1
#![doc(html_logo_url = "https://raw.githubusercontent.com/sevki/jetstream/main/logo/JetStream.png")]
2
#![doc(
3
    html_favicon_url = "https://raw.githubusercontent.com/sevki/jetstream/main/logo/JetStream.png"
4
)]
5
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
6
// Copyright (c) 2024, Sevki <s@sevki.io>
7
// Copyright 2018 The ChromiumOS Authors
8
// Use of this source code is governed by a BSD-style license that can be
9
// found in the LICENSE file.
10
pub use jetstream_macros::JetStreamWireFormat;
11

            
12
use {
13
    bytes::Buf,
14
    std::{
15
        ffi::{CStr, CString, OsStr},
16
        fmt,
17
        io::{self, ErrorKind, Read, Write},
18
        mem,
19
        ops::{Deref, DerefMut},
20
        string::String,
21
        vec::Vec,
22
    },
23
    zerocopy::LittleEndian,
24
};
25
pub mod wire_format_extensions;
26

            
27
/// Serialization contract for values that travel over the wire in the 9P protocol.
///
/// Implementors must be `Sized` and `Send`.
pub trait WireFormat: Sized + Send {
    /// Reports how many bytes the encoded form of `self` occupies.
    fn byte_size(&self) -> u32;

    /// Serializes `self` by writing its encoded bytes to `writer`.
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()>;

    /// Reads bytes from `reader` and reconstructs a value of `Self`.
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self>;
}
38

            
39
/// A string as carried by the 9P protocol.
///
/// `P9String::new()` only accepts valid UTF-8 that is at most 65535 bytes long.
///
/// The contents are stored in a `CString`, i.e. the bytes are followed by a terminating 0 (NUL)
/// character, which allows the string to be handed directly to libc functions.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct P9String {
    // Owned, NUL-terminated storage for the string bytes.
    cstr: CString,
}
49

            
50
impl P9String {
51
    pub fn new(string_bytes: impl Into<Vec<u8>>) -> io::Result<Self> {
52
        let string_bytes: Vec<u8> = string_bytes.into();
53

            
54
        if string_bytes.len() > u16::MAX as usize {
55
            return Err(io::Error::new(
56
                ErrorKind::InvalidInput,
57
                "string is too long",
58
            ));
59
        }
60

            
61
        // 9p strings must be valid UTF-8.
62
        let _check_utf8 = std::str::from_utf8(&string_bytes)
63
            .map_err(|e| io::Error::new(ErrorKind::InvalidInput, e))?;
64

            
65
        let cstr =
66
            CString::new(string_bytes).map_err(|e| io::Error::new(ErrorKind::InvalidInput, e))?;
67

            
68
        Ok(P9String { cstr })
69
    }
70

            
71
    pub fn len(&self) -> usize {
72
        self.cstr.as_bytes().len()
73
    }
74

            
75
    pub fn is_empty(&self) -> bool {
76
        self.cstr.as_bytes().is_empty()
77
    }
78

            
79
    pub fn as_c_str(&self) -> &CStr {
80
        self.cstr.as_c_str()
81
    }
82

            
83
    pub fn as_bytes(&self) -> &[u8] {
84
        self.cstr.as_bytes()
85
    }
86
    #[cfg(not(target_arch = "wasm32"))]
87
    /// Returns a raw pointer to the string's storage.
88
    ///
89
    /// The string bytes are always followed by a NUL terminator ('\0'), so the pointer can be
90
    /// passed directly to libc functions that expect a C string.
91
    pub fn as_ptr(&self) -> *const libc::c_char {
92
        self.cstr.as_ptr()
93
    }
94
}
95

            
96
impl PartialEq<&str> for P9String {
97
    fn eq(&self, other: &&str) -> bool {
98
        self.cstr.as_bytes() == other.as_bytes()
99
    }
100
}
101

            
102
impl TryFrom<&OsStr> for P9String {
103
    type Error = io::Error;
104

            
105
    fn try_from(value: &OsStr) -> io::Result<Self> {
106
        let string_bytes = value.as_encoded_bytes();
107
        Self::new(string_bytes)
108
    }
109
}
110

            
111
// The 9P protocol requires that strings are UTF-8 encoded.  The wire format is a u16
112
// count |N|, encoded in little endian, followed by |N| bytes of UTF-8 data.
113
impl WireFormat for P9String {
114
    fn byte_size(&self) -> u32 {
115
        (mem::size_of::<u16>() + self.len()) as u32
116
    }
117

            
118
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
119
        (self.len() as u16).encode(writer)?;
120
        writer.write_all(self.cstr.as_bytes())
121
    }
122

            
123
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
124
        let len: u16 = WireFormat::decode(reader)?;
125
        let mut string_bytes = vec![0u8; usize::from(len)];
126
        reader.read_exact(&mut string_bytes)?;
127
        Self::new(string_bytes)
128
    }
129
}
130

            
131
// This doesn't really _need_ to be a macro but unfortunately there is no trait bound to
132
// express "can be casted to another type", which means we can't write `T as u8` in a trait
133
// based implementation.  So instead we have this macro, which is implemented for all the
134
// stable unsigned types with the added benefit of not being implemented for the signed
135
// types which are not allowed by the protocol.
136
macro_rules! uint_wire_format_impl {
137
    ($Ty:ty) => {
138
        impl WireFormat for $Ty {
139
40
            fn byte_size(&self) -> u32 {
140
40
                mem::size_of::<$Ty>() as u32
141
40
            }
142

            
143
1114
            fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
144
1114
                writer.write_all(&self.to_le_bytes())
145
1114
            }
146

            
147
1542
            fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
148
1542
                let mut buf = [0; mem::size_of::<$Ty>()];
149
1542
                reader.read_exact(&mut buf)?;
150
                paste::expr! {
151
1542
                    let num: zerocopy::[<$Ty:snake:upper>]<LittleEndian> =  zerocopy::byteorder::[<$Ty:snake:upper>]::from_bytes(buf);
152
1542
                    Ok(num.get())
153
                }
154
1542
            }
155
        }
156
    };
157
}
158
// unsigned integers
159
uint_wire_format_impl!(u16);
160
uint_wire_format_impl!(u32);
161
uint_wire_format_impl!(u64);
162
uint_wire_format_impl!(u128);
163
// signed integers
164
uint_wire_format_impl!(i16);
165
uint_wire_format_impl!(i32);
166
uint_wire_format_impl!(i64);
167
uint_wire_format_impl!(i128);
168

            
169
macro_rules! float_wire_format_impl {
170
    ($Ty:ty) => {
171
        impl WireFormat for $Ty {
172
            fn byte_size(&self) -> u32 {
173
                mem::size_of::<$Ty>() as u32
174
            }
175

            
176
            fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
177
                paste::expr! {
178
                    writer.write_all(&self.to_le_bytes())
179
                }
180
            }
181

            
182
            fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
183
                let mut buf = [0; mem::size_of::<$Ty>()];
184
                reader.read_exact(&mut buf)?;
185
                paste::expr! {
186
                    let num: zerocopy::[<$Ty:snake:upper>]<LittleEndian> =  zerocopy::byteorder::[<$Ty:snake:upper>]::from_bytes(buf);
187
                    Ok(num.get())
188
                }
189
            }
190
        }
191
    };
192
}
193

            
194
float_wire_format_impl!(f32);
195
float_wire_format_impl!(f64);
196

            
197
impl WireFormat for u8 {
198
5
    fn byte_size(&self) -> u32 {
199
5
        1
200
5
    }
201

            
202
420
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
203
420
        writer.write_all(&[*self])
204
420
    }
205

            
206
12
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
207
12
        let mut byte = [0u8; 1];
208
12
        reader.read_exact(&mut byte)?;
209
12
        Ok(byte[0])
210
12
    }
211
}
212

            
213
impl WireFormat for usize {
214
    fn byte_size(&self) -> u32 {
215
        mem::size_of::<usize>() as u32
216
    }
217

            
218
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
219
        writer.write_all(&self.to_le_bytes())
220
    }
221

            
222
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
223
        let mut buf = [0; mem::size_of::<usize>()];
224
        reader.read_exact(&mut buf)?;
225
        Ok(usize::from_le_bytes(buf))
226
    }
227
}
228

            
229
impl WireFormat for isize {
230
    fn byte_size(&self) -> u32 {
231
        mem::size_of::<isize>() as u32
232
    }
233

            
234
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
235
        writer.write_all(&self.to_le_bytes())
236
    }
237

            
238
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
239
        let mut buf = [0; mem::size_of::<isize>()];
240
        reader.read_exact(&mut buf)?;
241
        Ok(isize::from_le_bytes(buf))
242
    }
243
}
244

            
245
// The 9P protocol requires that strings are UTF-8 encoded.  The wire format is a u16
246
// count |N|, encoded in little endian, followed by |N| bytes of UTF-8 data.
247
impl WireFormat for String {
248
327
    fn byte_size(&self) -> u32 {
249
327
        (mem::size_of::<u16>() + self.len()) as u32
250
327
    }
251

            
252
224
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
253
224
        if self.len() > u16::MAX as usize {
254
2
            return Err(io::Error::new(
255
2
                ErrorKind::InvalidInput,
256
2
                "string is too long",
257
2
            ));
258
222
        }
259
222

            
260
222
        (self.len() as u16).encode(writer)?;
261
222
        writer.write_all(self.as_bytes())
262
224
    }
263

            
264
232
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
265
232
        let len: u16 = WireFormat::decode(reader)?;
266
232
        let mut result = String::with_capacity(len as usize);
267
232
        reader.take(len as u64).read_to_string(&mut result)?;
268
222
        Ok(result)
269
232
    }
270
}
271

            
272
// The wire format for repeated types is similar to that of strings: a little endian
273
// encoded u16 |N|, followed by |N| instances of the given type.
274
impl<T: WireFormat> WireFormat for Vec<T> {
275
    fn byte_size(&self) -> u32 {
276
        mem::size_of::<u16>() as u32 + self.iter().map(|elem| elem.byte_size()).sum::<u32>()
277
    }
278

            
279
10
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
280
10
        if self.len() > u16::MAX as usize {
281
2
            return Err(std::io::Error::new(
282
2
                std::io::ErrorKind::InvalidInput,
283
2
                "too many elements in vector",
284
2
            ));
285
8
        }
286
8

            
287
8
        (self.len() as u16).encode(writer)?;
288
46
        for elem in self {
289
38
            elem.encode(writer)?;
290
        }
291

            
292
8
        Ok(())
293
10
    }
294

            
295
8
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
296
8
        let len: u16 = WireFormat::decode(reader)?;
297
8
        let mut result = Vec::with_capacity(len as usize);
298
8

            
299
8
        for _ in 0..len {
300
52
            result.push(WireFormat::decode(reader)?);
301
        }
302

            
303
8
        Ok(result)
304
8
    }
305
}
306

            
307
/// A buffer holding an arbitrary number of bytes of data, typically used for Rread and
/// Twrite messages.
///
/// Unlike a `Vec<u8>`, whose encoding stores the number of bytes in a `u16`, the encoding
/// of `Data` stores it in a `u32`.
#[derive(Clone, PartialEq, Eq)]
#[repr(transparent)]
#[cfg_attr(feature = "testing", derive(serde::Serialize, serde::Deserialize))]
pub struct Data(pub Vec<u8>);
314

            
315
// Upper bound (32 MiB) on the length of a data buffer that we accept.  The server's max
// message size should normally keep reads well below this; the limit mainly exists so that
// a malicious client cannot trick us into allocating massive amounts of memory.
const MAX_DATA_LENGTH: u32 = 32 << 20;
319

            
320
impl fmt::Debug for Data {
321
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
322
        // There may be a lot of data and we don't want to spew it all out in a trace.  Instead
323
        // just print out the number of bytes in the buffer.
324
        write!(f, "Data({} bytes)", self.len())
325
    }
326
}
327

            
328
// Implement Deref and DerefMut so that we don't have to use self.0 everywhere.
329
impl Deref for Data {
330
    type Target = Vec<u8>;
331
175
    fn deref(&self) -> &Self::Target {
332
175
        &self.0
333
175
    }
334
}
335
impl DerefMut for Data {
336
    fn deref_mut(&mut self) -> &mut Self::Target {
337
        &mut self.0
338
    }
339
}
340

            
341
// Same as Vec<u8> except that it encodes the length as a u32 instead of a u16.
342
impl WireFormat for Data {
343
    fn byte_size(&self) -> u32 {
344
        mem::size_of::<u32>() as u32 + self.iter().map(|elem| elem.byte_size()).sum::<u32>()
345
    }
346

            
347
6
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
348
6
        if self.len() > u32::MAX as usize {
349
            return Err(std::io::Error::new(
350
                std::io::ErrorKind::InvalidInput,
351
                "data is too large",
352
            ));
353
6
        }
354
6
        (self.len() as u32).encode(writer)?;
355
6
        writer.write_all(self)
356
6
    }
357

            
358
6
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
359
6
        let len: u32 = WireFormat::decode(reader)?;
360
6
        if len > MAX_DATA_LENGTH {
361
            return Err(std::io::Error::new(
362
                std::io::ErrorKind::InvalidData,
363
                format!("data length ({} bytes) is too large", len),
364
            ));
365
6
        }
366
6

            
367
6
        let mut buf = Vec::with_capacity(len as usize);
368
6
        reader.take(len as u64).read_to_end(&mut buf)?;
369

            
370
6
        if buf.len() == len as usize {
371
6
            Ok(Data(buf))
372
        } else {
373
            Err(io::Error::new(
374
                std::io::ErrorKind::UnexpectedEof,
375
                format!(
376
                    "unexpected end of data: want: {} bytes, got: {} bytes",
377
                    len,
378
                    buf.len()
379
                ),
380
            ))
381
        }
382
6
    }
383
}
384

            
385
impl<T> WireFormat for Option<T>
386
where
387
    T: WireFormat,
388
{
389
8
    fn byte_size(&self) -> u32 {
390
8
        1 + match self {
391
2
            None => 0,
392
6
            Some(value) => value.byte_size(),
393
        }
394
8
    }
395

            
396
8
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
397
8
        match self {
398
2
            None => WireFormat::encode(&0u8, writer),
399
6
            Some(value) => {
400
6
                WireFormat::encode(&1u8, writer)?;
401
6
                WireFormat::encode(value, writer)
402
            }
403
        }
404
8
    }
405

            
406
10
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
407
10
        let tag: u8 = WireFormat::decode(reader)?;
408
10
        match tag {
409
2
            0 => Ok(None),
410
6
            1 => Ok(Some(WireFormat::decode(reader)?)),
411
            _ => {
412
2
                Err(io::Error::new(
413
2
                    io::ErrorKind::InvalidData,
414
2
                    format!("Invalid Option tag: {}", tag),
415
2
                ))
416
            }
417
        }
418
10
    }
419
}
420

            
421
impl WireFormat for () {
422
5
    fn byte_size(&self) -> u32 {
423
5
        0
424
5
    }
425

            
426
2
    fn encode<W: Write>(&self, _writer: &mut W) -> io::Result<()> {
427
2
        Ok(())
428
2
    }
429

            
430
2
    fn decode<R: Read>(_reader: &mut R) -> io::Result<Self> {
431
2
        Ok(())
432
2
    }
433
}
434

            
435
impl WireFormat for bool {
436
    fn byte_size(&self) -> u32 {
437
        1
438
    }
439

            
440
4
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
441
4
        writer.write_all(&[*self as u8])
442
4
    }
443

            
444
6
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
445
6
        let mut byte = [0u8; 1];
446
6
        reader.read_exact(&mut byte)?;
447
6
        match byte[0] {
448
2
            0 => Ok(false),
449
2
            1 => Ok(true),
450
            _ => {
451
2
                Err(io::Error::new(
452
2
                    io::ErrorKind::InvalidData,
453
2
                    "invalid byte for bool",
454
2
                ))
455
            }
456
        }
457
6
    }
458
}
459

            
460
impl io::Read for Data {
461
5
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
462
5
        self.0.reader().read(buf)
463
5
    }
464
}