1
use std::pin::pin;
2

            
3
use jetstream_rpc::{Error, Frame, Protocol, ServiceTransport};
4
use jetstream_wireformat::WireFormat;
5
use tokio_util::{
6
    bytes::{self, Buf, BufMut},
7
    codec::{Decoder, Encoder},
8
};
9

            
10
/// Server-side [`Decoder`]/[`Encoder`] pair for a jetstream [`Protocol`]:
/// decodes incoming `Frame<P::Request>`s and encodes outgoing
/// `Frame<P::Response>`s.
///
/// The codec carries no state; `P` only selects the request/response types.
pub struct ServerCodec<P: Protocol> {
    _phantom: std::marker::PhantomData<P>,
}

// Written by hand rather than derived so that `P` itself need not be `Debug`.
impl<P: Protocol> std::fmt::Debug for ServerCodec<P> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServerCodec").finish_non_exhaustive()
    }
}
13

            
14
impl<P: Protocol> ServerCodec<P> {
    /// Creates a codec for protocol `P`.
    ///
    /// The codec holds no data, so this is a `const fn` and can also be used
    /// to initialise constants or statics.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
21

            
22
impl<P: Protocol> Default for ServerCodec<P> {
23
4
    fn default() -> Self {
24
4
        Self::new()
25
4
    }
26
}
27

            
28
impl<P> Decoder for ServerCodec<P>
29
where
30
    P: Protocol,
31
{
32
    type Error = Error;
33
    type Item = Frame<P::Request>;
34

            
35
404
    fn decode(
36
404
        &mut self,
37
404
        src: &mut bytes::BytesMut,
38
404
    ) -> Result<Option<Self::Item>, Self::Error> {
39
404
        // check to see if you have at least 4 bytes to figure out the size
40
404
        if src.len() < 4 {
41
202
            src.reserve(4);
42
202
            return Ok(None);
43
202
        }
44
202
        let Some(mut bytz) = src.get(..4) else {
45
            return Ok(None);
46
        };
47

            
48
202
        let byte_size: u32 = WireFormat::decode(&mut bytz)?;
49
202
        if src.len() < byte_size as usize {
50
            src.reserve(byte_size as usize);
51
            return Ok(None);
52
202
        }
53
202

            
54
202
        Frame::<P::Request>::decode(&mut src.reader())
55
202
            .map(Some)
56
202
            .map_err(|_| Error::Custom("()".to_string()))
57
404
    }
58
}
59

            
60
impl<P> Encoder<Frame<P::Response>> for ServerCodec<P>
61
where
62
    P: Protocol,
63
{
64
    type Error = Error;
65

            
66
202
    fn encode(
67
202
        &mut self,
68
202
        item: Frame<P::Response>,
69
202
        dst: &mut bytes::BytesMut,
70
202
    ) -> Result<(), Self::Error> {
71
202
        item.encode(&mut dst.writer())
72
202
            .map_err(|_| Error::Custom("()".to_string()))
73
202
            .map(|_| ())
74
202
    }
75
}
76

            
77
4
pub async fn run<T, P>(p: &mut P, mut stream: T) -> Result<(), P::Error>
78
4
where
79
4
    T: ServiceTransport<P>,
80
4
    P: Protocol,
81
4
{
82
    use futures::{SinkExt, StreamExt};
83
4
    let mut a = pin!(p);
84
206
    while let Some(Ok(frame)) = stream.next().await {
85
202
        stream.send(a.rpc(frame).await?).await?
86
    }
87
    Ok(())
88
}