1
use {
2
    jetstream_rpc::{Error, Frame, Protocol, ServiceTransport},
3
    jetstream_wireformat::WireFormat,
4
    std::pin::pin,
5
    tokio_util::{
6
        bytes::{self, Buf, BufMut},
7
        codec::{Decoder, Encoder},
8
    },
9
};
10

            
11
pub struct ServerCodec<P: Protocol> {
12
    _phantom: std::marker::PhantomData<P>,
13
}
14

            
15
impl<P: Protocol> ServerCodec<P> {
16
4
    pub fn new() -> Self {
17
4
        Self {
18
4
            _phantom: std::marker::PhantomData,
19
4
        }
20
4
    }
21
}
22

            
23
impl<P: Protocol> Default for ServerCodec<P> {
24
4
    fn default() -> Self {
25
4
        Self::new()
26
4
    }
27
}
28

            
29
impl<P> Decoder for ServerCodec<P>
30
where
31
    P: Protocol,
32
{
33
    type Error = Error;
34
    type Item = Frame<P::Request>;
35

            
36
404
    fn decode(&mut self, src: &mut bytes::BytesMut) -> Result<Option<Self::Item>, Self::Error> {
37
404
        // check to see if you have at least 4 bytes to figure out the size
38
404
        if src.len() < 4 {
39
202
            src.reserve(4);
40
202
            return Ok(None);
41
202
        }
42
202
        let Some(mut bytz) = src.get(..4) else {
43
            return Ok(None);
44
        };
45

            
46
202
        let byte_size: u32 = WireFormat::decode(&mut bytz)?;
47
202
        if src.len() < byte_size as usize {
48
            src.reserve(byte_size as usize);
49
            return Ok(None);
50
202
        }
51
202

            
52
202
        Frame::<P::Request>::decode(&mut src.reader())
53
202
            .map(Some)
54
202
            .map_err(|_| Error::Custom("()".to_string()))
55
404
    }
56
}
57

            
58
impl<P> Encoder<Frame<P::Response>> for ServerCodec<P>
59
where
60
    P: Protocol,
61
{
62
    type Error = Error;
63

            
64
202
    fn encode(
65
202
        &mut self,
66
202
        item: Frame<P::Response>,
67
202
        dst: &mut bytes::BytesMut,
68
202
    ) -> Result<(), Self::Error> {
69
202
        item.encode(&mut dst.writer())
70
202
            .map_err(|_| Error::Custom("()".to_string()))
71
202
            .map(|_| ())
72
202
    }
73
}
74

            
75
4
pub async fn run<T, P>(p: &mut P, mut stream: T) -> Result<(), P::Error>
76
4
where
77
4
    T: ServiceTransport<P>,
78
4
    P: Protocol,
79
4
{
80
    use futures::{SinkExt, StreamExt};
81
4
    let mut a = pin!(p);
82
206
    while let Some(Ok(frame)) = stream.next().await {
83
202
        stream.send(a.rpc(frame).await?).await?
84
    }
85
    Ok(())
86
}