A better Rust ATProto crate

websocket impl in progress

Orual c6d2a31d f7db2a76

+1159 -89
+66
Cargo.lock
··· 2005 2005 "smol_str", 2006 2006 "thiserror 2.0.17", 2007 2007 "tokio", 2008 + "tokio-tungstenite-wasm", 2008 2009 "tracing", 2009 2010 "trait-variant", 2010 2011 "url", ··· 3577 3578 ] 3578 3579 3579 3580 [[package]] 3581 + name = "sha1" 3582 + version = "0.10.6" 3583 + source = "registry+https://github.com/rust-lang/crates.io-index" 3584 + checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" 3585 + dependencies = [ 3586 + "cfg-if", 3587 + "cpufeatures", 3588 + "digest", 3589 + ] 3590 + 3591 + [[package]] 3580 3592 name = "sha1_smol" 3581 3593 version = "1.0.1" 3582 3594 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 4021 4033 ] 4022 4034 4023 4035 [[package]] 4036 + name = "tokio-tungstenite" 4037 + version = "0.24.0" 4038 + source = "registry+https://github.com/rust-lang/crates.io-index" 4039 + checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" 4040 + dependencies = [ 4041 + "futures-util", 4042 + "log", 4043 + "tokio", 4044 + "tungstenite", 4045 + ] 4046 + 4047 + [[package]] 4048 + name = "tokio-tungstenite-wasm" 4049 + version = "0.4.0" 4050 + source = "registry+https://github.com/rust-lang/crates.io-index" 4051 + checksum = "e21a5c399399c3db9f08d8297ac12b500e86bca82e930253fdc62eaf9c0de6ae" 4052 + dependencies = [ 4053 + "futures-channel", 4054 + "futures-util", 4055 + "http", 4056 + "httparse", 4057 + "js-sys", 4058 + "thiserror 1.0.69", 4059 + "tokio", 4060 + "tokio-tungstenite", 4061 + "wasm-bindgen", 4062 + "web-sys", 4063 + ] 4064 + 4065 + [[package]] 4024 4066 name = "tokio-util" 4025 4067 version = "0.7.16" 4026 4068 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 4162 4204 checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" 4163 4205 4164 4206 [[package]] 4207 + name = "tungstenite" 4208 + version = "0.24.0" 4209 + source = "registry+https://github.com/rust-lang/crates.io-index" 4210 + checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" 4211 + dependencies = [ 4212 + "byteorder", 4213 + "bytes", 4214 + "data-encoding", 4215 + "http", 4216 + "httparse", 4217 + "log", 4218 + "rand 0.8.5", 4219 + "sha1", 4220 + "thiserror 1.0.69", 4221 + "utf-8", 4222 + ] 4223 + 4224 + [[package]] 4165 4225 name = "twoway" 4166 4226 version = "0.1.8" 4167 4227 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 4271 4331 version = "2.1.3" 4272 4332 source = "registry+https://github.com/rust-lang/crates.io-index" 4273 4333 checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" 4334 + 4335 + [[package]] 4336 + name = "utf-8" 4337 + version = "0.7.6" 4338 + source = "registry+https://github.com/rust-lang/crates.io-index" 4339 + checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" 4274 4340 4275 4341 [[package]] 4276 4342 name = "utf8_iter"
+19 -7
crates/jacquard-api/src/app_bsky/feed/post.rs
··· 15 15 PartialEq, 16 16 Eq, 17 17 jacquard_derive::IntoStatic, 18 - bon::Builder, 18 + bon::Builder 19 19 )] 20 20 #[serde(rename_all = "camelCase")] 21 21 pub struct Entity<'a> { ··· 40 40 PartialEq, 41 41 Eq, 42 42 jacquard_derive::IntoStatic, 43 - bon::Builder, 43 + bon::Builder 44 44 )] 45 45 #[serde(rename_all = "camelCase")] 46 46 pub struct Post<'a> { ··· 99 99 100 100 #[jacquard_derive::open_union] 101 101 #[derive( 102 - serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq, jacquard_derive::IntoStatic, 102 + serde::Serialize, 103 + serde::Deserialize, 104 + Debug, 105 + Clone, 106 + PartialEq, 107 + Eq, 108 + jacquard_derive::IntoStatic 103 109 )] 104 110 #[serde(tag = "$type")] 105 111 #[serde(bound(deserialize = "'de: 'a"))] ··· 118 124 119 125 /// Typed wrapper for GetRecord response with this collection's record type. 120 126 #[derive( 121 - serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq, jacquard_derive::IntoStatic, 127 + serde::Serialize, 128 + serde::Deserialize, 129 + Debug, 130 + Clone, 131 + PartialEq, 132 + Eq, 133 + jacquard_derive::IntoStatic 122 134 )] 123 135 #[serde(rename_all = "camelCase")] 124 136 pub struct PostGetRecordOutput<'a> { ··· 167 179 PartialEq, 168 180 Eq, 169 181 jacquard_derive::IntoStatic, 170 - bon::Builder, 182 + bon::Builder 171 183 )] 172 184 #[serde(rename_all = "camelCase")] 173 185 pub struct ReplyRef<'a> { ··· 187 199 PartialEq, 188 200 Eq, 189 201 jacquard_derive::IntoStatic, 190 - bon::Builder, 202 + bon::Builder 191 203 )] 192 204 #[serde(rename_all = "camelCase")] 193 205 pub struct TextSlice<'a> { 194 206 pub end: i64, 195 207 pub start: i64, 196 - } 208 + }
+26 -2
crates/jacquard-api/src/com_atproto/label/subscribe_labels.rs
··· 74 74 #[serde(bound(deserialize = "'de: 'a"))] 75 75 pub enum SubscribeLabelsMessage<'a> { 76 76 #[serde(rename = "#labels")] 77 - Labels(Box<jacquard_common::types::value::Data<'a>>), 77 + Labels(Box<crate::com_atproto::label::subscribe_labels::Labels<'a>>), 78 78 #[serde(rename = "#info")] 79 - Info(Box<jacquard_common::types::value::Data<'a>>), 79 + Info(Box<crate::com_atproto::label::subscribe_labels::Info<'a>>), 80 80 } 81 81 82 82 #[jacquard_derive::open_union] ··· 111 111 Self::Unknown(err) => write!(f, "Unknown error: {:?}", err), 112 112 } 113 113 } 114 + } 115 + 116 + ///Stream response type for 117 + ///com.atproto.label.subscribeLabels 118 + pub struct SubscribeLabelsStream; 119 + impl jacquard_common::xrpc::SubscriptionResp for SubscribeLabelsStream { 120 + const NSID: &'static str = "com.atproto.label.subscribeLabels"; 121 + const ENCODING: jacquard_common::xrpc::MessageEncoding = jacquard_common::xrpc::MessageEncoding::DagCbor; 122 + type Message<'de> = SubscribeLabelsMessage<'de>; 123 + type Error<'de> = SubscribeLabelsError<'de>; 124 + } 125 + 126 + impl jacquard_common::xrpc::XrpcSubscription for SubscribeLabels { 127 + const NSID: &'static str = "com.atproto.label.subscribeLabels"; 128 + const ENCODING: jacquard_common::xrpc::MessageEncoding = jacquard_common::xrpc::MessageEncoding::DagCbor; 129 + type Stream = SubscribeLabelsStream; 130 + } 131 + 132 + pub struct SubscribeLabelsEndpoint; 133 + impl jacquard_common::xrpc::SubscriptionEndpoint for SubscribeLabelsEndpoint { 134 + const PATH: &'static str = "/xrpc/com.atproto.label.subscribeLabels"; 135 + const ENCODING: jacquard_common::xrpc::MessageEncoding = jacquard_common::xrpc::MessageEncoding::DagCbor; 136 + type Params<'de> = SubscribeLabels; 137 + type Stream = SubscribeLabelsStream; 114 138 }
+42 -21
crates/jacquard-api/src/com_atproto/sync/subscribe_repos.rs
··· 15 15 PartialEq, 16 16 Eq, 17 17 jacquard_derive::IntoStatic, 18 - bon::Builder 18 + bon::Builder, 19 19 )] 20 20 #[serde(rename_all = "camelCase")] 21 21 pub struct Account<'a> { ··· 42 42 PartialEq, 43 43 Eq, 44 44 jacquard_derive::IntoStatic, 45 - bon::Builder 45 + bon::Builder, 46 46 )] 47 47 #[serde(rename_all = "camelCase")] 48 48 pub struct Commit<'a> { ··· 87 87 PartialEq, 88 88 Eq, 89 89 jacquard_derive::IntoStatic, 90 - bon::Builder 90 + bon::Builder, 91 91 )] 92 92 #[serde(rename_all = "camelCase")] 93 93 pub struct Identity<'a> { ··· 111 111 PartialEq, 112 112 Eq, 113 113 jacquard_derive::IntoStatic, 114 - Default 114 + Default, 115 115 )] 116 116 #[serde(rename_all = "camelCase")] 117 117 pub struct Info<'a> { ··· 130 130 PartialEq, 131 131 Eq, 132 132 bon::Builder, 133 - jacquard_derive::IntoStatic 133 + jacquard_derive::IntoStatic, 134 134 )] 135 135 #[builder(start_fn = new)] 136 136 #[serde(rename_all = "camelCase")] ··· 141 141 142 142 #[jacquard_derive::open_union] 143 143 #[derive( 144 - serde::Serialize, 145 - serde::Deserialize, 146 - Debug, 147 - Clone, 148 - PartialEq, 149 - Eq, 150 - jacquard_derive::IntoStatic 144 + serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq, jacquard_derive::IntoStatic, 151 145 )] 152 146 #[serde(tag = "$type")] 153 147 #[serde(bound(deserialize = "'de: 'a"))] 154 148 pub enum SubscribeReposMessage<'a> { 155 149 #[serde(rename = "#commit")] 156 - Commit(Box<jacquard_common::types::value::Data<'a>>), 150 + Commit(Box<crate::com_atproto::sync::subscribe_repos::Commit<'a>>), 157 151 #[serde(rename = "#sync")] 158 - Sync(Box<jacquard_common::types::value::Data<'a>>), 152 + Sync(Box<crate::com_atproto::sync::subscribe_repos::Sync<'a>>), 159 153 #[serde(rename = "#identity")] 160 - Identity(Box<jacquard_common::types::value::Data<'a>>), 154 + Identity(Box<crate::com_atproto::sync::subscribe_repos::Identity<'a>>), 161 155 #[serde(rename = "#account")] 162 - Account(Box<jacquard_common::types::value::Data<'a>>), 156 + Account(Box<crate::com_atproto::sync::subscribe_repos::Account<'a>>), 163 157 #[serde(rename = "#info")] 164 - Info(Box<jacquard_common::types::value::Data<'a>>), 158 + Info(Box<crate::com_atproto::sync::subscribe_repos::Info<'a>>), 165 159 } 166 160 167 161 #[jacquard_derive::open_union] ··· 174 168 Eq, 175 169 thiserror::Error, 176 170 miette::Diagnostic, 177 - jacquard_derive::IntoStatic 171 + jacquard_derive::IntoStatic, 178 172 )] 179 173 #[serde(tag = "error", content = "message")] 180 174 #[serde(bound(deserialize = "'de: 'a"))] ··· 208 202 } 209 203 } 210 204 205 + ///Stream response type for 206 + ///com.atproto.sync.subscribeRepos 207 + pub struct SubscribeReposStream; 208 + impl jacquard_common::xrpc::SubscriptionResp for SubscribeReposStream { 209 + const NSID: &'static str = "com.atproto.sync.subscribeRepos"; 210 + const ENCODING: jacquard_common::xrpc::MessageEncoding = 211 + jacquard_common::xrpc::MessageEncoding::DagCbor; 212 + type Message<'de> = SubscribeReposMessage<'de>; 213 + type Error<'de> = SubscribeReposError<'de>; 214 + } 215 + 216 + impl jacquard_common::xrpc::XrpcSubscription for SubscribeRepos { 217 + const NSID: &'static str = "com.atproto.sync.subscribeRepos"; 218 + const ENCODING: jacquard_common::xrpc::MessageEncoding = 219 + jacquard_common::xrpc::MessageEncoding::DagCbor; 220 + type Stream = SubscribeReposStream; 221 + } 222 + 223 + pub struct SubscribeReposEndpoint; 224 + impl jacquard_common::xrpc::SubscriptionEndpoint for SubscribeReposEndpoint { 225 + const PATH: &'static str = "/xrpc/com.atproto.sync.subscribeRepos"; 226 + const ENCODING: jacquard_common::xrpc::MessageEncoding = 227 + jacquard_common::xrpc::MessageEncoding::DagCbor; 228 + type Params<'de> = SubscribeRepos; 229 + type Stream = SubscribeReposStream; 230 + } 231 + 211 232 /// A repo operation, ie a mutation of a single record. 212 233 #[jacquard_derive::lexicon] 213 234 #[derive( ··· 218 239 PartialEq, 219 240 Eq, 220 241 jacquard_derive::IntoStatic, 221 - bon::Builder 242 + bon::Builder, 222 243 )] 223 244 #[serde(rename_all = "camelCase")] 224 245 pub struct RepoOp<'a> { ··· 248 269 PartialEq, 249 270 Eq, 250 271 jacquard_derive::IntoStatic, 251 - bon::Builder 272 + bon::Builder, 252 273 )] 253 274 #[serde(rename_all = "camelCase")] 254 275 pub struct Sync<'a> { ··· 265 286 pub seq: i64, 266 287 /// Timestamp of when this message was originally broadcast. 267 288 pub time: jacquard_common::types::string::Datetime, 268 - } 289 + }
+3 -2
crates/jacquard-common/Cargo.toml
··· 44 44 # Streaming support (optional) 45 45 n0-future = { version = "0.1", optional = true } 46 46 futures = { version = "0.3", optional = true } 47 + tokio-tungstenite-wasm = { version = "0.4", optional = true } 47 48 48 49 [target.'cfg(target_family = "wasm")'.dependencies] 49 50 getrandom = { version = "0.3.4", features = ["wasm_js"] } ··· 52 53 reqwest = { workspace = true, optional = true, features = [ "http2", "system-proxy", "rustls-tls"] } 53 54 54 55 [features] 55 - default = ["service-auth", "reqwest-client", "crypto"] 56 + default = ["service-auth", "reqwest-client", "crypto", "websocket"] 56 57 crypto = [] 57 58 crypto-ed25519 = ["crypto", "dep:ed25519-dalek"] 58 59 crypto-k256 = ["crypto", "dep:k256", "k256/ecdsa"] ··· 61 62 reqwest-client = ["dep:reqwest"] 62 63 tracing = ["dep:tracing"] 63 64 streaming = ["n0-future", "futures"] 64 - websocket = ["streaming"] 65 + websocket = ["streaming", "tokio-tungstenite-wasm"] 65 66 66 67 [dependencies.ed25519-dalek] 67 68 version = "2"
+7
crates/jacquard-common/src/error.rs
··· 91 91 #[source] 92 92 serde_ipld_dagcbor::DecodeError<HttpError>, 93 93 ), 94 + /// DAG-CBOR deserialization failed (in-memory, e.g., WebSocket frames) 95 + #[error("Failed to deserialize DAG-CBOR: {0}")] 96 + DagCborInfallible( 97 + #[from] 98 + #[source] 99 + serde_ipld_dagcbor::DecodeError<std::convert::Infallible>, 100 + ), 94 101 } 95 102 96 103 /// HTTP error response (non-200 status codes outside of XRPC error handling)
+4 -1
crates/jacquard-common/src/lib.rs
··· 234 234 pub mod websocket; 235 235 236 236 #[cfg(feature = "websocket")] 237 - pub use websocket::{WebSocketClient, WebSocketConnection}; 237 + pub use websocket::{ 238 + tungstenite_client::TungsteniteClient, CloseCode, CloseFrame, WebSocketClient, 239 + WebSocketConnection, WsMessage, WsSink, WsStream, WsText, 240 + }; 238 241 239 242 pub use types::value::*; 240 243
+29 -6
crates/jacquard-common/src/stream.rs
··· 64 64 Closed, 65 65 /// Protocol violation or framing error 66 66 Protocol, 67 + /// Message deserialization failed 68 + Decode, 69 + /// Wrong message format (e.g., text frame when expecting binary) 70 + WrongMessageFormat, 67 71 } 68 72 69 73 impl StreamError { ··· 105 109 source: Some(msg.into().into()), 106 110 } 107 111 } 112 + 113 + /// Create a decode error with source 114 + pub fn decode(source: impl Error + Send + Sync + 'static) -> Self { 115 + Self { 116 + kind: StreamErrorKind::Decode, 117 + source: Some(Box::new(source)), 118 + } 119 + } 120 + 121 + /// Create a wrong message format error 122 + pub fn wrong_message_format(msg: impl Into<String>) -> Self { 123 + Self { 124 + kind: StreamErrorKind::WrongMessageFormat, 125 + source: Some(msg.into().into()), 126 + } 127 + } 108 128 } 109 129 110 130 impl fmt::Display for StreamError { ··· 113 133 StreamErrorKind::Transport => write!(f, "Transport error"), 114 134 StreamErrorKind::Closed => write!(f, "Stream closed"), 115 135 StreamErrorKind::Protocol => write!(f, "Protocol error"), 136 + StreamErrorKind::Decode => write!(f, "Decode error"), 137 + StreamErrorKind::WrongMessageFormat => write!(f, "Wrong message format"), 116 138 }?; 117 139 118 140 if let Some(source) = &self.source { ··· 125 147 126 148 impl Error for StreamError { 127 149 fn source(&self) -> Option<&(dyn Error + 'static)> { 128 - self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static)) 150 + self.source 151 + .as_ref() 152 + .map(|e| e.as_ref() as &(dyn Error + 'static)) 129 153 } 130 154 } 131 155 ··· 153 177 } 154 178 155 179 /// Convert into the inner boxed stream 156 - pub fn into_inner(self) -> Box<dyn n0_future::Stream<Item = Result<Bytes, StreamError>> + Unpin> { 180 + pub fn into_inner( 181 + self, 182 + ) -> Box<dyn n0_future::Stream<Item = Result<Bytes, StreamError>> + Unpin> { 157 183 self.inner 158 184 } 159 185 } ··· 219 245 async fn byte_stream_can_be_created() { 220 246 use futures::stream; 221 247 222 - let data = vec![ 223 - Ok(Bytes::from("hello")), 224 - Ok(Bytes::from(" world")), 225 - ]; 248 + let data = vec![Ok(Bytes::from("hello")), Ok(Bytes::from(" world"))]; 226 249 let stream = stream::iter(data); 227 250 228 251 let byte_stream = ByteStream::new(stream);
+15 -1
crates/jacquard-common/src/types/uri.rs
··· 6 6 }; 7 7 use serde::{Deserialize, Deserializer, Serialize, Serializer}; 8 8 use smol_str::ToSmolStr; 9 - use std::{fmt::Display, marker::PhantomData, str::FromStr}; 9 + use std::{fmt::Display, marker::PhantomData, ops::Deref, str::FromStr}; 10 10 use url::Url; 11 11 12 12 /// Generic URI with type-specific parsing ··· 202 202 impl<R: Collection> Display for RecordUri<'_, R> { 203 203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 204 204 self.0.fmt(f) 205 + } 206 + } 207 + 208 + impl<'a, R: Collection> AsRef<AtUri<'a>> for RecordUri<'a, R> { 209 + fn as_ref(&self) -> &AtUri<'a> { 210 + &self.0 211 + } 212 + } 213 + 214 + impl<'a, R: Collection> Deref for RecordUri<'a, R> { 215 + type Target = AtUri<'a>; 216 + 217 + fn deref(&self) -> &Self::Target { 218 + &self.0 205 219 } 206 220 } 207 221
+593 -23
crates/jacquard-common/src/websocket.rs
··· 1 1 //! WebSocket client abstraction 2 2 3 - use crate::stream::{ByteStream, ByteSink}; 3 + use bytes::Bytes; 4 + use n0_future::Stream; 5 + use n0_future::stream::Boxed; 6 + use std::borrow::Borrow; 7 + use std::fmt::{self, Display}; 4 8 use std::future::Future; 9 + use std::ops::Deref; 5 10 use url::Url; 6 11 12 + use crate::CowStr; 13 + use crate::stream::StreamError; 14 + 15 + /// UTF-8 validated bytes for WebSocket text messages 16 + #[repr(transparent)] 17 + #[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)] 18 + pub struct WsText(Bytes); 19 + 20 + impl WsText { 21 + /// Create from static string 22 + pub const fn from_static(s: &'static str) -> Self { 23 + Self(Bytes::from_static(s.as_bytes())) 24 + } 25 + 26 + /// Get as string slice 27 + pub fn as_str(&self) -> &str { 28 + unsafe { std::str::from_utf8_unchecked(&self.0) } 29 + } 30 + 31 + /// Create from bytes without validation (caller must ensure UTF-8) 32 + /// 33 + /// # Safety 34 + /// Bytes must be valid UTF-8 35 + pub unsafe fn from_bytes_unchecked(bytes: Bytes) -> Self { 36 + Self(bytes) 37 + } 38 + 39 + /// Convert into underlying bytes 40 + pub fn into_bytes(self) -> Bytes { 41 + self.0 42 + } 43 + } 44 + 45 + impl Deref for WsText { 46 + type Target = str; 47 + fn deref(&self) -> &str { 48 + self.as_str() 49 + } 50 + } 51 + 52 + impl AsRef<str> for WsText { 53 + fn as_ref(&self) -> &str { 54 + self.as_str() 55 + } 56 + } 57 + 58 + impl AsRef<[u8]> for WsText { 59 + fn as_ref(&self) -> &[u8] { 60 + &self.0 61 + } 62 + } 63 + 64 + impl AsRef<Bytes> for WsText { 65 + fn as_ref(&self) -> &Bytes { 66 + &self.0 67 + } 68 + } 69 + 70 + impl Borrow<str> for WsText { 71 + fn borrow(&self) -> &str { 72 + self.as_str() 73 + } 74 + } 75 + 76 + impl Display for WsText { 77 + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 78 + Display::fmt(self.as_str(), f) 79 + } 80 + } 81 + 82 + impl From<String> for WsText { 83 + fn from(s: String) -> Self { 84 + Self(Bytes::from(s)) 85 + } 86 + } 87 + 88 + impl From<&str> for WsText { 89 + fn from(s: &str) -> Self { 90 + Self(Bytes::copy_from_slice(s.as_bytes())) 91 + } 92 + } 93 + 94 + impl From<&String> for WsText { 95 + fn from(s: &String) -> Self { 96 + Self::from(s.as_str()) 97 + } 98 + } 99 + 100 + impl TryFrom<Bytes> for WsText { 101 + type Error = std::str::Utf8Error; 102 + fn try_from(bytes: Bytes) -> Result<Self, Self::Error> { 103 + std::str::from_utf8(&bytes)?; 104 + Ok(Self(bytes)) 105 + } 106 + } 107 + 108 + impl TryFrom<Vec<u8>> for WsText { 109 + type Error = std::str::Utf8Error; 110 + fn try_from(vec: Vec<u8>) -> Result<Self, Self::Error> { 111 + Self::try_from(Bytes::from(vec)) 112 + } 113 + } 114 + 115 + impl From<WsText> for Bytes { 116 + fn from(t: WsText) -> Bytes { 117 + t.0 118 + } 119 + } 120 + 121 + impl Default for WsText { 122 + fn default() -> Self { 123 + Self(Bytes::new()) 124 + } 125 + } 126 + 127 + /// WebSocket close code 128 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 129 + #[repr(u16)] 130 + pub enum CloseCode { 131 + /// Normal closure 132 + Normal = 1000, 133 + /// Endpoint going away 134 + Away = 1001, 135 + /// Protocol error 136 + Protocol = 1002, 137 + /// Unsupported data 138 + Unsupported = 1003, 139 + /// Invalid frame payload data 140 + Invalid = 1007, 141 + /// Policy violation 142 + Policy = 1008, 143 + /// Message too big 144 + Size = 1009, 145 + /// Extension negotiation failure 146 + Extension = 1010, 147 + /// Unexpected condition 148 + Error = 1011, 149 + /// TLS handshake failure 150 + Tls = 1015, 151 + /// Other code 152 + Other(u16), 153 + } 154 + 155 + impl From<u16> for CloseCode { 156 + fn from(code: u16) -> Self { 157 + match code { 158 + 1000 => CloseCode::Normal, 159 + 1001 => CloseCode::Away, 160 + 1002 => CloseCode::Protocol, 161 + 1003 => CloseCode::Unsupported, 162 + 1007 => CloseCode::Invalid, 163 + 1008 => CloseCode::Policy, 164 + 1009 => CloseCode::Size, 165 + 1010 => CloseCode::Extension, 166 + 1011 => CloseCode::Error, 167 + 1015 => CloseCode::Tls, 168 + other => CloseCode::Other(other), 169 + } 170 + } 171 + } 172 + 173 + impl From<CloseCode> for u16 { 174 + fn from(code: CloseCode) -> u16 { 175 + match code { 176 + CloseCode::Normal => 1000, 177 + CloseCode::Away => 1001, 178 + CloseCode::Protocol => 1002, 179 + CloseCode::Unsupported => 1003, 180 + CloseCode::Invalid => 1007, 181 + CloseCode::Policy => 1008, 182 + CloseCode::Size => 1009, 183 + CloseCode::Extension => 1010, 184 + CloseCode::Error => 1011, 185 + CloseCode::Tls => 1015, 186 + CloseCode::Other(code) => code, 187 + } 188 + } 189 + } 190 + 191 + /// WebSocket close frame 192 + #[derive(Debug, Clone, PartialEq, Eq)] 193 + pub struct CloseFrame<'a> { 194 + /// Close code 195 + pub code: CloseCode, 196 + /// Close reason text 197 + pub reason: CowStr<'a>, 198 + } 199 + 200 + impl<'a> CloseFrame<'a> { 201 + /// Create a new close frame 202 + pub fn new(code: CloseCode, reason: impl Into<CowStr<'a>>) -> Self { 203 + Self { 204 + code, 205 + reason: reason.into(), 206 + } 207 + } 208 + } 209 + 210 + /// WebSocket message 211 + #[derive(Debug, Clone, PartialEq, Eq)] 212 + pub enum WsMessage { 213 + /// Text message (UTF-8) 214 + Text(WsText), 215 + /// Binary message 216 + Binary(Bytes), 217 + /// Close frame 218 + Close(Option<CloseFrame<'static>>), 219 + } 220 + 221 + impl WsMessage { 222 + /// Check if this is a text message 223 + pub fn is_text(&self) -> bool { 224 + matches!(self, WsMessage::Text(_)) 225 + } 226 + 227 + /// Check if this is a binary message 228 + pub fn is_binary(&self) -> bool { 229 + matches!(self, WsMessage::Binary(_)) 230 + } 231 + 232 + /// Check if this is a close message 233 + pub fn is_close(&self) -> bool { 234 + matches!(self, WsMessage::Close(_)) 235 + } 236 + 237 + /// Get as text, if this is a text message 238 + pub fn as_text(&self) -> Option<&str> { 239 + match self { 240 + WsMessage::Text(t) => Some(t.as_str()), 241 + _ => None, 242 + } 243 + } 244 + 245 + /// Get as bytes 246 + pub fn as_bytes(&self) -> Option<&[u8]> { 247 + match self { 248 + WsMessage::Text(t) => Some(t.as_ref()), 249 + WsMessage::Binary(b) => Some(b), 250 + WsMessage::Close(_) => None, 251 + } 252 + } 253 + } 254 + 255 + impl From<WsText> for WsMessage { 256 + fn from(text: WsText) -> Self { 257 + WsMessage::Text(text) 258 + } 259 + } 260 + 261 + impl From<String> for WsMessage { 262 + fn from(s: String) -> Self { 263 + WsMessage::Text(WsText::from(s)) 264 + } 265 + } 266 + 267 + impl From<&str> for WsMessage { 268 + fn from(s: &str) -> Self { 269 + WsMessage::Text(WsText::from(s)) 270 + } 271 + } 272 + 273 + impl From<Bytes> for WsMessage { 274 + fn from(bytes: Bytes) -> Self { 275 + WsMessage::Binary(bytes) 276 + } 277 + } 278 + 279 + impl From<Vec<u8>> for WsMessage { 280 + fn from(vec: Vec<u8>) -> Self { 281 + WsMessage::Binary(Bytes::from(vec)) 282 + } 283 + } 284 + 285 + /// WebSocket message stream 286 + pub struct WsStream(Boxed<Result<WsMessage, StreamError>>); 287 + 288 + impl WsStream { 289 + /// Create a new message stream 290 + #[cfg(not(target_arch = "wasm32"))] 291 + pub fn new<S>(stream: S) -> Self 292 + where 293 + S: Stream<Item = Result<WsMessage, StreamError>> + Send + 'static, 294 + { 295 + Self(Box::pin(stream)) 296 + } 297 + 298 + /// Create a new message stream 299 + #[cfg(target_arch = "wasm32")] 300 + pub fn new<S>(stream: S) -> Self 301 + where 302 + S: Stream<Item = Result<WsMessage, StreamError>> + 'static, 303 + { 304 + Self(Box::pin(stream)) 305 + } 306 + 307 + /// Convert into the inner pinned boxed stream 308 + pub fn into_inner(self) -> Boxed<Result<WsMessage, StreamError>> { 309 + self.0 310 + } 311 + } 312 + 313 + /// Extension trait for decoding typed messages from WebSocket streams 314 + pub trait WsStreamExt: Sized { 315 + /// Decode JSON text/binary frames into typed messages 316 + /// 317 + /// Deserializes messages but does not automatically convert to owned. 318 + /// The caller is responsible for calling `.into_static()` if needed. 319 + fn decode_json<T>(self) -> impl Stream<Item = Result<T, StreamError>> 320 + where 321 + T: for<'de> serde::Deserialize<'de>; 322 + 323 + /// Decode DAG-CBOR binary frames into typed messages 324 + /// 325 + /// Deserializes messages but does not automatically convert to owned. 326 + /// The caller is responsible for calling `.into_static()` if needed. 327 + fn decode_cbor<T>(self) -> impl Stream<Item = Result<T, StreamError>> 328 + where 329 + T: for<'de> serde::Deserialize<'de>; 330 + } 331 + 332 + impl WsStreamExt for WsStream { 333 + fn decode_json<T>(self) -> impl Stream<Item = Result<T, StreamError>> 334 + where 335 + T: for<'de> serde::Deserialize<'de>, 336 + { 337 + use n0_future::StreamExt as _; 338 + 339 + Box::pin(self.into_inner().filter_map(|msg_result| match msg_result { 340 + Ok(WsMessage::Text(text)) => { 341 + Some(serde_json::from_slice(text.as_ref()).map_err(StreamError::decode)) 342 + } 343 + Ok(WsMessage::Binary(bytes)) => { 344 + Some(serde_json::from_slice(&bytes).map_err(StreamError::decode)) 345 + } 346 + Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 347 + Err(e) => Some(Err(e)), 348 + })) 349 + } 350 + 351 + fn decode_cbor<T>(self) -> impl Stream<Item = Result<T, StreamError>> 352 + where 353 + T: for<'de> serde::Deserialize<'de>, 354 + { 355 + use n0_future::StreamExt as _; 356 + 357 + Box::pin(self.into_inner().filter_map(|msg_result| { 358 + match msg_result { 359 + Ok(WsMessage::Binary(bytes)) => Some( 360 + serde_ipld_dagcbor::from_slice(&bytes) 361 + .map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))), 362 + ), 363 + Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format( 364 + "expected binary frame for CBOR, got text", 365 + ))), 366 + Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())), 367 + Err(e) => Some(Err(e)), 368 + } 369 + })) 370 + } 371 + } 372 + 373 + impl fmt::Debug for WsStream { 374 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 375 + f.debug_struct("WsStream").finish_non_exhaustive() 376 + } 377 + } 378 + 379 + /// WebSocket message sink 380 + pub struct WsSink(Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>); 381 + 382 + impl WsSink { 383 + /// Create a new message sink 384 + pub fn new<S>(sink: S) -> Self 385 + where 386 + S: n0_future::Sink<WsMessage, Error = StreamError> + 'static, 387 + { 388 + Self(Box::new(sink)) 389 + } 390 + 391 + /// Convert into the inner boxed sink 392 + pub fn into_inner(self) -> Box<dyn n0_future::Sink<WsMessage, Error = StreamError>> { 393 + self.0 394 + } 395 + } 396 + 397 + impl fmt::Debug for WsSink { 398 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 399 + f.debug_struct("WsSink").finish_non_exhaustive() 400 + } 401 + } 402 + 7 403 /// WebSocket client trait 8 404 #[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] 9 405 pub trait WebSocketClient { ··· 11 407 type Error: std::error::Error + Send + Sync + 'static; 12 408 13 409 /// Connect to a WebSocket endpoint 14 - fn connect( 15 - &self, 16 - url: Url, 17 - ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>>; 410 + fn connect(&self, url: Url) -> impl Future<Output = Result<WebSocketConnection, Self::Error>>; 18 411 } 19 412 20 413 /// WebSocket connection with bidirectional streams 21 414 pub struct WebSocketConnection { 22 - tx: ByteSink, 23 - rx: ByteStream, 415 + tx: WsSink, 416 + rx: WsStream, 24 417 } 25 418 26 419 impl WebSocketConnection { 27 420 /// Create a new WebSocket connection 28 - pub fn new(tx: ByteSink, rx: ByteStream) -> Self { 421 + pub fn new(tx: WsSink, rx: WsStream) -> Self { 29 422 Self { tx, rx } 30 423 } 31 424 32 425 /// Get mutable access to the sender 33 - pub fn sender_mut(&mut self) -> &mut ByteSink { 426 + pub fn sender_mut(&mut self) -> &mut WsSink { 34 427 &mut self.tx 35 428 } 36 429 37 430 /// Get mutable access to the receiver 38 - pub fn receiver_mut(&mut self) -> &mut ByteStream { 431 + pub fn receiver_mut(&mut self) -> &mut WsStream { 39 432 &mut self.rx 40 433 } 41 434 435 + /// Get a reference to the receiver 436 + pub fn receiver(&self) -> &WsStream { 437 + &self.rx 438 + } 439 + 440 + /// Get a reference to the sender 441 + pub fn sender(&self) -> &WsSink { 442 + &self.tx 443 + } 444 + 42 445 /// Split into sender and receiver 43 - pub fn split(self) -> (ByteSink, ByteStream) { 446 + pub fn split(self) -> (WsSink, WsStream) { 44 447 (self.tx, self.rx) 45 448 } 46 449 ··· 50 453 } 51 454 } 52 455 53 - impl std::fmt::Debug for WebSocketConnection { 54 - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 456 + impl fmt::Debug for WebSocketConnection { 457 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 55 458 f.debug_struct("WebSocketConnection") 56 459 .finish_non_exhaustive() 57 460 } 58 461 } 59 462 463 + /// Concrete WebSocket client implementation using tokio-tungstenite-wasm 464 + pub mod tungstenite_client { 465 + use super::*; 466 + use crate::IntoStatic; 467 + use futures::{SinkExt, StreamExt}; 468 + 469 + /// WebSocket client backed by tokio-tungstenite-wasm 470 + #[derive(Debug, Clone, Default)] 471 + pub struct TungsteniteClient; 472 + 473 + impl TungsteniteClient { 474 + /// Create a new tungstenite WebSocket client 475 + pub fn new() -> Self { 476 + Self 477 + } 478 + } 479 + 480 + impl WebSocketClient for TungsteniteClient { 481 + type Error = tokio_tungstenite_wasm::Error; 482 + 483 + async fn connect(&self, url: Url) -> Result<WebSocketConnection, Self::Error> { 484 + let ws_stream = tokio_tungstenite_wasm::connect(url.as_str()).await?; 485 + 486 + let (sink, stream) = ws_stream.split(); 487 + 488 + // Convert tungstenite messages to our WsMessage 489 + let rx_stream = stream.filter_map(|result| async move { 490 + match result { 491 + Ok(msg) => match convert_message(msg) { 492 + Some(ws_msg) => Some(Ok(ws_msg)), 493 + None => None, // Skip ping/pong 494 + }, 495 + Err(e) => Some(Err(StreamError::transport(e))), 496 + } 497 + }); 498 + 499 + let rx = WsStream::new(rx_stream); 500 + 501 + // Convert our WsMessage to tungstenite messages 502 + let tx_sink = sink.with(|msg: WsMessage| async move { 503 + Ok::<_, tokio_tungstenite_wasm::Error>(msg.into()) 504 + }); 505 + 506 + let tx_sink_mapped = tx_sink.sink_map_err(|e| StreamError::transport(e)); 507 + let tx = WsSink::new(tx_sink_mapped); 508 + 509 + Ok(WebSocketConnection::new(tx, rx)) 510 + } 511 + } 512 + 513 + /// Convert tokio-tungstenite-wasm Message to our WsMessage 514 + /// Returns None for Ping/Pong which we auto-handle 515 + fn convert_message(msg: tokio_tungstenite_wasm::Message) -> Option<WsMessage> { 516 + use tokio_tungstenite_wasm::Message; 517 + 518 + match msg { 519 + Message::Text(vec) => { 520 + // tokio-tungstenite-wasm Text contains Vec<u8> (UTF-8 validated) 521 + let bytes = Bytes::from(vec); 522 + Some(WsMessage::Text(unsafe { 523 + WsText::from_bytes_unchecked(bytes) 524 + })) 525 + } 526 + Message::Binary(vec) => Some(WsMessage::Binary(Bytes::from(vec))), 527 + Message::Close(frame) => { 528 + let close_frame = frame.map(|f| { 529 + let code = convert_close_code(f.code); 530 + CloseFrame::new(code, CowStr::from(f.reason.into_owned())) 531 + }); 532 + Some(WsMessage::Close(close_frame)) 533 + } 534 + } 535 + } 536 + 537 + /// Convert tokio-tungstenite-wasm CloseCode to our CloseCode 538 + fn convert_close_code(code: tokio_tungstenite_wasm::CloseCode) -> CloseCode { 539 + use tokio_tungstenite_wasm::CloseCode as TungsteniteCode; 540 + 541 + match code { 542 + TungsteniteCode::Normal => CloseCode::Normal, 543 + TungsteniteCode::Away => CloseCode::Away, 544 + TungsteniteCode::Protocol => CloseCode::Protocol, 545 + TungsteniteCode::Unsupported => CloseCode::Unsupported, 546 + TungsteniteCode::Invalid => CloseCode::Invalid, 547 + TungsteniteCode::Policy => CloseCode::Policy, 548 + TungsteniteCode::Size => CloseCode::Size, 549 + TungsteniteCode::Extension => CloseCode::Extension, 550 + TungsteniteCode::Error => CloseCode::Error, 551 + TungsteniteCode::Tls => CloseCode::Tls, 552 + // For other variants, extract raw code 553 + other => { 554 + let raw: u16 = other.into(); 555 + CloseCode::from(raw) 556 + } 557 + } 558 + } 559 + 560 + impl From<WsMessage> for tokio_tungstenite_wasm::Message { 561 + fn from(msg: WsMessage) -> Self { 562 + use tokio_tungstenite_wasm::Message; 563 + 564 + match msg { 565 + WsMessage::Text(text) => { 566 + // tokio-tungstenite-wasm Text expects String 567 + let bytes = text.into_bytes(); 568 + // Safe: WsText is already UTF-8 validated 569 + let string = unsafe { String::from_utf8_unchecked(bytes.to_vec()) }; 570 + Message::Text(string) 571 + } 572 + WsMessage::Binary(bytes) => Message::Binary(bytes.to_vec()), 573 + WsMessage::Close(frame) => { 574 + let close_frame = frame.map(|f| { 575 + let code = u16::from(f.code).into(); 576 + tokio_tungstenite_wasm::CloseFrame { 577 + code, 578 + reason: f.reason.into_static().to_string().into(), 579 + } 580 + }); 581 + Message::Close(close_frame) 582 + } 583 + } 584 + } 585 + } 586 + } 587 + 60 588 #[cfg(test)] 61 589 mod tests { 62 590 use super::*; 63 - use crate::stream::StreamError; 591 + 592 + #[test] 593 + fn ws_text_from_string() { 594 + let text = WsText::from("hello"); 595 + assert_eq!(text.as_str(), "hello"); 596 + } 597 + 598 + #[test] 599 + fn ws_text_deref() { 600 + let text = WsText::from(String::from("world")); 601 + assert_eq!(&*text, "world"); 602 + } 603 + 604 + #[test] 605 + fn ws_text_try_from_bytes() { 606 + let bytes = Bytes::from("test"); 607 + let text = WsText::try_from(bytes).unwrap(); 608 + assert_eq!(text.as_str(), "test"); 609 + } 610 + 611 + #[test] 612 + fn ws_text_invalid_utf8() { 613 + let bytes = Bytes::from(vec![0xFF, 0xFE]); 614 + assert!(WsText::try_from(bytes).is_err()); 615 + } 616 + 617 + #[test] 618 + fn ws_message_text() { 619 + let msg = WsMessage::from("hello"); 620 + assert!(msg.is_text()); 621 + assert_eq!(msg.as_text(), Some("hello")); 622 + } 623 + 624 + #[test] 625 + fn ws_message_binary() { 626 + let msg = WsMessage::from(vec![1, 2, 3]); 627 + assert!(msg.is_binary()); 628 + assert_eq!(msg.as_bytes(), Some(&[1u8, 2, 3][..])); 629 + } 630 + 631 + #[test] 632 + fn close_code_conversion() { 633 + assert_eq!(u16::from(CloseCode::Normal), 1000); 634 + assert_eq!(CloseCode::from(1000), CloseCode::Normal); 635 + assert_eq!(CloseCode::from(9999), CloseCode::Other(9999)); 636 + } 64 637 65 638 #[test] 66 639 fn websocket_connection_has_tx_and_rx() { 67 - use futures::stream; 68 640 use futures::sink::SinkExt; 69 - use bytes::Bytes; 641 + use futures::stream; 70 642 71 - let rx_stream = stream::iter(vec![Ok(Bytes::from("test"))]); 72 - let rx = ByteStream::new(rx_stream); 643 + let rx_stream = stream::iter(vec![Ok(WsMessage::from("test"))]); 644 + let rx = WsStream::new(rx_stream); 73 645 74 - // Create a sink that converts Infallible to StreamError 75 - let drain_sink = futures::sink::drain().sink_map_err(|_: std::convert::Infallible| { 76 - StreamError::closed() 77 - }); 78 - let tx = ByteSink::new(drain_sink); 646 + let drain_sink = futures::sink::drain() 647 + .sink_map_err(|_: std::convert::Infallible| StreamError::closed()); 648 + let tx = WsSink::new(drain_sink); 79 649 80 650 let conn = WebSocketConnection::new(tx, rx); 81 651 assert!(conn.is_open());
+6
crates/jacquard-common/src/xrpc.rs
··· 16 16 #[cfg(feature = "streaming")] 17 17 pub use streaming::StreamingResponse; 18 18 19 + #[cfg(feature = "websocket")] 20 + pub mod subscription; 21 + 22 + #[cfg(feature = "websocket")] 23 + pub use subscription::{MessageEncoding, SubscriptionEndpoint, SubscriptionResp, XrpcSubscription}; 24 + 19 25 use bytes::Bytes; 20 26 use http::{ 21 27 HeaderName, HeaderValue, Request, StatusCode,
+95
crates/jacquard-common/src/xrpc/subscription.rs
··· 1 + //! WebSocket subscription support for XRPC 2 + //! 3 + //! This module defines traits and types for typed WebSocket subscriptions, 4 + //! mirroring the request/response pattern used for HTTP XRPC endpoints. 5 + 6 + use serde::{Deserialize, Serialize}; 7 + use std::error::Error; 8 + 9 + use crate::IntoStatic; 10 + 11 + /// Encoding format for subscription messages 12 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 13 + pub enum MessageEncoding { 14 + /// JSON text frames 15 + Json, 16 + /// DAG-CBOR binary frames 17 + DagCbor, 18 + } 19 + 20 + /// XRPC subscription stream response trait 21 + /// 22 + /// Analogous to `XrpcResp` but for WebSocket subscriptions. 23 + /// Defines the message and error types for a subscription stream. 24 + /// 25 + /// This trait is implemented on a marker struct to keep it lifetime-free 26 + /// while using GATs for the message/error types. 27 + pub trait SubscriptionResp { 28 + /// The NSID for this subscription 29 + const NSID: &'static str; 30 + 31 + /// Message encoding (JSON or DAG-CBOR) 32 + const ENCODING: MessageEncoding; 33 + 34 + /// Message union type 35 + type Message<'de>: Deserialize<'de> + IntoStatic; 36 + 37 + /// Error union type 38 + type Error<'de>: Error + Deserialize<'de> + IntoStatic; 39 + } 40 + 41 + /// XRPC subscription (WebSocket) 42 + /// 43 + /// This trait is analogous to `XrpcRequest` but for WebSocket subscriptions. 44 + /// It defines the NSID and associated stream response type. 45 + /// 46 + /// The trait is implemented on the subscription parameters type. 47 + pub trait XrpcSubscription: Serialize { 48 + /// The NSID for this XRPC subscription 49 + const NSID: &'static str; 50 + 51 + /// Message encoding (JSON or DAG-CBOR) 52 + const ENCODING: MessageEncoding; 53 + 54 + /// Stream response type (marker struct) 55 + type Stream: SubscriptionResp; 56 + 57 + /// Encode query params for WebSocket URL 58 + /// 59 + /// Default implementation uses serde_html_form to encode the struct as query parameters. 60 + fn query_params(&self) -> Vec<(String, String)> { 61 + // Default: use serde_html_form to encode self 62 + serde_html_form::to_string(self) 63 + .ok() 64 + .map(|s| { 65 + s.split('&') 66 + .filter_map(|pair| { 67 + let mut parts = pair.splitn(2, '='); 68 + Some((parts.next()?.to_string(), parts.next()?.to_string())) 69 + }) 70 + .collect() 71 + }) 72 + .unwrap_or_default() 73 + } 74 + } 75 + 76 + /// XRPC subscription endpoint trait (server-side) 77 + /// 78 + /// Analogous to `XrpcEndpoint` but for WebSocket subscriptions. 79 + /// Defines the fully-qualified path and associated parameter/stream types. 80 + /// 81 + /// This exists primarily for server-side frameworks (like Axum) to extract 82 + /// typed subscription parameters without lifetime issues. 83 + pub trait SubscriptionEndpoint { 84 + /// Fully-qualified path ('/xrpc/[nsid]') where this subscription endpoint lives 85 + const PATH: &'static str; 86 + 87 + /// Message encoding (JSON or DAG-CBOR) 88 + const ENCODING: MessageEncoding; 89 + 90 + /// Subscription parameters type 91 + type Params<'de>: XrpcSubscription + Deserialize<'de> + IntoStatic; 92 + 93 + /// Stream response type 94 + type Stream: SubscriptionResp; 95 + }
+254 -26
crates/jacquard-lexicon/src/codegen/xrpc.rs
··· 7 7 use proc_macro2::TokenStream; 8 8 use quote::quote; 9 9 10 - use super::utils::make_ident; 11 10 use super::CodeGenerator; 11 + use super::utils::make_ident; 12 12 13 13 impl<'c> CodeGenerator<'c> { 14 14 /// Generate query type ··· 178 178 output.push(error_enum); 179 179 } 180 180 181 + // Generate XrpcSubscription trait impl 182 + let params_has_lifetime = sub 183 + .parameters 184 + .as_ref() 185 + .map(|p| match p { 186 + crate::lexicon::LexXrpcSubscriptionParameter::Params(params) => { 187 + self.params_need_lifetime(params) 188 + } 189 + }) 190 + .unwrap_or(false); 191 + 192 + let has_params = sub.parameters.is_some(); 193 + let has_message = sub.message.is_some(); 194 + let has_errors = sub.errors.is_some(); 195 + 196 + let subscription_impl = self.generate_xrpc_subscription_impl( 197 + nsid, 198 + &type_base, 199 + has_params, 200 + params_has_lifetime, 201 + has_message, 202 + has_errors, 203 + )?; 204 + output.push(subscription_impl); 205 + 181 206 Ok(quote! { 182 207 #(#output)* 183 208 }) ··· 200 225 let mut variants = Vec::new(); 201 226 for ref_str in &union.refs { 202 227 let ref_str_s = ref_str.as_ref(); 228 + 229 + // Normalize local refs (starting with #) by prepending current NSID 230 + let normalized_ref = if ref_str.starts_with('#') { 231 + format!("{}{}", nsid, ref_str) 232 + } else { 233 + ref_str.to_string() 234 + }; 235 + 203 236 // Parse ref to get NSID and def name 204 237 let (ref_nsid, ref_def) = 205 - if let Some((nsid, fragment)) = ref_str.split_once('#') { 206 - (nsid, fragment) 238 + if let Some((nsid_part, fragment)) = normalized_ref.split_once('#') { 239 + (nsid_part, fragment) 207 240 } else { 208 - (ref_str.as_ref(), "main") 241 + (normalized_ref.as_str(), "main") 209 242 }; 210 243 211 244 let variant_name = if ref_def == "main" { ··· 215 248 }; 216 249 let variant_ident = 217 250 syn::Ident::new(&variant_name, proc_macro2::Span::call_site()); 218 - let type_path = self.ref_to_rust_type(ref_str)?; 251 + let type_path = self.ref_to_rust_type(&normalized_ref)?; 219 252 220 253 variants.push(quote! { 221 254 #[serde(rename = #ref_str_s)] ··· 262 295 match field_type { 263 296 LexObjectProperty::Union(union) => { 264 297 // Skip empty, single-variant unions unless they're self-referential 265 - if !union.refs.is_empty() && (union.refs.len() > 1 || self.is_self_referential_union(nsid, &struct_name, union)) { 266 - let union_name = self.generate_field_type_name(nsid, &struct_name, field_name, ""); 298 + if !union.refs.is_empty() 299 + && (union.refs.len() > 1 300 + || self.is_self_referential_union(nsid, &struct_name, union)) 301 + { 302 + let union_name = self.generate_field_type_name( 303 + nsid, 304 + &struct_name, 305 + field_name, 306 + "", 307 + ); 267 308 let refs: Vec<_> = union.refs.iter().cloned().collect(); 268 - let union_def = 269 - self.generate_union(nsid, &union_name, &refs, None, union.closed)?; 309 + let union_def = self.generate_union( 310 + nsid, 311 + &union_name, 312 + &refs, 313 + None, 314 + union.closed, 315 + )?; 270 316 unions.push(union_def); 271 317 } 272 318 } ··· 274 320 if let LexArrayItem::Union(union) = &array.items { 275 321 // Skip single-variant array unions 276 322 if union.refs.len() > 1 { 277 - let union_name = self.generate_field_type_name(nsid, &struct_name, field_name, "Item"); 323 + let union_name = self.generate_field_type_name( 324 + nsid, 325 + &struct_name, 326 + field_name, 327 + "Item", 328 + ); 278 329 let refs: Vec<_> = union.refs.iter().cloned().collect(); 279 - let union_def = self.generate_union(nsid, &union_name, &refs, None, union.closed)?; 330 + let union_def = self.generate_union( 331 + nsid, 332 + &union_name, 333 + &refs, 334 + None, 335 + union.closed, 336 + )?; 280 337 unions.push(union_def); 281 338 } 282 339 } ··· 415 472 let (has_default, has_builder) = if is_binary_body { 416 473 (false, true) 417 474 } else if let Some(crate::lexicon::LexXrpcBodySchema::Object(obj)) = &body.schema { 418 - use crate::codegen::structs::{count_required_fields, all_required_are_defaultable_strings, conflicts_with_builder_macro}; 475 + use crate::codegen::structs::{ 476 + all_required_are_defaultable_strings, conflicts_with_builder_macro, 477 + count_required_fields, 478 + }; 419 479 let required_count = count_required_fields(obj); 420 480 let can_default = required_count == 0 || all_required_are_defaultable_strings(obj); 421 - let can_builder = required_count >= 1 && !can_default && !conflicts_with_builder_macro(type_base); 481 + let can_builder = 482 + required_count >= 1 && !can_default && !conflicts_with_builder_macro(type_base); 422 483 (can_default, can_builder) 423 484 } else { 424 485 (false, false) ··· 495 556 match field_type { 496 557 LexObjectProperty::Union(union) => { 497 558 // Skip empty, single-variant unions unless they're self-referential 498 - if !union.refs.is_empty() && (union.refs.len() > 1 || self.is_self_referential_union(nsid, type_base, union)) { 499 - let union_name = self.generate_field_type_name(nsid, type_base, field_name, ""); 559 + if !union.refs.is_empty() 560 + && (union.refs.len() > 1 561 + || self.is_self_referential_union(nsid, type_base, union)) 562 + { 563 + let union_name = 564 + self.generate_field_type_name(nsid, type_base, field_name, ""); 500 565 let refs: Vec<_> = union.refs.iter().cloned().collect(); 501 566 let union_def = 502 567 self.generate_union(nsid, &union_name, &refs, None, union.closed)?; ··· 507 572 if let LexArrayItem::Union(union) = &array.items { 508 573 // Skip single-variant array unions 509 574 if union.refs.len() > 1 { 510 - let union_name = self.generate_field_type_name(nsid, type_base, field_name, "Item"); 575 + let union_name = self 576 + .generate_field_type_name(nsid, type_base, field_name, "Item"); 511 577 let refs: Vec<_> = union.refs.iter().cloned().collect(); 512 - let union_def = self.generate_union(nsid, &union_name, &refs, None, union.closed)?; 578 + let union_def = self.generate_union( 579 + nsid, 580 + &union_name, 581 + &refs, 582 + None, 583 + union.closed, 584 + )?; 513 585 unions.push(union_def); 514 586 } 515 587 } ··· 545 617 546 618 // Determine if we should derive Default 547 619 // Check if schema is an Object and apply heuristics 548 - let has_default = if let Some(crate::lexicon::LexXrpcBodySchema::Object(obj)) = &body.schema { 549 - use crate::codegen::structs::{count_required_fields, all_required_are_defaultable_strings}; 620 + let has_default = if let Some(crate::lexicon::LexXrpcBodySchema::Object(obj)) = &body.schema 621 + { 622 + use crate::codegen::structs::{ 623 + all_required_are_defaultable_strings, count_required_fields, 624 + }; 550 625 let required_count = count_required_fields(obj); 551 626 required_count == 0 || all_required_are_defaultable_strings(obj) 552 627 } else { ··· 584 659 match field_type { 585 660 LexObjectProperty::Union(union) => { 586 661 // Skip single-variant unions unless they're self-referential 587 - if union.refs.len() > 1 || self.is_self_referential_union(nsid, &struct_name, union) { 588 - let union_name = self.generate_field_type_name(nsid, &struct_name, field_name, ""); 662 + if union.refs.len() > 1 663 + || self.is_self_referential_union(nsid, &struct_name, union) 664 + { 665 + let union_name = 666 + self.generate_field_type_name(nsid, &struct_name, field_name, ""); 589 667 let refs: Vec<_> = union.refs.iter().cloned().collect(); 590 668 let union_def = 591 669 self.generate_union(nsid, &union_name, &refs, None, union.closed)?; ··· 596 674 if let LexArrayItem::Union(union) = &array.items { 597 675 // Skip single-variant array unions 598 676 if union.refs.len() > 1 { 599 - let union_name = self.generate_field_type_name(nsid, &struct_name, field_name, "Item"); 677 + let union_name = self.generate_field_type_name( 678 + nsid, 679 + &struct_name, 680 + field_name, 681 + "Item", 682 + ); 600 683 let refs: Vec<_> = union.refs.iter().cloned().collect(); 601 - let union_def = self.generate_union(nsid, &union_name, &refs, None, union.closed)?; 684 + let union_def = self.generate_union( 685 + nsid, 686 + &union_name, 687 + &refs, 688 + None, 689 + union.closed, 690 + )?; 602 691 unions.push(union_def); 603 692 } 604 693 } ··· 953 1042 ); 954 1043 955 1044 let response_type = quote! { 956 - #[doc = "Response type for "] 1045 + #[doc = " Response type for "] 957 1046 #[doc = #nsid] 958 1047 pub struct #response_ident; 959 1048 ··· 1027 1116 #decode_body_method 1028 1117 } 1029 1118 1030 - #[doc = "Endpoint type for "] 1119 + #[doc = " Endpoint type for "] 1031 1120 #[doc = #nsid] 1032 1121 pub struct #endpoint_ident; 1033 1122 ··· 1057 1146 type Response = #response_ident; 1058 1147 } 1059 1148 1060 - #[doc = "Endpoint type for "] 1149 + #[doc = " Endpoint type for "] 1061 1150 #[doc = #nsid] 1062 1151 pub struct #endpoint_ident; 1063 1152 ··· 1070 1159 } 1071 1160 }) 1072 1161 } 1162 + } 1163 + 1164 + /// Generate XrpcSubscription trait impl for a subscription endpoint 1165 + pub(super) fn generate_xrpc_subscription_impl( 1166 + &self, 1167 + nsid: &str, 1168 + type_base: &str, 1169 + has_params: bool, 1170 + params_has_lifetime: bool, 1171 + has_message: bool, 1172 + has_errors: bool, 1173 + ) -> Result<TokenStream> { 1174 + // Generate stream response marker struct 1175 + let stream_ident = syn::Ident::new( 1176 + &format!("{}Stream", type_base), 1177 + proc_macro2::Span::call_site(), 1178 + ); 1179 + 1180 + let message_type = if has_message { 1181 + let msg_ident = syn::Ident::new( 1182 + &format!("{}Message", type_base), 1183 + proc_macro2::Span::call_site(), 1184 + ); 1185 + quote! { #msg_ident<'de> } 1186 + } else { 1187 + quote! { () } 1188 + }; 1189 + 1190 + let error_type = if has_errors { 1191 + let err_ident = syn::Ident::new( 1192 + &format!("{}Error", type_base), 1193 + proc_macro2::Span::call_site(), 1194 + ); 1195 + quote! { #err_ident<'de> } 1196 + } else { 1197 + quote! { jacquard_common::xrpc::GenericError<'de> } 1198 + }; 1199 + 1200 + // Determine encoding from nsid convention 1201 + // ATProto subscriptions use DAG-CBOR, community ones might use JSON 1202 + let encoding = if nsid.starts_with("com.atproto") { 1203 + quote! { jacquard_common::xrpc::MessageEncoding::DagCbor } 1204 + } else { 1205 + quote! { jacquard_common::xrpc::MessageEncoding::Json } 1206 + }; 1207 + 1208 + // Generate SubscriptionResp impl 1209 + let stream_resp_impl = quote! { 1210 + #[doc = "Stream response type for "] 1211 + #[doc = #nsid] 1212 + pub struct #stream_ident; 1213 + 1214 + impl jacquard_common::xrpc::SubscriptionResp for #stream_ident { 1215 + const NSID: &'static str = #nsid; 1216 + const ENCODING: jacquard_common::xrpc::MessageEncoding = #encoding; 1217 + 1218 + type Message<'de> = #message_type; 1219 + type Error<'de> = #error_type; 1220 + } 1221 + }; 1222 + 1223 + let params_ident = if has_params { 1224 + syn::Ident::new(type_base, proc_macro2::Span::call_site()) 1225 + } else { 1226 + // Generate marker struct if no params 1227 + let marker = syn::Ident::new(type_base, proc_macro2::Span::call_site()); 1228 + let endpoint_ident = syn::Ident::new( 1229 + &format!("{}Endpoint", type_base), 1230 + proc_macro2::Span::call_site(), 1231 + ); 1232 + let endpoint_path = format!("/xrpc/{}", nsid); 1233 + 1234 + return Ok(quote! { 1235 + #stream_resp_impl 1236 + 1237 + #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)] 1238 + pub struct #marker; 1239 + 1240 + impl jacquard_common::xrpc::XrpcSubscription for #marker { 1241 + const NSID: &'static str = #nsid; 1242 + const ENCODING: jacquard_common::xrpc::MessageEncoding = #encoding; 1243 + 1244 + type Stream = #stream_ident; 1245 + } 1246 + 1247 + pub struct #endpoint_ident; 1248 + 1249 + impl jacquard_common::xrpc::SubscriptionEndpoint for #endpoint_ident { 1250 + const PATH: &'static str = #endpoint_path; 1251 + const ENCODING: jacquard_common::xrpc::MessageEncoding = #encoding; 1252 + 1253 + type Params<'de> = #marker; 1254 + type Stream = #stream_ident; 1255 + } 1256 + }); 1257 + }; 1258 + 1259 + let (impl_generics, impl_target, endpoint_params_type) = 1260 + if has_params && params_has_lifetime { 1261 + ( 1262 + quote! { <'a> }, 1263 + quote! { #params_ident<'a> }, 1264 + quote! { #params_ident<'de> }, 1265 + ) 1266 + } else { 1267 + ( 1268 + quote! {}, 1269 + quote! { #params_ident }, 1270 + quote! { #params_ident }, 1271 + ) 1272 + }; 1273 + 1274 + let endpoint_ident = syn::Ident::new( 1275 + &format!("{}Endpoint", type_base), 1276 + proc_macro2::Span::call_site(), 1277 + ); 1278 + 1279 + let endpoint_path = format!("/xrpc/{}", nsid); 1280 + 1281 + Ok(quote! { 1282 + #stream_resp_impl 1283 + 1284 + impl #impl_generics jacquard_common::xrpc::XrpcSubscription for #impl_target { 1285 + const NSID: &'static str = #nsid; 1286 + const ENCODING: jacquard_common::xrpc::MessageEncoding = #encoding; 1287 + 1288 + type Stream = #stream_ident; 1289 + } 1290 + 1291 + pub struct #endpoint_ident; 1292 + 1293 + impl jacquard_common::xrpc::SubscriptionEndpoint for #endpoint_ident { 1294 + const PATH: &'static str = #endpoint_path; 1295 + const ENCODING: jacquard_common::xrpc::MessageEncoding = #encoding; 1296 + 1297 + type Params<'de> = #endpoint_params_type; 1298 + type Stream = #stream_ident; 1299 + } 1300 + }) 1073 1301 } 1074 1302 }