Forked from atscan.net/plcbundle-rs.
High-performance implementation of plcbundle, written in Rust.
1// WebSocket handler for streaming operations
2
3use crate::server::ServerState;
4use axum::extract::ws::Message;
5use axum::{
6 body::Bytes,
7 extract::{Query, State, ws::WebSocketUpgrade},
8 response::Response,
9};
10use futures_util::{SinkExt, StreamExt};
11use serde::Deserialize;
12use std::sync::Arc;
13use tokio::time::{Duration, interval};
14
15#[derive(Deserialize)]
16pub struct CursorQuery {
17 pub cursor: Option<u64>,
18}
19
20pub async fn handle_websocket(
21 State(state): State<ServerState>,
22 ws: WebSocketUpgrade,
23 Query(params): Query<CursorQuery>,
24) -> Response {
25 let start_cursor = params
26 .cursor
27 .unwrap_or_else(|| state.manager.get_current_cursor());
28 ws.on_upgrade(move |socket| handle_websocket_connection(socket, state, start_cursor))
29}
30
31async fn handle_websocket_connection(
32 socket: axum::extract::ws::WebSocket,
33 state: ServerState,
34 start_cursor: u64,
35) {
36 let (mut sender, mut receiver) = socket.split();
37
38 // Spawn task to handle incoming messages (for close/pong)
39 let receiver_task = tokio::spawn(async move {
40 while let Some(msg) = receiver.next().await {
41 match msg {
42 Ok(axum::extract::ws::Message::Close(_))
43 | Ok(axum::extract::ws::Message::Pong(_)) => {
44 // Normal close or pong, continue
45 }
46 Err(e) => {
47 eprintln!("WebSocket receive error: {}", e);
48 break;
49 }
50 _ => {
51 // Ignore other messages
52 }
53 }
54 }
55 });
56
57 // Stream operations (this will handle pings internally)
58 let stream_result = stream_live_operations(state, start_cursor, &mut sender).await;
59
60 // Close receiver task
61 receiver_task.abort();
62
63 if let Err(e) = stream_result {
64 eprintln!("WebSocket stream error: {}", e);
65 }
66}
67
68async fn stream_live_operations(
69 state: ServerState,
70 start_cursor: u64,
71 sender: &mut futures_util::stream::SplitSink<
72 axum::extract::ws::WebSocket,
73 axum::extract::ws::Message,
74 >,
75) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
76 let index = state.manager.get_index();
77 let bundles = &index.bundles;
78 let mut current_record = start_cursor;
79
80 // Stream existing bundles
81 if !bundles.is_empty() {
82 let (start_bundle_num, start_position) =
83 crate::constants::global_to_bundle_position(start_cursor);
84 let start_bundle_idx = (start_bundle_num - 1) as usize;
85
86 if start_bundle_idx < bundles.len() {
87 for (i, bundle) in bundles.iter().enumerate().skip(start_bundle_idx) {
88 let bundle_num = bundle.bundle_number;
89 let skip_until = if i == start_bundle_idx {
90 start_position
91 } else {
92 0
93 };
94
95 let streamed =
96 stream_bundle(&state.manager, bundle_num, skip_until, sender).await?;
97 current_record += streamed as u64;
98 }
99 }
100 }
101
102 // Stream mempool operations
103 let mut bundle_record_base = crate::constants::total_operations_from_bundles(bundles.len() as u32);
104 let mut last_seen_mempool_count = 0;
105
106 stream_mempool(
107 &state.manager,
108 start_cursor,
109 bundle_record_base,
110 &mut current_record,
111 &mut last_seen_mempool_count,
112 sender,
113 )
114 .await?;
115
116 // Poll for new bundles and mempool updates
117 let mut ticker = interval(Duration::from_millis(500));
118 let mut last_bundle_count = bundles.len();
119
120 loop {
121 ticker.tick().await;
122
123 // Check for new bundles
124 let current_bundle_count = state.manager.bundle_count();
125
126 if current_bundle_count > last_bundle_count {
127 let new_bundle_count = current_bundle_count - last_bundle_count;
128 current_record +=
129 crate::constants::total_operations_from_bundles(new_bundle_count as u32);
130 last_bundle_count = current_bundle_count;
131 bundle_record_base = crate::constants::total_operations_from_bundles(last_bundle_count as u32);
132 last_seen_mempool_count = 0;
133 }
134
135 // Stream new mempool operations
136 stream_mempool(
137 &state.manager,
138 start_cursor,
139 bundle_record_base,
140 &mut current_record,
141 &mut last_seen_mempool_count,
142 sender,
143 )
144 .await?;
145
146 // Send ping
147 if let Err(e) = sender.send(Message::Ping(Bytes::new())).await {
148 return Err(Box::new(std::io::Error::new(
149 std::io::ErrorKind::ConnectionAborted,
150 format!("WebSocket ping failed: {}", e),
151 )));
152 }
153 }
154}
155
156async fn stream_bundle(
157 manager: &Arc<crate::manager::BundleManager>,
158 bundle_num: u32,
159 skip_until: usize,
160 sender: &mut futures_util::stream::SplitSink<
161 axum::extract::ws::WebSocket,
162 axum::extract::ws::Message,
163 >,
164) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
165 // Get decompressed bundle reader
166 let reader = match manager.stream_bundle_decompressed(bundle_num) {
167 Ok(r) => r,
168 Err(_) => return Ok(0), // Bundle not found, skip
169 };
170
171 // Read bundle in blocking task
172 let lines = tokio::task::spawn_blocking(move || {
173 use std::io::{BufRead, BufReader};
174 let mut buf_reader = BufReader::new(reader);
175 let mut lines = Vec::new();
176 let mut line = String::new();
177 let mut position = 0;
178
179 while buf_reader.read_line(&mut line).unwrap_or(0) > 0 {
180 if position >= skip_until && !line.trim().is_empty() {
181 lines.push(line.clone());
182 }
183 line.clear();
184 position += 1;
185 }
186 lines
187 })
188 .await?;
189
190 // Send lines
191 let mut streamed = 0;
192 for line in lines.iter() {
193 if let Err(e) = sender.send(Message::Text(line.clone().into())).await {
194 return Err(Box::new(std::io::Error::new(
195 std::io::ErrorKind::BrokenPipe,
196 format!("WebSocket write error: {}", e),
197 )));
198 }
199 streamed += 1;
200
201 // Send ping every 1000 operations
202 if streamed % 1000 == 0 && sender.send(Message::Ping(Bytes::new())).await.is_err() {
203 break;
204 }
205 }
206
207 Ok(streamed)
208}
209
210async fn stream_mempool(
211 manager: &Arc<crate::manager::BundleManager>,
212 start_cursor: u64,
213 bundle_record_base: u64,
214 current_record: &mut u64,
215 last_seen_count: &mut usize,
216 sender: &mut futures_util::stream::SplitSink<
217 axum::extract::ws::WebSocket,
218 axum::extract::ws::Message,
219 >,
220) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
221 let mempool_ops = match manager.get_mempool_operations() {
222 Ok(ops) => ops,
223 Err(_) => return Ok(()), // No mempool or error
224 };
225
226 if mempool_ops.len() <= *last_seen_count {
227 return Ok(());
228 }
229
230 for (i, op) in mempool_ops.iter().enumerate().skip(*last_seen_count) {
231 let record_num = bundle_record_base + i as u64;
232 if record_num < start_cursor {
233 continue;
234 }
235
236 // Send operation as JSON
237 let json = match sonic_rs::to_string(op) {
238 Ok(j) => j,
239 Err(_) => continue, // Skip invalid operations
240 };
241
242 if let Err(e) = sender.send(Message::Text(json.into())).await {
243 return Err(Box::new(std::io::Error::new(
244 std::io::ErrorKind::BrokenPipe,
245 format!("WebSocket write error: {}", e),
246 )));
247 }
248
249 *current_record += 1;
250 }
251
252 *last_seen_count = mempool_ops.len();
253 Ok(())
254}