1
#![doc(
2
    html_logo_url = "https://raw.githubusercontent.com/sevki/jetstream/main/logo/JetStream.png"
3
)]
4
#![doc(
5
    html_favicon_url = "https://raw.githubusercontent.com/sevki/jetstream/main/logo/JetStream.png"
6
)]
7
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
8
// Copyright (c) 2024, Sevki <s@sevki.io>
9
// Copyright 2018 The ChromiumOS Authors
10
// Use of this source code is governed by a BSD-style license that can be
11
// found in the LICENSE file.
12
use std::{
13
    ffi::{CStr, CString, OsStr},
14
    fmt,
15
    io::{self, ErrorKind, Read, Write},
16
    marker::PhantomData,
17
    mem,
18
    ops::{Deref, DerefMut},
19
    string::String,
20
    vec::Vec,
21
};
22

            
23
use bytes::Buf;
24
pub use jetstream_macros::JetStreamWireFormat;
25
use zerocopy::LittleEndian;
26
pub mod wire_format_extensions;
27

            
28
#[cfg(target_arch = "wasm32")]
29
pub mod wasm;
30

            
31
/// A type that can be encoded on the wire using the 9P protocol.
32
#[cfg(not(target_arch = "wasm32"))]
33
pub trait WireFormat: Send {
34
    /// Returns the number of bytes necessary to fully encode `self`.
35
    fn byte_size(&self) -> u32;
36

            
37
    /// Encodes `self` into `writer`.
38
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> where 
39
        Self: Sized;
40

            
41
    /// Decodes `Self` from `reader`.
42
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self>
43
    where
44
        Self: Sized;
45
}
46

            
47
/// A type that can be encoded on the wire using the 9P protocol.
48
/// WebAssembly doesn't fully support Send, so we don't require it.
49
#[cfg(target_arch = "wasm32")]
50
pub trait WireFormat: std::marker::Sized {
51
    /// Returns the number of bytes necessary to fully encode `self`.
52
    fn byte_size(&self) -> u32;
53

            
54
    /// Encodes `self` into `writer`.
55
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()>;
56

            
57
    /// Decodes `Self` from `reader`.
58
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self>;
59
}
60

            
61
/// A 9P protocol string.
62
///
63
/// The string is always valid UTF-8 and 65535 bytes or less (enforced by `P9String::new()`).
64
///
65
/// It is represented as a C string with a terminating 0 (NUL) character to allow it to be passed
66
/// directly to libc functions.
67
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
68
pub struct P9String {
69
    cstr: CString,
70
}
71

            
72
impl P9String {
73
    pub fn new(string_bytes: impl Into<Vec<u8>>) -> io::Result<Self> {
74
        let string_bytes: Vec<u8> = string_bytes.into();
75

            
76
        if string_bytes.len() > u16::MAX as usize {
77
            return Err(io::Error::new(
78
                ErrorKind::InvalidInput,
79
                "string is too long",
80
            ));
81
        }
82

            
83
        // 9p strings must be valid UTF-8.
84
        let _check_utf8 = std::str::from_utf8(&string_bytes)
85
            .map_err(|e| io::Error::new(ErrorKind::InvalidInput, e))?;
86

            
87
        let cstr = CString::new(string_bytes)
88
            .map_err(|e| io::Error::new(ErrorKind::InvalidInput, e))?;
89

            
90
        Ok(P9String { cstr })
91
    }
92

            
93
    pub fn len(&self) -> usize {
94
        self.cstr.as_bytes().len()
95
    }
96

            
97
    pub fn is_empty(&self) -> bool {
98
        self.cstr.as_bytes().is_empty()
99
    }
100

            
101
    pub fn as_c_str(&self) -> &CStr {
102
        self.cstr.as_c_str()
103
    }
104

            
105
    pub fn as_bytes(&self) -> &[u8] {
106
        self.cstr.as_bytes()
107
    }
108

            
109
    #[cfg(not(target_arch = "wasm32"))]
110
    /// Returns a raw pointer to the string's storage.
111
    ///
112
    /// The string bytes are always followed by a NUL terminator ('\0'), so the pointer can be
113
    /// passed directly to libc functions that expect a C string.
114
    pub fn as_ptr(&self) -> *const libc::c_char {
115
        self.cstr.as_ptr()
116
    }
117

            
118
    #[cfg(target_arch = "wasm32")]
119
    /// Returns a raw pointer to the string's storage.
120
    ///
121
    /// The string bytes are always followed by a NUL terminator ('\0').
122
    /// Note: In WebAssembly, returns a raw pointer but libc is not available.
123
    pub fn as_ptr(&self) -> *const std::os::raw::c_char {
124
        self.cstr.as_ptr()
125
    }
126
}
127

            
128
impl PartialEq<&str> for P9String {
129
    fn eq(&self, other: &&str) -> bool {
130
        self.cstr.as_bytes() == other.as_bytes()
131
    }
132
}
133

            
134
impl TryFrom<&OsStr> for P9String {
135
    type Error = io::Error;
136

            
137
    fn try_from(value: &OsStr) -> io::Result<Self> {
138
        let string_bytes = value.as_encoded_bytes();
139
        Self::new(string_bytes)
140
    }
141
}
142

            
143
// The 9P protocol requires that strings are UTF-8 encoded.  The wire format is a u16
144
// count |N|, encoded in little endian, followed by |N| bytes of UTF-8 data.
145
impl WireFormat for P9String {
146
    fn byte_size(&self) -> u32 {
147
        (mem::size_of::<u16>() + self.len()) as u32
148
    }
149

            
150
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
151
        (self.len() as u16).encode(writer)?;
152
        writer.write_all(self.cstr.as_bytes())
153
    }
154

            
155
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
156
        let len: u16 = WireFormat::decode(reader)?;
157
        let mut string_bytes = vec![0u8; usize::from(len)];
158
        reader.read_exact(&mut string_bytes)?;
159
        Self::new(string_bytes)
160
    }
161
}
162

            
163
// This doesn't really _need_ to be a macro but unfortunately there is no trait bound to
164
// express "can be casted to another type", which means we can't write `T as u8` in a trait
165
// based implementation.  So instead we have this macro, which is implemented for all the
166
// stable unsigned types with the added benefit of not being implemented for the signed
167
// types which are not allowed by the protocol.
168
macro_rules! uint_wire_format_impl {
169
    ($Ty:ty) => {
170
        impl WireFormat for $Ty {
171
56
            fn byte_size(&self) -> u32 {
172
56
                mem::size_of::<$Ty>() as u32
173
56
            }
174

            
175
1046
            fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
176
1046
                writer.write_all(&self.to_le_bytes())
177
1046
            }
178

            
179
1450
            fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
180
1450
                let mut buf = [0; mem::size_of::<$Ty>()];
181
1450
                reader.read_exact(&mut buf)?;
182
                paste::expr! {
183
1450
                    let num: zerocopy::[<$Ty:snake:upper>]<LittleEndian> =  zerocopy::byteorder::[<$Ty:snake:upper>]::from_bytes(buf);
184
1450
                    Ok(num.get())
185
                }
186
1450
            }
187
        }
188
    };
189
}
190
// unsigned integers
191
uint_wire_format_impl!(u16);
192
uint_wire_format_impl!(u32);
193
uint_wire_format_impl!(u64);
194
uint_wire_format_impl!(u128);
195
// signed integers
196
uint_wire_format_impl!(i16);
197
uint_wire_format_impl!(i32);
198
uint_wire_format_impl!(i64);
199
uint_wire_format_impl!(i128);
200

            
201
macro_rules! float_wire_format_impl {
202
    ($Ty:ty) => {
203
        impl WireFormat for $Ty {
204
            fn byte_size(&self) -> u32 {
205
                mem::size_of::<$Ty>() as u32
206
            }
207

            
208
            fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
209
                paste::expr! {
210
                    writer.write_all(&self.to_le_bytes())
211
                }
212
            }
213

            
214
            fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
215
                let mut buf = [0; mem::size_of::<$Ty>()];
216
                reader.read_exact(&mut buf)?;
217
                paste::expr! {
218
                    let num: zerocopy::[<$Ty:snake:upper>]<LittleEndian> =  zerocopy::byteorder::[<$Ty:snake:upper>]::from_bytes(buf);
219
                    Ok(num.get())
220
                }
221
            }
222
        }
223
    };
224
}
225

            
226
float_wire_format_impl!(f32);
227
float_wire_format_impl!(f64);
228

            
229
impl WireFormat for u8 {
230
    fn byte_size(&self) -> u32 {
231
        1
232
    }
233

            
234
432
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
235
432
        writer.write_all(&[*self])
236
432
    }
237

            
238
20
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
239
20
        let mut byte = [0u8; 1];
240
20
        reader.read_exact(&mut byte)?;
241
20
        Ok(byte[0])
242
20
    }
243
}
244

            
245
impl WireFormat for usize {
246
    fn byte_size(&self) -> u32 {
247
        mem::size_of::<usize>() as u32
248
    }
249

            
250
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
251
        writer.write_all(&self.to_le_bytes())
252
    }
253

            
254
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
255
        let mut buf = [0; mem::size_of::<usize>()];
256
        reader.read_exact(&mut buf)?;
257
        Ok(usize::from_le_bytes(buf))
258
    }
259
}
260

            
261
impl WireFormat for isize {
262
    fn byte_size(&self) -> u32 {
263
        mem::size_of::<isize>() as u32
264
    }
265

            
266
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
267
        writer.write_all(&self.to_le_bytes())
268
    }
269

            
270
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
271
        let mut buf = [0; mem::size_of::<isize>()];
272
        reader.read_exact(&mut buf)?;
273
        Ok(isize::from_le_bytes(buf))
274
    }
275
}
276

            
277
// The 9P protocol requires that strings are UTF-8 encoded.  The wire format is a u16
278
// count |N|, encoded in little endian, followed by |N| bytes of UTF-8 data.
279
impl WireFormat for String {
280
540
    fn byte_size(&self) -> u32 {
281
540
        (mem::size_of::<u16>() + self.len()) as u32
282
540
    }
283

            
284
212
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
285
212
        if self.len() > u16::MAX as usize {
286
            return Err(io::Error::new(
287
                ErrorKind::InvalidInput,
288
                "string is too long",
289
            ));
290
212
        }
291
212

            
292
212
        (self.len() as u16).encode(writer)?;
293
212
        writer.write_all(self.as_bytes())
294
212
    }
295

            
296
214
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
297
214
        let len: u16 = WireFormat::decode(reader)?;
298
214
        let mut result = String::with_capacity(len as usize);
299
214
        reader.take(len as u64).read_to_string(&mut result)?;
300
214
        Ok(result)
301
214
    }
302
}
303

            
304
// The wire format for repeated types is similar to that of strings: a little endian
305
// encoded u16 |N|, followed by |N| instances of the given type.
306
impl<T: WireFormat> WireFormat for Vec<T> {
307
    fn byte_size(&self) -> u32 {
308
        mem::size_of::<u16>() as u32
309
            + self.iter().map(|elem| elem.byte_size()).sum::<u32>()
310
    }
311

            
312
4
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
313
4
        if self.len() > u16::MAX as usize {
314
            return Err(std::io::Error::new(
315
                std::io::ErrorKind::InvalidInput,
316
                "too many elements in vector",
317
            ));
318
4
        }
319
4

            
320
4
        (self.len() as u16).encode(writer)?;
321
22
        for elem in self {
322
18
            elem.encode(writer)?;
323
        }
324

            
325
4
        Ok(())
326
4
    }
327

            
328
2
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
329
2
        let len: u16 = WireFormat::decode(reader)?;
330
2
        let mut result = Vec::with_capacity(len as usize);
331
2

            
332
2
        for _ in 0..len {
333
10
            result.push(WireFormat::decode(reader)?);
334
        }
335

            
336
2
        Ok(result)
337
2
    }
338
}
339

            
340
/// A type that encodes an arbitrary number of bytes of data.  Typically used for Rread
341
/// Twrite messages.  This differs from a `Vec<u8>` in that it encodes the number of bytes
342
/// using a `u32` instead of a `u16`.
343
#[derive(PartialEq, Eq, Clone)]
344
#[repr(transparent)]
345
#[cfg_attr(feature = "testing", derive(serde::Serialize, serde::Deserialize))]
346
pub struct Data(pub Vec<u8>);
347

            
348
// The maximum length of a data buffer that we support.  In practice the server's max message
349
// size should prevent us from reading too much data so this check is mainly to ensure a
350
// malicious client cannot trick us into allocating massive amounts of memory.
351
const MAX_DATA_LENGTH: u32 = 32 * 1024 * 1024;
352

            
353
impl fmt::Debug for Data {
354
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
355
        // There may be a lot of data and we don't want to spew it all out in a trace.  Instead
356
        // just print out the number of bytes in the buffer.
357
        write!(f, "Data({} bytes)", self.len())
358
    }
359
}
360

            
361
// Implement Deref and DerefMut so that we don't have to use self.0 everywhere.
362
impl Deref for Data {
363
    type Target = Vec<u8>;
364

            
365
    fn deref(&self) -> &Self::Target {
366
        &self.0
367
    }
368
}
369
impl DerefMut for Data {
370
    fn deref_mut(&mut self) -> &mut Self::Target {
371
        &mut self.0
372
    }
373
}
374

            
375
// Same as Vec<u8> except that it encodes the length as a u32 instead of a u16.
376
impl WireFormat for Data {
377
    fn byte_size(&self) -> u32 {
378
        mem::size_of::<u32>() as u32
379
            + self.iter().map(|elem| elem.byte_size()).sum::<u32>()
380
    }
381

            
382
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
383
        if self.len() > u32::MAX as usize {
384
            return Err(std::io::Error::new(
385
                std::io::ErrorKind::InvalidInput,
386
                "data is too large",
387
            ));
388
        }
389
        (self.len() as u32).encode(writer)?;
390
        writer.write_all(self)
391
    }
392

            
393
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
394
        let len: u32 = WireFormat::decode(reader)?;
395
        if len > MAX_DATA_LENGTH {
396
            return Err(std::io::Error::new(
397
                std::io::ErrorKind::InvalidData,
398
                format!("data length ({} bytes) is too large", len),
399
            ));
400
        }
401

            
402
        let mut buf = Vec::with_capacity(len as usize);
403
        reader.take(len as u64).read_to_end(&mut buf)?;
404

            
405
        if buf.len() == len as usize {
406
            Ok(Data(buf))
407
        } else {
408
            Err(io::Error::new(
409
                std::io::ErrorKind::UnexpectedEof,
410
                format!(
411
                    "unexpected end of data: want: {} bytes, got: {} bytes",
412
                    len,
413
                    buf.len()
414
                ),
415
            ))
416
        }
417
    }
418
}
419

            
420
impl<T> WireFormat for Option<T>
421
where
422
    T: WireFormat,
423
{
424
    fn byte_size(&self) -> u32 {
425
        1 + match self {
426
            None => 0,
427
            Some(value) => value.byte_size(),
428
        }
429
    }
430

            
431
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
432
        match self {
433
            None => WireFormat::encode(&0u8, writer),
434
            Some(value) => {
435
                WireFormat::encode(&1u8, writer)?;
436
                WireFormat::encode(value, writer)
437
            }
438
        }
439
    }
440

            
441
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
442
        let tag: u8 = WireFormat::decode(reader)?;
443
        match tag {
444
            0 => Ok(None),
445
            1 => Ok(Some(WireFormat::decode(reader)?)),
446
            _ => Err(io::Error::new(
447
                io::ErrorKind::InvalidData,
448
                format!("Invalid Option tag: {}", tag),
449
            )),
450
        }
451
    }
452
}
453

            
454
impl WireFormat for () {
455
7
    fn byte_size(&self) -> u32 {
456
7
        0
457
7
    }
458

            
459
2
    fn encode<W: Write>(&self, _writer: &mut W) -> io::Result<()> {
460
2
        Ok(())
461
2
    }
462

            
463
2
    fn decode<R: Read>(_reader: &mut R) -> io::Result<Self> {
464
2
        Ok(())
465
2
    }
466
}
467

            
468
impl WireFormat for bool {
469
    fn byte_size(&self) -> u32 {
470
        1
471
    }
472

            
473
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
474
        writer.write_all(&[*self as u8])
475
    }
476

            
477
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
478
        let mut byte = [0u8; 1];
479
        reader.read_exact(&mut byte)?;
480
        match byte[0] {
481
            0 => Ok(false),
482
            1 => Ok(true),
483
            _ => Err(io::Error::new(
484
                io::ErrorKind::InvalidData,
485
                "invalid byte for bool",
486
            )),
487
        }
488
    }
489
}
490

            
491
impl io::Read for Data {
492
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
493
        self.0.reader().read(buf)
494
    }
495
}
496

            
497
#[repr(transparent)]
498
pub struct Wrapped<T, I>(pub T, PhantomData<I>);
499

            
500
impl<T, I> Wrapped<T, I> {
501
    pub fn new(value: T) -> Self {
502
        Wrapped(value, PhantomData)
503
    }
504
}
505

            
506
#[cfg(not(target_arch = "wasm32"))]
507
impl<T, I> WireFormat for Wrapped<T, I>
508
where
509
    T: Send + std::convert::AsRef<I>,
510
    I: WireFormat + std::convert::Into<T>,
511
{
512
    fn byte_size(&self) -> u32 {
513
        AsRef::<I>::as_ref(&self.0).byte_size()
514
    }
515

            
516
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
517
        AsRef::<I>::as_ref(&self.0).encode(writer)
518
    }
519

            
520
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
521
        let inner = I::decode(reader)?;
522
        Ok(Wrapped(inner.into(), PhantomData))
523
    }
524
}
525

            
526
#[cfg(target_arch = "wasm32")]
527
impl<T, I> WireFormat for Wrapped<T, I>
528
where
529
    T: std::convert::AsRef<I>,
530
    I: WireFormat + std::convert::Into<T>,
531
{
532
    fn byte_size(&self) -> u32 {
533
        AsRef::<I>::as_ref(&self.0).byte_size()
534
    }
535

            
536
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
537
        AsRef::<I>::as_ref(&self.0).encode(writer)
538
    }
539

            
540
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
541
        let inner = I::decode(reader)?;
542
        Ok(Wrapped(inner.into(), PhantomData))
543
    }
544
}
545

            
546

            
547

            
548
impl<T:WireFormat> WireFormat for Box<T> {
549
    fn byte_size(&self) -> u32 {
550
        (**self).byte_size()
551
    }
552

            
553
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> where 
554
        Self: Sized {
555
        (**self).encode(writer)
556
    }
557

            
558
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self>
559
    where
560
        Self: Sized {
561
        let inner = T::decode(reader)?;
562
        Ok(Box::new(inner))
563
    }
564
}