// plcbundle: a high-performance implementation written in Rust.
// Source: branch `main` (254 lines, 7.8 kB).
1// WebSocket handler for streaming operations 2 3use crate::server::ServerState; 4use axum::extract::ws::Message; 5use axum::{ 6 body::Bytes, 7 extract::{Query, State, ws::WebSocketUpgrade}, 8 response::Response, 9}; 10use futures_util::{SinkExt, StreamExt}; 11use serde::Deserialize; 12use std::sync::Arc; 13use tokio::time::{Duration, interval}; 14 15#[derive(Deserialize)] 16pub struct CursorQuery { 17 pub cursor: Option<u64>, 18} 19 20pub async fn handle_websocket( 21 State(state): State<ServerState>, 22 ws: WebSocketUpgrade, 23 Query(params): Query<CursorQuery>, 24) -> Response { 25 let start_cursor = params 26 .cursor 27 .unwrap_or_else(|| state.manager.get_current_cursor()); 28 ws.on_upgrade(move |socket| handle_websocket_connection(socket, state, start_cursor)) 29} 30 31async fn handle_websocket_connection( 32 socket: axum::extract::ws::WebSocket, 33 state: ServerState, 34 start_cursor: u64, 35) { 36 let (mut sender, mut receiver) = socket.split(); 37 38 // Spawn task to handle incoming messages (for close/pong) 39 let receiver_task = tokio::spawn(async move { 40 while let Some(msg) = receiver.next().await { 41 match msg { 42 Ok(axum::extract::ws::Message::Close(_)) 43 | Ok(axum::extract::ws::Message::Pong(_)) => { 44 // Normal close or pong, continue 45 } 46 Err(e) => { 47 eprintln!("WebSocket receive error: {}", e); 48 break; 49 } 50 _ => { 51 // Ignore other messages 52 } 53 } 54 } 55 }); 56 57 // Stream operations (this will handle pings internally) 58 let stream_result = stream_live_operations(state, start_cursor, &mut sender).await; 59 60 // Close receiver task 61 receiver_task.abort(); 62 63 if let Err(e) = stream_result { 64 eprintln!("WebSocket stream error: {}", e); 65 } 66} 67 68async fn stream_live_operations( 69 state: ServerState, 70 start_cursor: u64, 71 sender: &mut futures_util::stream::SplitSink< 72 axum::extract::ws::WebSocket, 73 axum::extract::ws::Message, 74 >, 75) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { 76 let index = 
state.manager.get_index(); 77 let bundles = &index.bundles; 78 let mut current_record = start_cursor; 79 80 // Stream existing bundles 81 if !bundles.is_empty() { 82 let (start_bundle_num, start_position) = 83 crate::constants::global_to_bundle_position(start_cursor); 84 let start_bundle_idx = (start_bundle_num - 1) as usize; 85 86 if start_bundle_idx < bundles.len() { 87 for (i, bundle) in bundles.iter().enumerate().skip(start_bundle_idx) { 88 let bundle_num = bundle.bundle_number; 89 let skip_until = if i == start_bundle_idx { 90 start_position 91 } else { 92 0 93 }; 94 95 let streamed = 96 stream_bundle(&state.manager, bundle_num, skip_until, sender).await?; 97 current_record += streamed as u64; 98 } 99 } 100 } 101 102 // Stream mempool operations 103 let mut bundle_record_base = crate::constants::total_operations_from_bundles(bundles.len() as u32); 104 let mut last_seen_mempool_count = 0; 105 106 stream_mempool( 107 &state.manager, 108 start_cursor, 109 bundle_record_base, 110 &mut current_record, 111 &mut last_seen_mempool_count, 112 sender, 113 ) 114 .await?; 115 116 // Poll for new bundles and mempool updates 117 let mut ticker = interval(Duration::from_millis(500)); 118 let mut last_bundle_count = bundles.len(); 119 120 loop { 121 ticker.tick().await; 122 123 // Check for new bundles 124 let current_bundle_count = state.manager.bundle_count(); 125 126 if current_bundle_count > last_bundle_count { 127 let new_bundle_count = current_bundle_count - last_bundle_count; 128 current_record += 129 crate::constants::total_operations_from_bundles(new_bundle_count as u32); 130 last_bundle_count = current_bundle_count; 131 bundle_record_base = crate::constants::total_operations_from_bundles(last_bundle_count as u32); 132 last_seen_mempool_count = 0; 133 } 134 135 // Stream new mempool operations 136 stream_mempool( 137 &state.manager, 138 start_cursor, 139 bundle_record_base, 140 &mut current_record, 141 &mut last_seen_mempool_count, 142 sender, 143 ) 144 .await?; 145 
146 // Send ping 147 if let Err(e) = sender.send(Message::Ping(Bytes::new())).await { 148 return Err(Box::new(std::io::Error::new( 149 std::io::ErrorKind::ConnectionAborted, 150 format!("WebSocket ping failed: {}", e), 151 ))); 152 } 153 } 154} 155 156async fn stream_bundle( 157 manager: &Arc<crate::manager::BundleManager>, 158 bundle_num: u32, 159 skip_until: usize, 160 sender: &mut futures_util::stream::SplitSink< 161 axum::extract::ws::WebSocket, 162 axum::extract::ws::Message, 163 >, 164) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> { 165 // Get decompressed bundle reader 166 let reader = match manager.stream_bundle_decompressed(bundle_num) { 167 Ok(r) => r, 168 Err(_) => return Ok(0), // Bundle not found, skip 169 }; 170 171 // Read bundle in blocking task 172 let lines = tokio::task::spawn_blocking(move || { 173 use std::io::{BufRead, BufReader}; 174 let mut buf_reader = BufReader::new(reader); 175 let mut lines = Vec::new(); 176 let mut line = String::new(); 177 let mut position = 0; 178 179 while buf_reader.read_line(&mut line).unwrap_or(0) > 0 { 180 if position >= skip_until && !line.trim().is_empty() { 181 lines.push(line.clone()); 182 } 183 line.clear(); 184 position += 1; 185 } 186 lines 187 }) 188 .await?; 189 190 // Send lines 191 let mut streamed = 0; 192 for line in lines.iter() { 193 if let Err(e) = sender.send(Message::Text(line.clone().into())).await { 194 return Err(Box::new(std::io::Error::new( 195 std::io::ErrorKind::BrokenPipe, 196 format!("WebSocket write error: {}", e), 197 ))); 198 } 199 streamed += 1; 200 201 // Send ping every 1000 operations 202 if streamed % 1000 == 0 && sender.send(Message::Ping(Bytes::new())).await.is_err() { 203 break; 204 } 205 } 206 207 Ok(streamed) 208} 209 210async fn stream_mempool( 211 manager: &Arc<crate::manager::BundleManager>, 212 start_cursor: u64, 213 bundle_record_base: u64, 214 current_record: &mut u64, 215 last_seen_count: &mut usize, 216 sender: &mut futures_util::stream::SplitSink< 217 
axum::extract::ws::WebSocket, 218 axum::extract::ws::Message, 219 >, 220) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { 221 let mempool_ops = match manager.get_mempool_operations() { 222 Ok(ops) => ops, 223 Err(_) => return Ok(()), // No mempool or error 224 }; 225 226 if mempool_ops.len() <= *last_seen_count { 227 return Ok(()); 228 } 229 230 for (i, op) in mempool_ops.iter().enumerate().skip(*last_seen_count) { 231 let record_num = bundle_record_base + i as u64; 232 if record_num < start_cursor { 233 continue; 234 } 235 236 // Send operation as JSON 237 let json = match sonic_rs::to_string(op) { 238 Ok(j) => j, 239 Err(_) => continue, // Skip invalid operations 240 }; 241 242 if let Err(e) = sender.send(Message::Text(json.into())).await { 243 return Err(Box::new(std::io::Error::new( 244 std::io::ErrorKind::BrokenPipe, 245 format!("WebSocket write error: {}", e), 246 ))); 247 } 248 249 *current_record += 1; 250 } 251 252 *last_seen_count = mempool_ops.len(); 253 Ok(()) 254}