1
// Copyright (c) 2024, Sevki <s@sevki.io>
2
// Use of this source code is governed by a BSD-style license that can be
3
// found in the LICENSE file.
4
use std::{
5
    io::{self},
6
    net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6},
7
};
8

            
9
use bytes::Bytes;
10

            
11
use super::WireFormat;
12

            
13
/// Asynchronous counterpart of a wire-format codec.
///
/// Implementors return `Send` futures that either encode `self` into a
/// writer or decode a fresh value from a reader.
///
/// NOTE(review): both the writer and the reader are bounded by
/// `AsyncWireFormat` itself rather than by an async I/O trait (e.g.
/// `tokio::io::AsyncWrite` / `AsyncRead`); confirm this is intentional.
pub trait AsyncWireFormat: std::marker::Sized {
    /// Encodes `self` into `writer`.
    ///
    /// The returned future resolves to an `io::Result<()>` describing whether
    /// encoding succeeded.
    fn encode_async<W>(
        self,
        writer: &mut W,
    ) -> impl std::future::Future<Output = io::Result<()>> + Send
    where
        W: AsyncWireFormat + Unpin + Send;

    /// Decodes a new value of this type from `reader`.
    ///
    /// The returned future resolves to the decoded value, or to the I/O error
    /// that stopped decoding.
    fn decode_async<R>(
        reader: &mut R,
    ) -> impl std::future::Future<Output = io::Result<Self>> + Send
    where
        R: AsyncWireFormat + Unpin + Send;
}
22

            
23
#[cfg(not(target_arch = "wasm32"))]
24
pub mod tokio {
25
    use std::{future::Future, io};
26

            
27
    use tokio::io::{AsyncRead, AsyncWrite};
28

            
29
    use crate::WireFormat;
30
    /// Extension trait for asynchronous wire format encoding and decoding.
31
    pub trait AsyncWireFormatExt
32
    where
33
        Self: WireFormat + Send,
34
    {
35
        /// Encodes the object asynchronously into the provided writer.
36
        ///
37
        /// # Arguments
38
        ///
39
        /// * `writer` - The writer to encode the object into.n
40
        ///
41
        /// # Returns
42
        ///
43
        /// A future that resolves to an `io::Result<()>` indicating the success or failure of the encoding operation.
44
4
        fn encode_async<W>(self, writer: W) -> impl Future<Output = io::Result<()>>
45
4
        where
46
4
            Self: Sync,
47
4
            W: AsyncWrite + Unpin + Send,
48
4
        {
49
4
            let mut writer = tokio_util::io::SyncIoBridge::new(writer);
50
4
            async { tokio::task::block_in_place(move || self.encode(&mut writer)) }
51
4
        }
52

            
53
        /// Decodes an object asynchronously from the provided reader.
54
        ///
55
        /// # Arguments
56
        ///
57
        /// * `reader` - The reader to decode the object from.
58
        ///
59
        /// # Returns
60
        ///
61
        /// A future that resolves to an `io::Result<Self>` indicating the success or failure of the decoding operation.
62
4
        fn decode_async<R>(reader: R) -> impl Future<Output = io::Result<Self>> + Send
63
4
        where
64
4
            Self: Sync,
65
4
            R: AsyncRead + Unpin + Send,
66
4
        {
67
4
            let mut reader = tokio_util::io::SyncIoBridge::new(reader);
68
4
            async { tokio::task::block_in_place(move || Self::decode(&mut reader)) }
69
4
        }
70
    }
71
    /// Implements the `AsyncWireFormatExt` trait for types that implement the `WireFormat` trait and can be sent across threads.
72
    impl<T: WireFormat + Send> AsyncWireFormatExt for T {}
73
}
74

            
75
/// A trait for converting types to and from a wire format.
76
pub trait ConvertWireFormat: WireFormat {
77
    /// Converts the type to a byte representation.
78
    ///
79
    /// # Returns
80
    ///
81
    /// A `Bytes` object representing the byte representation of the type.
82
    fn to_bytes(&self) -> Bytes;
83

            
84
    /// Converts a byte buffer to the type.
85
    ///
86
    /// # Arguments
87
    ///
88
    /// * `buf` - A mutable reference to a `Bytes` object containing the byte buffer.
89
    ///
90
    /// # Returns
91
    ///
92
    /// A `Result` containing the converted type or an `std::io::Error` if the conversion fails.
93
    fn from_bytes(buf: &Bytes) -> Result<Self, std::io::Error>;
94

            
95
    /// AsRef<[u8]> for the type.
96
    ///
97
    /// # Returns
98
    ///
99
    /// A reference to the byte representation of the type.
100
    fn as_bytes(&self) -> Vec<u8> {
101
        self.to_bytes().to_vec()
102
    }
103
}
104

            
105
/// Implements the `ConvertWireFormat` trait for types that implement `jetstream_p9::WireFormat`.
106
/// This trait provides methods for converting the type to and from bytes.
107
impl<T> ConvertWireFormat for T
108
where
109
    T: WireFormat,
110
{
111
    /// Converts the type to bytes.
112
    /// Returns a `Bytes` object containing the encoded bytes.
113
2
    fn to_bytes(&self) -> Bytes {
114
2
        let mut buf = vec![];
115
2
        let res = self.encode(&mut buf);
116
2
        if let Err(e) = res {
117
            panic!("Failed to encode: {}", e);
118
2
        }
119
2
        Bytes::from(buf)
120
2
    }
121

            
122
    /// Converts bytes to the type.
123
    /// Returns a `Result` containing the decoded type or an `std::io::Error` if decoding fails.
124
2
    fn from_bytes(buf: &Bytes) -> Result<Self, std::io::Error> {
125
2
        let buf = buf.to_vec();
126
2
        T::decode(&mut buf.as_slice())
127
2
    }
128
}
129

            
130
#[cfg(feature = "std")]
impl WireFormat for Ipv4Addr {
    /// Encoded size: one byte per address octet.
    fn byte_size(&self) -> u32 {
        let octets = self.octets();
        octets.len() as u32
    }

    /// Writes the address octets in the order returned by `octets()`.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        let octets = self.octets();
        writer.write_all(&octets)
    }

    /// Reads exactly four octets and builds the address from them.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let mut octets = [0u8; 4];
        reader.read_exact(&mut octets).map(|()| Self::from(octets))
    }
}
146

            
147
#[cfg(feature = "std")]
impl WireFormat for Ipv6Addr {
    /// Encoded size: one byte per address octet.
    fn byte_size(&self) -> u32 {
        let octets = self.octets();
        octets.len() as u32
    }

    /// Writes the address octets in the order returned by `octets()`.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        let octets = self.octets();
        writer.write_all(&octets)
    }

    /// Reads exactly sixteen octets and builds the address from them.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let mut octets = [0u8; 16];
        reader.read_exact(&mut octets).map(|()| Self::from(octets))
    }
}
163

            
164
#[cfg(feature = "std")]
impl WireFormat for SocketAddrV4 {
    /// Encoded size: the IPv4 address followed by a 2-byte port.
    fn byte_size(&self) -> u32 {
        const PORT_SIZE: u32 = 2;
        self.ip().byte_size() + PORT_SIZE
    }

    /// Writes the address, then the port, each with its own `WireFormat` encoding.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        let (ip, port) = (self.ip(), self.port());
        ip.encode(writer)?;
        port.encode(writer)
    }

    /// Reads an IPv4 address followed by a port.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let ip = Ipv4Addr::decode(reader)?;
        let port = u16::decode(reader)?;
        Ok(Self::new(ip, port))
    }
}
180

            
181
/// Wire encoding for IPv6 socket addresses: the address followed by the port.
///
/// Only the address and port are transmitted; `flowinfo` and `scope_id` are
/// not encoded and always decode as `0`.
#[cfg(feature = "std")]
impl WireFormat for SocketAddrV6 {
    /// Encoded size: the IPv6 address followed by a 2-byte port.
    fn byte_size(&self) -> u32 {
        const PORT_SIZE: u32 = 2;
        self.ip().byte_size() + PORT_SIZE
    }

    /// Writes the address, then the port, each with its own `WireFormat` encoding.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        let (ip, port) = (self.ip(), self.port());
        ip.encode(writer)?;
        port.encode(writer)
    }

    /// Reads an IPv6 address followed by a port; `flowinfo` and `scope_id` are set to `0`.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let ip = Ipv6Addr::decode(reader)?;
        let port = u16::decode(reader)?;
        Ok(Self::new(ip, port, 0, 0))
    }
}