1
// Copyright (c) 2024, Sevki <s@sevki.io>
2
// Use of this source code is governed by a BSD-style license that can be
3
// found in the LICENSE file.
4
use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5
use std::io::{self};
6

            
7
use bytes::Bytes;
8

            
9
pub extern crate bytes;
10

            
11
use super::WireFormat;
12

            
13
/// Trait for types that can be encoded into, and decoded from, a stream
/// asynchronously; both operations hand back a `Send` future.
///
/// NOTE(review): the stream parameters are bounded by `AsyncWireFormat` itself
/// rather than by an async I/O trait. This looks unintended, but it is kept as-is
/// so existing implementors keep compiling.
pub trait AsyncWireFormat: Sized {
    /// Encodes `self` into `writer`; the returned future resolves to the
    /// outcome of the encoding.
    fn encode_async<W>(
        self,
        writer: &mut W,
    ) -> impl std::future::Future<Output = io::Result<()>> + Send
    where
        W: AsyncWireFormat + Unpin + Send;

    /// Decodes a value of this type from `reader`; the returned future resolves
    /// to the decoded value or the error that stopped decoding.
    fn decode_async<R>(
        reader: &mut R,
    ) -> impl std::future::Future<Output = io::Result<Self>> + Send
    where
        R: AsyncWireFormat + Unpin + Send;
}
22

            
23
#[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
pub mod tokio {
    use std::{future::Future, io};

    use tokio::io::{AsyncRead, AsyncWrite};

    use crate::WireFormat;

    /// Extension trait for asynchronous wire format encoding and decoding.
    ///
    /// Both methods run the synchronous [`WireFormat`] codec inside
    /// [`tokio::task::block_in_place`], adapting the async stream to blocking
    /// `std::io` with [`tokio_util::io::SyncIoBridge`].
    pub trait AsyncWireFormatExt
    where
        Self: WireFormat + Send,
    {
        /// Encodes the object asynchronously into the provided writer.
        ///
        /// # Arguments
        ///
        /// * `writer` - The writer to encode the object into.
        ///
        /// # Returns
        ///
        /// A `Send` future that resolves to an `io::Result<()>` indicating the success or
        /// failure of the encoding operation. This method does not call `flush` on the
        /// writer; flush it yourself if it buffers output.
        ///
        /// # Panics
        ///
        /// The returned future panics when polled outside a Tokio runtime (the I/O bridge
        /// needs a runtime handle) or on a `current_thread` runtime (`block_in_place`
        /// requires the multi-threaded scheduler).
        fn encode_async<W>(self, writer: W) -> impl Future<Output = io::Result<()>> + Send
        where
            Self: Sync + Sized,
            W: AsyncWrite + Unpin + Send,
        {
            async move {
                // Build the bridge on first poll rather than at call time: `SyncIoBridge::new`
                // grabs the current runtime handle, so creating it eagerly would panic if
                // this method were called outside a runtime context.
                let mut writer = tokio_util::io::SyncIoBridge::new(writer);
                tokio::task::block_in_place(move || self.encode(&mut writer))
            }
        }

        /// Decodes an object asynchronously from the provided reader.
        ///
        /// # Arguments
        ///
        /// * `reader` - The reader to decode the object from.
        ///
        /// # Returns
        ///
        /// A `Send` future that resolves to an `io::Result<Self>` indicating the success or
        /// failure of the decoding operation.
        ///
        /// # Panics
        ///
        /// The returned future panics when polled outside a Tokio runtime (the I/O bridge
        /// needs a runtime handle) or on a `current_thread` runtime (`block_in_place`
        /// requires the multi-threaded scheduler).
        fn decode_async<R>(reader: R) -> impl Future<Output = io::Result<Self>> + Send
        where
            Self: Sync + Sized,
            R: AsyncRead + Unpin + Send,
        {
            async move {
                // Created lazily for the same reason as in `encode_async`.
                let mut reader = tokio_util::io::SyncIoBridge::new(reader);
                tokio::task::block_in_place(move || Self::decode(&mut reader))
            }
        }
    }

    /// Implements the `AsyncWireFormatExt` trait for types that implement the `WireFormat` trait and can be sent across threads.
    impl<T: WireFormat + Send> AsyncWireFormatExt for T {}
}
83

            
84
/// A trait for converting types to and from a wire format.
85
pub trait ConvertWireFormat: WireFormat {
86
    /// Converts the type to a byte representation.
87
    ///
88
    /// # Returns
89
    ///
90
    /// A `Bytes` object representing the byte representation of the type.
91
    fn to_bytes(&self) -> Bytes;
92

            
93
    /// Converts a byte buffer to the type.
94
    ///
95
    /// # Arguments
96
    ///
97
    /// * `buf` - A mutable reference to a `Bytes` object containing the byte buffer.
98
    ///
99
    /// # Returns
100
    ///
101
    /// A `Result` containing the converted type or an `std::io::Error` if the conversion fails.
102
    fn from_bytes(buf: &Bytes) -> Result<Self, std::io::Error>
103
    where
104
        Self: Sized;
105

            
106
    /// AsRef<[u8]> for the type.
107
    ///
108
    /// # Returns
109
    ///
110
    /// A reference to the byte representation of the type.
111
    fn as_bytes(&self) -> Vec<u8> {
112
        self.to_bytes().to_vec()
113
    }
114
}
115

            
116
/// Implements the `ConvertWireFormat` trait for types that implement `jetstream_p9::WireFormat`.
117
/// This trait provides methods for converting the type to and from bytes.
118
impl<T> ConvertWireFormat for T
119
where
120
    T: WireFormat,
121
{
122
    /// Converts the type to bytes.
123
    /// Returns a `Bytes` object containing the encoded bytes.
124
    fn to_bytes(&self) -> Bytes {
125
        let mut buf = vec![];
126
        let res = self.encode(&mut buf);
127
        if let Err(e) = res {
128
            panic!("Failed to encode: {}", e);
129
        }
130
        Bytes::from(buf)
131
    }
132

            
133
    /// Converts bytes to the type.
134
    /// Returns a `Result` containing the decoded type or an `std::io::Error` if decoding fails.
135
    fn from_bytes(buf: &Bytes) -> Result<Self, std::io::Error> {
136
        let buf = buf.to_vec();
137
        T::decode(&mut buf.as_slice())
138
    }
139
}
140

            
141
impl WireFormat for Ipv4Addr {
142
    fn byte_size(&self) -> u32 {
143
        self.octets().len() as u32
144
    }
145

            
146
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
147
        writer.write_all(&self.octets())
148
    }
149

            
150
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
151
        let mut buf = [0u8; 4];
152
        reader.read_exact(&mut buf)?;
153
        Ok(Ipv4Addr::from(buf))
154
    }
155
}
156

            
157
impl WireFormat for Ipv6Addr {
158
    fn byte_size(&self) -> u32 {
159
        self.octets().len() as u32
160
    }
161

            
162
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
163
        writer.write_all(&self.octets())
164
    }
165

            
166
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
167
        let mut buf = [0u8; 16];
168
        reader.read_exact(&mut buf)?;
169
        Ok(Ipv6Addr::from(buf))
170
    }
171
}
172

            
173
impl WireFormat for SocketAddrV4 {
174
    fn byte_size(&self) -> u32 {
175
        self.ip().byte_size() + 2
176
    }
177

            
178
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
179
        self.ip().encode(writer)?;
180
        self.port().encode(writer)
181
    }
182

            
183
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
184
        self::Ipv4Addr::decode(reader).and_then(|ip| {
185
            u16::decode(reader).map(|port| SocketAddrV4::new(ip, port))
186
        })
187
    }
188
}
189

            
190
impl WireFormat for SocketAddrV6 {
191
    fn byte_size(&self) -> u32 {
192
        self.ip().byte_size() + 2
193
    }
194

            
195
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
196
        self.ip().encode(writer)?;
197
        self.port().encode(writer)
198
    }
199

            
200
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
201
        self::Ipv6Addr::decode(reader).and_then(|ip| {
202
            u16::decode(reader).map(|port| SocketAddrV6::new(ip, port, 0, 0))
203
        })
204
    }
205
}
206

            
207
impl WireFormat for SocketAddr {
208
    fn byte_size(&self) -> u32 {
209
        1 + match self {
210
            SocketAddr::V4(socket_addr_v4) => socket_addr_v4.byte_size(),
211
            SocketAddr::V6(socket_addr_v6) => socket_addr_v6.byte_size(),
212
        }
213
    }
214

            
215
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()>
216
    where
217
        Self: Sized,
218
    {
219
        match self {
220
            SocketAddr::V4(socket_addr_v4) => {
221
                writer.write_all(&[0])?;
222
                socket_addr_v4.encode(writer)
223
            }
224
            SocketAddr::V6(socket_addr_v6) => {
225
                writer.write_all(&[1])?;
226
                socket_addr_v6.encode(writer)
227
            }
228
        }
229
    }
230

            
231
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self>
232
    where
233
        Self: Sized,
234
    {
235
        let mut buf = [0u8; 1];
236
        reader.read_exact(&mut buf)?;
237
        match buf[0] {
238
            0 => Ok(SocketAddr::V4(SocketAddrV4::decode(reader)?)),
239
            1 => Ok(SocketAddr::V6(SocketAddrV6::decode(reader)?)),
240
            _ => Err(std::io::Error::new(
241
                std::io::ErrorKind::InvalidData,
242
                "Invalid address type",
243
            )),
244
        }
245
    }
246
}