1
use jetstream_wireformat::WireFormat;
2
use std::io;
3
use std::io::ErrorKind;
4
use std::io::Read;
5
use std::io::Write;
6
use std::mem;
7

            
8
pub struct Frame<T: Framer> {
9
    pub tag: u16,
10
    pub msg: T,
11
}
12

            
13
impl<T: Framer> From<(u16, T)> for Frame<T> {
14
444
    fn from((tag, msg): (u16, T)) -> Self {
15
444
        Self { tag, msg }
16
444
    }
17
}
18

            
19
impl<T: Framer> WireFormat for Frame<T> {
20
444
    fn byte_size(&self) -> u32 {
21
444
        let msg_size = self.msg.byte_size();
22
        // size + type + tag + message size
23
444
        (mem::size_of::<u32>() + mem::size_of::<u8>() + mem::size_of::<u16>())
24
444
            as u32
25
444
            + msg_size
26
444
    }
27

            
28
444
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
29
444
        self.byte_size().encode(writer)?;
30

            
31
444
        let ty = self.msg.message_type();
32

            
33
444
        ty.encode(writer)?;
34
444
        self.tag.encode(writer)?;
35

            
36
444
        self.msg.encode(writer)?;
37

            
38
444
        Ok(())
39
444
    }
40

            
41
444
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
42
444
        let byte_size: u32 = WireFormat::decode(reader)?;
43

            
44
        // byte_size includes the size of byte_size so remove that from the
45
        // expected length of the message.  Also make sure that byte_size is at least
46
        // that long to begin with.
47
444
        if byte_size < mem::size_of::<u32>() as u32 {
48
            return Err(io::Error::new(
49
                ErrorKind::InvalidData,
50
                format!("byte_size(= {byte_size}) is less than 4 bytes"),
51
            ));
52
444
        }
53
444
        let reader =
54
444
            &mut reader.take((byte_size - mem::size_of::<u32>() as u32) as u64);
55

            
56
444
        let mut ty = [0u8];
57
444
        reader.read_exact(&mut ty)?;
58

            
59
444
        let tag: u16 = WireFormat::decode(reader)?;
60
444
        let msg = T::decode(reader, ty[0])?;
61

            
62
444
        Ok(Frame { tag, msg })
63
444
    }
64
}
65

            
/// A message type that can be carried inside a [`Frame`].
///
/// Implementors provide the one-byte message type written into the frame
/// header, plus their own size calculation and encode/decode logic for the
/// message body.
pub trait Framer
where
    Self: Sized + Send + Sync,
{
    /// Returns the message-type byte written into the frame header ahead of the tag.
    fn message_type(&self) -> u8;

    /// Returns the number of bytes necessary to fully encode `self`.
    fn byte_size(&self) -> u32;

    /// Encodes `self` into `writer`.
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()>;

    /// Decodes `Self` from `reader`, given the message-type byte `ty` read from
    /// the frame header.
    fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<Self>;
}