1
use std::pin::pin;
2

            
3
use crate::{
4
    context::{Context, Contextual},
5
    Error, Frame, Protocol,
6
};
7
use futures::{Sink, Stream};
8
use jetstream_wireformat::WireFormat;
9
use tokio_util::{
10
    bytes::{self, Buf, BufMut},
11
    codec::{Decoder, Encoder},
12
};
13

            
14
pub struct ServerCodec<P: Protocol> {
15
    _phantom: std::marker::PhantomData<P>,
16
}
17

            
18
impl<P: Protocol> ServerCodec<P> {
19
8
    pub fn new() -> Self {
20
8
        Self {
21
8
            _phantom: std::marker::PhantomData,
22
8
        }
23
8
    }
24
}
25

            
26
impl<P: Protocol> Default for ServerCodec<P> {
27
4
    fn default() -> Self {
28
4
        Self::new()
29
4
    }
30
}
31

            
32
pub trait ServiceTransport<P: Protocol>:
33
    Sink<Frame<P::Response>, Error = P::Error>
34
    + Stream<Item = Result<Frame<P::Request>, P::Error>>
35
    + Send
36
    + Sync
37
    + Unpin
38
{
39
    fn context(&self) -> Context;
40
}
41

            
42
impl<P: Protocol, T> ServiceTransport<P> for T
43
where
44
    T: Sink<Frame<P::Response>, Error = P::Error>
45
        + Stream<Item = Result<Frame<P::Request>, P::Error>>
46
        + Send
47
        + Sync
48
        + Unpin,
49
    T: Contextual,
50
{
51
202
    fn context(&self) -> Context {
52
202
        <Self as Contextual>::context(self)
53
202
    }
54
}
55

            
56
impl<P> Decoder for ServerCodec<P>
57
where
58
    P: Protocol,
59
{
60
    type Error = Error;
61
    type Item = Frame<P::Request>;
62

            
63
444
    fn decode(
64
444
        &mut self,
65
444
        src: &mut bytes::BytesMut,
66
444
    ) -> Result<Option<Self::Item>, Self::Error> {
67
        // check to see if you have at least 4 bytes to figure out the size
68
444
        if src.len() < 4 {
69
222
            src.reserve(4);
70
222
            return Ok(None);
71
222
        }
72
222
        let Some(mut bytz) = src.get(..4) else {
73
            return Ok(None);
74
        };
75

            
76
222
        let byte_size: u32 = WireFormat::decode(&mut bytz)?;
77
222
        if src.len() < byte_size as usize {
78
            src.reserve(byte_size as usize);
79
            return Ok(None);
80
222
        }
81

            
82
222
        Frame::<P::Request>::decode(&mut src.reader())
83
222
            .map(Some)
84
222
            .map_err(|_| Error::Custom("()".to_string()))
85
444
    }
86
}
87

            
88
impl<P> Encoder<Frame<P::Response>> for ServerCodec<P>
89
where
90
    P: Protocol,
91
{
92
    type Error = Error;
93

            
94
222
    fn encode(
95
222
        &mut self,
96
222
        item: Frame<P::Response>,
97
222
        dst: &mut bytes::BytesMut,
98
222
    ) -> Result<(), Self::Error> {
99
222
        item.encode(&mut dst.writer())
100
222
            .map_err(|_| Error::Custom("()".to_string()))
101
222
            .map(|_| ())
102
222
    }
103
}
104

            
105
4
pub async fn run<T, P>(p: &mut P, mut stream: T) -> Result<(), P::Error>
106
4
where
107
4
    T: ServiceTransport<P>,
108
4
    P: Protocol,
109
4
{
110
    use futures::{SinkExt, StreamExt};
111
4
    let mut a = pin!(p);
112
206
    while let Some(Ok(frame)) = stream.next().await {
113
202
        stream.send(a.rpc(stream.context(), frame).await?).await?
114
    }
115
    Ok(())
116
}