zernel_ebpf/
websocket_server.rs

1// Copyright (C) 2026 Dyber, Inc. — GPL-2.0
2
3use crate::aggregation::AggregatedMetrics;
4use anyhow::Result;
5use futures_util::{SinkExt, StreamExt};
6use std::net::SocketAddr;
7use std::sync::Arc;
8use tokio::net::TcpListener;
9use tokio::sync::RwLock;
10use tokio_tungstenite::accept_async;
11use tokio_tungstenite::tungstenite::Message;
12use tracing::{debug, error, info, warn};
13
14/// WebSocket server for real-time telemetry streaming to `zernel watch`.
15pub struct WebSocketServer {
16    metrics: Arc<RwLock<AggregatedMetrics>>,
17    port: u16,
18    push_interval_ms: u64,
19}
20
21impl WebSocketServer {
22    pub fn new(metrics: Arc<RwLock<AggregatedMetrics>>, port: u16, push_interval_ms: u64) -> Self {
23        Self {
24            metrics,
25            port,
26            push_interval_ms,
27        }
28    }
29
30    pub async fn serve(&self) -> Result<()> {
31        let addr = SocketAddr::from(([0, 0, 0, 0], self.port));
32        let listener = TcpListener::bind(addr).await?;
33        info!(port = self.port, "WebSocket server listening");
34
35        let metrics = Arc::clone(&self.metrics);
36        let interval_ms = self.push_interval_ms;
37
38        loop {
39            let (stream, peer) = listener.accept().await?;
40            let metrics = Arc::clone(&metrics);
41            info!(?peer, "WebSocket client connected");
42
43            tokio::spawn(async move {
44                match accept_async(stream).await {
45                    Ok(ws) => {
46                        if let Err(e) = handle_connection(ws, metrics, interval_ms).await {
47                            debug!(?peer, error = %e, "WebSocket connection ended");
48                        }
49                    }
50                    Err(e) => {
51                        warn!(?peer, error = %e, "WebSocket handshake failed");
52                    }
53                }
54            });
55        }
56    }
57}
58
59async fn handle_connection(
60    ws: tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
61    metrics: Arc<RwLock<AggregatedMetrics>>,
62    interval_ms: u64,
63) -> Result<()> {
64    let (mut write, mut read) = ws.split();
65    let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(interval_ms));
66
67    loop {
68        tokio::select! {
69            _ = interval.tick() => {
70                let snapshot = {
71                    let m = metrics.read().await;
72                    m.to_ws_snapshot()
73                };
74                let msg = Message::Text(snapshot.to_string().into());
75                if write.send(msg).await.is_err() {
76                    break;
77                }
78            }
79            msg = read.next() => {
80                match msg {
81                    Some(Ok(Message::Close(_))) | None => break,
82                    Some(Ok(Message::Ping(data))) => {
83                        let _ = write.send(Message::Pong(data)).await;
84                    }
85                    Some(Err(e)) => {
86                        error!("WebSocket read error: {e}");
87                        break;
88                    }
89                    _ => {}
90                }
91            }
92        }
93    }
94
95    Ok(())
96}