1
#[cfg(tokio_unix)]
2
use std::ops::{Deref, DerefMut};
3
#[cfg(tokio_unix)]
4
use std::path::PathBuf;
5
use std::{collections::BTreeSet, fmt::Display, net::IpAddr};
6

            
7
use jetstream_wireformat::JetStreamWireFormat;
8
#[cfg(feature = "s2n-quic")]
9
use s2n_quic::stream::BidirectionalStream;
10
#[cfg(tokio_unix)]
11
use tokio::net::{unix::UCred, UnixStream};
12
#[cfg(any(feature = "s2n-quic", feature = "turmoil", tokio_unix))]
13
use tokio_util::codec::Framed;
14
use url::Url;
15

            
16
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
17
pub struct Context {
18
    remote: Option<RemoteAddr>,
19
    peer: Option<Peer>,
20
}
21

            
22
impl Display for Context {
23
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24
        match self.peer {
25
            Some(Peer::NodeId(ref id)) => write!(f, "{}", id.0),
26
            #[cfg(tokio_unix)]
27
            Some(Peer::Unix(ref cred)) => write!(
28
                f,
29
                "{}",
30
                cred.process_path()
31
                    .expect("Failed to get process path")
32
                    .to_string_lossy()
33
            ),
34
            None => write!(f, "None"),
35
        }
36
    }
37
}
38

            
39
impl From<NodeId> for Context {
40
66
    fn from(value: NodeId) -> Self {
41
66
        Context {
42
66
            remote: None,
43
66
            peer: Some(Peer::NodeId(value)),
44
66
        }
45
66
    }
46
}
47

            
48
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
49
pub enum RemoteAddr {
50
    #[cfg(tokio_unix)]
51
    UnixAddr(PathBuf),
52
    NodeAddr(NodeAddr),
53
    IpAddr(IpAddr),
54
}
55

            
56
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
57
pub enum Peer {
58
    #[cfg(tokio_unix)]
59
    Unix(Unix),
60
    NodeId(NodeId),
61
}
62

            
63
#[derive(Debug, Clone, PartialEq, Eq, Hash, JetStreamWireFormat)]
64
pub struct NodeId(okid::OkId);
65

            
66
#[derive(Debug, Clone, PartialEq, Eq, Hash, JetStreamWireFormat)]
67
pub struct NodeAddr {
68
    id: NodeId,
69
    relay_url: Option<Url>,
70
    direct_addresses: BTreeSet<std::net::SocketAddr>,
71
}
72

            
73
#[cfg(feature = "iroh")]
impl From<iroh::PublicKey> for NodeId {
    /// Wraps an iroh public key, converted to an `OkId`, as a [`NodeId`].
    fn from(value: iroh::PublicKey) -> Self {
        let ok_id: okid::OkId = value.into();
        Self(ok_id)
    }
}

#[cfg(feature = "iroh")]
impl From<NodeAddr> for iroh::NodeAddr {
    /// Converts a wire-format [`NodeAddr`] into an `iroh::NodeAddr`.
    ///
    /// # Panics
    ///
    /// Panics if the node id cannot be converted into an iroh node id
    /// (the `OkId` -> iroh node id `TryInto` conversion fails).
    fn from(value: NodeAddr) -> Self {
        // `value` is owned, so move every field out rather than cloning the
        // address set or unwrapping the relay URL by hand.
        let NodeAddr {
            id,
            relay_url,
            direct_addresses,
        } = value;
        iroh::NodeAddr {
            node_id: id
                .0
                .try_into()
                .expect("Failed to convert NodeId to iroh::NodeId"),
            relay_url: relay_url.map(iroh::RelayUrl::from),
            direct_addresses,
        }
    }
}

#[cfg(feature = "iroh")]
impl From<iroh::NodeAddr> for NodeAddr {
    /// Converts an `iroh::NodeAddr` into the wire-format [`NodeAddr`],
    /// moving its relay URL and direct address set across unchanged.
    fn from(value: iroh::NodeAddr) -> Self {
        let iroh::NodeAddr {
            node_id,
            relay_url,
            direct_addresses,
        } = value;
        Self {
            id: NodeId(node_id.into()),
            relay_url: relay_url.map(Into::into),
            direct_addresses,
        }
    }
}

/// Credentials (`UCred`) of the process on the other end of a Unix-domain
/// socket.
///
/// Equality and hashing for this type are implemented manually over the
/// peer's pid, uid and gid.
#[cfg(tokio_unix)]
#[derive(Debug, Clone)]
pub struct Unix(UCred);

#[cfg(tokio_unix)]
impl std::hash::Hash for Unix {
    /// Hashes the same `(pid, uid, gid)` triple that `PartialEq` compares.
    ///
    /// `pid` is hashed as a whole `Option`, so its discriminant is written
    /// too. This keeps the byte stream prefix-free: a value with no pid can
    /// no longer produce a prefix of the stream for a value that has one.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&(self.0.pid(), self.0.uid(), self.0.gid()), state);
    }
}

#[cfg(tokio_unix)]
impl PartialEq for Unix {
    /// Two peers are equal when their pid, uid and gid all match.
    fn eq(&self, other: &Self) -> bool {
        let key = |u: &Unix| (u.0.pid(), u.0.uid(), u.0.gid());
        key(self) == key(other)
    }
}

/// Equality on `(pid, uid, gid)` is reflexive, so `Unix` is `Eq`.
#[cfg(tokio_unix)]
impl Eq for Unix {}

#[cfg(tokio_unix)]
impl Deref for Unix {
    type Target = UCred;

    /// Borrows the wrapped `UCred`.
    fn deref(&self) -> &UCred {
        let Unix(cred) = self;
        cred
    }
}

#[cfg(tokio_unix)]
impl DerefMut for Unix {
    /// Mutably borrows the wrapped `UCred`.
    fn deref_mut(&mut self) -> &mut UCred {
        let Unix(cred) = self;
        cred
    }
}

#[cfg(tokio_unix)]
impl Unix {
    /// Resolves the peer process' executable by reading the
    /// `/proc/<pid>/exe` symlink.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::NotFound` error ("PID not found") when the
    /// credentials carry no pid. Otherwise it propagates any error from
    /// `std::fs::read_link`.
    pub fn process_path(&self) -> Result<PathBuf, std::io::Error> {
        match self.pid() {
            None => Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "PID not found",
            )),
            Some(pid) => {
                let exe_link = format!("/proc/{}/exe", pid);
                std::fs::read_link(exe_link)
            }
        }
    }
}

pub trait Contextual {
171
    fn context(&self) -> Context;
172
}
173

            
174
#[cfg(tokio_unix)]
impl<U> Contextual for Framed<UnixStream, U> {
    /// Builds a [`Context`] from the underlying Unix-domain socket.
    ///
    /// * `remote` is the peer's socket path. It is `None` when the peer
    ///   address cannot be read or is not a pathname, e.g. a client socket
    ///   that was never bound to a path.
    /// * `peer` carries the peer's credentials, or `None` if they cannot be
    ///   read.
    ///
    /// Never panics; fields that cannot be determined are left `None`.
    fn context(&self) -> Context {
        let stream = self.get_ref();
        // An unnamed peer address is a normal condition, not an invariant
        // violation, so map it to `None` instead of `expect`-ing a path.
        let remote = stream.peer_addr().ok().and_then(|addr| {
            addr.as_pathname()
                .map(|path| RemoteAddr::UnixAddr(path.to_path_buf()))
        });
        let peer = stream
            .peer_cred()
            .ok()
            .map(|cred| Peer::Unix(Unix(cred)));
        Context { remote, peer }
    }
}

#[cfg(feature = "s2n-quic")]
impl<U> Contextual for Framed<BidirectionalStream, U> {
    /// Builds a [`Context`] from the QUIC connection's remote socket address.
    ///
    /// `remote` holds the remote IP address, or `None` when the connection
    /// reports an error for `remote_addr()` (previously a panic). `peer` is
    /// always `None` for this transport.
    fn context(&self) -> Context {
        let remote = self
            .get_ref()
            .connection()
            .remote_addr()
            .ok()
            .map(|addr| RemoteAddr::IpAddr(addr.ip()));
        Context { remote, peer: None }
    }
}

#[cfg(feature = "turmoil")]
impl<U> Contextual for Framed<turmoil::net::TcpStream, U> {
    /// Builds a [`Context`] from the simulated TCP stream's peer address.
    ///
    /// `remote` holds the peer IP address, or `None` if `peer_addr()` fails
    /// (previously a panic). `peer` is always `None` for this transport.
    fn context(&self) -> Context {
        let remote = self
            .get_ref()
            .peer_addr()
            .ok()
            .map(|addr| RemoteAddr::IpAddr(addr.ip()));
        Context { remote, peer: None }
    }
}

#[cfg(cloudflare)]
impl Contextual for worker::Request {
    /// Returns an empty context: this implementation does not extract a
    /// remote address or peer identity from the request.
    fn context(&self) -> Context {
        Context::default()
    }
}

impl Context {
227
    pub fn new(remote: Option<RemoteAddr>, peer: Option<Peer>) -> Self {
228
        Context { remote, peer }
229
    }
230
}