1
// Copyright (c) 2024, Sevki <s@sevki.io>
2
// Use of this source code is governed by a BSD-style license that can be
3
// found in the LICENSE file.
4
use std::{
5
    io::{self},
6
    net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6},
7
};
8

            
9
use bytes::Bytes;
10

            
11
use super::WireFormat;
12

            
13
/// Asynchronous encode/decode contract, mirroring the shape of [`WireFormat`]'s
/// `encode`/`decode` pair but returning `Send` futures.
///
/// NOTE(review): both methods bound their peer (`writer` / `reader`) by
/// `AsyncWireFormat` itself rather than by an async I/O trait such as
/// `AsyncWrite`/`AsyncRead`. That looks unintentional, but it is part of the
/// public contract implementers rely on, so it is kept as-is — confirm before
/// tightening it.
pub trait AsyncWireFormat: Sized {
    /// Encodes `self` into `writer`; the future resolves once encoding ends.
    fn encode_async<W>(
        self,
        writer: &mut W,
    ) -> impl std::future::Future<Output = io::Result<()>> + Send
    where
        W: AsyncWireFormat + Unpin + Send;

    /// Produces a new `Self` from `reader`; the future resolves with the value
    /// or the I/O error encountered.
    fn decode_async<R>(
        reader: &mut R,
    ) -> impl std::future::Future<Output = io::Result<Self>> + Send
    where
        R: AsyncWireFormat + Unpin + Send;
}
22

            
23
#[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
/// Tokio adapters that drive the synchronous [`WireFormat`] codec against
/// async streams.
pub mod tokio {
    use std::{future::Future, io};

    use tokio::io::{AsyncRead, AsyncWrite};

    use crate::WireFormat;

    /// Extension trait for asynchronous wire format encoding and decoding.
    ///
    /// Both methods wrap the async stream in `tokio_util::io::SyncIoBridge`
    /// and run the synchronous `WireFormat` codec inside
    /// `tokio::task::block_in_place`, so the polling worker thread is blocked
    /// for the whole encode/decode. The returned futures contain no `.await`
    /// and complete on their first poll.
    ///
    /// # Panics
    ///
    /// The returned futures panic when polled outside a Tokio runtime (the
    /// bridge needs a runtime handle) or on a `current_thread` runtime
    /// (`block_in_place` requires the multi-threaded scheduler).
    pub trait AsyncWireFormatExt
    where
        Self: WireFormat + Send,
    {
        /// Encodes the object asynchronously into the provided writer.
        ///
        /// # Arguments
        ///
        /// * `writer` - The writer to encode the object into. It is consumed
        ///   and is not flushed here; pass `&mut writer` to keep using (and
        ///   flushing) it afterwards.
        ///
        /// # Returns
        ///
        /// A future that resolves to an `io::Result<()>` indicating the success or failure of the encoding operation.
        fn encode_async<W>(self, writer: W) -> impl Future<Output = io::Result<()>> + Send
        where
            Self: Sync + Sized,
            W: AsyncWrite + Unpin + Send,
        {
            // Build the bridge at poll time rather than call time:
            // `SyncIoBridge::new` captures the current runtime handle and panics
            // without one, so constructing it eagerly made even *creating* this
            // future outside a runtime context panic.
            async move {
                let mut writer = tokio_util::io::SyncIoBridge::new(writer);
                tokio::task::block_in_place(move || self.encode(&mut writer))
            }
        }

        /// Decodes an object asynchronously from the provided reader.
        ///
        /// # Arguments
        ///
        /// * `reader` - The reader to decode the object from. It is consumed;
        ///   pass `&mut reader` to keep using it afterwards.
        ///
        /// # Returns
        ///
        /// A future that resolves to an `io::Result<Self>` indicating the success or failure of the decoding operation.
        fn decode_async<R>(reader: R) -> impl Future<Output = io::Result<Self>> + Send
        where
            Self: Sync + Sized,
            R: AsyncRead + Unpin + Send,
        {
            // As in `encode_async`, the bridge is created lazily when the future
            // is first polled, inside the runtime that polls it.
            async move {
                let mut reader = tokio_util::io::SyncIoBridge::new(reader);
                tokio::task::block_in_place(move || Self::decode(&mut reader))
            }
        }
    }

    /// Implements the `AsyncWireFormatExt` trait for types that implement the `WireFormat` trait and can be sent across threads.
    impl<T: WireFormat + Send> AsyncWireFormatExt for T {}
}
83

            
84
/// A trait for converting types to and from a wire format.
85
pub trait ConvertWireFormat: WireFormat {
86
    /// Converts the type to a byte representation.
87
    ///
88
    /// # Returns
89
    ///
90
    /// A `Bytes` object representing the byte representation of the type.
91
    fn to_bytes(&self) -> Bytes;
92

            
93
    /// Converts a byte buffer to the type.
94
    ///
95
    /// # Arguments
96
    ///
97
    /// * `buf` - A mutable reference to a `Bytes` object containing the byte buffer.
98
    ///
99
    /// # Returns
100
    ///
101
    /// A `Result` containing the converted type or an `std::io::Error` if the conversion fails.
102
    fn from_bytes(buf: &Bytes) -> Result<Self, std::io::Error>
103
    where
104
        Self: Sized;
105

            
106
    /// AsRef<[u8]> for the type.
107
    ///
108
    /// # Returns
109
    ///
110
    /// A reference to the byte representation of the type.
111
    fn as_bytes(&self) -> Vec<u8> {
112
        self.to_bytes().to_vec()
113
    }
114
}
115

            
116
/// Implements the `ConvertWireFormat` trait for types that implement `jetstream_p9::WireFormat`.
117
/// This trait provides methods for converting the type to and from bytes.
118
impl<T> ConvertWireFormat for T
119
where
120
    T: WireFormat,
121
{
122
    /// Converts the type to bytes.
123
    /// Returns a `Bytes` object containing the encoded bytes.
124
    fn to_bytes(&self) -> Bytes {
125
        let mut buf = vec![];
126
        let res = self.encode(&mut buf);
127
        if let Err(e) = res {
128
            panic!("Failed to encode: {}", e);
129
        }
130
        Bytes::from(buf)
131
    }
132

            
133
    /// Converts bytes to the type.
134
    /// Returns a `Result` containing the decoded type or an `std::io::Error` if decoding fails.
135
    fn from_bytes(buf: &Bytes) -> Result<Self, std::io::Error> {
136
        let buf = buf.to_vec();
137
        T::decode(&mut buf.as_slice())
138
    }
139
}
140

            
141
#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
impl WireFormat for Ipv4Addr {
    /// Fixed size: the four octets of the address.
    fn byte_size(&self) -> u32 {
        4
    }

    /// Writes the octets in the order returned by `Ipv4Addr::octets`.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        let octets = self.octets();
        writer.write_all(&octets)
    }

    /// Reads exactly four octets and builds the address from them.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let mut octets = [0u8; 4];
        reader.read_exact(&mut octets).map(|()| Self::from(octets))
    }
}
157

            
158
#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
impl WireFormat for Ipv6Addr {
    /// Fixed size: the sixteen octets of the address.
    fn byte_size(&self) -> u32 {
        16
    }

    /// Writes the octets in the order returned by `Ipv6Addr::octets`.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        let octets = self.octets();
        writer.write_all(&octets)
    }

    /// Reads exactly sixteen octets and builds the address from them.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let mut octets = [0u8; 16];
        reader.read_exact(&mut octets).map(|()| Self::from(octets))
    }
}
174

            
175
#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
impl WireFormat for SocketAddrV4 {
    /// Size of the encoded IPv4 address plus two bytes for the port.
    fn byte_size(&self) -> u32 {
        let port_size = 2;
        self.ip().byte_size() + port_size
    }

    /// Encodes the IPv4 address first, then the port.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        let (ip, port) = (self.ip(), self.port());
        ip.encode(writer)?;
        port.encode(writer)
    }

    /// Decodes an IPv4 address followed by a port, the order `encode` writes them in.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let ip = Ipv4Addr::decode(reader)?;
        let port = u16::decode(reader)?;
        Ok(SocketAddrV4::new(ip, port))
    }
}
192

            
193
#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
impl WireFormat for SocketAddrV6 {
    /// Size of the encoded IPv6 address plus two bytes for the port.
    fn byte_size(&self) -> u32 {
        let port_size = 2;
        self.ip().byte_size() + port_size
    }

    /// Encodes the IPv6 address first, then the port.
    ///
    /// The flow info and scope id are not written, so they do not survive a
    /// round trip.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        let (ip, port) = (self.ip(), self.port());
        ip.encode(writer)?;
        port.encode(writer)
    }

    /// Decodes an IPv6 address followed by a port; flow info and scope id
    /// are not on the wire and are set to `0`.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let ip = Ipv6Addr::decode(reader)?;
        let port = u16::decode(reader)?;
        Ok(SocketAddrV6::new(ip, port, 0, 0))
    }
}