1
use std::{
2
    io,
3
    marker::PhantomData,
4
    pin::Pin,
5
    task::{Context, Poll},
6
};
7

            
8
use futures::{Sink, Stream};
9
use jetstream_rpc::{Frame, Framer, Protocol};
10
use jetstream_wireformat::{
11
    wire_format_extensions::ConvertWireFormat, WireFormat,
12
};
13
use tungstenite::{Message, WebSocket};
14

            
15
pub struct WebSocketTransport<P: Protocol>(
16
    WebSocket<tungstenite::stream::MaybeTlsStream<std::net::TcpStream>>,
17
    PhantomData<P>,
18
);
19

            
20
impl<P: Protocol>
21
    From<WebSocket<tungstenite::stream::MaybeTlsStream<std::net::TcpStream>>>
22
    for WebSocketTransport<P>
23
{
24
    fn from(
25
        value: WebSocket<
26
            tungstenite::stream::MaybeTlsStream<std::net::TcpStream>,
27
        >,
28
    ) -> Self {
29
        Self(value, PhantomData)
30
    }
31
}
32

            
33
impl<P: Protocol> Sink<jetstream_rpc::Frame<P::Request>>
34
    for WebSocketTransport<P>
35
where
36
    Self: Unpin,
37
{
38
    type Error = io::Error;
39

            
40
    fn poll_ready(
41
        self: Pin<&mut Self>,
42
        _cx: &mut Context<'_>,
43
    ) -> Poll<Result<(), Self::Error>> {
44
        match self.0.can_write() {
45
            true => Poll::Ready(Ok(())),
46
            false => Poll::Pending,
47
        }
48
    }
49

            
50
    fn start_send(
51
        self: Pin<&mut Self>,
52
        item: jetstream_rpc::Frame<P::Request>,
53
    ) -> Result<(), Self::Error> {
54
        self.get_mut()
55
            .0
56
            .send(WebsocketFrame(item).into())
57
            .map_err(io::Error::other)?;
58
        Ok(())
59
    }
60

            
61
    fn poll_flush(
62
        self: Pin<&mut Self>,
63
        _cx: &mut Context<'_>,
64
    ) -> Poll<Result<(), Self::Error>> {
65
        Poll::Ready(self.get_mut().0.flush().map_err(io::Error::other))
66
    }
67

            
68
    fn poll_close(
69
        self: Pin<&mut Self>,
70
        _cx: &mut Context<'_>,
71
    ) -> Poll<Result<(), Self::Error>> {
72
        Poll::Ready(self.get_mut().0.close(None).map_err(io::Error::other))
73
    }
74
}
75

            
76
impl<P: Protocol> Stream for WebSocketTransport<P>
77
where
78
    Self: Unpin,
79
{
80
    type Item = Result<jetstream_rpc::Frame<P::Response>, io::Error>;
81

            
82
    fn poll_next(
83
        self: Pin<&mut Self>,
84
        _cx: &mut Context<'_>,
85
    ) -> Poll<Option<Self::Item>> {
86
        match self.get_mut().0.read() {
87
            Ok(Message::Binary(bytes)) => {
88
                let mut reader = io::Cursor::new(bytes);
89
                let frame = Frame::<P::Response>::decode(&mut reader).unwrap();
90

            
91
                Poll::Ready(Some(Ok(frame)))
92
            }
93
            Err(e) => {
94
                eprintln!("Error reading from websocket: {:?}", e);
95
                Poll::Ready(None)
96
            }
97
            _ => {
98
                eprintln!("Unexpected message type from websocket");
99
                Poll::Ready(None)
100
            }
101
        }
102
    }
103
}
104

            
105
/// Newtype around a jetstream [`Frame`] so that it can be turned into a tungstenite
/// [`Message`] via `From`.
pub struct WebsocketFrame<F>(Frame<F>)
where
    F: Framer;
106

            
107
impl<F: Framer> From<WebsocketFrame<F>> for tungstenite::protocol::Message {
108
    fn from(value: WebsocketFrame<F>) -> Self {
109
        Message::Binary(value.0.to_bytes())
110
    }
111
}