A better Rust ATProto crate

further work on websocket stream impl

Orual 96bc86e1 c6d2a31d

+95 -22
+95 -22
crates/jacquard-common/src/websocket.rs
··· 9 9 use std::ops::Deref; 10 10 use url::Url; 11 11 12 - use crate::CowStr; 13 12 use crate::stream::StreamError; 13 + use crate::{CowStr, IntoStatic}; 14 14 15 15 /// UTF-8 validated bytes for WebSocket text messages 16 16 #[repr(transparent)] ··· 308 308 pub fn into_inner(self) -> Boxed<Result<WsMessage, StreamError>> { 309 309 self.0 310 310 } 311 + 312 + /// Split this stream into two streams that both receive all messages 313 + /// 314 + /// Messages are cloned (cheaply via Bytes rc). Spawns a forwarder task. 315 + /// Both returned streams will receive all messages from the original stream. 316 + /// The forwarder continues as long as at least one stream is alive. 317 + /// If the underlying stream errors, both teed streams will end. 318 + pub fn tee(self) -> (WsStream, WsStream) { 319 + use futures::channel::mpsc; 320 + use n0_future::StreamExt as _; 321 + 322 + let (tx1, rx1) = mpsc::unbounded(); 323 + let (tx2, rx2) = mpsc::unbounded(); 324 + 325 + n0_future::task::spawn(async move { 326 + let mut stream = self.0; 327 + while let Some(result) = stream.next().await { 328 + match result { 329 + Ok(msg) => { 330 + // Clone message (cheap - Bytes is rc'd) 331 + let msg2 = msg.clone(); 332 + 333 + // Send to both channels, continue if at least one succeeds 334 + let send1 = tx1.unbounded_send(Ok(msg)); 335 + let send2 = tx2.unbounded_send(Ok(msg2)); 336 + 337 + // Only stop if both channels are closed 338 + if send1.is_err() && send2.is_err() { 339 + break; 340 + } 341 + } 342 + Err(_e) => { 343 + // Underlying stream errored, stop forwarding. 344 + // Both channels will close, ending both streams. 345 + break; 346 + } 347 + } 348 + } 349 + }); 350 + 351 + (WsStream::new(rx1), WsStream::new(rx2)) 352 + } 311 353 } 312 354 313 355 /// Extension trait for decoding typed messages from WebSocket streams 314 356 pub trait WsStreamExt: Sized { 315 357 /// Decode JSON text/binary frames into typed messages 316 358 /// 317 - /// Deserializes messages but does not automatically convert to owned. 318 - /// The caller is responsible for calling `.into_static()` if needed. 319 - fn decode_json<T>(self) -> impl Stream<Item = Result<T, StreamError>> 359 + /// Deserializes borrowing from temporary frame bytes, then converts to owned. 360 + fn decode_json<T>(self) -> impl Stream<Item = Result<T::Output, StreamError>> 320 361 where 321 - T: for<'de> serde::Deserialize<'de>; 362 + T: IntoStatic, 363 + for<'de> T: serde::Deserialize<'de>, 364 + T::Output: 'static; 322 365 323 366 /// Decode DAG-CBOR binary frames into typed messages 324 367 /// 325 - /// Deserializes messages but does not automatically convert to owned. 326 - /// The caller is responsible for calling `.into_static()` if needed. 327 - fn decode_cbor<T>(self) -> impl Stream<Item = Result<T, StreamError>> 368 + /// Deserializes borrowing from temporary frame bytes, then converts to owned. 369 + fn decode_cbor<T>(self) -> impl Stream<Item = Result<T::Output, StreamError>> 328 370 where 329 - T: for<'de> serde::Deserialize<'de>; 371 + T: IntoStatic, 372 + for<'de> T: serde::Deserialize<'de>, 373 + T::Output: 'static; 330 374 } 331 375 332 376 impl WsStreamExt for WsStream { 333 - fn decode_json<T>(self) -> impl Stream<Item = Result<T, StreamError>> 377 + fn decode_json<T>(self) -> impl Stream<Item = Result<T::Output, StreamError>> 334 378 where 335 - T: for<'de> serde::Deserialize<'de>, 379 + T: IntoStatic, 380 + for<'de> T: serde::Deserialize<'de>, 381 + T::Output: 'static, 336 382 { 337 383 use n0_future::StreamExt as _; 338 384 339 - Box::pin(self.into_inner().filter_map(|msg_result| match msg_result { 340 - Ok(WsMessage::Text(text)) => { 341 - Some(serde_json::from_slice(text.as_ref()).map_err(StreamError::decode)) 342 - } 343 - Ok(WsMessage::Binary(bytes)) => { 344 - Some(serde_json::from_slice(&bytes).map_err(StreamError::decode)) 385 + // Helper to deserialize with concrete lifetime 386 + fn parse_json<'a, T>(bytes: &'a [u8]) -> Result<T, serde_json::Error> 387 + where 388 + T: serde::Deserialize<'a>, 389 + { 390 + serde_json::from_slice(bytes) 391 + } 392 + 393 + Box::pin(self.into_inner().filter_map(|msg_result| { 394 + match msg_result { 395 + Ok(WsMessage::Text(text)) => Some( 396 + parse_json::<T>(text.as_ref()) 397 + .map(|v| v.into_static()) 398 + .map_err(StreamError::decode), 399 + ), 400 + Ok(WsMessage::Binary(bytes)) => Some( 401 + parse_json::<T>(&bytes) 402 + .map(|v| v.into_static()) 403 + .map_err(StreamError::decode), 404 + ), 405 + Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 406 + Err(e) => Some(Err(e)), 345 407 } 346 - Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 347 - Err(e) => Some(Err(e)), 348 408 })) 349 409 } 350 410 351 - fn decode_cbor<T>(self) -> impl Stream<Item = Result<T, StreamError>> 411 + fn decode_cbor<T>(self) -> impl Stream<Item = Result<T::Output, StreamError>> 352 412 where 353 - T: for<'de> serde::Deserialize<'de>, 413 + T: IntoStatic, 414 + for<'de> T: serde::Deserialize<'de>, 415 + T::Output: 'static, 354 416 { 355 417 use n0_future::StreamExt as _; 356 418 419 + // Helper to deserialize with concrete lifetime 420 + fn parse_cbor<'a, T>( 421 + bytes: &'a [u8], 422 + ) -> Result<T, serde_ipld_dagcbor::DecodeError<std::convert::Infallible>> 423 + where 424 + T: serde::Deserialize<'a>, 425 + { 426 + serde_ipld_dagcbor::from_slice(bytes) 427 + } 428 + 357 429 Box::pin(self.into_inner().filter_map(|msg_result| { 358 430 match msg_result { 359 431 Ok(WsMessage::Binary(bytes)) => Some( 360 - serde_ipld_dagcbor::from_slice(&bytes) 432 + parse_cbor::<T>(&bytes) 433 + .map(|v| v.into_static()) 361 434 .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))), 362 435 ), 363 436 Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(