#![doc(
    html_logo_url = "https://raw.githubusercontent.com/sevki/jetstream/main/logo/JetStream.png",
    html_favicon_url = "https://raw.githubusercontent.com/sevki/jetstream/main/logo/JetStream.png"
)]
//! # JetStream RPC
//!
//! RPC building blocks for JetStream: tagged [`Frame`]s, the [`Framer`] payload trait,
//! and the transport traits used by clients and services.
//!
//! The central item is the [`Protocol`] trait, which is intended to be used together
//! with the `service` attribute macro.
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]

            
12
use std::{
13
    io::{self, ErrorKind, Read, Write},
14
    mem,
15
};
16

            
17
use futures::{
18
    stream::{SplitSink, SplitStream},
19
    Sink, Stream, StreamExt,
20
};
21
use jetstream_wireformat::WireFormat;
22
// Re-export codecs
23
pub use tokio_util::codec::{Decoder, Encoder, Framed};
24

            
25
/// A trait representing a message that can be encoded and decoded.
26
#[cfg(not(target_arch = "wasm32"))]
27
pub trait Message: WireFormat + Sync {}
28

            
29
/// A trait representing a message that can be encoded and decoded.
30
/// WebAssembly doesn't fully support Send+Sync, so we don't require those.
31
#[cfg(target_arch = "wasm32")]
32
pub trait Message: WireFormat {}
33

            
/// A `u16` tag carried alongside a message in a [`Context`].
///
/// It can be extracted from a context through [`FromContext`]. The derives let tags
/// be copied, compared, hashed and printed for diagnostics.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Tag(u16);

            
37
impl From<u16> for Tag {
38
    fn from(tag: u16) -> Self {
39
        Self(tag)
40
    }
41
}
42

            
43
pub struct Context<T: WireFormat> {
44
    pub tag: Tag,
45
    pub msg: T,
46
}
47

            
48
pub trait FromContext<T: WireFormat> {
49
    fn from_context(ctx: Context<T>) -> Self;
50
}
51

            
52
impl<T: WireFormat> FromContext<T> for T {
53
    fn from_context(ctx: Context<T>) -> Self {
54
        ctx.msg
55
    }
56
}
57

            
58
impl<T: WireFormat> FromContext<T> for Tag {
59
    fn from_context(ctx: Context<T>) -> Self {
60
        ctx.tag
61
    }
62
}
63

            
64
pub trait Handler<T: WireFormat> {
65
    fn call(self, context: Context<T>);
66
}
67

            
68
/// Defines the request and response types for the JetStream protocol.
69
#[trait_variant::make(Send + Sync + Sized)]
70
pub trait Protocol: Send + Sync {
71
    type Request: Framer;
72
    type Response: Framer;
73
    type Error: std::error::Error + Send + Sync + 'static;
74
    const VERSION: &'static str;
75
    async fn rpc(
76
        &mut self,
77
        frame: Frame<Self::Request>,
78
    ) -> Result<Frame<Self::Response>, Self::Error>;
79
}
80

            
81
#[derive(Debug, thiserror::Error)]
82
pub enum Error {
83
    #[error("io error: {0}")]
84
    Io(#[from] io::Error),
85
    #[error("generic error: {0}")]
86
    Generic(#[from] Box<dyn std::error::Error + Send + Sync>),
87
    #[error("{0}")]
88
    Custom(String),
89
    #[error("invalid response")]
90
    InvalidResponse,
91
}
92

            
93
pub struct Frame<T: Framer> {
94
    pub tag: u16,
95
    pub msg: T,
96
}
97

            
98
impl<T: Framer> From<(u16, T)> for Frame<T> {
99
404
    fn from((tag, msg): (u16, T)) -> Self {
100
404
        Self { tag, msg }
101
404
    }
102
}
103

            
104
impl<T: Framer> WireFormat for Frame<T> {
105
404
    fn byte_size(&self) -> u32 {
106
404
        let msg_size = self.msg.byte_size();
107
404
        // size + type + tag + message size
108
404
        (mem::size_of::<u32>() + mem::size_of::<u8>() + mem::size_of::<u16>())
109
404
            as u32
110
404
            + msg_size
111
404
    }
112

            
113
404
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
114
404
        self.byte_size().encode(writer)?;
115

            
116
404
        let ty = self.msg.message_type();
117
404

            
118
404
        ty.encode(writer)?;
119
404
        self.tag.encode(writer)?;
120

            
121
404
        self.msg.encode(writer)?;
122

            
123
404
        Ok(())
124
404
    }
125

            
126
404
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
127
404
        let byte_size: u32 = WireFormat::decode(reader)?;
128

            
129
        // byte_size includes the size of byte_size so remove that from the
130
        // expected length of the message.  Also make sure that byte_size is at least
131
        // that long to begin with.
132
404
        if byte_size < mem::size_of::<u32>() as u32 {
133
            return Err(io::Error::new(
134
                ErrorKind::InvalidData,
135
                format!("byte_size(= {}) is less than 4 bytes", byte_size),
136
            ));
137
404
        }
138
404
        let reader =
139
404
            &mut reader.take((byte_size - mem::size_of::<u32>() as u32) as u64);
140
404

            
141
404
        let mut ty = [0u8];
142
404
        reader.read_exact(&mut ty)?;
143

            
144
404
        let tag: u16 = WireFormat::decode(reader)?;
145
404
        let msg = T::decode(reader, ty[0])?;
146

            
147
404
        Ok(Frame { tag, msg })
148
404
    }
149
}
150

            
/// A payload that can be carried inside a [`Frame`].
///
/// Unlike a plain wire-format type, decoding receives the message-type byte that was
/// read from the frame header.
pub trait Framer: Sized + Send + Sync {
    /// The message-type byte written into the frame header, ahead of the tag.
    fn message_type(&self) -> u8;

    /// Returns the number of bytes necessary to fully encode `self`.
    fn byte_size(&self) -> u32;

    /// Encodes `self` into `writer`.
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()>;

    /// Decodes `Self` from `reader`, given the message-type byte `ty` from the header.
    fn decode<R: Read>(reader: &mut R, ty: u8) -> io::Result<Self>;
}

            
163
pub trait ServiceTransport<P: Protocol>:
164
    Sink<Frame<P::Response>, Error = P::Error>
165
    + Stream<Item = Result<Frame<P::Request>, P::Error>>
166
    + Send
167
    + Sync
168
    + Unpin
169
{
170
}
171

            
172
impl<P: Protocol, T> ServiceTransport<P> for T where
173
    T: Sink<Frame<P::Response>, Error = P::Error>
174
        + Stream<Item = Result<Frame<P::Request>, P::Error>>
175
        + Send
176
        + Sync
177
        + Unpin
178
{
179
}
180

            
181
pub trait ClientTransport<P: Protocol>:
182
    Sink<Frame<P::Request>, Error = std::io::Error>
183
    + Stream<Item = Result<Frame<P::Response>, std::io::Error>>
184
    + Send
185
    + Sync
186
    + Unpin
187
{
188
}
189

            
190
impl<P: Protocol, T> ClientTransport<P> for T
191
where
192
    Self: Sized,
193
    T: Sink<Frame<P::Request>, Error = std::io::Error>
194
        + Stream<Item = Result<Frame<P::Response>, std::io::Error>>
195
        + Send
196
        + Sync
197
        + Unpin,
198
{
199
}
200

            
201
pub trait Channel<P: Protocol>: Unpin + Sized {
202
    fn split(self) -> (SplitSink<Self, Frame<P::Request>>, SplitStream<Self>);
203
}
204

            
205
impl<P, T> Channel<P> for T
206
where
207
    P: Protocol,
208
    T: ClientTransport<P> + Unpin + Sized,
209
{
210
    fn split(
211
        self,
212
    ) -> (
213
        SplitSink<Self, Frame<<P as Protocol>::Request>>,
214
        SplitStream<Self>,
215
    ) {
216
        StreamExt::split(self)
217
    }
218
}