1
use {
2
    proc_macro2::{Literal, TokenStream},
3
    quote::{format_ident, quote, ToTokens},
4
    syn::{Ident, ItemTrait, TraitItem},
5
};
6

            
7
/// Newtype over [`Ident`] that carries the case-conversion helpers used to
/// derive enum variant and constant names from message struct identifiers.
struct IdentCased(Ident);
8

            
9
impl From<&Ident> for IdentCased {
10
80
    fn from(ident: &Ident) -> Self {
11
80
        IdentCased(ident.clone())
12
80
    }
13
}
14

            
15
impl IdentCased {
16
72
    fn remove_prefix(&self) -> Self {
17
72
        let s = self.0.to_string();
18
72
        IdentCased(Ident::new(&s[1..], self.0.span()))
19
72
    }
20
    #[allow(dead_code)]
21
    fn to_title_case(&self) -> Self {
22
        let converter = convert_case::Converter::new().to_case(convert_case::Case::Title);
23
        let converted = converter.convert(self.0.to_string());
24
        IdentCased(Ident::new(&converted, self.0.span()))
25
    }
26
    #[allow(dead_code)]
27
    fn to_upper_case(&self) -> Self {
28
        let converter = convert_case::Converter::new().to_case(convert_case::Case::Upper);
29
        let converted = converter.convert(self.0.to_string());
30
        IdentCased(Ident::new(&converted, self.0.span()))
31
    }
32
32
    fn to_screaming_snake_case(&self) -> Self {
33
32
        let converter = convert_case::Converter::new().to_case(convert_case::Case::ScreamingSnake);
34
32
        let converted = converter.convert(self.0.to_string());
35
32
        IdentCased(Ident::new(&converted, self.0.span()))
36
32
    }
37
88
    fn to_pascal_case(&self) -> Self {
38
88
        let converter = convert_case::Converter::new().to_case(convert_case::Case::Pascal);
39
88
        let converted = converter.convert(self.0.to_string());
40
88
        IdentCased(Ident::new(&converted, self.0.span()))
41
88
    }
42
    #[allow(dead_code)]
43
    fn to_upper_flat(&self) -> Self {
44
        let converter = convert_case::Converter::new().to_case(convert_case::Case::UpperFlat);
45
        let converted = converter.convert(self.0.to_string());
46
        IdentCased(Ident::new(&converted, self.0.span()))
47
    }
48
    #[allow(dead_code)]
49
    fn remove_whitespace(&self) -> Self {
50
        let s = self.0.to_string().split_whitespace().collect::<String>();
51
        IdentCased(Ident::new(&s, self.0.span()))
52
    }
53
}
54

            
55
impl From<IdentCased> for Ident {
56
120
    fn from(ident: IdentCased) -> Self {
57
120
        ident.0
58
120
    }
59
}
60

            
61
/// Selects which message enum `generate_frame` emits.
enum Direction {
    /// Response direction: the generated enum is named `Rmessage`.
    Rx,
    /// Request direction: the generated enum is named `Tmessage`.
    Tx,
}
65

            
66
12
fn generate_frame(
67
12
    direction: Direction,
68
12
    msgs: &[(Ident, proc_macro2::TokenStream)],
69
12
) -> proc_macro2::TokenStream {
70
12
    let enum_name = match direction {
71
6
        Direction::Rx => quote! { Rmessage },
72
6
        Direction::Tx => quote! { Tmessage },
73
    };
74

            
75
22
    let msg_variants = msgs.iter().map(|(ident, _p)| {
76
16
        let name: IdentCased = ident.into();
77
16
        let variant_name: Ident = name.remove_prefix().to_pascal_case().into();
78
16
        let constant_name: Ident = name.to_screaming_snake_case().into();
79
16
        quote! {
80
16
            #variant_name(#ident) = #constant_name,
81
16
        }
82
22
    });
83
22
    let cloned_byte_sizes = msgs.iter().map(|(ident, _)| {
84
16
        let name: IdentCased = ident.into();
85
16
        let variant_name: Ident = name.remove_prefix().to_pascal_case().into();
86
16
        quote! {
87
16
            #enum_name::#variant_name(msg) => msg.byte_size()
88
16
        }
89
22
    });
90
12

            
91
22
    let match_arms = msgs.iter().map(|(ident, _)| {
92
16
        let name: IdentCased = ident.into();
93
16
        let variant_name: Ident = name.remove_prefix().to_pascal_case().into();
94
16
        quote! {
95
16
            #enum_name::#variant_name(msg)
96
16
        }
97
22
    });
98
12

            
99
22
    let decode_bodies = msgs.iter().map(|(ident, _)| {
100
16
        let name: IdentCased = ident.into();
101
16
        let variant_name: Ident = name.remove_prefix().to_pascal_case().into();
102
16

            
103
16
        let const_name: Ident = name.to_screaming_snake_case().into();
104
16
        quote! {
105
16
                #const_name => Ok(#enum_name::#variant_name(WireFormat::decode(reader)?)),
106
16
        }
107
22
    });
108
12

            
109
22
    let encode_match_arms = match_arms.clone().map(|arm| {
110
16
        quote! {
111
16
            #arm => msg.encode(writer)?,
112
16
        }
113
22
    });
114
12

            
115
12
    quote! {
116
12
        #[derive(Debug)]
117
12
        #[repr(u8)]
118
12
        pub enum #enum_name {
119
12
            #( #msg_variants )*
120
12
        }
121
12

            
122
12
        impl Framer for #enum_name {
123
12
            fn byte_size(&self) -> u32 {
124
12
                match &self {
125
12
                    #(
126
12
                        #cloned_byte_sizes,
127
12
                     )*
128
12
                }
129
12
            }
130
12

            
131
12
            fn message_type(&self) -> u8 {
132
12
                // SAFETY: Because `Self` is marked `repr(u8)`, its layout is a `repr(C)` `union`
133
12
                // between `repr(C)` structs, each of which has the `u8` discriminant as its first
134
12
                // field, so we can read the discriminant without offsetting the pointer.
135
12
                unsafe { *<*const _>::from(self).cast::<u8>() }
136
12
            }
137
12

            
138
12
            fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
139
12
                match &self {
140
12
                    #(
141
12
                        #encode_match_arms
142
12
                     )*
143
12
                }
144
12

            
145
12
                Ok(())
146
12
            }
147
12

            
148
12
            fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<#enum_name> {
149
12
                match ty {
150
12
                    #(
151
12
                        #decode_bodies
152
12
                     )*
153
12
                    _ => Err(std::io::Error::new(
154
12
                        std::io::ErrorKind::InvalidData,
155
12
                        format!("unknown message type: {}", ty),
156
12
                    )),
157
12
                }
158
12
            }
159
12
        }
160
12
    }
161
12
}
162

            
163
6
fn generate_tframe(tmsgs: &[(Ident, proc_macro2::TokenStream)]) -> proc_macro2::TokenStream {
164
6
    generate_frame(Direction::Tx, tmsgs)
165
6
}
166

            
167
6
fn generate_rframe(rmsgs: &[(Ident, proc_macro2::TokenStream)]) -> proc_macro2::TokenStream {
168
6
    generate_frame(Direction::Rx, rmsgs)
169
6
}
170

            
171
8
fn generate_msg_id(index: usize, method_name: &Ident) -> proc_macro2::TokenStream {
172
8
    let upper_cased_method_name = method_name.to_string().to_uppercase();
173
8
    let tmsg_const_name = Ident::new(&format!("T{}", upper_cased_method_name), method_name.span());
174
8
    let rmsg_const_name = Ident::new(&format!("R{}", upper_cased_method_name), method_name.span());
175
8
    let offset = 2 * index as u8;
176
8

            
177
8
    quote! {
178
8
        pub const #tmsg_const_name: u8 = MESSAGE_ID_START + #offset;
179
8
        pub const #rmsg_const_name: u8 = MESSAGE_ID_START + #offset + 1;
180
8
    }
181
8
}
182

            
183
8
fn generate_input_struct(
184
8
    request_struct_ident: &Ident,
185
8
    method_sig: &syn::Signature,
186
8
) -> proc_macro2::TokenStream {
187
12
    let inputs = method_sig.inputs.iter().map(|arg| {
188
8
        match arg {
189
            syn::FnArg::Typed(pat) => {
190
                let name = pat.pat.clone();
191
                let ty = pat.ty.clone();
192
                quote! {
193
                    pub #name: #ty,
194
                }
195
            }
196
8
            syn::FnArg::Receiver(_) => quote! {},
197
        }
198
12
    });
199
8

            
200
8
    quote! {
201
8
        #[allow(non_camel_case_types)]
202
8
        #[derive(Debug, JetStreamWireFormat)]
203
8
        pub struct #request_struct_ident {
204
8
            #(#inputs)*
205
8
        }
206
8
    }
207
8
}
208
8
fn generate_return_struct(
209
8
    return_struct_ident: &Ident,
210
8
    method_sig: &syn::Signature,
211
8
) -> proc_macro2::TokenStream {
212
8
    match &method_sig.output {
213
8
        syn::ReturnType::Type(_, ty) => {
214
8
            match &**ty {
215
8
                syn::Type::Path(type_path) => {
216
                    // Check if it's a Result type
217
8
                    if let Some(segment) = type_path.path.segments.last() {
218
8
                        if segment.ident == "Result" {
219
                            // Extract the success type from Result<T, E>
220
8
                            if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
221
8
                                if let Some(syn::GenericArgument::Type(success_type)) =
222
8
                                    args.args.first()
223
                                {
224
8
                                    return quote! {
225
8
                                        #[allow(non_camel_case_types)]
226
8
                                        #[derive(Debug, JetStreamWireFormat)]
227
8
                                        pub struct #return_struct_ident(pub #success_type);
228
8
                                    };
229
                                }
230
                            }
231
                        }
232
                    }
233
                    // If not a Result or couldn't extract type, use the whole type
234
                    quote! {
235
                        #[allow(non_camel_case_types)]
236
                        #[derive(Debug, JetStreamWireFormat)]
237
                        pub struct #return_struct_ident(pub #ty);
238
                    }
239
                }
240
                // Handle other return type variants if needed
241
                _ => {
242
                    quote! {
243
                        #[allow(non_camel_case_types)]
244
                        #[derive(Debug, JetStreamWireFormat)]
245
                        pub struct #return_struct_ident(pub #ty);
246
                    }
247
                }
248
            }
249
        }
250
        syn::ReturnType::Default => {
251
            quote! {
252
               #[allow(non_camel_case_types)]
253
               #[derive(Debug, JetStreamWireFormat)]
254
               pub struct #return_struct_ident;
255
            }
256
        }
257
    }
258
8
}
259

            
260
6
fn generate_match_arms(
261
6
    tmsgs: impl Iterator<Item = (Ident, proc_macro2::TokenStream)>,
262
6
) -> impl Iterator<Item = proc_macro2::TokenStream> {
263
8
    tmsgs.map(|(ident, _)| {
264
8
        let name: IdentCased = (&ident).into();
265
8
        let variant_name: Ident = name.remove_prefix().to_pascal_case().into();
266
8
        quote! {
267
8
            Tmessage::#variant_name(msg)
268
8
        }
269
8
    })
270
6
}
271
8
fn handle_receiver(recv: &syn::Receiver) -> proc_macro2::TokenStream {
272
8
    let mutability = &recv.mutability;
273
8
    let reference = &recv.reference;
274
8

            
275
8
    match (reference, mutability) {
276
8
        (Some(_), Some(_)) => quote! { &mut self.inner, },
277
        (Some(_), None) => quote! { &self.inner, },
278
        (None, _) => quote! { self.inner, },
279
    }
280
8
}
281
6
pub(crate) fn service_impl(item: ItemTrait, is_async_trait: bool) -> TokenStream {
282
6
    let trait_name = &item.ident;
283
6
    let trait_items = &item.items;
284
6
    let vis = &item.vis;
285
6

            
286
6
    // Generate message structs and enum variants
287
6
    // let mut message_structs = Vec::new();
288
6
    let mut tmsgs = Vec::new();
289
6
    let mut rmsgs = Vec::new();
290
6
    let mut msg_ids = Vec::new();
291
6
    let service_name = format_ident!("{}Service", trait_name);
292
6
    let channel_name = format_ident!("{}Channel", trait_name);
293
6
    let digest = sha256::digest(item.to_token_stream().to_string());
294
6

            
295
6
    #[allow(clippy::to_string_in_format_args)]
296
6
    let protocol_version = format!(
297
6
        "dev.branch.jetstream.proto/{}/{}.{}.{}-{}",
298
6
        trait_name.to_string().to_lowercase(),
299
6
        env!("CARGO_PKG_VERSION_MAJOR"),
300
6
        env!("CARGO_PKG_VERSION_MINOR"),
301
6
        env!("CARGO_PKG_VERSION_PATCH"),
302
6
        digest[0..8].to_string()
303
6
    );
304
6
    let protocol_version = Literal::string(protocol_version.as_str());
305
6
    let mut calls = vec![];
306
6
    let tag_name = format_ident!("{}_TAG", trait_name.to_string().to_uppercase());
307
6

            
308
6
    let mut server_calls = vec![];
309
6

            
310
6
    {
311
11
        let with_calls = item.items.iter().enumerate().map(|(index, item)| {
312
8
            if let TraitItem::Fn(method) = item {
313
8
                let method_name = &method.sig.ident;
314
8

            
315
8
                let request_struct_ident =
316
8
                    Ident::new(&format!("T{}", method_name), method_name.span());
317
8
                let return_struct_ident =
318
8
                    Ident::new(&format!("R{}", method_name), method_name.span());
319
8
                let _output_type = match &method.sig.output {
320
8
                    syn::ReturnType::Type(_, ty) => quote! { #ty },
321
                    syn::ReturnType::Default => quote! { () },
322
                };
323
8
                let msg_id = generate_msg_id(index, method_name);
324
8
                msg_ids.push(msg_id);
325
8
                let request_struct =
326
8
                    generate_input_struct(&request_struct_ident.clone(), &method.sig);
327
8
                let return_struct =
328
8
                    generate_return_struct(&return_struct_ident.clone(), &method.sig);
329
8

            
330
8
                tmsgs.insert(
331
8
                    index,
332
8
                    (request_struct_ident.clone(), request_struct.clone()),
333
8
                );
334
8
                rmsgs.insert(index, (return_struct_ident.clone(), return_struct.clone()));
335
            }
336
11
        });
337
6
        calls.extend(with_calls);
338
6
    }
339
6
    let mut client_calls = vec![];
340
6
    {
341
11
        item.items.iter().enumerate().for_each(|(index,item)|{
342
8
            let TraitItem::Fn(method) = item else {return;};
343
8
            let method_name = &method.sig.ident;
344
8
            let has_req = method.sig.inputs.iter().count() > 1;
345
8
            let is_async = method.sig.asyncness.is_some();
346
8
            let maybe_await = if is_async { quote! { .await } } else { quote! {} };
347

            
348
8
            let request_struct_ident = tmsgs.get(index).unwrap().0.clone();
349
8
            let return_struct_ident = rmsgs.get(index).unwrap().0.clone();
350
8
            let new = if has_req {
351
                let spread_req = method.sig.inputs.iter().map(|arg| match arg {
352
                    syn::FnArg::Typed(pat) => {
353
                        let name = pat.pat.clone();
354
                        quote! { req.#name, }
355
                    }
356
                    syn::FnArg::Receiver(_) => quote! {},
357
                });
358
                quote! {
359
                    fn #method_name(&mut self, tag: u16, req: #request_struct_ident) -> impl ::core::future::Future<
360
                        Output = Result<#return_struct_ident, Error,
361
                    > + Send + Sync {
362
                        Box::pin(async move {
363
                            #return_struct_ident(tag, #trait_name::#method_name(&mut self.inner,
364
                                #(#spread_req)*
365
                            )#maybe_await)
366
                        })
367
                    }
368
                }
369
            } else {
370
8
                quote! {
371
8
                    fn #method_name(&mut self, tag: u16, req: #request_struct_ident) -> impl ::core::future::Future<
372
8
                        Output = Result<#return_struct_ident, Error,
373
8
                    > + Send + Sync {
374
8
                        Box::pin(async move {
375
8
                            #return_struct_ident(tag, #trait_name::#method_name(&mut self.inner)#maybe_await)
376
8
                        })
377
8
                    }
378
8
                    }
379
                };
380
8
            server_calls.extend(new);
381
11
        });
382
6
    }
383
6
    {
384
11
        item.items.iter().enumerate().for_each(|(index, item)| {
385
8
            let TraitItem::Fn(method) = item else {
386
                return;
387
            };
388
8
            let method_name = &method.sig.ident;
389
8
            let variant_name: Ident = IdentCased(method_name.clone()).to_pascal_case().into();
390
8
            let retn = &method.sig.output;
391
8
            let is_async = method.sig.asyncness.is_some();
392
8
            let maybe_async = if is_async {
393
8
                quote! { async }
394
            } else {
395
                quote! {}
396
            };
397
8
            let request_struct_ident = tmsgs.get(index).unwrap().0.clone();
398
8
            let inputs = method.sig.inputs.iter().map(|arg| {
399
8
                match arg {
400
                    syn::FnArg::Typed(pat) => {
401
                        let name = pat.pat.clone();
402
                        let ty = pat.ty.clone();
403
                        quote! {
404
                             #name: #ty,
405
                        }
406
                    }
407
8
                    syn::FnArg::Receiver(_) => quote! {},
408
                }
409
8
            });
410
8
            let args = method.sig.inputs.iter().map(|arg| {
411
8
                match arg {
412
                    syn::FnArg::Typed(pat) => {
413
                        let name = pat.pat.clone();
414
                        quote! {
415
                             #name,
416
                        }
417
                    }
418
8
                    syn::FnArg::Receiver(_) => quote! {},
419
                }
420
8
            });
421
8
            let new = quote! {
422
8
                #maybe_async fn #method_name(&mut self, #(#inputs)*)  #retn {
423
8
                    let tag =#tag_name.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
424
8
                    let req = Tmessage::#variant_name(#request_struct_ident {
425
8
                        #(
426
8
                            #args
427
8
                        )*
428
8
                    });
429
8
                    let tframe= Frame::from((tag, req));
430
8
                    let rframe = self.rpc(tframe).await?;
431
8
                    let rmsg = rframe.msg;
432
8
                    match rmsg {
433
8
                        Rmessage::#variant_name(msg) => Ok(msg.0),
434
8
                        _ => Err(Error::InvalidResponse),
435
8
                    }
436
8
                }
437
8
            };
438

            
439
8
            client_calls.extend(new);
440
11
        });
441
6
    }
442
6

            
443
6
    // make a const with the digest
444
6
    let digest = Literal::string(digest.as_str());
445
11
    let tmsg_definitions = tmsgs.iter().map(|(_ident, def)| {
446
8
        quote! {
447
8
            #def
448
8
        }
449
11
    });
450
6

            
451
11
    let rmsg_definitions = rmsgs.iter().map(|(_ident, def)| {
452
8
        quote! {
453
8
            #def
454
8
        }
455
11
    });
456
6
    let tmessage = generate_tframe(&tmsgs);
457
6
    let rmessage = generate_rframe(&rmsgs);
458
6
    let proto_mod = format_ident!("{}_protocol", trait_name.to_string().to_lowercase());
459
6

            
460
6
    let match_arms = generate_match_arms(tmsgs.clone().into_iter());
461
6
    let match_arm_bodies: Vec<proc_macro2::TokenStream> = item
462
6
        .items
463
6
        .clone()
464
6
        .iter()
465
11
        .map(|item| {
466
8
            match item {
467
8
                TraitItem::Fn(method) => {
468
8
                    let method_name = &method.sig.ident;
469
8
                    let name: IdentCased = method_name.into();
470
8
                    let variant_name: Ident = name.to_pascal_case().into();
471
8
                    let return_struct_ident =
472
8
                        Ident::new(&format!("R{}", method_name), method_name.span());
473
8
                    let variables_spead = method.sig.inputs.iter().map(|arg| {
474
8
                        match arg {
475
                            syn::FnArg::Typed(pat) => {
476
                                let name = pat.pat.clone();
477
                                quote! { msg.#name, }
478
                            }
479
8
                            syn::FnArg::Receiver(recv) => handle_receiver(recv),
480
                        }
481
8
                    });
482
8
                    quote! {
483
8
                         {
484
8
                            let msg = #trait_name::#method_name(
485
8
                                #(
486
8
                                    #variables_spead
487
8
                                )*
488
8
                            ).await?;
489
8
                            let ret = #return_struct_ident(msg);
490
8
                            Ok(Rmessage::#variant_name(ret))
491
8
                        }
492
8
                    }
493
                }
494
                _ => quote! {},
495
            }
496
11
        })
497
6
        .collect();
498
11
    let matches = std::iter::zip(match_arms, match_arm_bodies.iter()).map(|(arm, body)| {
499
8
        quote! {
500
8
            #arm => #body
501
8
        }
502
11
    });
503

            
504
6
    let trait_attribute = if is_async_trait {
505
        quote! { #[jetstream::prelude::async_trait] }
506
    } else {
507
6
        quote! { #[jetstream::prelude::trait_variant::make(Send + Sync)] }
508
    };
509
6
    quote! {
510
6
        #vis mod #proto_mod{
511
6
            use jetstream::prelude::*;
512
6
            use std::io::{self,Read,Write};
513
6
            use std::mem;
514
6
            use super::#trait_name;
515
6
            const MESSAGE_ID_START: u8 = 101;
516
6
            pub const PROTOCOL_VERSION: &str = #protocol_version;
517
6
            const DIGEST: &str = #digest;
518
6

            
519
6
            #(#msg_ids)*
520
6

            
521
6
            #(#tmsg_definitions)*
522
6

            
523
6
            #(#rmsg_definitions)*
524
6

            
525
6
            #tmessage
526
6

            
527
6
            #rmessage
528
6

            
529
6
            #[derive(Clone)]
530
6
            pub struct #service_name<T: #trait_name> {
531
6
                pub inner: T,
532
6
            }
533
6

            
534
6
            impl<T> Protocol for #service_name<T>
535
6
            where
536
6
                T: #trait_name+ Send + Sync + Sized
537
6
            {
538
6
                type Request = Tmessage;
539
6
                type Response = Rmessage;
540
6
                type Error = Error;
541
6
                const VERSION: &'static str = PROTOCOL_VERSION;
542
6

            
543
6
                fn rpc(&mut self, frame: Frame<<Self as Protocol>::Request>) -> impl ::core::future::Future<
544
6
                    Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
545
6
                > + Send + Sync {
546
6
                    Box::pin(async move {
547
6
                        let req: <Self as Protocol>::Request = frame.msg;
548
6
                        let res: Result<<Self as Protocol>::Response, Self::Error> =match req {
549
6
                                #(
550
6
                                    #matches
551
6
                                )*
552
6
                        };
553
6
                        let rframe: Frame<<Self as Protocol>::Response> = Frame::from((frame.tag, res?));
554
6
                        Ok(rframe)
555
6
                    })
556
6
                }
557
6
            }
558
6
            pub struct #channel_name<'a> {
559
6
                pub inner: Box<&'a mut dyn ClientTransport<Self>>,
560
6
            }
561
6
            impl<'a> Protocol for #channel_name<'a>
562
6
            {
563
6
                type Request = Tmessage;
564
6
                type Response = Rmessage;
565
6
                type Error = Error;
566
6
                const VERSION: &'static str = PROTOCOL_VERSION;
567
6
                fn rpc(&mut self, frame: Frame<<Self as Protocol>::Request>) -> impl ::core::future::Future<
568
6
                    Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
569
6
                > + Send + Sync {
570
6
                    use futures::{SinkExt, StreamExt};
571
6
                    Box::pin(async move {
572
6
                        self.inner
573
6
                            .send(frame)
574
6
                            .await?;
575
6
                        let frame = self.inner.next().await.unwrap()?;
576
6
                        Ok(frame)
577
6
                    })
578
6
                }
579
6
            }
580
6
            lazy_static::lazy_static! {
581
6
                static ref #tag_name: std::sync::atomic::AtomicU16 = std::sync::atomic::AtomicU16::new(0);
582
6
            }
583
6
            impl<'a> #trait_name for #channel_name<'a>
584
6
            {
585
6
                #(#client_calls)*
586
6
            }
587
6

            
588
6
        }
589
6

            
590
6
        #trait_attribute
591
6
        #vis trait #trait_name {
592
6
            #(#trait_items)*
593
6
        }
594
6
    }
595
6
}
596

            
597
#[cfg(test)]
598
mod tests {
599
    use core::panic;
600

            
601
    use {super::*, syn::parse_quote};
602

            
603
    fn run_test_with_filters<F>(test_fn: F)
604
    where
605
        F: FnOnce() + panic::UnwindSafe,
606
    {
607
        let filters = vec![
608
            // Filter for protocol version strings
609
            (
610
                r"dev\.branch\.jetstream\.proto/\w+/\d+\.\d+\.\d+-[a-f0-9]{8}",
611
                "dev.branch.jetstream.proto/NAME/VERSION-HASH",
612
            ),
613
            // Filter for digest strings
614
            (r"[a-f0-9]{64}", "DIGEST_HASH"),
615
        ];
616

            
617
        insta::with_settings!({
618
            filters => filters,
619
        }, {
620
            test_fn();
621
        });
622
    }
623

            
624
    #[test]
625
    fn test_simple_service() {
626
        let input: ItemTrait = parse_quote! {
627
            pub trait Echo {
628
                async fn ping(&self) -> Result<(), std::io::Error>;
629
            }
630
        };
631
        let output = service_impl(input, false);
632
        let syntax_tree: syn::File = syn::parse2(output).unwrap();
633
        let output_str = prettyplease::unparse(&syntax_tree);
634
        run_test_with_filters(|| {
635
            insta::assert_snapshot!(output_str, @r###"
636
            pub mod echo_protocol {
637
                use jetstream::prelude::*;
638
                use std::io::{self, Read, Write};
639
                use std::mem;
640
                use super::Echo;
641
                const MESSAGE_ID_START: u8 = 101;
642
                pub const PROTOCOL_VERSION: &str = "dev.branch.jetstream.proto/NAME/VERSION-HASH";
643
                const DIGEST: &str = "DIGEST_HASH";
644
                pub const TPING: u8 = MESSAGE_ID_START + 0u8;
645
                pub const RPING: u8 = MESSAGE_ID_START + 0u8 + 1;
646
                #[allow(non_camel_case_types)]
647
                #[derive(Debug, JetStreamWireFormat)]
648
                pub struct Tping {}
649
                #[allow(non_camel_case_types)]
650
                #[derive(Debug, JetStreamWireFormat)]
651
                pub struct Rping(pub ());
652
                #[derive(Debug)]
653
                #[repr(u8)]
654
                pub enum Tmessage {
655
                    Ping(Tping) = TPING,
656
                }
657
                impl Framer for Tmessage {
658
                    fn byte_size(&self) -> u32 {
659
                        match &self {
660
                            Tmessage::Ping(msg) => msg.byte_size(),
661
                        }
662
                    }
663
                    fn message_type(&self) -> u8 {
664
                        unsafe { *<*const _>::from(self).cast::<u8>() }
665
                    }
666
                    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
667
                        match &self {
668
                            Tmessage::Ping(msg) => msg.encode(writer)?,
669
                        }
670
                        Ok(())
671
                    }
672
                    fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<Tmessage> {
673
                        match ty {
674
                            TPING => Ok(Tmessage::Ping(WireFormat::decode(reader)?)),
675
                            _ => {
676
                                Err(
677
                                    std::io::Error::new(
678
                                        std::io::ErrorKind::InvalidData,
679
                                        format!("unknown message type: {}", ty),
680
                                    ),
681
                                )
682
                            }
683
                        }
684
                    }
685
                }
686
                #[derive(Debug)]
687
                #[repr(u8)]
688
                pub enum Rmessage {
689
                    Ping(Rping) = RPING,
690
                }
691
                impl Framer for Rmessage {
692
                    fn byte_size(&self) -> u32 {
693
                        match &self {
694
                            Rmessage::Ping(msg) => msg.byte_size(),
695
                        }
696
                    }
697
                    fn message_type(&self) -> u8 {
698
                        unsafe { *<*const _>::from(self).cast::<u8>() }
699
                    }
700
                    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
701
                        match &self {
702
                            Rmessage::Ping(msg) => msg.encode(writer)?,
703
                        }
704
                        Ok(())
705
                    }
706
                    fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<Rmessage> {
707
                        match ty {
708
                            RPING => Ok(Rmessage::Ping(WireFormat::decode(reader)?)),
709
                            _ => {
710
                                Err(
711
                                    std::io::Error::new(
712
                                        std::io::ErrorKind::InvalidData,
713
                                        format!("unknown message type: {}", ty),
714
                                    ),
715
                                )
716
                            }
717
                        }
718
                    }
719
                }
720
                #[derive(Clone)]
721
                pub struct EchoService<T: Echo> {
722
                    pub inner: T,
723
                }
724
                impl<T> Protocol for EchoService<T>
725
                where
726
                    T: Echo + Send + Sync + Sized,
727
                {
728
                    type Request = Tmessage;
729
                    type Response = Rmessage;
730
                    type Error = Error;
731
                    const VERSION: &'static str = PROTOCOL_VERSION;
732
                    fn rpc(
733
                        &mut self,
734
                        frame: Frame<<Self as Protocol>::Request>,
735
                    ) -> impl ::core::future::Future<
736
                        Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
737
                    > + Send + Sync {
738
                        Box::pin(async move {
739
                            let req: <Self as Protocol>::Request = frame.msg;
740
                            let res: Result<<Self as Protocol>::Response, Self::Error> = match req {
741
                                Tmessage::Ping(msg) => {
742
                                    let msg = Echo::ping(&self.inner).await?;
743
                                    let ret = Rping(msg);
744
                                    Ok(Rmessage::Ping(ret))
745
                                }
746
                            };
747
                            let rframe: Frame<<Self as Protocol>::Response> = Frame::from((
748
                                frame.tag,
749
                                res?,
750
                            ));
751
                            Ok(rframe)
752
                        })
753
                    }
754
                }
755
                pub struct EchoChannel<'a> {
756
                    pub inner: Box<&'a mut dyn ClientTransport<Self>>,
757
                }
758
                impl<'a> Protocol for EchoChannel<'a> {
759
                    type Request = Tmessage;
760
                    type Response = Rmessage;
761
                    type Error = Error;
762
                    const VERSION: &'static str = PROTOCOL_VERSION;
763
                    fn rpc(
764
                        &mut self,
765
                        frame: Frame<<Self as Protocol>::Request>,
766
                    ) -> impl ::core::future::Future<
767
                        Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
768
                    > + Send + Sync {
769
                        use futures::{SinkExt, StreamExt};
770
                        Box::pin(async move {
771
                            self.inner.send(frame).await?;
772
                            let frame = self.inner.next().await.unwrap()?;
773
                            Ok(frame)
774
                        })
775
                    }
776
                }
777
                lazy_static::lazy_static! {
778
                    static ref ECHO_TAG : std::sync::atomic::AtomicU16 =
779
                    std::sync::atomic::AtomicU16::new(0);
780
                }
781
                impl<'a> Echo for EchoChannel<'a> {
782
                    async fn ping(&mut self) -> Result<(), std::io::Error> {
783
                        let tag = ECHO_TAG.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
784
                        let req = Tmessage::Ping(Tping {});
785
                        let tframe = Frame::from((tag, req));
786
                        let rframe = self.rpc(tframe).await?;
787
                        let rmsg = rframe.msg;
788
                        match rmsg {
789
                            Rmessage::Ping(msg) => Ok(msg.0),
790
                            _ => Err(Error::InvalidResponse),
791
                        }
792
                    }
793
                }
794
            }
795
            #[jetstream::prelude::trait_variant::make(Send+Sync)]
796
            pub trait Echo {
797
                async fn ping(&self) -> Result<(), std::io::Error>;
798
            }
799
            "###)
800
        })
801
    }
802

            
803
    /// Snapshot test for `service_impl` on a trait whose single async method takes
    /// `&self` and a `String` argument and returns `Result<String, std::io::Error>`.
    ///
    /// The macro output is parsed as a `syn::File`, pretty-printed with `prettyplease`,
    /// and compared against the inline `insta` snapshot below. Called with `false`, the
    /// snapshot shows the re-emitted trait annotated with
    /// `#[jetstream::prelude::trait_variant::make(Send+Sync)]`.
    ///
    /// NOTE(review): only syntactic validity of the generated code is checked here.
    /// The snapshot records the client impl as `ping(&mut self, ...)` while the trait
    /// declares `ping(&self, ...)`, which looks like a receiver mismatch that would not
    /// type-check. The client's fallback arm also returns `Err(Error::InvalidResponse)`
    /// from a fn declared to return `std::io::Error`, which presumably only compiles if
    /// `Error` is that type. Confirm whether the generator should be fixed (the snapshot
    /// would then need regenerating).
    #[test]
804
    fn test_service_with_args() {
805
        let input: ItemTrait = parse_quote! {
806
            pub trait Echo {
807
                async fn ping(&self, message: String) -> Result<String, std::io::Error>;
808
            }
809
        };
810
        let output = service_impl(input, false);
        // `unwrap` fails the test if the macro output is not a syntactically valid Rust file.
811
        let syntax_tree: syn::File = syn::parse2(output).unwrap();
812
        let output_str = prettyplease::unparse(&syntax_tree);
        // NOTE(review): `run_test_with_filters` presumably installs insta filters that redact
        // the protocol name/version hash and the digest (the snapshot contains the placeholders
        // `NAME/VERSION-HASH` and `DIGEST_HASH`). Confirm against its definition.
813
        run_test_with_filters(|| {
814
            insta::assert_snapshot!(output_str, @r###"
815
            pub mod echo_protocol {
816
                use jetstream::prelude::*;
817
                use std::io::{self, Read, Write};
818
                use std::mem;
819
                use super::Echo;
820
                const MESSAGE_ID_START: u8 = 101;
821
                pub const PROTOCOL_VERSION: &str = "dev.branch.jetstream.proto/NAME/VERSION-HASH";
822
                const DIGEST: &str = "DIGEST_HASH";
823
                pub const TPING: u8 = MESSAGE_ID_START + 0u8;
824
                pub const RPING: u8 = MESSAGE_ID_START + 0u8 + 1;
825
                #[allow(non_camel_case_types)]
826
                #[derive(Debug, JetStreamWireFormat)]
827
                pub struct Tping {
828
                    pub message: String,
829
                }
830
                #[allow(non_camel_case_types)]
831
                #[derive(Debug, JetStreamWireFormat)]
832
                pub struct Rping(pub String);
833
                #[derive(Debug)]
834
                #[repr(u8)]
835
                pub enum Tmessage {
836
                    Ping(Tping) = TPING,
837
                }
838
                impl Framer for Tmessage {
839
                    fn byte_size(&self) -> u32 {
840
                        match &self {
841
                            Tmessage::Ping(msg) => msg.byte_size(),
842
                        }
843
                    }
844
                    fn message_type(&self) -> u8 {
845
                        unsafe { *<*const _>::from(self).cast::<u8>() }
846
                    }
847
                    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
848
                        match &self {
849
                            Tmessage::Ping(msg) => msg.encode(writer)?,
850
                        }
851
                        Ok(())
852
                    }
853
                    fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<Tmessage> {
854
                        match ty {
855
                            TPING => Ok(Tmessage::Ping(WireFormat::decode(reader)?)),
856
                            _ => {
857
                                Err(
858
                                    std::io::Error::new(
859
                                        std::io::ErrorKind::InvalidData,
860
                                        format!("unknown message type: {}", ty),
861
                                    ),
862
                                )
863
                            }
864
                        }
865
                    }
866
                }
867
                #[derive(Debug)]
868
                #[repr(u8)]
869
                pub enum Rmessage {
870
                    Ping(Rping) = RPING,
871
                }
872
                impl Framer for Rmessage {
873
                    fn byte_size(&self) -> u32 {
874
                        match &self {
875
                            Rmessage::Ping(msg) => msg.byte_size(),
876
                        }
877
                    }
878
                    fn message_type(&self) -> u8 {
879
                        unsafe { *<*const _>::from(self).cast::<u8>() }
880
                    }
881
                    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
882
                        match &self {
883
                            Rmessage::Ping(msg) => msg.encode(writer)?,
884
                        }
885
                        Ok(())
886
                    }
887
                    fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<Rmessage> {
888
                        match ty {
889
                            RPING => Ok(Rmessage::Ping(WireFormat::decode(reader)?)),
890
                            _ => {
891
                                Err(
892
                                    std::io::Error::new(
893
                                        std::io::ErrorKind::InvalidData,
894
                                        format!("unknown message type: {}", ty),
895
                                    ),
896
                                )
897
                            }
898
                        }
899
                    }
900
                }
901
                #[derive(Clone)]
902
                pub struct EchoService<T: Echo> {
903
                    pub inner: T,
904
                }
905
                impl<T> Protocol for EchoService<T>
906
                where
907
                    T: Echo + Send + Sync + Sized,
908
                {
909
                    type Request = Tmessage;
910
                    type Response = Rmessage;
911
                    type Error = Error;
912
                    const VERSION: &'static str = PROTOCOL_VERSION;
913
                    fn rpc(
914
                        &mut self,
915
                        frame: Frame<<Self as Protocol>::Request>,
916
                    ) -> impl ::core::future::Future<
917
                        Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
918
                    > + Send + Sync {
919
                        Box::pin(async move {
920
                            let req: <Self as Protocol>::Request = frame.msg;
921
                            let res: Result<<Self as Protocol>::Response, Self::Error> = match req {
922
                                Tmessage::Ping(msg) => {
923
                                    let msg = Echo::ping(&self.inner, msg.message).await?;
924
                                    let ret = Rping(msg);
925
                                    Ok(Rmessage::Ping(ret))
926
                                }
927
                            };
928
                            let rframe: Frame<<Self as Protocol>::Response> = Frame::from((
929
                                frame.tag,
930
                                res?,
931
                            ));
932
                            Ok(rframe)
933
                        })
934
                    }
935
                }
936
                pub struct EchoChannel<'a> {
937
                    pub inner: Box<&'a mut dyn ClientTransport<Self>>,
938
                }
939
                impl<'a> Protocol for EchoChannel<'a> {
940
                    type Request = Tmessage;
941
                    type Response = Rmessage;
942
                    type Error = Error;
943
                    const VERSION: &'static str = PROTOCOL_VERSION;
944
                    fn rpc(
945
                        &mut self,
946
                        frame: Frame<<Self as Protocol>::Request>,
947
                    ) -> impl ::core::future::Future<
948
                        Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
949
                    > + Send + Sync {
950
                        use futures::{SinkExt, StreamExt};
951
                        Box::pin(async move {
952
                            self.inner.send(frame).await?;
953
                            let frame = self.inner.next().await.unwrap()?;
954
                            Ok(frame)
955
                        })
956
                    }
957
                }
958
                lazy_static::lazy_static! {
959
                    static ref ECHO_TAG : std::sync::atomic::AtomicU16 =
960
                    std::sync::atomic::AtomicU16::new(0);
961
                }
962
                impl<'a> Echo for EchoChannel<'a> {
963
                    async fn ping(&mut self, message: String) -> Result<String, std::io::Error> {
964
                        let tag = ECHO_TAG.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
965
                        let req = Tmessage::Ping(Tping { message });
966
                        let tframe = Frame::from((tag, req));
967
                        let rframe = self.rpc(tframe).await?;
968
                        let rmsg = rframe.msg;
969
                        match rmsg {
970
                            Rmessage::Ping(msg) => Ok(msg.0),
971
                            _ => Err(Error::InvalidResponse),
972
                        }
973
                    }
974
                }
975
            }
976
            #[jetstream::prelude::trait_variant::make(Send+Sync)]
977
            pub trait Echo {
978
                async fn ping(&self, message: String) -> Result<String, std::io::Error>;
979
            }
980
            "###)
981
        })
982
    }
983

            
984
    /// Snapshot test for `service_impl` called with `true` on a trait whose single async
    /// method takes `&mut self` and a `String` argument and returns
    /// `Result<String, std::io::Error>`.
    ///
    /// Compared with `test_service_with_args`, the snapshot differs in three ways:
    /// - the re-emitted trait is annotated with `#[jetstream::prelude::async_trait]` instead
    ///   of `trait_variant::make(Send+Sync)`;
    /// - the server dispatch borrows the inner handler mutably
    ///   (`Echo::ping(&mut self.inner, msg.message)`);
    /// - the trait and the generated client impl both use `&mut self`.
    ///
    /// NOTE(review): only syntactic validity of the generated code is checked. The client's
    /// fallback arm still returns `Err(Error::InvalidResponse)` from a fn declared to return
    /// `std::io::Error`; confirm that this type-checks.
    #[test]
985
    fn test_async_trait_service_with_args() {
986
        let input: ItemTrait = parse_quote! {
987
            pub trait Echo {
988
                async fn ping(&mut self, message: String) -> Result<String, std::io::Error>;
989
            }
990
        };
991
        let output = service_impl(input, true);
        // `unwrap` fails the test if the macro output is not a syntactically valid Rust file.
992
        let syntax_tree: syn::File = syn::parse2(output).unwrap();
993
        let output_str = prettyplease::unparse(&syntax_tree);
        // NOTE(review): `run_test_with_filters` presumably installs insta filters that redact
        // the protocol name/version hash and the digest (the snapshot contains the placeholders
        // `NAME/VERSION-HASH` and `DIGEST_HASH`). Confirm against its definition.
994
        run_test_with_filters(|| {
995
            insta::assert_snapshot!(output_str, @r###"
996
            pub mod echo_protocol {
997
                use jetstream::prelude::*;
998
                use std::io::{self, Read, Write};
999
                use std::mem;
                use super::Echo;
                const MESSAGE_ID_START: u8 = 101;
                pub const PROTOCOL_VERSION: &str = "dev.branch.jetstream.proto/NAME/VERSION-HASH";
                const DIGEST: &str = "DIGEST_HASH";
                pub const TPING: u8 = MESSAGE_ID_START + 0u8;
                pub const RPING: u8 = MESSAGE_ID_START + 0u8 + 1;
                #[allow(non_camel_case_types)]
                #[derive(Debug, JetStreamWireFormat)]
                pub struct Tping {
                    pub message: String,
                }
                #[allow(non_camel_case_types)]
                #[derive(Debug, JetStreamWireFormat)]
                pub struct Rping(pub String);
                #[derive(Debug)]
                #[repr(u8)]
                pub enum Tmessage {
                    Ping(Tping) = TPING,
                }
                impl Framer for Tmessage {
                    fn byte_size(&self) -> u32 {
                        match &self {
                            Tmessage::Ping(msg) => msg.byte_size(),
                        }
                    }
                    fn message_type(&self) -> u8 {
                        unsafe { *<*const _>::from(self).cast::<u8>() }
                    }
                    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
                        match &self {
                            Tmessage::Ping(msg) => msg.encode(writer)?,
                        }
                        Ok(())
                    }
                    fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<Tmessage> {
                        match ty {
                            TPING => Ok(Tmessage::Ping(WireFormat::decode(reader)?)),
                            _ => {
                                Err(
                                    std::io::Error::new(
                                        std::io::ErrorKind::InvalidData,
                                        format!("unknown message type: {}", ty),
                                    ),
                                )
                            }
                        }
                    }
                }
                #[derive(Debug)]
                #[repr(u8)]
                pub enum Rmessage {
                    Ping(Rping) = RPING,
                }
                impl Framer for Rmessage {
                    fn byte_size(&self) -> u32 {
                        match &self {
                            Rmessage::Ping(msg) => msg.byte_size(),
                        }
                    }
                    fn message_type(&self) -> u8 {
                        unsafe { *<*const _>::from(self).cast::<u8>() }
                    }
                    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
                        match &self {
                            Rmessage::Ping(msg) => msg.encode(writer)?,
                        }
                        Ok(())
                    }
                    fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<Rmessage> {
                        match ty {
                            RPING => Ok(Rmessage::Ping(WireFormat::decode(reader)?)),
                            _ => {
                                Err(
                                    std::io::Error::new(
                                        std::io::ErrorKind::InvalidData,
                                        format!("unknown message type: {}", ty),
                                    ),
                                )
                            }
                        }
                    }
                }
                #[derive(Clone)]
                pub struct EchoService<T: Echo> {
                    pub inner: T,
                }
                impl<T> Protocol for EchoService<T>
                where
                    T: Echo + Send + Sync + Sized,
                {
                    type Request = Tmessage;
                    type Response = Rmessage;
                    type Error = Error;
                    const VERSION: &'static str = PROTOCOL_VERSION;
                    fn rpc(
                        &mut self,
                        frame: Frame<<Self as Protocol>::Request>,
                    ) -> impl ::core::future::Future<
                        Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
                    > + Send + Sync {
                        Box::pin(async move {
                            let req: <Self as Protocol>::Request = frame.msg;
                            let res: Result<<Self as Protocol>::Response, Self::Error> = match req {
                                Tmessage::Ping(msg) => {
                                    let msg = Echo::ping(&mut self.inner, msg.message).await?;
                                    let ret = Rping(msg);
                                    Ok(Rmessage::Ping(ret))
                                }
                            };
                            let rframe: Frame<<Self as Protocol>::Response> = Frame::from((
                                frame.tag,
                                res?,
                            ));
                            Ok(rframe)
                        })
                    }
                }
                pub struct EchoChannel<'a> {
                    pub inner: Box<&'a mut dyn ClientTransport<Self>>,
                }
                impl<'a> Protocol for EchoChannel<'a> {
                    type Request = Tmessage;
                    type Response = Rmessage;
                    type Error = Error;
                    const VERSION: &'static str = PROTOCOL_VERSION;
                    fn rpc(
                        &mut self,
                        frame: Frame<<Self as Protocol>::Request>,
                    ) -> impl ::core::future::Future<
                        Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
                    > + Send + Sync {
                        use futures::{SinkExt, StreamExt};
                        Box::pin(async move {
                            self.inner.send(frame).await?;
                            let frame = self.inner.next().await.unwrap()?;
                            Ok(frame)
                        })
                    }
                }
                lazy_static::lazy_static! {
                    static ref ECHO_TAG : std::sync::atomic::AtomicU16 =
                    std::sync::atomic::AtomicU16::new(0);
                }
                impl<'a> Echo for EchoChannel<'a> {
                    async fn ping(&mut self, message: String) -> Result<String, std::io::Error> {
                        let tag = ECHO_TAG.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                        let req = Tmessage::Ping(Tping { message });
                        let tframe = Frame::from((tag, req));
                        let rframe = self.rpc(tframe).await?;
                        let rmsg = rframe.msg;
                        match rmsg {
                            Rmessage::Ping(msg) => Ok(msg.0),
                            _ => Err(Error::InvalidResponse),
                        }
                    }
                }
            }
            #[jetstream::prelude::async_trait]
            pub trait Echo {
                async fn ping(&mut self, message: String) -> Result<String, std::io::Error>;
            }
            "###)
        })
    }
}