1
use proc_macro2::{Literal, TokenStream};
2
use quote::{format_ident, quote, ToTokens};
3
use syn::{Ident, ItemTrait, TraitItem};
4

            
5
/// Newtype around [`Ident`] that carries the case-conversion helpers used by
/// the code generators in this file; every helper returns a new `IdentCased`
/// built with the span of the wrapped identifier.
struct IdentCased(Ident);
6

            
7
impl From<&Ident> for IdentCased {
8
80
    fn from(ident: &Ident) -> Self {
9
80
        IdentCased(ident.clone())
10
80
    }
11
}
12

            
13
impl IdentCased {
14
72
    fn remove_prefix(&self) -> Self {
15
72
        let s = self.0.to_string();
16
72
        IdentCased(Ident::new(&s[1..], self.0.span()))
17
72
    }
18

            
19
    #[allow(dead_code)]
20
    fn to_title_case(&self) -> Self {
21
        let converter =
22
            convert_case::Converter::new().to_case(convert_case::Case::Title);
23
        let converted = converter.convert(self.0.to_string());
24
        IdentCased(Ident::new(&converted, self.0.span()))
25
    }
26

            
27
    #[allow(dead_code)]
28
    fn to_upper_case(&self) -> Self {
29
        let converter =
30
            convert_case::Converter::new().to_case(convert_case::Case::Upper);
31
        let converted = converter.convert(self.0.to_string());
32
        IdentCased(Ident::new(&converted, self.0.span()))
33
    }
34

            
35
32
    fn to_screaming_snake_case(&self) -> Self {
36
32
        let converter = convert_case::Converter::new()
37
32
            .to_case(convert_case::Case::ScreamingSnake);
38
32
        let converted = converter.convert(self.0.to_string());
39
32
        IdentCased(Ident::new(&converted, self.0.span()))
40
32
    }
41

            
42
88
    fn to_pascal_case(&self) -> Self {
43
88
        let converter =
44
88
            convert_case::Converter::new().to_case(convert_case::Case::Pascal);
45
88
        let converted = converter.convert(self.0.to_string());
46
88
        IdentCased(Ident::new(&converted, self.0.span()))
47
88
    }
48

            
49
    #[allow(dead_code)]
50
    fn to_upper_flat(&self) -> Self {
51
        let converter = convert_case::Converter::new()
52
            .to_case(convert_case::Case::UpperFlat);
53
        let converted = converter.convert(self.0.to_string());
54
        IdentCased(Ident::new(&converted, self.0.span()))
55
    }
56

            
57
    #[allow(dead_code)]
58
    fn remove_whitespace(&self) -> Self {
59
        let s = self.0.to_string().split_whitespace().collect::<String>();
60
        IdentCased(Ident::new(&s, self.0.span()))
61
    }
62
}
63

            
64
impl From<IdentCased> for Ident {
65
120
    fn from(ident: IdentCased) -> Self {
66
120
        ident.0
67
120
    }
68
}
69

            
70
/// Selects which message enum `generate_frame` emits.
enum Direction {
    /// Emit the `Rmessage` enum.
    Rx,
    /// Emit the `Tmessage` enum.
    Tx,
}
74

            
75
12
fn generate_frame(
76
12
    direction: Direction,
77
12
    msgs: &[(Ident, proc_macro2::TokenStream)],
78
12
) -> proc_macro2::TokenStream {
79
12
    let enum_name = match direction {
80
6
        Direction::Rx => quote! { Rmessage },
81
6
        Direction::Tx => quote! { Tmessage },
82
    };
83

            
84
22
    let msg_variants = msgs.iter().map(|(ident, _p)| {
85
16
        let name: IdentCased = ident.into();
86
16
        let variant_name: Ident = name.remove_prefix().to_pascal_case().into();
87
16
        let constant_name: Ident = name.to_screaming_snake_case().into();
88
16
        quote! {
89
16
            #variant_name(#ident) = #constant_name,
90
16
        }
91
22
    });
92
22
    let cloned_byte_sizes = msgs.iter().map(|(ident, _)| {
93
16
        let name: IdentCased = ident.into();
94
16
        let variant_name: Ident = name.remove_prefix().to_pascal_case().into();
95
16
        quote! {
96
16
            #enum_name::#variant_name(msg) => msg.byte_size()
97
16
        }
98
22
    });
99
12

            
100
22
    let match_arms = msgs.iter().map(|(ident, _)| {
101
16
        let name: IdentCased = ident.into();
102
16
        let variant_name: Ident = name.remove_prefix().to_pascal_case().into();
103
16
        quote! {
104
16
            #enum_name::#variant_name(msg)
105
16
        }
106
22
    });
107
12

            
108
22
    let decode_bodies = msgs.iter().map(|(ident, _)| {
109
16
        let name: IdentCased = ident.into();
110
16
        let variant_name: Ident = name.remove_prefix().to_pascal_case().into();
111
16

            
112
16
        let const_name: Ident = name.to_screaming_snake_case().into();
113
16
        quote! {
114
16
                #const_name => Ok(#enum_name::#variant_name(WireFormat::decode(reader)?)),
115
16
        }
116
22
    });
117
12

            
118
22
    let encode_match_arms = match_arms.clone().map(|arm| {
119
16
        quote! {
120
16
            #arm => msg.encode(writer)?,
121
16
        }
122
22
    });
123
12

            
124
12
    quote! {
125
12
        #[derive(Debug)]
126
12
        #[repr(u8)]
127
12
        pub enum #enum_name {
128
12
            #( #msg_variants )*
129
12
        }
130
12

            
131
12
        impl Framer for #enum_name {
132
12
            fn byte_size(&self) -> u32 {
133
12
                match &self {
134
12
                    #(
135
12
                        #cloned_byte_sizes,
136
12
                     )*
137
12
                }
138
12
            }
139
12

            
140
12
            fn message_type(&self) -> u8 {
141
12
                // SAFETY: Because `Self` is marked `repr(u8)`, its layout is a `repr(C)` `union`
142
12
                // between `repr(C)` structs, each of which has the `u8` discriminant as its first
143
12
                // field, so we can read the discriminant without offsetting the pointer.
144
12
                unsafe { *<*const _>::from(self).cast::<u8>() }
145
12
            }
146
12

            
147
12
            fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
148
12
                match &self {
149
12
                    #(
150
12
                        #encode_match_arms
151
12
                     )*
152
12
                }
153
12

            
154
12
                Ok(())
155
12
            }
156
12

            
157
12
            fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<#enum_name> {
158
12
                match ty {
159
12
                    #(
160
12
                        #decode_bodies
161
12
                     )*
162
12
                    _ => Err(std::io::Error::new(
163
12
                        std::io::ErrorKind::InvalidData,
164
12
                        format!("unknown message type: {}", ty),
165
12
                    )),
166
12
                }
167
12
            }
168
12
        }
169
12
    }
170
12
}
171

            
172
6
fn generate_tframe(
173
6
    tmsgs: &[(Ident, proc_macro2::TokenStream)],
174
6
) -> proc_macro2::TokenStream {
175
6
    generate_frame(Direction::Tx, tmsgs)
176
6
}
177

            
178
6
fn generate_rframe(
179
6
    rmsgs: &[(Ident, proc_macro2::TokenStream)],
180
6
) -> proc_macro2::TokenStream {
181
6
    generate_frame(Direction::Rx, rmsgs)
182
6
}
183

            
184
8
fn generate_msg_id(
185
8
    index: usize,
186
8
    method_name: &Ident,
187
8
) -> proc_macro2::TokenStream {
188
8
    let upper_cased_method_name = method_name.to_string().to_uppercase();
189
8
    let tmsg_const_name = Ident::new(
190
8
        &format!("T{}", upper_cased_method_name),
191
8
        method_name.span(),
192
8
    );
193
8
    let rmsg_const_name = Ident::new(
194
8
        &format!("R{}", upper_cased_method_name),
195
8
        method_name.span(),
196
8
    );
197
8
    let offset = 2 * index as u8;
198
8

            
199
8
    quote! {
200
8
        pub const #tmsg_const_name: u8 = MESSAGE_ID_START + #offset;
201
8
        pub const #rmsg_const_name: u8 = MESSAGE_ID_START + #offset + 1;
202
8
    }
203
8
}
204

            
205
8
fn generate_input_struct(
206
8
    request_struct_ident: &Ident,
207
8
    method_sig: &syn::Signature,
208
8
) -> proc_macro2::TokenStream {
209
12
    let inputs = method_sig.inputs.iter().map(|arg| match arg {
210
        syn::FnArg::Typed(pat) => {
211
            let name = pat.pat.clone();
212
            let ty = pat.ty.clone();
213
            quote! {
214
                pub #name: #ty,
215
            }
216
        }
217
8
        syn::FnArg::Receiver(_) => quote! {},
218
12
    });
219
8

            
220
8
    quote! {
221
8
        #[allow(non_camel_case_types)]
222
8
        #[derive(Debug, JetStreamWireFormat)]
223
8
        pub struct #request_struct_ident {
224
8
            #(#inputs)*
225
8
        }
226
8
    }
227
8
}
228
8
fn generate_return_struct(
229
8
    return_struct_ident: &Ident,
230
8
    method_sig: &syn::Signature,
231
8
) -> proc_macro2::TokenStream {
232
8
    match &method_sig.output {
233
8
        syn::ReturnType::Type(_, ty) => {
234
8
            match &**ty {
235
8
                syn::Type::Path(type_path) => {
236
                    // Check if it's a Result type
237
8
                    if let Some(segment) = type_path.path.segments.last() {
238
8
                        if segment.ident == "Result" {
239
                            // Extract the success type from Result<T, E>
240
8
                            if let syn::PathArguments::AngleBracketed(args) =
241
8
                                &segment.arguments
242
                            {
243
                                if let Some(syn::GenericArgument::Type(
244
8
                                    success_type,
245
8
                                )) = args.args.first()
246
                                {
247
8
                                    return quote! {
248
8
                                        #[allow(non_camel_case_types)]
249
8
                                        #[derive(Debug, JetStreamWireFormat)]
250
8
                                        pub struct #return_struct_ident(pub #success_type);
251
8
                                    };
252
                                }
253
                            }
254
                        }
255
                    }
256
                    // If not a Result or couldn't extract type, use the whole type
257
                    quote! {
258
                        #[allow(non_camel_case_types)]
259
                        #[derive(Debug, JetStreamWireFormat)]
260
                        pub struct #return_struct_ident(pub #ty);
261
                    }
262
                }
263
                // Handle other return type variants if needed
264
                _ => {
265
                    quote! {
266
                        #[allow(non_camel_case_types)]
267
                        #[derive(Debug, JetStreamWireFormat)]
268
                        pub struct #return_struct_ident(pub #ty);
269
                    }
270
                }
271
            }
272
        }
273
        syn::ReturnType::Default => {
274
            quote! {
275
               #[allow(non_camel_case_types)]
276
               #[derive(Debug, JetStreamWireFormat)]
277
               pub struct #return_struct_ident;
278
            }
279
        }
280
    }
281
8
}
282

            
283
6
fn generate_match_arms(
284
6
    tmsgs: impl Iterator<Item = (Ident, proc_macro2::TokenStream)>,
285
6
) -> impl Iterator<Item = proc_macro2::TokenStream> {
286
8
    tmsgs.map(|(ident, _)| {
287
8
        let name: IdentCased = (&ident).into();
288
8
        let variant_name: Ident = name.remove_prefix().to_pascal_case().into();
289
8
        quote! {
290
8
            Tmessage::#variant_name(msg)
291
8
        }
292
8
    })
293
6
}
294
8
fn handle_receiver(recv: &syn::Receiver) -> proc_macro2::TokenStream {
295
8
    let mutability = &recv.mutability;
296
8
    let reference = &recv.reference;
297
8

            
298
8
    match (reference, mutability) {
299
8
        (Some(_), Some(_)) => quote! { &mut self.inner, },
300
        (Some(_), None) => quote! { &self.inner, },
301
        (None, _) => quote! { self.inner, },
302
    }
303
8
}
304
6
pub(crate) fn service_impl(
305
6
    item: ItemTrait,
306
6
    is_async_trait: bool,
307
6
) -> TokenStream {
308
6
    let trait_name = &item.ident;
309
6
    let trait_items = &item.items;
310
6
    let vis = &item.vis;
311
6

            
312
6
    // Generate message structs and enum variants
313
6
    // let mut message_structs = Vec::new();
314
6
    let mut tmsgs = Vec::new();
315
6
    let mut rmsgs = Vec::new();
316
6
    let mut msg_ids = Vec::new();
317
6
    let service_name = format_ident!("{}Service", trait_name);
318
6
    let channel_name = format_ident!("{}Channel", trait_name);
319
6
    let digest = sha256::digest(item.to_token_stream().to_string());
320
6

            
321
6
    #[allow(clippy::to_string_in_format_args)]
322
6
    let protocol_version = format!(
323
6
        "dev.branch.jetstream.proto/{}/{}.{}.{}-{}",
324
6
        trait_name.to_string().to_lowercase(),
325
6
        env!("CARGO_PKG_VERSION_MAJOR"),
326
6
        env!("CARGO_PKG_VERSION_MINOR"),
327
6
        env!("CARGO_PKG_VERSION_PATCH"),
328
6
        digest[0..8].to_string()
329
6
    );
330
6
    let protocol_version = Literal::string(protocol_version.as_str());
331
6
    let mut calls = vec![];
332
6
    let tag_name =
333
6
        format_ident!("{}_TAG", trait_name.to_string().to_uppercase());
334
6

            
335
6
    let mut server_calls = vec![];
336
6

            
337
6
    {
338
11
        let with_calls = item.items.iter().enumerate().map(|(index, item)| {
339
8
            if let TraitItem::Fn(method) = item {
340
8
                let method_name = &method.sig.ident;
341
8

            
342
8
                let request_struct_ident = Ident::new(
343
8
                    &format!("T{}", method_name),
344
8
                    method_name.span(),
345
8
                );
346
8
                let return_struct_ident = Ident::new(
347
8
                    &format!("R{}", method_name),
348
8
                    method_name.span(),
349
8
                );
350
8
                let _output_type = match &method.sig.output {
351
8
                    syn::ReturnType::Type(_, ty) => quote! { #ty },
352
                    syn::ReturnType::Default => quote! { () },
353
                };
354
8
                let msg_id = generate_msg_id(index, method_name);
355
8
                msg_ids.push(msg_id);
356
8
                let request_struct = generate_input_struct(
357
8
                    &request_struct_ident.clone(),
358
8
                    &method.sig,
359
8
                );
360
8
                let return_struct = generate_return_struct(
361
8
                    &return_struct_ident.clone(),
362
8
                    &method.sig,
363
8
                );
364
8

            
365
8
                tmsgs.insert(
366
8
                    index,
367
8
                    (request_struct_ident.clone(), request_struct.clone()),
368
8
                );
369
8
                rmsgs.insert(
370
8
                    index,
371
8
                    (return_struct_ident.clone(), return_struct.clone()),
372
8
                );
373
            }
374
11
        });
375
6
        calls.extend(with_calls);
376
6
    }
377
6
    let mut client_calls = vec![];
378
6
    {
379
11
        item.items.iter().enumerate().for_each(|(index,item)|{
380
8
            let TraitItem::Fn(method) = item else {return;};
381
8
            let method_name = &method.sig.ident;
382
8
            let has_req = method.sig.inputs.iter().count() > 1;
383
8
            let is_async = method.sig.asyncness.is_some();
384
8
            let maybe_await = if is_async { quote! { .await } } else { quote! {} };
385

            
386
8
            let request_struct_ident = tmsgs.get(index).unwrap().0.clone();
387
8
            let return_struct_ident = rmsgs.get(index).unwrap().0.clone();
388
8
            let new = if has_req {
389
                let spread_req = method.sig.inputs.iter().map(|arg| match arg {
390
                    syn::FnArg::Typed(pat) => {
391
                        let name = pat.pat.clone();
392
                        quote! { req.#name, }
393
                    }
394
                    syn::FnArg::Receiver(_) => quote! {},
395
                });
396
                quote! {
397
                    fn #method_name(&mut self, tag: u16, req: #request_struct_ident) -> impl ::core::future::Future<
398
                        Output = Result<#return_struct_ident, Error,
399
                    > + Send + Sync {
400
                        Box::pin(async move {
401
                            #return_struct_ident(tag, #trait_name::#method_name(&mut self.inner,
402
                                #(#spread_req)*
403
                            )#maybe_await)
404
                        })
405
                    }
406
                }
407
            } else {
408
8
                quote! {
409
8
                    fn #method_name(&mut self, tag: u16, req: #request_struct_ident) -> impl ::core::future::Future<
410
8
                        Output = Result<#return_struct_ident, Error,
411
8
                    > + Send + Sync {
412
8
                        Box::pin(async move {
413
8
                            #return_struct_ident(tag, #trait_name::#method_name(&mut self.inner)#maybe_await)
414
8
                        })
415
8
                    }
416
8
                    }
417
                };
418
8
            server_calls.extend(new);
419
11
        });
420
6
    }
421
6
    {
422
11
        item.items.iter().enumerate().for_each(|(index, item)| {
423
8
            let TraitItem::Fn(method) = item else {
424
                return;
425
            };
426
8
            let method_name = &method.sig.ident;
427
8
            let variant_name: Ident = IdentCased(method_name.clone()).to_pascal_case().into();
428
8
            let retn = &method.sig.output;
429
8
            let is_async = method.sig.asyncness.is_some();
430
8
            let maybe_async = if is_async {
431
8
                quote! { async }
432
            } else {
433
                quote! {}
434
            };
435
8
            let request_struct_ident = tmsgs.get(index).unwrap().0.clone();
436
8
            let inputs = method.sig.inputs.iter().map(|arg| {
437
8
                match arg {
438
                    syn::FnArg::Typed(pat) => {
439
                        let name = pat.pat.clone();
440
                        let ty = pat.ty.clone();
441
                        quote! {
442
                             #name: #ty,
443
                        }
444
                    }
445
8
                    syn::FnArg::Receiver(_) => quote! {},
446
                }
447
8
            });
448
8
            let args = method.sig.inputs.iter().map(|arg| {
449
8
                match arg {
450
                    syn::FnArg::Typed(pat) => {
451
                        let name = pat.pat.clone();
452
                        quote! {
453
                             #name,
454
                        }
455
                    }
456
8
                    syn::FnArg::Receiver(_) => quote! {},
457
                }
458
8
            });
459
8
            let new = quote! {
460
8
                #maybe_async fn #method_name(&mut self, #(#inputs)*)  #retn {
461
8
                    let tag =#tag_name.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
462
8
                    let req = Tmessage::#variant_name(#request_struct_ident {
463
8
                        #(
464
8
                            #args
465
8
                        )*
466
8
                    });
467
8
                    let tframe= Frame::from((tag, req));
468
8
                    let rframe = self.rpc(tframe).await?;
469
8
                    let rmsg = rframe.msg;
470
8
                    match rmsg {
471
8
                        Rmessage::#variant_name(msg) => Ok(msg.0),
472
8
                        _ => Err(Error::InvalidResponse),
473
8
                    }
474
8
                }
475
8
            };
476

            
477
8
            client_calls.extend(new);
478
11
        });
479
6
    }
480
6

            
481
6
    // make a const with the digest
482
6
    let digest = Literal::string(digest.as_str());
483
11
    let tmsg_definitions = tmsgs.iter().map(|(_ident, def)| {
484
8
        quote! {
485
8
            #def
486
8
        }
487
11
    });
488
6

            
489
11
    let rmsg_definitions = rmsgs.iter().map(|(_ident, def)| {
490
8
        quote! {
491
8
            #def
492
8
        }
493
11
    });
494
6
    let tmessage = generate_tframe(&tmsgs);
495
6
    let rmessage = generate_rframe(&rmsgs);
496
6
    let proto_mod =
497
6
        format_ident!("{}_protocol", trait_name.to_string().to_lowercase());
498
6

            
499
6
    let match_arms = generate_match_arms(tmsgs.clone().into_iter());
500
6
    let match_arm_bodies: Vec<proc_macro2::TokenStream> = item
501
6
        .items
502
6
        .clone()
503
6
        .iter()
504
11
        .map(|item| match item {
505
8
            TraitItem::Fn(method) => {
506
8
                let method_name = &method.sig.ident;
507
8
                let name: IdentCased = method_name.into();
508
8
                let variant_name: Ident = name.to_pascal_case().into();
509
8
                let return_struct_ident = Ident::new(
510
8
                    &format!("R{}", method_name),
511
8
                    method_name.span(),
512
8
                );
513
8
                let variables_spead =
514
8
                    method.sig.inputs.iter().map(|arg| match arg {
515
                        syn::FnArg::Typed(pat) => {
516
                            let name = pat.pat.clone();
517
                            quote! { msg.#name, }
518
                        }
519
8
                        syn::FnArg::Receiver(recv) => handle_receiver(recv),
520
8
                    });
521
8
                quote! {
522
8
                     {
523
8
                        let msg = #trait_name::#method_name(
524
8
                            #(
525
8
                                #variables_spead
526
8
                            )*
527
8
                        ).await?;
528
8
                        let ret = #return_struct_ident(msg);
529
8
                        Ok(Rmessage::#variant_name(ret))
530
8
                    }
531
8
                }
532
            }
533
            _ => quote! {},
534
11
        })
535
6
        .collect();
536
6
    let matches = std::iter::zip(match_arms, match_arm_bodies.iter()).map(
537
11
        |(arm, body)| {
538
8
            quote! {
539
8
                #arm => #body
540
8
            }
541
11
        },
542
6
    );
543

            
544
6
    let trait_attribute = if is_async_trait {
545
        quote! { #[jetstream::prelude::async_trait] }
546
    } else {
547
6
        quote! { #[jetstream::prelude::make(Send + Sync)] }
548
    };
549
6
    quote! {
550
6
        #vis mod #proto_mod{
551
6
            use jetstream::prelude::*;
552
6
            use std::io::{self,Read,Write};
553
6
            use std::mem;
554
6
            use super::#trait_name;
555
6
            const MESSAGE_ID_START: u8 = 101;
556
6
            pub const PROTOCOL_VERSION: &str = #protocol_version;
557
6
            const DIGEST: &str = #digest;
558
6

            
559
6
            #(#msg_ids)*
560
6

            
561
6
            #(#tmsg_definitions)*
562
6

            
563
6
            #(#rmsg_definitions)*
564
6

            
565
6
            #tmessage
566
6

            
567
6
            #rmessage
568
6

            
569
6
            #[derive(Clone)]
570
6
            pub struct #service_name<T: #trait_name> {
571
6
                pub inner: T,
572
6
            }
573
6

            
574
6
            impl<T> Protocol for #service_name<T>
575
6
            where
576
6
                T: #trait_name+ Send + Sync + Sized
577
6
            {
578
6
                type Request = Tmessage;
579
6
                type Response = Rmessage;
580
6
                type Error = Error;
581
6
                const VERSION: &'static str = PROTOCOL_VERSION;
582
6

            
583
6
                fn rpc(&mut self, frame: Frame<<Self as Protocol>::Request>) -> impl ::core::future::Future<
584
6
                    Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
585
6
                > + Send + Sync {
586
6
                    Box::pin(async move {
587
6
                        let req: <Self as Protocol>::Request = frame.msg;
588
6
                        let res: Result<<Self as Protocol>::Response, Self::Error> =match req {
589
6
                                #(
590
6
                                    #matches
591
6
                                )*
592
6
                        };
593
6
                        let rframe: Frame<<Self as Protocol>::Response> = Frame::from((frame.tag, res?));
594
6
                        Ok(rframe)
595
6
                    })
596
6
                }
597
6
            }
598
6
            pub struct #channel_name<'a> {
599
6
                pub inner: Box<&'a mut dyn ClientTransport<Self>>,
600
6
            }
601
6
            impl<'a> Protocol for #channel_name<'a>
602
6
            {
603
6
                type Request = Tmessage;
604
6
                type Response = Rmessage;
605
6
                type Error = Error;
606
6
                const VERSION: &'static str = PROTOCOL_VERSION;
607
6
                fn rpc(&mut self, frame: Frame<<Self as Protocol>::Request>) -> impl ::core::future::Future<
608
6
                    Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
609
6
                > + Send + Sync {
610
6
                    use futures::{SinkExt, StreamExt};
611
6
                    Box::pin(async move {
612
6
                        self.inner
613
6
                            .send(frame)
614
6
                            .await?;
615
6
                        let frame = self.inner.next().await.unwrap()?;
616
6
                        Ok(frame)
617
6
                    })
618
6
                }
619
6
            }
620
6
            lazy_static::lazy_static! {
621
6
                static ref #tag_name: std::sync::atomic::AtomicU16 = std::sync::atomic::AtomicU16::new(0);
622
6
            }
623
6
            impl<'a> #trait_name for #channel_name<'a>
624
6
            {
625
6
                #(#client_calls)*
626
6
            }
627
6

            
628
6
        }
629
6

            
630
6
        #trait_attribute
631
6
        #vis trait #trait_name {
632
6
            #(#trait_items)*
633
6
        }
634
6
    }
635
6
}
636

            
637
#[cfg(test)]
638
mod tests {
639
    use core::panic;
640

            
641
    use syn::parse_quote;
642

            
643
    use super::*;
644

            
645
    fn run_test_with_filters<F>(test_fn: F)
646
    where
647
        F: FnOnce() + panic::UnwindSafe,
648
    {
649
        let filters = vec![
650
            // Filter for protocol version strings
651
            (
652
                r"dev\.branch\.jetstream\.proto/\w+/\d+\.\d+\.\d+-[a-f0-9]{8}",
653
                "dev.branch.jetstream.proto/NAME/VERSION-HASH",
654
            ),
655
            // Filter for digest strings
656
            (r"[a-f0-9]{64}", "DIGEST_HASH"),
657
        ];
658

            
659
        insta::with_settings!({
660
            filters => filters,
661
        }, {
662
            test_fn();
663
        });
664
    }
665

            
666
    #[test]
667
    fn test_simple_service() {
668
        let input: ItemTrait = parse_quote! {
669
            pub trait Echo {
670
                async fn ping(&self) -> Result<(), std::io::Error>;
671
            }
672
        };
673
        let output = service_impl(input, false);
674
        let syntax_tree: syn::File = syn::parse2(output).unwrap();
675
        let output_str = prettyplease::unparse(&syntax_tree);
676
        run_test_with_filters(|| {
677
            insta::assert_snapshot!(output_str, @r#"
678
            pub mod echo_protocol {
679
                use jetstream::prelude::*;
680
                use std::io::{self, Read, Write};
681
                use std::mem;
682
                use super::Echo;
683
                const MESSAGE_ID_START: u8 = 101;
684
                pub const PROTOCOL_VERSION: &str = "dev.branch.jetstream.proto/NAME/VERSION-HASH";
685
                const DIGEST: &str = "DIGEST_HASH";
686
                pub const TPING: u8 = MESSAGE_ID_START + 0u8;
687
                pub const RPING: u8 = MESSAGE_ID_START + 0u8 + 1;
688
                #[allow(non_camel_case_types)]
689
                #[derive(Debug, JetStreamWireFormat)]
690
                pub struct Tping {}
691
                #[allow(non_camel_case_types)]
692
                #[derive(Debug, JetStreamWireFormat)]
693
                pub struct Rping(pub ());
694
                #[derive(Debug)]
695
                #[repr(u8)]
696
                pub enum Tmessage {
697
                    Ping(Tping) = TPING,
698
                }
699
                impl Framer for Tmessage {
700
                    fn byte_size(&self) -> u32 {
701
                        match &self {
702
                            Tmessage::Ping(msg) => msg.byte_size(),
703
                        }
704
                    }
705
                    fn message_type(&self) -> u8 {
706
                        unsafe { *<*const _>::from(self).cast::<u8>() }
707
                    }
708
                    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
709
                        match &self {
710
                            Tmessage::Ping(msg) => msg.encode(writer)?,
711
                        }
712
                        Ok(())
713
                    }
714
                    fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<Tmessage> {
715
                        match ty {
716
                            TPING => Ok(Tmessage::Ping(WireFormat::decode(reader)?)),
717
                            _ => {
718
                                Err(
719
                                    std::io::Error::new(
720
                                        std::io::ErrorKind::InvalidData,
721
                                        format!("unknown message type: {}", ty),
722
                                    ),
723
                                )
724
                            }
725
                        }
726
                    }
727
                }
728
                #[derive(Debug)]
729
                #[repr(u8)]
730
                pub enum Rmessage {
731
                    Ping(Rping) = RPING,
732
                }
733
                impl Framer for Rmessage {
734
                    fn byte_size(&self) -> u32 {
735
                        match &self {
736
                            Rmessage::Ping(msg) => msg.byte_size(),
737
                        }
738
                    }
739
                    fn message_type(&self) -> u8 {
740
                        unsafe { *<*const _>::from(self).cast::<u8>() }
741
                    }
742
                    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
743
                        match &self {
744
                            Rmessage::Ping(msg) => msg.encode(writer)?,
745
                        }
746
                        Ok(())
747
                    }
748
                    fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<Rmessage> {
749
                        match ty {
750
                            RPING => Ok(Rmessage::Ping(WireFormat::decode(reader)?)),
751
                            _ => {
752
                                Err(
753
                                    std::io::Error::new(
754
                                        std::io::ErrorKind::InvalidData,
755
                                        format!("unknown message type: {}", ty),
756
                                    ),
757
                                )
758
                            }
759
                        }
760
                    }
761
                }
762
                #[derive(Clone)]
763
                pub struct EchoService<T: Echo> {
764
                    pub inner: T,
765
                }
766
                impl<T> Protocol for EchoService<T>
767
                where
768
                    T: Echo + Send + Sync + Sized,
769
                {
770
                    type Request = Tmessage;
771
                    type Response = Rmessage;
772
                    type Error = Error;
773
                    const VERSION: &'static str = PROTOCOL_VERSION;
774
                    fn rpc(
775
                        &mut self,
776
                        frame: Frame<<Self as Protocol>::Request>,
777
                    ) -> impl ::core::future::Future<
778
                        Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
779
                    > + Send + Sync {
780
                        Box::pin(async move {
781
                            let req: <Self as Protocol>::Request = frame.msg;
782
                            let res: Result<<Self as Protocol>::Response, Self::Error> = match req {
783
                                Tmessage::Ping(msg) => {
784
                                    let msg = Echo::ping(&self.inner).await?;
785
                                    let ret = Rping(msg);
786
                                    Ok(Rmessage::Ping(ret))
787
                                }
788
                            };
789
                            let rframe: Frame<<Self as Protocol>::Response> = Frame::from((
790
                                frame.tag,
791
                                res?,
792
                            ));
793
                            Ok(rframe)
794
                        })
795
                    }
796
                }
797
                pub struct EchoChannel<'a> {
798
                    pub inner: Box<&'a mut dyn ClientTransport<Self>>,
799
                }
800
                impl<'a> Protocol for EchoChannel<'a> {
801
                    type Request = Tmessage;
802
                    type Response = Rmessage;
803
                    type Error = Error;
804
                    const VERSION: &'static str = PROTOCOL_VERSION;
805
                    fn rpc(
806
                        &mut self,
807
                        frame: Frame<<Self as Protocol>::Request>,
808
                    ) -> impl ::core::future::Future<
809
                        Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
810
                    > + Send + Sync {
811
                        use futures::{SinkExt, StreamExt};
812
                        Box::pin(async move {
813
                            self.inner.send(frame).await?;
814
                            let frame = self.inner.next().await.unwrap()?;
815
                            Ok(frame)
816
                        })
817
                    }
818
                }
819
                lazy_static::lazy_static! {
820
                    static ref ECHO_TAG : std::sync::atomic::AtomicU16 =
821
                    std::sync::atomic::AtomicU16::new(0);
822
                }
823
                impl<'a> Echo for EchoChannel<'a> {
824
                    async fn ping(&mut self) -> Result<(), std::io::Error> {
825
                        let tag = ECHO_TAG.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
826
                        let req = Tmessage::Ping(Tping {});
827
                        let tframe = Frame::from((tag, req));
828
                        let rframe = self.rpc(tframe).await?;
829
                        let rmsg = rframe.msg;
830
                        match rmsg {
831
                            Rmessage::Ping(msg) => Ok(msg.0),
832
                            _ => Err(Error::InvalidResponse),
833
                        }
834
                    }
835
                }
836
            }
837
            #[jetstream::prelude::make(Send+Sync)]
838
            pub trait Echo {
839
                async fn ping(&self) -> Result<(), std::io::Error>;
840
            }
841
            "#)
842
        })
843
    }
844

            
845
    /// Snapshot test for `service_impl` called with `false` as its second
    /// argument on a one-method `Echo` trait whose `ping` takes `&self` and a
    /// `String` and returns `Result<String, std::io::Error>`.
    ///
    /// The generated token stream is parsed into a `syn::File`, rendered with
    /// `prettyplease::unparse`, and compared against the inline snapshot. In
    /// the expected output the request struct `Tping` carries the `message`
    /// field, the response struct `Rping` wraps the returned `String`, and the
    /// trait is re-emitted under `#[jetstream::prelude::make(Send+Sync)]`.
    ///
    /// NOTE(review): in the recorded output the `EchoChannel` impl declares
    /// `ping(&mut self, ..)` while the trait itself declares `ping(&self, ..)`
    /// — confirm this is what the generator is meant to produce.
    #[test]
    fn test_service_with_args() {
        let input: ItemTrait = parse_quote! {
            pub trait Echo {
                async fn ping(&self, message: String) -> Result<String, std::io::Error>;
            }
        };
        // NOTE(review): the `false` flag presumably selects the non-`async_trait`
        // flavour of the generator — confirm against `service_impl`.
        let output = service_impl(input, false);
        // Parsing as a whole file fails the test early if the generator emitted
        // tokens that are not syntactically valid Rust.
        let syntax_tree: syn::File = syn::parse2(output).unwrap();
        let output_str = prettyplease::unparse(&syntax_tree);
        run_test_with_filters(|| {
            insta::assert_snapshot!(output_str, @r#"
            pub mod echo_protocol {
                use jetstream::prelude::*;
                use std::io::{self, Read, Write};
                use std::mem;
                use super::Echo;
                const MESSAGE_ID_START: u8 = 101;
                pub const PROTOCOL_VERSION: &str = "dev.branch.jetstream.proto/NAME/VERSION-HASH";
                const DIGEST: &str = "DIGEST_HASH";
                pub const TPING: u8 = MESSAGE_ID_START + 0u8;
                pub const RPING: u8 = MESSAGE_ID_START + 0u8 + 1;
                #[allow(non_camel_case_types)]
                #[derive(Debug, JetStreamWireFormat)]
                pub struct Tping {
                    pub message: String,
                }
                #[allow(non_camel_case_types)]
                #[derive(Debug, JetStreamWireFormat)]
                pub struct Rping(pub String);
                #[derive(Debug)]
                #[repr(u8)]
                pub enum Tmessage {
                    Ping(Tping) = TPING,
                }
                impl Framer for Tmessage {
                    fn byte_size(&self) -> u32 {
                        match &self {
                            Tmessage::Ping(msg) => msg.byte_size(),
                        }
                    }
                    fn message_type(&self) -> u8 {
                        unsafe { *<*const _>::from(self).cast::<u8>() }
                    }
                    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
                        match &self {
                            Tmessage::Ping(msg) => msg.encode(writer)?,
                        }
                        Ok(())
                    }
                    fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<Tmessage> {
                        match ty {
                            TPING => Ok(Tmessage::Ping(WireFormat::decode(reader)?)),
                            _ => {
                                Err(
                                    std::io::Error::new(
                                        std::io::ErrorKind::InvalidData,
                                        format!("unknown message type: {}", ty),
                                    ),
                                )
                            }
                        }
                    }
                }
                #[derive(Debug)]
                #[repr(u8)]
                pub enum Rmessage {
                    Ping(Rping) = RPING,
                }
                impl Framer for Rmessage {
                    fn byte_size(&self) -> u32 {
                        match &self {
                            Rmessage::Ping(msg) => msg.byte_size(),
                        }
                    }
                    fn message_type(&self) -> u8 {
                        unsafe { *<*const _>::from(self).cast::<u8>() }
                    }
                    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
                        match &self {
                            Rmessage::Ping(msg) => msg.encode(writer)?,
                        }
                        Ok(())
                    }
                    fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<Rmessage> {
                        match ty {
                            RPING => Ok(Rmessage::Ping(WireFormat::decode(reader)?)),
                            _ => {
                                Err(
                                    std::io::Error::new(
                                        std::io::ErrorKind::InvalidData,
                                        format!("unknown message type: {}", ty),
                                    ),
                                )
                            }
                        }
                    }
                }
                #[derive(Clone)]
                pub struct EchoService<T: Echo> {
                    pub inner: T,
                }
                impl<T> Protocol for EchoService<T>
                where
                    T: Echo + Send + Sync + Sized,
                {
                    type Request = Tmessage;
                    type Response = Rmessage;
                    type Error = Error;
                    const VERSION: &'static str = PROTOCOL_VERSION;
                    fn rpc(
                        &mut self,
                        frame: Frame<<Self as Protocol>::Request>,
                    ) -> impl ::core::future::Future<
                        Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
                    > + Send + Sync {
                        Box::pin(async move {
                            let req: <Self as Protocol>::Request = frame.msg;
                            let res: Result<<Self as Protocol>::Response, Self::Error> = match req {
                                Tmessage::Ping(msg) => {
                                    let msg = Echo::ping(&self.inner, msg.message).await?;
                                    let ret = Rping(msg);
                                    Ok(Rmessage::Ping(ret))
                                }
                            };
                            let rframe: Frame<<Self as Protocol>::Response> = Frame::from((
                                frame.tag,
                                res?,
                            ));
                            Ok(rframe)
                        })
                    }
                }
                pub struct EchoChannel<'a> {
                    pub inner: Box<&'a mut dyn ClientTransport<Self>>,
                }
                impl<'a> Protocol for EchoChannel<'a> {
                    type Request = Tmessage;
                    type Response = Rmessage;
                    type Error = Error;
                    const VERSION: &'static str = PROTOCOL_VERSION;
                    fn rpc(
                        &mut self,
                        frame: Frame<<Self as Protocol>::Request>,
                    ) -> impl ::core::future::Future<
                        Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
                    > + Send + Sync {
                        use futures::{SinkExt, StreamExt};
                        Box::pin(async move {
                            self.inner.send(frame).await?;
                            let frame = self.inner.next().await.unwrap()?;
                            Ok(frame)
                        })
                    }
                }
                lazy_static::lazy_static! {
                    static ref ECHO_TAG : std::sync::atomic::AtomicU16 =
                    std::sync::atomic::AtomicU16::new(0);
                }
                impl<'a> Echo for EchoChannel<'a> {
                    async fn ping(&mut self, message: String) -> Result<String, std::io::Error> {
                        let tag = ECHO_TAG.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                        let req = Tmessage::Ping(Tping { message });
                        let tframe = Frame::from((tag, req));
                        let rframe = self.rpc(tframe).await?;
                        let rmsg = rframe.msg;
                        match rmsg {
                            Rmessage::Ping(msg) => Ok(msg.0),
                            _ => Err(Error::InvalidResponse),
                        }
                    }
                }
            }
            #[jetstream::prelude::make(Send+Sync)]
            pub trait Echo {
                async fn ping(&self, message: String) -> Result<String, std::io::Error>;
            }
            "#)
        })
    }
    /// Snapshot test for `service_impl` called with `true` as its second
    /// argument on a one-method `Echo` trait whose `ping` takes `&mut self`
    /// and a `String` and returns `Result<String, std::io::Error>`.
    ///
    /// The generated tokens are parsed as a `syn::File`, pretty-printed with
    /// `prettyplease::unparse`, and checked against the inline snapshot. In
    /// the expected output the trait is re-emitted under
    /// `#[jetstream::prelude::async_trait]` and the service dispatches the
    /// call through `&mut self.inner`.
    #[test]
    fn test_async_trait_service_with_args() {
        let trait_def: ItemTrait = parse_quote! {
            pub trait Echo {
                async fn ping(&mut self, message: String) -> Result<String, std::io::Error>;
            }
        };
        let rendered = {
            // NOTE(review): `true` presumably selects the `async_trait` flavour
            // of the generator — confirm against `service_impl`.
            let tokens = service_impl(trait_def, true);
            // Parsing as a whole file fails the test early if the generator
            // emitted tokens that are not syntactically valid Rust.
            let file = syn::parse2::<syn::File>(tokens).unwrap();
            prettyplease::unparse(&file)
        };
        run_test_with_filters(|| {
            insta::assert_snapshot!(rendered, @r###"
            pub mod echo_protocol {
                use jetstream::prelude::*;
                use std::io::{self, Read, Write};
                use std::mem;
                use super::Echo;
                const MESSAGE_ID_START: u8 = 101;
                pub const PROTOCOL_VERSION: &str = "dev.branch.jetstream.proto/NAME/VERSION-HASH";
                const DIGEST: &str = "DIGEST_HASH";
                pub const TPING: u8 = MESSAGE_ID_START + 0u8;
                pub const RPING: u8 = MESSAGE_ID_START + 0u8 + 1;
                #[allow(non_camel_case_types)]
                #[derive(Debug, JetStreamWireFormat)]
                pub struct Tping {
                    pub message: String,
                }
                #[allow(non_camel_case_types)]
                #[derive(Debug, JetStreamWireFormat)]
                pub struct Rping(pub String);
                #[derive(Debug)]
                #[repr(u8)]
                pub enum Tmessage {
                    Ping(Tping) = TPING,
                }
                impl Framer for Tmessage {
                    fn byte_size(&self) -> u32 {
                        match &self {
                            Tmessage::Ping(msg) => msg.byte_size(),
                        }
                    }
                    fn message_type(&self) -> u8 {
                        unsafe { *<*const _>::from(self).cast::<u8>() }
                    }
                    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
                        match &self {
                            Tmessage::Ping(msg) => msg.encode(writer)?,
                        }
                        Ok(())
                    }
                    fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<Tmessage> {
                        match ty {
                            TPING => Ok(Tmessage::Ping(WireFormat::decode(reader)?)),
                            _ => {
                                Err(
                                    std::io::Error::new(
                                        std::io::ErrorKind::InvalidData,
                                        format!("unknown message type: {}", ty),
                                    ),
                                )
                            }
                        }
                    }
                }
                #[derive(Debug)]
                #[repr(u8)]
                pub enum Rmessage {
                    Ping(Rping) = RPING,
                }
                impl Framer for Rmessage {
                    fn byte_size(&self) -> u32 {
                        match &self {
                            Rmessage::Ping(msg) => msg.byte_size(),
                        }
                    }
                    fn message_type(&self) -> u8 {
                        unsafe { *<*const _>::from(self).cast::<u8>() }
                    }
                    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
                        match &self {
                            Rmessage::Ping(msg) => msg.encode(writer)?,
                        }
                        Ok(())
                    }
                    fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<Rmessage> {
                        match ty {
                            RPING => Ok(Rmessage::Ping(WireFormat::decode(reader)?)),
                            _ => {
                                Err(
                                    std::io::Error::new(
                                        std::io::ErrorKind::InvalidData,
                                        format!("unknown message type: {}", ty),
                                    ),
                                )
                            }
                        }
                    }
                }
                #[derive(Clone)]
                pub struct EchoService<T: Echo> {
                    pub inner: T,
                }
                impl<T> Protocol for EchoService<T>
                where
                    T: Echo + Send + Sync + Sized,
                {
                    type Request = Tmessage;
                    type Response = Rmessage;
                    type Error = Error;
                    const VERSION: &'static str = PROTOCOL_VERSION;
                    fn rpc(
                        &mut self,
                        frame: Frame<<Self as Protocol>::Request>,
                    ) -> impl ::core::future::Future<
                        Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
                    > + Send + Sync {
                        Box::pin(async move {
                            let req: <Self as Protocol>::Request = frame.msg;
                            let res: Result<<Self as Protocol>::Response, Self::Error> = match req {
                                Tmessage::Ping(msg) => {
                                    let msg = Echo::ping(&mut self.inner, msg.message).await?;
                                    let ret = Rping(msg);
                                    Ok(Rmessage::Ping(ret))
                                }
                            };
                            let rframe: Frame<<Self as Protocol>::Response> = Frame::from((
                                frame.tag,
                                res?,
                            ));
                            Ok(rframe)
                        })
                    }
                }
                pub struct EchoChannel<'a> {
                    pub inner: Box<&'a mut dyn ClientTransport<Self>>,
                }
                impl<'a> Protocol for EchoChannel<'a> {
                    type Request = Tmessage;
                    type Response = Rmessage;
                    type Error = Error;
                    const VERSION: &'static str = PROTOCOL_VERSION;
                    fn rpc(
                        &mut self,
                        frame: Frame<<Self as Protocol>::Request>,
                    ) -> impl ::core::future::Future<
                        Output = Result<Frame<<Self as Protocol>::Response>, Self::Error>,
                    > + Send + Sync {
                        use futures::{SinkExt, StreamExt};
                        Box::pin(async move {
                            self.inner.send(frame).await?;
                            let frame = self.inner.next().await.unwrap()?;
                            Ok(frame)
                        })
                    }
                }
                lazy_static::lazy_static! {
                    static ref ECHO_TAG : std::sync::atomic::AtomicU16 =
                    std::sync::atomic::AtomicU16::new(0);
                }
                impl<'a> Echo for EchoChannel<'a> {
                    async fn ping(&mut self, message: String) -> Result<String, std::io::Error> {
                        let tag = ECHO_TAG.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                        let req = Tmessage::Ping(Tping { message });
                        let tframe = Frame::from((tag, req));
                        let rframe = self.rpc(tframe).await?;
                        let rmsg = rframe.msg;
                        match rmsg {
                            Rmessage::Ping(msg) => Ok(msg.0),
                            _ => Err(Error::InvalidResponse),
                        }
                    }
                }
            }
            #[jetstream::prelude::async_trait]
            pub trait Echo {
                async fn ping(&mut self, message: String) -> Result<String, std::io::Error>;
            }
            "###)
        })
    }
}