1
#![doc(html_logo_url = "https://raw.githubusercontent.com/sevki/jetstream/main/logo/JetStream.png")]
2
#![doc(
3
    html_favicon_url = "https://raw.githubusercontent.com/sevki/jetstream/main/logo/JetStream.png"
4
)]
5
//! # JetStream Rpc
6
//! Defines the RPC primitives used by JetStream.
7
//! Of particular note is the [`Protocol`] trait, which is meant to be used with the `service` attribute macro.
8
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
9

            
10
use {
11
    futures::{
12
        stream::{SplitSink, SplitStream},
13
        Sink,
14
        Stream,
15
        StreamExt,
16
    },
17
    jetstream_wireformat::WireFormat,
18
    std::{
19
        io::{self, ErrorKind, Read, Write},
20
        mem,
21
    },
22
};
23

            
24
// Re-export codecs
25
pub use tokio_util::codec::{Decoder, Encoder, Framed};
26

            
27
/// A trait representing a message that can be encoded and decoded.
28
pub trait Message: WireFormat + Send + Sync {}
29

            
30
/// A 16-bit message tag.
///
/// It is wrapped in a newtype so that it can be pulled out of a [`Context`]
/// through [`FromContext`]. The derives make it printable, comparable, hashable
/// and freely copyable, as befits a plain integer id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct Tag(u16);

impl From<u16> for Tag {
    fn from(tag: u16) -> Self {
        Self(tag)
    }
}

/// Recovers the raw `u16` value of a [`Tag`].
impl From<Tag> for u16 {
    fn from(tag: Tag) -> Self {
        tag.0
    }
}

pub struct Context<T: WireFormat> {
40
    pub tag: Tag,
41
    pub msg: T,
42
}
43

            
44
pub trait FromContext<T: WireFormat> {
45
    fn from_context(ctx: Context<T>) -> Self;
46
}
47

            
48
impl<T: WireFormat> FromContext<T> for T {
49
    fn from_context(ctx: Context<T>) -> Self {
50
        ctx.msg
51
    }
52
}
53

            
54
impl<T: WireFormat> FromContext<T> for Tag {
55
    fn from_context(ctx: Context<T>) -> Self {
56
        ctx.tag
57
    }
58
}
59

            
60
pub trait Handler<T: WireFormat> {
61
    fn call(self, context: Context<T>);
62
}
63

            
64
/// Defines the request and response types for the JetStream protocol.
65
#[trait_variant::make(Send + Sync + Sized)]
66
pub trait Protocol: Send + Sync {
67
    type Request: Framer;
68
    type Response: Framer;
69
    type Error: std::error::Error + Send + Sync + 'static;
70
    const VERSION: &'static str;
71
    async fn rpc(
72
        &mut self,
73
        frame: Frame<Self::Request>,
74
    ) -> Result<Frame<Self::Response>, Self::Error>;
75
}
76

            
77
#[derive(Debug, thiserror::Error)]
78
pub enum Error {
79
    #[error("io error: {0}")]
80
    Io(#[from] io::Error),
81
    #[error("generic error: {0}")]
82
    Generic(#[from] Box<dyn std::error::Error + Send + Sync>),
83
    #[error("{0}")]
84
    Custom(String),
85
    #[error("invalid response")]
86
    InvalidResponse,
87
}
88

            
89
pub struct Frame<T: Framer> {
90
    pub tag: u16,
91
    pub msg: T,
92
}
93

            
94
impl<T: Framer> From<(u16, T)> for Frame<T> {
95
404
    fn from((tag, msg): (u16, T)) -> Self {
96
404
        Self { tag, msg }
97
404
    }
98
}
99

            
100
impl<T: Framer> WireFormat for Frame<T> {
101
404
    fn byte_size(&self) -> u32 {
102
404
        let msg_size = self.msg.byte_size();
103
404
        // size + type + tag + message size
104
404
        (mem::size_of::<u32>() + mem::size_of::<u8>() + mem::size_of::<u16>()) as u32 + msg_size
105
404
    }
106

            
107
404
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
108
404
        self.byte_size().encode(writer)?;
109

            
110
404
        let ty = self.msg.message_type();
111
404

            
112
404
        ty.encode(writer)?;
113
404
        self.tag.encode(writer)?;
114

            
115
404
        self.msg.encode(writer)?;
116

            
117
404
        Ok(())
118
404
    }
119

            
120
404
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
121
404
        let byte_size: u32 = WireFormat::decode(reader)?;
122

            
123
        // byte_size includes the size of byte_size so remove that from the
124
        // expected length of the message.  Also make sure that byte_size is at least
125
        // that long to begin with.
126
404
        if byte_size < mem::size_of::<u32>() as u32 {
127
            return Err(io::Error::new(
128
                ErrorKind::InvalidData,
129
                format!("byte_size(= {}) is less than 4 bytes", byte_size),
130
            ));
131
404
        }
132
404
        let reader = &mut reader.take((byte_size - mem::size_of::<u32>() as u32) as u64);
133
404

            
134
404
        let mut ty = [0u8];
135
404
        reader.read_exact(&mut ty)?;
136

            
137
404
        let tag: u16 = WireFormat::decode(reader)?;
138
404
        let msg = T::decode(reader, ty[0])?;
139

            
140
404
        Ok(Frame { tag, msg })
141
404
    }
142
}
143

            
144
/// A payload that can travel inside a [`Frame`].
///
/// Unlike a plain [`WireFormat`] value, a framer reports a one-byte message
/// type. That byte is written into the frame header, and it is handed back to
/// `decode` when the frame is read.
pub trait Framer
where
    Self: Sized + Send + Sync,
{
    /// One-byte message type stored in the frame header for this payload.
    fn message_type(&self) -> u8;

    /// Returns the number of bytes necessary to fully encode `self`.
    fn byte_size(&self) -> u32;

    /// Encodes `self` into `writer`.
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()>;

    /// Decodes `Self` from `reader`, given the message type `ty` taken from
    /// the frame header.
    fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<Self>;
}

pub trait ServiceTransport<P: Protocol>:
157
    Sink<Frame<P::Response>, Error = P::Error>
158
    + Stream<Item = Result<Frame<P::Request>, P::Error>>
159
    + Send
160
    + Sync
161
    + Unpin
162
{
163
}
164

            
165
impl<P: Protocol, T> ServiceTransport<P> for T where
166
    T: Sink<Frame<P::Response>, Error = P::Error>
167
        + Stream<Item = Result<Frame<P::Request>, P::Error>>
168
        + Send
169
        + Sync
170
        + Unpin
171
{
172
}
173

            
174
pub trait ClientTransport<P: Protocol>:
175
    Sink<Frame<P::Request>, Error = std::io::Error>
176
    + Stream<Item = Result<Frame<P::Response>, std::io::Error>>
177
    + Send
178
    + Sync
179
    + Unpin
180
{
181
}
182

            
183
impl<P: Protocol, T> ClientTransport<P> for T
184
where
185
    Self: Sized,
186
    T: Sink<Frame<P::Request>, Error = std::io::Error>
187
        + Stream<Item = Result<Frame<P::Response>, std::io::Error>>
188
        + Send
189
        + Sync
190
        + Unpin,
191
{
192
}
193

            
194
pub trait Channel<P: Protocol>: Unpin + Sized {
195
    fn split(self) -> (SplitSink<Self, Frame<P::Request>>, SplitStream<Self>);
196
}
197

            
198
impl<P, T> Channel<P> for T
199
where
200
    P: Protocol,
201
    T: ClientTransport<P> + Unpin + Sized,
202
{
203
    fn split(
204
        self,
205
    ) -> (
206
        SplitSink<Self, Frame<<P as Protocol>::Request>>,
207
        SplitStream<Self>,
208
    ) {
209
        StreamExt::split(self)
210
    }
211
}