1
use std::{io, marker::PhantomData};
2

            
3
use crate::{context::NodeAddr, Frame, Protocol};
4
use futures::{
5
    stream::{SplitSink, SplitStream},
6
    Sink, Stream, StreamExt,
7
};
8
use jetstream_wireformat::WireFormat;
9
use tokio_util::{
10
    bytes::{self, Buf, BufMut},
11
    codec::{Decoder, Encoder},
12
};
13

            
14
pub struct ClientCodec<P>
15
where
16
    P: Protocol,
17
{
18
    _p: std::marker::PhantomData<P>,
19
}
20

            
21
impl<P: Protocol> Encoder<Frame<P::Request>> for ClientCodec<P> {
22
    type Error = std::io::Error;
23

            
24
222
    fn encode(
25
222
        &mut self,
26
222
        item: Frame<P::Request>,
27
222
        dst: &mut bytes::BytesMut,
28
222
    ) -> Result<(), Self::Error> {
29
222
        WireFormat::encode(&item, &mut dst.writer())
30
222
    }
31
}
32

            
33
impl<P: Protocol> Decoder for ClientCodec<P> {
34
    type Error = std::io::Error;
35
    type Item = Frame<P::Response>;
36

            
37
438
    fn decode(
38
438
        &mut self,
39
438
        src: &mut bytes::BytesMut,
40
438
    ) -> Result<Option<Self::Item>, Self::Error> {
41
        // check to see if you have at least 4 bytes to figure out the size
42
438
        if src.len() < 4 {
43
216
            src.reserve(4);
44
216
            return Ok(None);
45
222
        }
46
222
        let Some(mut bytz) = src.get(..4) else {
47
            return Ok(None);
48
        };
49

            
50
222
        let byte_size: u32 = WireFormat::decode(&mut bytz)?;
51
222
        if src.len() < byte_size as usize {
52
            src.reserve(byte_size as usize);
53
            return Ok(None);
54
222
        }
55
222
        Frame::<P::Response>::decode(&mut src.reader()).map(Some)
56
438
    }
57
}
58

            
59
impl<P> Default for ClientCodec<P>
60
where
61
    P: Protocol,
62
{
63
8
    fn default() -> Self {
64
8
        Self {
65
8
            _p: std::marker::PhantomData,
66
8
        }
67
8
    }
68
}
69

            
70
pub trait ClientTransport<P: Protocol>:
71
    Sink<Frame<P::Request>, Error = std::io::Error>
72
    + Stream<Item = Result<Frame<P::Response>, std::io::Error>>
73
    + Send
74
    + Sync
75
    + Unpin
76
{
77
}
78

            
79
impl<P: Protocol, T> ClientTransport<P> for T
80
where
81
    Self: Sized,
82
    T: Sink<Frame<P::Request>, Error = std::io::Error>
83
        + Stream<Item = Result<Frame<P::Response>, std::io::Error>>
84
        + Send
85
        + Sync
86
        + Unpin,
87
{
88
}
89

            
90
pub trait Channel<P: Protocol>: Unpin + Sized {
91
    fn split(self) -> (SplitSink<Self, Frame<P::Request>>, SplitStream<Self>);
92
}
93

            
94
impl<P, T> Channel<P> for T
95
where
96
    P: Protocol,
97
    T: ClientTransport<P> + Unpin + Sized,
98
{
99
    fn split(
100
        self,
101
    ) -> (
102
        SplitSink<Self, Frame<<P as Protocol>::Request>>,
103
        SplitStream<Self>,
104
    ) {
105
        StreamExt::split(self)
106
    }
107
}
108

            
109
#[derive(Debug)]
110
pub struct ClientBuilder<P: Protocol> {
111
    node_addr: NodeAddr,
112
    _phantom: PhantomData<P>,
113
}
114

            
115
impl<P: Protocol> WireFormat for ClientBuilder<P> {
116
    fn byte_size(&self) -> u32 {
117
        P::VERSION.to_string().byte_size() + self.node_addr.byte_size()
118
    }
119

            
120
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()>
121
    where
122
        Self: Sized,
123
    {
124
        P::VERSION.to_string().encode(writer)?;
125
        self.node_addr.encode(writer)
126
    }
127

            
128
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self>
129
    where
130
        Self: Sized,
131
    {
132
        let version = String::decode(reader)?;
133
        if version != P::VERSION {
134
            return Err(io::Error::new(
135
                io::ErrorKind::InvalidData,
136
                "version mismatch",
137
            ));
138
        }
139
        let node_addr = NodeAddr::decode(reader)?;
140
        Ok(ClientBuilder {
141
            node_addr,
142
            _phantom: PhantomData,
143
        })
144
    }
145
}
146

            
147
pub fn client_builder<P: Protocol>(
148
    addr: impl Into<NodeAddr>,
149
) -> ClientBuilder<P> {
150
    ClientBuilder {
151
        node_addr: addr.into(),
152
        _phantom: PhantomData,
153
    }
154
}
155

            
156
impl<P: Protocol> From<(P, NodeAddr)> for ClientBuilder<P> {
157
    fn from(value: (P, NodeAddr)) -> Self {
158
        ClientBuilder {
159
            node_addr: value.1,
160
            _phantom: PhantomData,
161
        }
162
    }
163
}