A better Rust ATProto crate

fixed decode trait bound issues

Orual 0f033cfb 3d82b460

+187 -107
-91
crates/jacquard-common/src/websocket.rs
··· 352 352 } 353 353 } 354 354 355 - /// Extension trait for decoding typed messages from WebSocket streams 356 - pub trait WsStreamExt: Sized { 357 - /// Decode JSON text/binary frames into typed messages 358 - /// 359 - /// Deserializes borrowing from temporary frame bytes, then converts to owned. 360 - fn decode_json<T>(self) -> impl Stream<Item = Result<T::Output, StreamError>> 361 - where 362 - T: IntoStatic, 363 - for<'de> T: serde::Deserialize<'de>, 364 - T::Output: 'static; 365 - 366 - /// Decode DAG-CBOR binary frames into typed messages 367 - /// 368 - /// Deserializes borrowing from temporary frame bytes, then converts to owned. 369 - fn decode_cbor<T>(self) -> impl Stream<Item = Result<T::Output, StreamError>> 370 - where 371 - T: IntoStatic, 372 - for<'de> T: serde::Deserialize<'de>, 373 - T::Output: 'static; 374 - } 375 - 376 - impl WsStreamExt for WsStream { 377 - fn decode_json<T>(self) -> impl Stream<Item = Result<T::Output, StreamError>> 378 - where 379 - T: IntoStatic, 380 - for<'de> T: serde::Deserialize<'de>, 381 - T::Output: 'static, 382 - { 383 - use n0_future::StreamExt as _; 384 - 385 - // Helper to deserialize with concrete lifetime 386 - fn parse_json<'a, T>(bytes: &'a [u8]) -> Result<T, serde_json::Error> 387 - where 388 - T: serde::Deserialize<'a>, 389 - { 390 - serde_json::from_slice(bytes) 391 - } 392 - 393 - Box::pin(self.into_inner().filter_map(|msg_result| { 394 - match msg_result { 395 - Ok(WsMessage::Text(text)) => Some( 396 - parse_json::<T>(text.as_ref()) 397 - .map(|v| v.into_static()) 398 - .map_err(StreamError::decode), 399 - ), 400 - Ok(WsMessage::Binary(bytes)) => Some( 401 - parse_json::<T>(&bytes) 402 - .map(|v| v.into_static()) 403 - .map_err(StreamError::decode), 404 - ), 405 - Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 406 - Err(e) => Some(Err(e)), 407 - } 408 - })) 409 - } 410 - 411 - fn decode_cbor<T>(self) -> impl Stream<Item = Result<T::Output, StreamError>> 412 - where 413 - T: IntoStatic, 414 - for<'de> T: serde::Deserialize<'de>, 415 - T::Output: 'static, 416 - { 417 - use n0_future::StreamExt as _; 418 - 419 - // Helper to deserialize with concrete lifetime 420 - fn parse_cbor<'a, T>( 421 - bytes: &'a [u8], 422 - ) -> Result<T, serde_ipld_dagcbor::DecodeError<std::convert::Infallible>> 423 - where 424 - T: serde::Deserialize<'a>, 425 - { 426 - serde_ipld_dagcbor::from_slice(bytes) 427 - } 428 - 429 - Box::pin(self.into_inner().filter_map(|msg_result| { 430 - match msg_result { 431 - Ok(WsMessage::Binary(bytes)) => Some( 432 - parse_cbor::<T>(&bytes) 433 - .map(|v| v.into_static()) 434 - .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))), 435 - ), 436 - Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format( 437 - "expected binary frame for CBOR, got text", 438 - ))), 439 - Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 440 - Err(e) => Some(Err(e)), 441 - } 442 - })) 443 - } 444 - } 445 - 446 355 impl fmt::Debug for WsStream { 447 356 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 448 357 f.debug_struct("WsStream").finish_non_exhaustive()
+1 -1
crates/jacquard-common/src/xrpc.rs
··· 23 23 pub use subscription::{ 24 24 BasicSubscriptionClient, MessageEncoding, SubscriptionCall, SubscriptionClient, 25 25 SubscriptionEndpoint, SubscriptionExt, SubscriptionOptions, SubscriptionResp, 26 - TungsteniteSubscriptionClient, XrpcSubscription, 26 + SubscriptionStream, TungsteniteSubscriptionClient, XrpcSubscription, 27 27 }; 28 28 29 29 use bytes::Bytes;
+182 -11
crates/jacquard-common/src/xrpc/subscription.rs
··· 6 6 use serde::{Deserialize, Serialize}; 7 7 use std::error::Error; 8 8 use std::future::Future; 9 + use std::marker::PhantomData; 9 10 use url::Url; 10 11 12 + use crate::stream::StreamError; 11 13 use crate::websocket::{WebSocketClient, WebSocketConnection}; 12 14 use crate::{CowStr, IntoStatic}; 13 15 ··· 76 78 } 77 79 } 78 80 81 + /// Decode JSON messages from a WebSocket stream 82 + fn decode_json_msg<S: SubscriptionResp>( 83 + msg_result: Result<crate::websocket::WsMessage, StreamError>, 84 + ) -> Option<Result<StreamMessage<'static, S>, StreamError>> 85 + where 86 + for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>, 87 + { 88 + use crate::websocket::WsMessage; 89 + 90 + fn parse_msg<'a, S: SubscriptionResp>( 91 + bytes: &'a [u8], 92 + ) -> Result<S::Message<'a>, serde_json::Error> { 93 + serde_json::from_slice(bytes) 94 + } 95 + 96 + match msg_result { 97 + Ok(WsMessage::Text(text)) => Some( 98 + parse_msg::<S>(text.as_ref()) 99 + .map(|v| v.into_static()) 100 + .map_err(StreamError::decode), 101 + ), 102 + Ok(WsMessage::Binary(bytes)) => Some( 103 + parse_msg::<S>(&bytes) 104 + .map(|v| v.into_static()) 105 + .map_err(StreamError::decode), 106 + ), 107 + Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 108 + Err(e) => Some(Err(e)), 109 + } 110 + } 111 + 112 + /// Decode CBOR messages from a WebSocket stream 113 + fn decode_cbor_msg<S: SubscriptionResp>( 114 + msg_result: Result<crate::websocket::WsMessage, StreamError>, 115 + ) -> Option<Result<StreamMessage<'static, S>, StreamError>> 116 + where 117 + for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>, 118 + { 119 + use crate::websocket::WsMessage; 120 + 121 + fn parse_cbor<'a, S: SubscriptionResp>( 122 + bytes: &'a [u8], 123 + ) -> Result<S::Message<'a>, serde_ipld_dagcbor::DecodeError<std::convert::Infallible>> { 124 + serde_ipld_dagcbor::from_slice(bytes) 125 + } 126 + 127 + match msg_result { 128 + Ok(WsMessage::Binary(bytes)) => Some( 129 + parse_cbor::<S>(&bytes) 130 + .map(|v| v.into_static()) 131 + .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))), 132 + ), 133 + Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format( 134 + "expected binary frame for CBOR, got text", 135 + ))), 136 + Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 137 + Err(e) => Some(Err(e)), 138 + } 139 + } 140 + 141 + /// Typed subscription stream wrapping a WebSocket connection. 142 + /// 143 + /// Analogous to `Response<R>` for XRPC but for subscription streams. 144 + /// Automatically decodes messages based on the subscription's encoding format. 145 + pub struct SubscriptionStream<S: SubscriptionResp> { 146 + _marker: PhantomData<fn() -> S>, 147 + connection: WebSocketConnection, 148 + } 149 + 150 + impl<S: SubscriptionResp> SubscriptionStream<S> { 151 + /// Create a new subscription stream from a WebSocket connection. 152 + pub fn new(connection: WebSocketConnection) -> Self { 153 + Self { 154 + _marker: PhantomData, 155 + connection, 156 + } 157 + } 158 + 159 + /// Get a reference to the underlying WebSocket connection. 160 + pub fn connection(&self) -> &WebSocketConnection { 161 + &self.connection 162 + } 163 + 164 + /// Get a mutable reference to the underlying WebSocket connection. 165 + pub fn connection_mut(&mut self) -> &mut WebSocketConnection { 166 + &mut self.connection 167 + } 168 + 169 + /// Split the connection and decode messages into a typed stream. 170 + /// 171 + /// Returns a tuple of (sender, typed message stream). 172 + /// Messages are decoded according to the subscription's ENCODING. 173 + pub fn into_stream( 174 + self, 175 + ) -> ( 176 + crate::websocket::WsSink, 177 + n0_future::stream::Boxed<Result<StreamMessage<'static, S>, StreamError>>, 178 + ) 179 + where 180 + for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>, 181 + { 182 + use n0_future::StreamExt as _; 183 + 184 + let (tx, rx) = self.connection.split(); 185 + 186 + let stream: n0_future::stream::Boxed<_> = match S::ENCODING { 187 + MessageEncoding::Json => { 188 + Box::pin(rx.into_inner().filter_map(|msg| decode_json_msg::<S>(msg))) 189 + } 190 + MessageEncoding::DagCbor => { 191 + Box::pin(rx.into_inner().filter_map(|msg| decode_cbor_msg::<S>(msg))) 192 + } 193 + }; 194 + 195 + (tx, stream) 196 + } 197 + 198 + /// Consume the stream and return the underlying connection. 199 + pub fn into_connection(self) -> WebSocketConnection { 200 + self.connection 201 + } 202 + 203 + /// Tee the stream, keeping the raw stream in self and returning a typed stream. 204 + /// 205 + /// Replaces the internal WebSocket stream with one copy and returns a typed decoded 206 + /// stream. Both streams receive all messages. Useful for observing raw messages 207 + /// while also processing typed messages. 208 + pub fn tee( 209 + &mut self, 210 + ) -> n0_future::stream::Boxed<Result<StreamMessage<'static, S>, StreamError>> 211 + where 212 + for<'a> StreamMessage<'a, S>: IntoStatic<Output = StreamMessage<'static, S>>, 213 + { 214 + use n0_future::StreamExt as _; 215 + 216 + let rx = self.connection.receiver_mut(); 217 + let (raw_rx, typed_rx_source) = std::mem::replace( 218 + rx, 219 + crate::websocket::WsStream::new(futures::stream::empty()), 220 + ) 221 + .tee(); 222 + 223 + // Put the raw stream back 224 + *rx = raw_rx; 225 + 226 + match S::ENCODING { 227 + MessageEncoding::Json => Box::pin( 228 + typed_rx_source 229 + .into_inner() 230 + .filter_map(|msg| decode_json_msg::<S>(msg)), 231 + ), 232 + MessageEncoding::DagCbor => Box::pin( 233 + typed_rx_source 234 + .into_inner() 235 + .filter_map(|msg| decode_cbor_msg::<S>(msg)), 236 + ), 237 + } 238 + } 239 + } 240 + 241 + type StreamMessage<'a, R> = <R as SubscriptionResp>::Message<'a>; 242 + 79 243 /// XRPC subscription endpoint trait (server-side) 80 244 /// 81 245 /// Analogous to `XrpcEndpoint` but for WebSocket subscriptions. ··· 163 327 /// 164 328 /// Builds a WebSocket URL from the base, appends the NSID path, 165 329 /// encodes query parameters from the subscription type, and connects. 166 - pub async fn subscribe<Sub>(self, params: &Sub) -> Result<WebSocketConnection, C::Error> 330 + /// Returns a typed SubscriptionStream that automatically decodes messages. 331 + pub async fn subscribe<Sub>( 332 + self, 333 + params: &Sub, 334 + ) -> Result<SubscriptionStream<Sub::Stream>, C::Error> 167 335 where 168 336 Sub: XrpcSubscription, 169 337 { ··· 185 353 url.set_query(None); 186 354 } 187 355 188 - self.client 356 + let connection = self 357 + .client 189 358 .connect_with_headers(url, self.opts.headers) 190 - .await 359 + .await?; 360 + 361 + Ok(SubscriptionStream::new(connection)) 191 362 } 192 363 } 193 364 ··· 210 381 fn subscribe<Sub>( 211 382 &self, 212 383 params: &Sub, 213 - ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>> 384 + ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>> 214 385 where 215 386 Sub: XrpcSubscription + Send + Sync, 216 387 Self: Sync; ··· 220 391 fn subscribe<Sub>( 221 392 &self, 222 393 params: &Sub, 223 - ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>> 394 + ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>> 224 395 where 225 396 Sub: XrpcSubscription + Send + Sync; 226 397 ··· 230 401 &self, 231 402 params: &Sub, 232 403 opts: SubscriptionOptions<'_>, 233 - ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>> 404 + ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>> 234 405 where 235 406 Sub: XrpcSubscription + Send + Sync, 236 407 Self: Sync; ··· 241 412 &self, 242 413 params: &Sub, 243 414 opts: SubscriptionOptions<'_>, 244 - ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>> 415 + ) -> impl Future<Output = Result<SubscriptionStream<Sub::Stream>, Self::Error>> 245 416 where 246 417 Sub: XrpcSubscription + Send + Sync; 247 418 } ··· 308 479 async fn subscribe<Sub>( 309 480 &self, 310 481 params: &Sub, 311 - ) -> Result<WebSocketConnection, Self::Error> 482 + ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error> 312 483 where 313 484 Sub: XrpcSubscription + Send + Sync, 314 485 Self: Sync, ··· 321 492 async fn subscribe<Sub>( 322 493 &self, 323 494 params: &Sub, 324 - ) -> Result<WebSocketConnection, Self::Error> 495 + ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error> 325 496 where 326 497 Sub: XrpcSubscription + Send + Sync, 327 498 { ··· 334 505 &self, 335 506 params: &Sub, 336 507 opts: SubscriptionOptions<'_>, 337 - ) -> Result<WebSocketConnection, Self::Error> 508 + ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error> 338 509 where 339 510 Sub: XrpcSubscription + Send + Sync, 340 511 Self: Sync, ··· 351 522 &self, 352 523 params: &Sub, 353 524 opts: SubscriptionOptions<'_>, 354 - ) -> Result<WebSocketConnection, Self::Error> 525 + ) -> Result<SubscriptionStream<Sub::Stream>, Self::Error> 355 526 where 356 527 Sub: XrpcSubscription + Send + Sync, 357 528 {
+2 -2
crates/jacquard-oauth/src/client.rs
··· 618 618 async fn subscribe<Sub>( 619 619 &self, 620 620 params: &Sub, 621 - ) -> std::result::Result<WebSocketConnection, Self::Error> 621 + ) -> std::result::Result<jacquard_common::xrpc::SubscriptionStream<Sub::Stream>, Self::Error> 622 622 where 623 623 Sub: XrpcSubscription + Send + Sync, 624 624 { ··· 630 630 &self, 631 631 params: &Sub, 632 632 opts: jacquard_common::xrpc::SubscriptionOptions<'_>, 633 - ) -> std::result::Result<WebSocketConnection, Self::Error> 633 + ) -> std::result::Result<jacquard_common::xrpc::SubscriptionStream<Sub::Stream>, Self::Error> 634 634 where 635 635 Sub: XrpcSubscription + Send + Sync, 636 636 {
+2 -2
crates/jacquard/src/client/credential_session.rs
··· 607 607 async fn subscribe<Sub>( 608 608 &self, 609 609 params: &Sub, 610 - ) -> Result<WebSocketConnection, Self::Error> 610 + ) -> Result<jacquard_common::xrpc::SubscriptionStream<Sub::Stream>, Self::Error> 611 611 where 612 612 Sub: XrpcSubscription + Send + Sync, 613 613 { ··· 619 619 &self, 620 620 params: &Sub, 621 621 opts: jacquard_common::xrpc::SubscriptionOptions<'_>, 622 - ) -> Result<WebSocketConnection, Self::Error> 622 + ) -> Result<jacquard_common::xrpc::SubscriptionStream<Sub::Stream>, Self::Error> 623 623 where 624 624 Sub: XrpcSubscription + Send + Sync, 625 625 {