1
use {
2
    proc_macro2::{Span, TokenStream},
3
    quote::{quote, quote_spanned},
4
    syn::{spanned::Spanned, Data, DeriveInput, Fields, Ident, Meta},
5
};
6

            
7
1336
fn has_skip_attr(field: &syn::Field) -> bool {
8
1336
    field.attrs.iter().any(|attr| {
9
        if attr.path().is_ident("jetstream") {
10
            if let Ok(()) = attr.parse_nested_meta(|meta| {
11
                if meta.path.is_ident("skip") {
12
                    Ok(())
13
                } else {
14
                    Err(meta.error("expected `skip`"))
15
                }
16
            }) {
17
                return true;
18
            }
19
        }
20
        false
21
1336
    })
22
1336
}
23

            
24
120
fn extract_jetstream_type(input: &DeriveInput) -> Option<Ident> {
25
660
    for attr in &input.attrs {
26
540
        if attr.path().is_ident("jetstream_type") {
27
            if let Ok(Meta::Path(path)) = attr.parse_args() {
28
                if let Some(ident) = path.get_ident() {
29
                    return Some(ident.clone());
30
                }
31
            }
32
540
        }
33
    }
34
120
    None
35
120
}
36

            
37
120
pub(crate) fn wire_format_inner(input: DeriveInput) -> TokenStream {
38
120
    if !input.generics.params.is_empty() {
39
        return quote! {
40
            compile_error!("derive(JetStreamWireFormat) does not support generic parameters");
41
        };
42
120
    }
43
120
    let jetstream_type = extract_jetstream_type(&input);
44
120
    let container = input.ident;
45

            
46
    // Generate message type implementation
47
120
    let message_impl = if let Some(msg_type) = jetstream_type {
48
        quote! {
49
           impl jetstream_wireformat::Message for #container {
50
               const MESSAGE_TYPE: u8 = super::#msg_type;
51
           }
52
        }
53
    } else {
54
120
        quote! {}
55
    };
56

            
57
120
    let byte_size_impl = byte_size_sum(&input.data);
58
120
    let encode_impl = encode_wire_format(&input.data);
59
120
    let decode_impl = decode_wire_format(&input.data, &container);
60
120

            
61
120
    let scope = format!("wire_format_{}", container).to_lowercase();
62
120
    let scope = Ident::new(&scope, Span::call_site());
63
120
    quote! {
64
120
        mod #scope {
65
120
            extern crate std;
66
120
            use self::std::io;
67
120
            use self::std::result::Result::Ok;
68
120
            use super::#container;
69
120
            use jetstream_wireformat::WireFormat;
70
120

            
71
120
            impl WireFormat for #container {
72
120
                fn byte_size(&self) -> u32 {
73
120
                    #byte_size_impl
74
120
                }
75
120

            
76
120
                fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
77
120
                    #encode_impl
78
120
                }
79
120

            
80
120
                fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
81
120
                    #decode_impl
82
120
                }
83
120
            }
84
120
            #message_impl
85
120
        }
86
120
    }
87
120
}
88

            
89
120
fn byte_size_sum(data: &Data) -> TokenStream {
90
120
    if let Data::Struct(ref data) = *data {
91
120
        if let Fields::Named(ref fields) = data.fields {
92
382
            let fields = fields.named.iter().filter(|f| !has_skip_attr(f)).map(|f| {
93
326
                let field = &f.ident;
94
326
                let span = field.span();
95
326
                quote_spanned! {span=>
96
326
                    WireFormat::byte_size(&self.#field)
97
326
                }
98
382
            });
99
112

            
100
112
            quote! {
101
112
                0 #(+ #fields)*
102
112
            }
103
8
        } else if let Fields::Unnamed(unnamed) = &data.fields {
104
8
            let fields = unnamed
105
8
                .unnamed
106
8
                .iter()
107
8
                .enumerate()
108
12
                .filter(|(_, f)| !has_skip_attr(f))
109
12
                .map(|(i, _f)| {
110
8
                    let index = syn::Index::from(i);
111
8
                    quote! {
112
8
                        WireFormat::byte_size(&self.#index)
113
8
                    }
114
12
                });
115
8

            
116
8
            quote! {
117
8
                0 #(+ #fields)*
118
8
            }
119
        } else {
120
            unimplemented!();
121
        }
122
    } else if let Data::Enum(ref data) = *data {
123
        let variants = data.variants.iter().map(|variant| {
124
            let variant_ident = &variant.ident;
125
            match &variant.fields {
126
                Fields::Named(fields) => {
127
                    let field_idents = fields
128
                        .named
129
                        .iter()
130
                        .filter(|f| !has_skip_attr(f))
131
                        .map(|f| &f.ident)
132
                        .collect::<Vec<_>>();
133
                    quote! {
134
                        Self::#variant_ident { #(ref #field_idents),* } => {
135
                            1 #(+ WireFormat::byte_size(#field_idents))*
136
                        }
137
                    }
138
                }
139
                Fields::Unnamed(fields) => {
140
                    let refs = fields
141
                        .unnamed
142
                        .iter()
143
                        .enumerate()
144
                        .filter(|(_, f)| !has_skip_attr(f))
145
                        .map(|(i, _)| format!("__{}", i))
146
                        .map(|name| Ident::new(&name, Span::call_site()))
147
                        .collect::<Vec<_>>();
148
                    quote! {
149
                        Self::#variant_ident(#(ref #refs),*) => {
150
                            1 #(+ WireFormat::byte_size(#refs))*
151
                        }
152
                    }
153
                }
154
                Fields::Unit => {
155
                    quote! {
156
                        Self::#variant_ident => 1
157
                    }
158
                }
159
            }
160
        });
161

            
162
        quote! {
163
            match self {
164
                #(#variants),*
165
            }
166
        }
167
    } else {
168
        unimplemented!();
169
    }
170
120
}
171

            
172
120
fn encode_wire_format(data: &Data) -> TokenStream {
173
120
    if let Data::Struct(ref data) = *data {
174
120
        if let Fields::Named(ref fields) = data.fields {
175
382
            let fields = fields.named.iter().filter(|f| !has_skip_attr(f)).map(|f| {
176
326
                let field = &f.ident;
177
326
                let span = field.span();
178
326
                quote_spanned! {span=>
179
326
                    WireFormat::encode(&self.#field, _writer)?;
180
326
                }
181
382
            });
182
112

            
183
112
            quote! {
184
112
                #(#fields)*
185
112
                Ok(())
186
112
            }
187
8
        } else if let Fields::Unnamed(unnamed) = &data.fields {
188
8
            let fields = unnamed
189
8
                .unnamed
190
8
                .iter()
191
8
                .enumerate()
192
12
                .filter(|(_, f)| !has_skip_attr(f))
193
12
                .map(|(i, _f)| {
194
8
                    let index = syn::Index::from(i);
195
8
                    quote! {
196
8
                        WireFormat::encode(&self.#index, _writer)?;
197
8
                    }
198
12
                });
199
8

            
200
8
            quote! {
201
8
                #(#fields)*
202
8
                Ok(())
203
8
            }
204
        } else {
205
            unimplemented!();
206
        }
207
    } else if let Data::Enum(ref data) = *data {
208
        let variants = data.variants.iter().enumerate().map(|(idx, variant)| {
209
            let variant_ident = &variant.ident;
210
            let idx = idx as u8;
211

            
212
            match &variant.fields {
213
                Fields::Named(ref fields) => {
214
                    let field_idents = fields
215
                        .named
216
                        .iter()
217
                        .filter(|f| !has_skip_attr(f))
218
                        .map(|f| &f.ident)
219
                        .collect::<Vec<_>>();
220

            
221
                    quote! {
222
                        Self::#variant_ident { #(ref #field_idents),* } => {
223
                            WireFormat::encode(&(#idx), _writer)?;
224
                            #(WireFormat::encode(#field_idents, _writer)?;)*
225
                        }
226
                    }
227
                }
228
                Fields::Unnamed(ref fields) => {
229
                    let field_refs = fields
230
                        .unnamed
231
                        .iter()
232
                        .enumerate()
233
                        .filter(|(_, f)| !has_skip_attr(f))
234
                        .map(|(i, _)| format!("__{}", i))
235
                        .map(|name| Ident::new(&name, Span::call_site()))
236
                        .collect::<Vec<_>>();
237
                    quote! {
238
                        Self::#variant_ident(#(ref #field_refs),*) => {
239
                            WireFormat::encode(&(#idx), _writer)?;
240
                            #(WireFormat::encode(#field_refs, _writer)?;)*
241
                        }
242
                    }
243
                }
244
                Fields::Unit => {
245
                    quote! {
246
                        Self::#variant_ident => {
247
                            WireFormat::encode(&(#idx), _writer)?;
248
                        }
249
                    }
250
                }
251
            }
252
        });
253

            
254
        quote! {
255
            match self {
256
                #(#variants),*
257
            }
258
            Ok(())
259
        }
260
    } else {
261
        unimplemented!();
262
    }
263
120
}
264

            
265
120
fn decode_wire_format(data: &Data, container: &Ident) -> TokenStream {
266
120
    if let Data::Struct(ref data) = *data {
267
120
        if let Fields::Named(ref fields) = data.fields {
268
112
            let all_fields = fields.named.iter().collect::<Vec<_>>();
269
382
            let non_skipped_values = fields.named.iter().filter(|f| !has_skip_attr(f)).map(|f| {
270
326
                let field = &f.ident;
271
326
                let span = field.span();
272
326
                quote_spanned! {span=>
273
326
                    let #field = WireFormat::decode(_reader)?;
274
326
                }
275
382
            });
276
112

            
277
382
            let members = all_fields.iter().map(|f| {
278
326
                let field = &f.ident;
279
326
                if has_skip_attr(f) {
280
                    quote! {
281
                        #field: Default::default(),
282
                    }
283
                } else {
284
326
                    quote! {
285
326
                        #field: #field,
286
326
                    }
287
                }
288
382
            });
289
112

            
290
112
            quote! {
291
112
                #(#non_skipped_values)*
292
112
                Ok(#container {
293
112
                    #(#members)*
294
112
                })
295
112
            }
296
8
        } else if let Fields::Unnamed(unnamed) = &data.fields {
297
8
            let all_fields = unnamed
298
8
                .unnamed
299
8
                .iter()
300
8
                .enumerate()
301
12
                .map(|(i, f)| (i, has_skip_attr(f)))
302
8
                .collect::<Vec<_>>();
303
8

            
304
8
            let non_skipped_values = unnamed
305
8
                .unnamed
306
8
                .iter()
307
8
                .enumerate()
308
12
                .filter(|(_, f)| !has_skip_attr(f))
309
12
                .map(|(i, _f)| {
310
8
                    let ident = Ident::new(&format!("__{}", i), Span::call_site());
311
8
                    quote! {
312
8
                        let #ident = WireFormat::decode(_reader)?;
313
8
                    }
314
12
                });
315
8

            
316
12
            let members = all_fields.iter().map(|(i, is_skipped)| {
317
8
                let ident = if *is_skipped {
318
                    quote! { Default::default() }
319
                } else {
320
8
                    let ident = Ident::new(&format!("__{}", i), Span::call_site());
321
8
                    quote! { #ident }
322
                };
323
8
                quote! { #ident }
324
12
            });
325
8

            
326
8
            quote! {
327
8
                #(#non_skipped_values)*
328
8
                Ok(#container(
329
8
                    #(#members,)*
330
8
                ))
331
8
            }
332
        } else {
333
            unimplemented!();
334
        }
335
    } else if let Data::Enum(ref data) = *data {
336
        let mut variant_matches = data
337
            .variants
338
            .iter()
339
            .enumerate()
340
            .map(|(idx, variant)| {
341
                let variant_ident = &variant.ident;
342
                let idx = idx as u8;
343

            
344
                match &variant.fields {
345
                    Fields::Named(ref fields) => {
346
                        let field_decodes =
347
                            fields.named.iter().filter(|f| !has_skip_attr(f)).map(|f| {
348
                                let field_ident = &f.ident;
349
                                quote! { let #field_ident = WireFormat::decode(_reader)?; }
350
                            });
351
                        let field_names = fields.named.iter().map(|f| {
352
                            let field_ident = &f.ident;
353
                            if has_skip_attr(f) {
354
                                quote! { #field_ident: Default::default() }
355
                            } else {
356
                                // Just use the field name directly for the shorthand syntax
357
                                quote! { #field_ident }
358
                            }
359
                        });
360

            
361
                        quote! {
362
                            #idx => {
363
                                #(#field_decodes)*
364
                                Ok(Self::#variant_ident { #(#field_names),* })
365
                            }
366
                        }
367
                    }
368
                    Fields::Unnamed(ref fields) => {
369
                        let field_decodes = fields
370
                            .unnamed
371
                            .iter()
372
                            .enumerate()
373
                            .filter(|(_, f)| !has_skip_attr(f))
374
                            .map(|(i, _)| {
375
                                let field_name = Ident::new(&format!("__{}", i), Span::call_site());
376
                                quote! { let #field_name = WireFormat::decode(_reader)?; }
377
                            });
378
                        let field_names = fields.unnamed.iter().enumerate().map(|(i, f)| {
379
                            if has_skip_attr(f) {
380
                                quote! { Default::default() }
381
                            } else {
382
                                let field_name = Ident::new(&format!("__{}", i), Span::call_site());
383
                                quote! { #field_name }
384
                            }
385
                        });
386

            
387
                        quote! {
388
                            #idx => {
389
                                #(#field_decodes)*
390
                                Ok(Self::#variant_ident(#(#field_names),*))
391
                            }
392
                        }
393
                    }
394
                    Fields::Unit => {
395
                        quote! {
396
                            #idx => Ok(Self::#variant_ident)
397
                        }
398
                    }
399
                }
400
            })
401
            .collect::<Vec<_>>();
402

            
403
        variant_matches.push(quote! {
404
              _ => Err(::std::io::Error::new(::std::io::ErrorKind::InvalidData, "invalid variant index"))
405
          });
406

            
407
        quote! {
408
            let variant_index: u8 = WireFormat::decode(_reader)?;
409
            match variant_index {
410
                #(#variant_matches),*
411
            }
412
        }
413
    } else {
414
        unimplemented!();
415
    }
416
120
}
417

            
418
#[cfg(test)]
419
mod tests {
420
    extern crate pretty_assertions;
421
    use syn::parse_quote;
422

            
423
    use {self::pretty_assertions::assert_eq, super::*};
424

            
425
    #[test]
426
    fn byte_size() {
427
        let input: DeriveInput = parse_quote! {
428
            struct Item {
429
                ident: u32,
430
                with_underscores: String,
431
                other: u8,
432
            }
433
        };
434

            
435
        let expected = quote! {
436
            0
437
                + WireFormat::byte_size(&self.ident)
438
                + WireFormat::byte_size(&self.with_underscores)
439
                + WireFormat::byte_size(&self.other)
440
        };
441

            
442
        assert_eq!(byte_size_sum(&input.data).to_string(), expected.to_string());
443
    }
444

            
445
    #[test]
446
    fn encode() {
447
        let input: DeriveInput = parse_quote! {
448
            struct Item {
449
                ident: u32,
450
                with_underscores: String,
451
                other: u8,
452
            }
453
        };
454

            
455
        let expected = quote! {
456
            WireFormat::encode(&self.ident, _writer)?;
457
            WireFormat::encode(&self.with_underscores, _writer)?;
458
            WireFormat::encode(&self.other, _writer)?;
459
            Ok(())
460
        };
461

            
462
        assert_eq!(
463
            encode_wire_format(&input.data).to_string(),
464
            expected.to_string(),
465
        );
466
    }
467

            
468
    #[test]
469
    fn decode() {
470
        let input: DeriveInput = parse_quote! {
471
            struct Item {
472
                ident: u32,
473
                with_underscores: String,
474
                other: u8,
475
            }
476
        };
477

            
478
        let container = Ident::new("Item", Span::call_site());
479
        let expected = quote! {
480
            let ident = WireFormat::decode(_reader)?;
481
            let with_underscores = WireFormat::decode(_reader)?;
482
            let other = WireFormat::decode(_reader)?;
483
            Ok(Item {
484
                ident: ident,
485
                with_underscores: with_underscores,
486
                other: other,
487
            })
488
        };
489

            
490
        assert_eq!(
491
            decode_wire_format(&input.data, &container).to_string(),
492
            expected.to_string(),
493
        );
494
    }
495

            
496
    #[test]
497
    fn end_to_end() {
498
        let input: DeriveInput = parse_quote! {
499
            struct Niijima_先輩 {
500
                a: u8,
501
                b: u16,
502
                c: u32,
503
                d: u64,
504
                e: String,
505
                f: Vec<String>,
506
                g: Nested,
507
            }
508
        };
509
        let output = wire_format_inner(input);
510
        let syntax_tree: syn::File = syn::parse2(output).unwrap();
511
        let output_str = prettyplease::unparse(&syntax_tree);
512
        insta::assert_snapshot!(output_str, @r###"
513
        mod wire_format_niijima_先輩 {
514
            extern crate std;
515
            use self::std::io;
516
            use self::std::result::Result::Ok;
517
            use super::Niijima_先輩;
518
            use jetstream_wireformat::WireFormat;
519
            impl WireFormat for Niijima_先輩 {
520
                fn byte_size(&self) -> u32 {
521
                    0 + WireFormat::byte_size(&self.a) + WireFormat::byte_size(&self.b)
522
                        + WireFormat::byte_size(&self.c) + WireFormat::byte_size(&self.d)
523
                        + WireFormat::byte_size(&self.e) + WireFormat::byte_size(&self.f)
524
                        + WireFormat::byte_size(&self.g)
525
                }
526
                fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
527
                    WireFormat::encode(&self.a, _writer)?;
528
                    WireFormat::encode(&self.b, _writer)?;
529
                    WireFormat::encode(&self.c, _writer)?;
530
                    WireFormat::encode(&self.d, _writer)?;
531
                    WireFormat::encode(&self.e, _writer)?;
532
                    WireFormat::encode(&self.f, _writer)?;
533
                    WireFormat::encode(&self.g, _writer)?;
534
                    Ok(())
535
                }
536
                fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
537
                    let a = WireFormat::decode(_reader)?;
538
                    let b = WireFormat::decode(_reader)?;
539
                    let c = WireFormat::decode(_reader)?;
540
                    let d = WireFormat::decode(_reader)?;
541
                    let e = WireFormat::decode(_reader)?;
542
                    let f = WireFormat::decode(_reader)?;
543
                    let g = WireFormat::decode(_reader)?;
544
                    Ok(Niijima_先輩 {
545
                        a: a,
546
                        b: b,
547
                        c: c,
548
                        d: d,
549
                        e: e,
550
                        f: f,
551
                        g: g,
552
                    })
553
                }
554
            }
555
        }
556
        "###);
557
    }
558

            
559
    #[test]
560
    fn end_to_end_unnamed() {
561
        let input: DeriveInput = parse_quote! {
562
            struct Niijima_先輩(u8, u16, u32, u64, String, Vec<String>, Nested);
563
        };
564

            
565
        let output = wire_format_inner(input);
566
        let syntax_tree: syn::File = syn::parse2(output).unwrap();
567
        let output_str = prettyplease::unparse(&syntax_tree);
568
        insta::assert_snapshot!(output_str, @r###"
569
        mod wire_format_niijima_先輩 {
570
            extern crate std;
571
            use self::std::io;
572
            use self::std::result::Result::Ok;
573
            use super::Niijima_先輩;
574
            use jetstream_wireformat::WireFormat;
575
            impl WireFormat for Niijima_先輩 {
576
                fn byte_size(&self) -> u32 {
577
                    0 + WireFormat::byte_size(&self.0) + WireFormat::byte_size(&self.1)
578
                        + WireFormat::byte_size(&self.2) + WireFormat::byte_size(&self.3)
579
                        + WireFormat::byte_size(&self.4) + WireFormat::byte_size(&self.5)
580
                        + WireFormat::byte_size(&self.6)
581
                }
582
                fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
583
                    WireFormat::encode(&self.0, _writer)?;
584
                    WireFormat::encode(&self.1, _writer)?;
585
                    WireFormat::encode(&self.2, _writer)?;
586
                    WireFormat::encode(&self.3, _writer)?;
587
                    WireFormat::encode(&self.4, _writer)?;
588
                    WireFormat::encode(&self.5, _writer)?;
589
                    WireFormat::encode(&self.6, _writer)?;
590
                    Ok(())
591
                }
592
                fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
593
                    let __0 = WireFormat::decode(_reader)?;
594
                    let __1 = WireFormat::decode(_reader)?;
595
                    let __2 = WireFormat::decode(_reader)?;
596
                    let __3 = WireFormat::decode(_reader)?;
597
                    let __4 = WireFormat::decode(_reader)?;
598
                    let __5 = WireFormat::decode(_reader)?;
599
                    let __6 = WireFormat::decode(_reader)?;
600
                    Ok(Niijima_先輩(__0, __1, __2, __3, __4, __5, __6))
601
                }
602
            }
603
        }
604
        "###);
605
    }
606

            
607
    #[test]
608
    fn enum_byte_size() {
609
        let input: DeriveInput = parse_quote! {
610
            enum Message {
611
                Ping,
612
                Text { content: String },
613
                Binary(Vec<u8>),
614
            }
615
        };
616

            
617
        let expected = quote! {
618
            match self {
619
                Self::Ping => 1,
620
                Self::Text { ref content } => { 1 + WireFormat::byte_size(content) },
621
                Self::Binary(ref __0) => { 1 + WireFormat::byte_size(__0) }
622
            }
623
        };
624

            
625
        assert_eq!(byte_size_sum(&input.data).to_string(), expected.to_string());
626
    }
627

            
628
    #[test]
629
    fn enum_encode() {
630
        let input: DeriveInput = parse_quote! {
631
            enum Message {
632
                Ping,
633
                Text { content: String },
634
                Binary(Vec<u8>),
635
            }
636
        };
637

            
638
        let expected = quote! {
639
            match self {
640
                Self::Ping => {
641
                    WireFormat::encode(&(0u8), _writer)?;
642
                },
643
                Self::Text { ref content } => {
644
                    WireFormat::encode(&(1u8), _writer)?;
645
                    WireFormat::encode(content, _writer)?;
646
                },
647
                Self::Binary(ref __0) => {
648
                    WireFormat::encode(&(2u8), _writer)?;
649
                    WireFormat::encode(__0, _writer)?;
650
                }
651
            }
652
            Ok(())
653
        };
654

            
655
        assert_eq!(
656
            encode_wire_format(&input.data).to_string(),
657
            expected.to_string()
658
        );
659
    }
660

            
661
    #[test]
662
    fn enum_decode() {
663
        let input: DeriveInput = parse_quote! {
664
            enum Message {
665
                Ping,
666
                Text { content: String },
667
                Binary(Vec<u8>),
668
            }
669
        };
670

            
671
        let container = Ident::new("Message", Span::call_site());
672
        let expected = quote! {
673
            let variant_index: u8 = WireFormat::decode(_reader)?;
674
            match variant_index {
675
                0u8 => Ok(Self::Ping) ,
676
                1u8 => {
677
                    let content = WireFormat::decode(_reader)?;
678
                    Ok(Self::Text { content })
679
                },
680
                2u8 => {
681
                    let __0 = WireFormat::decode(_reader)?;
682
                    Ok(Self::Binary(__0))
683
                },
684
                _ => Err(::std::io::Error::new(::std::io::ErrorKind::InvalidData, "invalid variant index"))
685
            }
686
        };
687

            
688
        assert_eq!(
689
            decode_wire_format(&input.data, &container).to_string(),
690
            expected.to_string()
691
        );
692
    }
693

            
694
    #[test]
695
    fn enum_end_to_end() {
696
        let input: DeriveInput = parse_quote! {
697
            enum Message {
698
                Ping,
699
                Text { content: String },
700
                Binary(Vec<u8>),
701
            }
702
        };
703
        let output = wire_format_inner(input);
704
        let syntax_tree: syn::File = syn::parse2(output).unwrap();
705
        let output_str = prettyplease::unparse(&syntax_tree);
706
        insta::assert_snapshot!(output_str, @r###"
707
        mod wire_format_message {
708
            extern crate std;
709
            use self::std::io;
710
            use self::std::result::Result::Ok;
711
            use super::Message;
712
            use jetstream_wireformat::WireFormat;
713
            impl WireFormat for Message {
714
                fn byte_size(&self) -> u32 {
715
                    match self {
716
                        Self::Ping => 1,
717
                        Self::Text { ref content } => 1 + WireFormat::byte_size(content),
718
                        Self::Binary(ref __0) => 1 + WireFormat::byte_size(__0),
719
                    }
720
                }
721
                fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
722
                    match self {
723
                        Self::Ping => {
724
                            WireFormat::encode(&(0u8), _writer)?;
725
                        }
726
                        Self::Text { ref content } => {
727
                            WireFormat::encode(&(1u8), _writer)?;
728
                            WireFormat::encode(content, _writer)?;
729
                        }
730
                        Self::Binary(ref __0) => {
731
                            WireFormat::encode(&(2u8), _writer)?;
732
                            WireFormat::encode(__0, _writer)?;
733
                        }
734
                    }
735
                    Ok(())
736
                }
737
                fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
738
                    let variant_index: u8 = WireFormat::decode(_reader)?;
739
                    match variant_index {
740
                        0u8 => Ok(Self::Ping),
741
                        1u8 => {
742
                            let content = WireFormat::decode(_reader)?;
743
                            Ok(Self::Text { content })
744
                        }
745
                        2u8 => {
746
                            let __0 = WireFormat::decode(_reader)?;
747
                            Ok(Self::Binary(__0))
748
                        }
749
                        _ => {
750
                            Err(
751
                                ::std::io::Error::new(
752
                                    ::std::io::ErrorKind::InvalidData,
753
                                    "invalid variant index",
754
                                ),
755
                            )
756
                        }
757
                    }
758
                }
759
            }
760
        }
761
        "###);
762
    }
763
    /// A `#[jetstream(skip)]` field on a named struct must be left out of `byte_size` and
    /// `encode`. `decode` must still construct it, using `Default::default()`, in its
    /// declared position.
    #[test]
    fn test_struct_skip_field() {
        let container = Ident::new("Item", Span::call_site());
        let derive_input: DeriveInput = parse_quote! {
            struct Item {
                ident: u32,
                #[jetstream(skip)]
                skipped: String,
                other: u8,
            }
        };
        let data = &derive_input.data;

        // byte_size: only `ident` and `other` contribute to the sum.
        let want_size = quote! {
            0
                + WireFormat::byte_size(&self.ident)
                + WireFormat::byte_size(&self.other)
        };
        let got_size = byte_size_sum(data).to_string();
        assert_eq!(got_size, want_size.to_string());

        // encode: the skipped field is never written.
        let want_encode = quote! {
            WireFormat::encode(&self.ident, _writer)?;
            WireFormat::encode(&self.other, _writer)?;
            Ok(())
        };
        let got_encode = encode_wire_format(data).to_string();
        assert_eq!(got_encode, want_encode.to_string());

        // decode: only wire fields are read; `skipped` is defaulted in place.
        let want_decode = quote! {
            let ident = WireFormat::decode(_reader)?;
            let other = WireFormat::decode(_reader)?;
            Ok(Item {
                ident: ident,
                skipped: Default::default(),
                other: other,
            })
        };
        let got_decode = decode_wire_format(data, &container).to_string();
        assert_eq!(got_decode, want_decode.to_string());
    }
815

            
816
    /// A skipped positional field in a tuple struct must be left out of `byte_size` and
    /// `encode`. The remaining fields keep their original indices (`self.0`, `self.2`), and
    /// `decode` must put `Default::default()` in the skipped slot.
    #[test]
    fn test_tuple_struct_skip_field() {
        let container = Ident::new("Item", Span::call_site());
        let derive_input: DeriveInput = parse_quote! {
            struct Item(u32, #[jetstream(skip)] String, u8);
        };
        let data = &derive_input.data;

        // byte_size: index 1 is skipped, so only 0 and 2 are summed.
        let want_size = quote! {
            0
                + WireFormat::byte_size(&self.0)
                + WireFormat::byte_size(&self.2)
        };
        let got_size = byte_size_sum(data).to_string();
        assert_eq!(got_size, want_size.to_string());

        // encode: index 1 never reaches the writer.
        let want_encode = quote! {
            WireFormat::encode(&self.0, _writer)?;
            WireFormat::encode(&self.2, _writer)?;
            Ok(())
        };
        let got_encode = encode_wire_format(data).to_string();
        assert_eq!(got_encode, want_encode.to_string());

        // decode: bindings are named after the original index, and the gap is defaulted.
        let want_decode = quote! {
            let __0 = WireFormat::decode(_reader)?;
            let __2 = WireFormat::decode(_reader)?;
            Ok(Item(__0, Default::default(), __2,))
        };
        let got_decode = decode_wire_format(data, &container).to_string();
        assert_eq!(got_decode, want_decode.to_string());
    }
859

            
860
    /// Skipped fields inside enum variants (both named and tuple variants) must be left out
    /// of the size and encode arms. `decode` must fill them with `Default::default()`.
    ///
    /// NOTE(review): the expected `byte_size`/`encode` arms bind only the non-skipped fields:
    /// `Self::Text { ref content }` has no `..`, and `Self::Binary(ref __0)` matches a
    /// two-field variant. rustc rejects such patterns (E0027 / E0023), so deriving on this
    /// enum presumably would not compile. The generator likely needs to emit
    /// `{ ref content, .. }` / `(ref __0, _)`. Update these expectations together with
    /// that fix.
    #[test]
    fn test_enum_skip_field() {
        let container = Ident::new("Message", Span::call_site());
        let derive_input: DeriveInput = parse_quote! {
            enum Message {
                Ping,
                Text {
                    content: String,
                    #[jetstream(skip)]
                    metadata: Vec<u8>
                },
                Binary(Vec<u8>, #[jetstream(skip)] String),
            }
        };
        let data = &derive_input.data;

        // byte_size: 1 byte for the variant tag, plus each non-skipped field.
        let want_size = quote! {
            match self {
                Self::Ping => 1,
                Self::Text { ref content } => { 1 + WireFormat::byte_size(content) },
                Self::Binary(ref __0) => { 1 + WireFormat::byte_size(__0) }
            }
        };
        let got_size = byte_size_sum(data).to_string();
        assert_eq!(got_size, want_size.to_string());

        // encode: write the u8 tag, then only the non-skipped fields.
        let want_encode = quote! {
            match self {
                Self::Ping => {
                    WireFormat::encode(&(0u8), _writer)?;
                },
                Self::Text { ref content } => {
                    WireFormat::encode(&(1u8), _writer)?;
                    WireFormat::encode(content, _writer)?;
                },
                Self::Binary(ref __0) => {
                    WireFormat::encode(&(2u8), _writer)?;
                    WireFormat::encode(__0, _writer)?;
                }
            }
            Ok(())
        };
        let got_encode = encode_wire_format(data).to_string();
        assert_eq!(got_encode, want_encode.to_string());

        // decode: read the tag, read the wire fields, default the skipped ones;
        // an unknown tag is reported as InvalidData.
        let want_decode = quote! {
            let variant_index: u8 = WireFormat::decode(_reader)?;
            match variant_index {
                0u8 => Ok(Self::Ping),
                1u8 => {
                    let content = WireFormat::decode(_reader)?;
                    Ok(Self::Text { content, metadata: Default::default() })
                },
                2u8 => {
                    let __0 = WireFormat::decode(_reader)?;
                    Ok(Self::Binary(__0, Default::default()))
                },
                _ => Err(::std::io::Error::new(::std::io::ErrorKind::InvalidData, "invalid variant index"))
            }
        };
        let got_decode = decode_wire_format(data, &container).to_string();
        assert_eq!(got_decode, want_decode.to_string());
    }
934

            
935
    /// Runs the complete derive on a struct with two skipped fields. The generated module is
    /// parsed back into a `syn::File`, pretty-printed, and compared against the inline
    /// snapshot below.
    #[test]
    fn test_end_to_end_with_skip() {
        let derive_input: DeriveInput = parse_quote! {
            struct Item {
                a: u8,
                #[jetstream(skip)]
                skip_this: String,
                b: u16,
                #[jetstream(skip)]
                also_skip: Vec<u8>,
                c: u32,
            }
        };

        // The derive output must be a syntactically valid file; `unwrap` fails the test
        // if it is not.
        let generated = wire_format_inner(derive_input);
        let output_str = prettyplease::unparse(&syn::parse2::<syn::File>(generated).unwrap());
        insta::assert_snapshot!(output_str, @r###"
        mod wire_format_item {
            extern crate std;
            use self::std::io;
            use self::std::result::Result::Ok;
            use super::Item;
            use jetstream_wireformat::WireFormat;
            impl WireFormat for Item {
                fn byte_size(&self) -> u32 {
                    0 + WireFormat::byte_size(&self.a) + WireFormat::byte_size(&self.b)
                        + WireFormat::byte_size(&self.c)
                }
                fn encode<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
                    WireFormat::encode(&self.a, _writer)?;
                    WireFormat::encode(&self.b, _writer)?;
                    WireFormat::encode(&self.c, _writer)?;
                    Ok(())
                }
                fn decode<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
                    let a = WireFormat::decode(_reader)?;
                    let b = WireFormat::decode(_reader)?;
                    let c = WireFormat::decode(_reader)?;
                    Ok(Item {
                        a: a,
                        skip_this: Default::default(),
                        b: b,
                        also_skip: Default::default(),
                        c: c,
                    })
                }
            }
        }
        "###);
    }
986
}