1
use std::{net::SocketAddr, path::Path, sync::Arc};
2

            
3
use echo_protocol::EchoChannel;
4
use jetstream::prelude::*;
5
use jetstream_macros::service;
6
use jetstream_quic::{Client, QuicTransport, Router, Server};
7
use jetstream_rpc::Protocol;
8

            
9
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
10

            
11
// RPC service definition for the echo example. `#[service]` presumably
// generates the `echo_protocol` module (`EchoChannel`, `EchoService`) that the
// rest of this file uses — confirm against jetstream_macros.
#[service]
pub trait Echo {
    // Takes no arguments and carries no data back beyond success/failure.
    async fn ping(&mut self) -> Result<()>;
}
15

            
16
#[derive(Clone)]
17
struct EchoImpl {}
18

            
19
impl Echo for EchoImpl {
20
    async fn ping(&mut self) -> Result<()> {
21
        eprintln!("Ping received");
22
        eprintln!("Pong sent");
23
        Ok(())
24
    }
25
}
26

            
27
pub static CA_CERT_PEM: &str =
28
    concat!(env!("CARGO_MANIFEST_DIR"), "/certs/ca.pem");
29
pub static CLIENT_CERT_PEM: &str =
30
    concat!(env!("CARGO_MANIFEST_DIR"), "/certs/client.pem");
31
pub static CLIENT_KEY_PEM: &str =
32
    concat!(env!("CARGO_MANIFEST_DIR"), "/certs/client.key");
33
pub static SERVER_CERT_PEM: &str =
34
    concat!(env!("CARGO_MANIFEST_DIR"), "/certs/server.pem");
35
pub static SERVER_KEY_PEM: &str =
36
    concat!(env!("CARGO_MANIFEST_DIR"), "/certs/server.key");
37

            
38
fn load_certs(path: &str) -> Vec<CertificateDer<'static>> {
39
    let data = std::fs::read(Path::new(path)).expect("Failed to read cert");
40
    rustls_pemfile::certs(&mut &*data)
41
        .filter_map(|r| r.ok())
42
        .collect()
43
}
44

            
45
fn load_key(path: &str) -> PrivateKeyDer<'static> {
46
    let data = std::fs::read(Path::new(path)).expect("Failed to read key");
47
    rustls_pemfile::private_key(&mut &*data)
48
        .expect("Failed to parse key")
49
        .expect("No key found")
50
}
51

            
52
async fn server(
53
    addr: SocketAddr,
54
) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
55
    let server_cert = load_certs(SERVER_CERT_PEM).pop().unwrap();
56
    let server_key = load_key(SERVER_KEY_PEM);
57
    let ca_cert = load_certs(CA_CERT_PEM).pop().unwrap();
58

            
59
    // Register the EchoService as a protocol handler
60
    // jetstream_quic's ProtocolHandler is auto-implemented for jetstream_rpc::Server
61
    let echo_service = echo_protocol::EchoService { inner: EchoImpl {} };
62

            
63
    let mut router = Router::new();
64
    router.register(Arc::new(echo_service));
65

            
66
    let server =
67
        Server::new_with_mtls(server_cert, server_key, ca_cert, addr, router);
68

            
69
    eprintln!("Server listening on {}", addr);
70
    server.run().await;
71

            
72
    Ok(())
73
}
74

            
75
async fn client(
76
    addr: SocketAddr,
77
) -> std::result::Result<(), Box<dyn std::error::Error + Send + Sync>> {
78
    // Wait for server to start
79
    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
80

            
81
    let ca_cert = load_certs(CA_CERT_PEM).pop().unwrap();
82
    let client_cert = load_certs(CLIENT_CERT_PEM).pop().unwrap();
83
    let client_key = load_key(CLIENT_KEY_PEM);
84

            
85
    // Use the protocol version as ALPN
86
    let alpn = vec![EchoChannel::VERSION.as_bytes().to_vec()];
87
    let client = Client::new_with_mtls(ca_cert, client_cert, client_key, alpn)?;
88

            
89
    let connection = client.connect(addr, "localhost").await?;
90

            
91
    // Open a bidirectional stream and wrap it in QuicTransport
92
    let (send, recv) = connection.open_bi().await?;
93
    let transport: QuicTransport<EchoChannel> = (send, recv).into();
94
    let mut chan = EchoChannel::new(10, Box::new(transport));
95

            
96
    eprintln!("Ping sent");
97
    chan.ping().await?;
98
    eprintln!("Pong received");
99

            
100
    Ok(())
101
}
102

            
103
#[tokio::main]
104
async fn main() {
105
    // Install the ring crypto provider for rustls
106
    rustls::crypto::ring::default_provider()
107
        .install_default()
108
        .ok();
109

            
110
    let addr: SocketAddr = "127.0.0.1:4433".parse().unwrap();
111
    tokio::select! {
112
      _ = server(addr) => {},
113
      _ = client(addr) => {},
114
    }
115
}