melvin_ob/console_communication/console_endpoint.rs

1use super::melvin_messages;
2use crate::{info, warn};
3use prost::Message;
4use std::{
5    io::{Cursor, ErrorKind},
6    sync::Arc,
7};
8use tokio::{
9    io::{AsyncReadExt, AsyncWriteExt},
10    net::{
11        TcpListener,
12        tcp::{ReadHalf, WriteHalf},
13    },
14    sync::{broadcast, oneshot},
15};
16
17/// Represents the different console endpoint event types.
18///
19/// # Variants
20/// - `Connected`: Indicates a new console connection.
21/// - `Disconnected`: Indicates when a console connection is closed.
22/// - `Message`: Represents an upstream message received from the console.
23#[derive(Debug, Clone)]
24pub enum ConsoleEvent {
25    Connected,
26    Disconnected,
27    Message(melvin_messages::UpstreamContent),
28}
29
30/// The `ConsoleEndpoint` handles communication with MELVINs operator console.
31pub(crate) struct ConsoleEndpoint {
32    /// Used to send downstream messages to connected consoles.
33    downstream: broadcast::Sender<Option<Arc<Vec<u8>>>>,
34    /// Used to broadcast upstream events from consoles.
35    upstream_event: broadcast::Sender<ConsoleEvent>,
36    /// A channel sender to trigger endpoint shutdown.
37    close_oneshot: Option<oneshot::Sender<()>>,
38}
39
40impl ConsoleEndpoint {
41    /// Handles incoming data from the connected console. It listens for messages
42    /// and broadcasts them as upstream events.
43    ///
44    /// # Arguments
45    /// - `socket`: The reading end of the connection.
46    /// - `upstream_event_sender`: The sender used to broadcast received upstream events.
47    ///
48    /// # Errors
49    /// Returns I/O errors if issues arise when reading data from the socket.
50    async fn handle_connection_rx(
51        socket: &mut ReadHalf<'_>,
52        upstream_event_sender: &broadcast::Sender<ConsoleEvent>,
53    ) -> Result<(), std::io::Error> {
54        loop {
55            let length = socket.read_u32().await?;
56
57            let mut buffer = vec![0u8; length as usize];
58            socket.read_exact(&mut buffer).await?;
59
60            if let Ok(melvin_messages::Upstream { content: Some(content) }) =
61                melvin_messages::Upstream::decode(&mut Cursor::new(buffer))
62            {
63                info!("Received upstream message: {content:?}");
64                upstream_event_sender.send(ConsoleEvent::Message(content)).unwrap();
65            }
66        }
67    }
68
69    /// Handles sending downstream messages to the connected console. It listens to a receiver
70    /// for messages and sends them to the console.
71    ///
72    /// # Arguments
73    /// - `socket`: The write end of the connection.
74    /// - `downstream_receiver`: A receiver to get downstream messages.
75    ///
76    /// # Errors
77    /// Returns I/O errors if issues arise when sending data to the socket.
78    #[allow(clippy::cast_possible_truncation)]
79    async fn handle_connection_tx(
80        socket: &mut WriteHalf<'_>,
81        downstream_receiver: &mut broadcast::Receiver<Option<Arc<Vec<u8>>>>,
82    ) -> Result<(), std::io::Error> {
83        while let Ok(Some(message_buffer)) = downstream_receiver.recv().await {
84            socket.write_u32(message_buffer.len() as u32).await?;
85            socket.write_all(&message_buffer).await?;
86        }
87
88        Ok(())
89    }
90
91    /// Starts the `ConsoleEndpoint`, binding to a TCP listener and handling new connections.
92    ///
93    /// # Returns
94    /// An instance of `ConsoleEndpoint`.
95    ///
96    /// # Notes
97    /// This method spawns an asynchronous task to listen for and handle incoming connections.
98    pub(crate) fn start() -> Self {
99        let downstream_sender = broadcast::Sender::new(5);
100        let upstream_event_sender = broadcast::Sender::new(5);
101        let (close_oneshot_sender, mut close_oneshot_receiver) = oneshot::channel();
102        let inst = Self {
103            downstream: downstream_sender.clone(),
104            upstream_event: upstream_event_sender.clone(),
105            close_oneshot: Some(close_oneshot_sender),
106        };
107        tokio::spawn(async move {
108            info!("Started Console Endpoint");
109            let listener = TcpListener::bind("0.0.0.0:1337").await.unwrap();
110            loop {
111                let accept = tokio::select! {
112                    accept = listener.accept() => accept,
113                    _ = &mut close_oneshot_receiver => break
114                };
115
116                if let Ok((mut socket, _)) = accept {
117                    let upstream_event_sender_local = upstream_event_sender.clone();
118                    upstream_event_sender_local.send(ConsoleEvent::Connected).unwrap();
119                    let mut downstream_receiver = downstream_sender.subscribe();
120
121                    tokio::spawn(async move {
122                        info!("New connection from console");
123                        let (mut rx_socket, mut tx_socket) = socket.split();
124
125                        let result = tokio::select! {
126                            res = ConsoleEndpoint::handle_connection_tx(&mut tx_socket, &mut downstream_receiver) => res,
127                            res = ConsoleEndpoint::handle_connection_rx(&mut rx_socket, &upstream_event_sender_local) => res
128                        };
129
130                        upstream_event_sender_local.send(ConsoleEvent::Disconnected).unwrap();
131                        match result {
132                            Err(e)
133                                if e.kind() == ErrorKind::UnexpectedEof
134                                    || e.kind() == ErrorKind::ConnectionReset
135                                    || e.kind() == ErrorKind::ConnectionAborted =>
136                            {
137                                return;
138                            }
139                            Err(e) => {
140                                warn!("Closing connection to console due to {e:?}");
141                            }
142                            _ => {}
143                        };
144                        let _ = socket.shutdown().await;
145                    });
146                } else {
147                    break;
148                }
149            }
150        });
151        inst
152    }
153
154    /// Sends a downstream message to the operator console.
155    ///
156    /// # Arguments
157    /// - `msg`: A `DownstreamContent` message to send.
158    pub(crate) fn send_downstream(&self, msg: melvin_messages::DownstreamContent) {
159        let _ = self.downstream.send(Some(Arc::new(
160            melvin_messages::Downstream { content: Some(msg) }.encode_to_vec(),
161        )));
162    }
163
164    /// Checks whether any console is currently connected to the endpoint.
165    ///
166    /// # Returns
167    /// `true` if at least one console is connected; otherwise, `false`.
168    pub(crate) fn is_console_connected(&self) -> bool {
169        self.downstream.receiver_count() > 0
170    }
171
172    /// Subscribes to upstream events from the connected console.
173    ///
174    /// # Returns
175    /// A broadcast receiver to listen for upstream events.
176    pub(crate) fn subscribe_upstream_events(&self) -> broadcast::Receiver<ConsoleEvent> {
177        self.upstream_event.subscribe()
178    }
179}
180
181impl Drop for ConsoleEndpoint {
182    /// Handles graceful shutdown of the `ConsoleEndpoint`. Signals the close channel
183    /// and notifies all downstream subscribers of disconnection.
184    fn drop(&mut self) {
185        self.close_oneshot.take().unwrap().send(()).unwrap();
186        self.downstream.send(None).unwrap();
187    }
188}