tangled
alpha
login
or
join now
nonbinary.computer
/
jacquard
80
fork
atom
A better Rust ATProto crate
80
fork
atom
overview
issues
9
pulls
pipelines
websocket impl in progress
Orual
5 months ago
c6d2a31d
f7db2a76
+1159
-89
13 changed files
expand all
collapse all
unified
split
Cargo.lock
crates
jacquard-api
src
app_bsky
feed
post.rs
com_atproto
label
subscribe_labels.rs
sync
subscribe_repos.rs
jacquard-common
Cargo.toml
src
error.rs
lib.rs
stream.rs
types
uri.rs
websocket.rs
xrpc
subscription.rs
xrpc.rs
jacquard-lexicon
src
codegen
xrpc.rs
+66
Cargo.lock
···
2005
2005
"smol_str",
2006
2006
"thiserror 2.0.17",
2007
2007
"tokio",
2008
2008
+
"tokio-tungstenite-wasm",
2008
2009
"tracing",
2009
2010
"trait-variant",
2010
2011
"url",
···
3577
3578
]
3578
3579
3579
3580
[[package]]
3581
3581
+
name = "sha1"
3582
3582
+
version = "0.10.6"
3583
3583
+
source = "registry+https://github.com/rust-lang/crates.io-index"
3584
3584
+
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
3585
3585
+
dependencies = [
3586
3586
+
"cfg-if",
3587
3587
+
"cpufeatures",
3588
3588
+
"digest",
3589
3589
+
]
3590
3590
+
3591
3591
+
[[package]]
3580
3592
name = "sha1_smol"
3581
3593
version = "1.0.1"
3582
3594
source = "registry+https://github.com/rust-lang/crates.io-index"
···
4021
4033
]
4022
4034
4023
4035
[[package]]
4036
4036
+
name = "tokio-tungstenite"
4037
4037
+
version = "0.24.0"
4038
4038
+
source = "registry+https://github.com/rust-lang/crates.io-index"
4039
4039
+
checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9"
4040
4040
+
dependencies = [
4041
4041
+
"futures-util",
4042
4042
+
"log",
4043
4043
+
"tokio",
4044
4044
+
"tungstenite",
4045
4045
+
]
4046
4046
+
4047
4047
+
[[package]]
4048
4048
+
name = "tokio-tungstenite-wasm"
4049
4049
+
version = "0.4.0"
4050
4050
+
source = "registry+https://github.com/rust-lang/crates.io-index"
4051
4051
+
checksum = "e21a5c399399c3db9f08d8297ac12b500e86bca82e930253fdc62eaf9c0de6ae"
4052
4052
+
dependencies = [
4053
4053
+
"futures-channel",
4054
4054
+
"futures-util",
4055
4055
+
"http",
4056
4056
+
"httparse",
4057
4057
+
"js-sys",
4058
4058
+
"thiserror 1.0.69",
4059
4059
+
"tokio",
4060
4060
+
"tokio-tungstenite",
4061
4061
+
"wasm-bindgen",
4062
4062
+
"web-sys",
4063
4063
+
]
4064
4064
+
4065
4065
+
[[package]]
4024
4066
name = "tokio-util"
4025
4067
version = "0.7.16"
4026
4068
source = "registry+https://github.com/rust-lang/crates.io-index"
···
4162
4204
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
4163
4205
4164
4206
[[package]]
4207
4207
+
name = "tungstenite"
4208
4208
+
version = "0.24.0"
4209
4209
+
source = "registry+https://github.com/rust-lang/crates.io-index"
4210
4210
+
checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a"
4211
4211
+
dependencies = [
4212
4212
+
"byteorder",
4213
4213
+
"bytes",
4214
4214
+
"data-encoding",
4215
4215
+
"http",
4216
4216
+
"httparse",
4217
4217
+
"log",
4218
4218
+
"rand 0.8.5",
4219
4219
+
"sha1",
4220
4220
+
"thiserror 1.0.69",
4221
4221
+
"utf-8",
4222
4222
+
]
4223
4223
+
4224
4224
+
[[package]]
4165
4225
name = "twoway"
4166
4226
version = "0.1.8"
4167
4227
source = "registry+https://github.com/rust-lang/crates.io-index"
···
4271
4331
version = "2.1.3"
4272
4332
source = "registry+https://github.com/rust-lang/crates.io-index"
4273
4333
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
4334
4334
+
4335
4335
+
[[package]]
4336
4336
+
name = "utf-8"
4337
4337
+
version = "0.7.6"
4338
4338
+
source = "registry+https://github.com/rust-lang/crates.io-index"
4339
4339
+
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
4274
4340
4275
4341
[[package]]
4276
4342
name = "utf8_iter"
+19
-7
crates/jacquard-api/src/app_bsky/feed/post.rs
···
15
15
PartialEq,
16
16
Eq,
17
17
jacquard_derive::IntoStatic,
18
18
-
bon::Builder,
18
18
+
bon::Builder
19
19
)]
20
20
#[serde(rename_all = "camelCase")]
21
21
pub struct Entity<'a> {
···
40
40
PartialEq,
41
41
Eq,
42
42
jacquard_derive::IntoStatic,
43
43
-
bon::Builder,
43
43
+
bon::Builder
44
44
)]
45
45
#[serde(rename_all = "camelCase")]
46
46
pub struct Post<'a> {
···
99
99
100
100
#[jacquard_derive::open_union]
101
101
#[derive(
102
102
-
serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq, jacquard_derive::IntoStatic,
102
102
+
serde::Serialize,
103
103
+
serde::Deserialize,
104
104
+
Debug,
105
105
+
Clone,
106
106
+
PartialEq,
107
107
+
Eq,
108
108
+
jacquard_derive::IntoStatic
103
109
)]
104
110
#[serde(tag = "$type")]
105
111
#[serde(bound(deserialize = "'de: 'a"))]
···
118
124
119
125
/// Typed wrapper for GetRecord response with this collection's record type.
120
126
#[derive(
121
121
-
serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq, jacquard_derive::IntoStatic,
127
127
+
serde::Serialize,
128
128
+
serde::Deserialize,
129
129
+
Debug,
130
130
+
Clone,
131
131
+
PartialEq,
132
132
+
Eq,
133
133
+
jacquard_derive::IntoStatic
122
134
)]
123
135
#[serde(rename_all = "camelCase")]
124
136
pub struct PostGetRecordOutput<'a> {
···
167
179
PartialEq,
168
180
Eq,
169
181
jacquard_derive::IntoStatic,
170
170
-
bon::Builder,
182
182
+
bon::Builder
171
183
)]
172
184
#[serde(rename_all = "camelCase")]
173
185
pub struct ReplyRef<'a> {
···
187
199
PartialEq,
188
200
Eq,
189
201
jacquard_derive::IntoStatic,
190
190
-
bon::Builder,
202
202
+
bon::Builder
191
203
)]
192
204
#[serde(rename_all = "camelCase")]
193
205
pub struct TextSlice<'a> {
194
206
pub end: i64,
195
207
pub start: i64,
196
196
-
}
208
208
+
}
+26
-2
crates/jacquard-api/src/com_atproto/label/subscribe_labels.rs
···
74
74
#[serde(bound(deserialize = "'de: 'a"))]
75
75
pub enum SubscribeLabelsMessage<'a> {
76
76
#[serde(rename = "#labels")]
77
77
-
Labels(Box<jacquard_common::types::value::Data<'a>>),
77
77
+
Labels(Box<crate::com_atproto::label::subscribe_labels::Labels<'a>>),
78
78
#[serde(rename = "#info")]
79
79
-
Info(Box<jacquard_common::types::value::Data<'a>>),
79
79
+
Info(Box<crate::com_atproto::label::subscribe_labels::Info<'a>>),
80
80
}
81
81
82
82
#[jacquard_derive::open_union]
···
111
111
Self::Unknown(err) => write!(f, "Unknown error: {:?}", err),
112
112
}
113
113
}
114
114
+
}
115
115
+
116
116
+
///Stream response type for
117
117
+
///com.atproto.label.subscribeLabels
118
118
+
pub struct SubscribeLabelsStream;
119
119
+
impl jacquard_common::xrpc::SubscriptionResp for SubscribeLabelsStream {
120
120
+
const NSID: &'static str = "com.atproto.label.subscribeLabels";
121
121
+
const ENCODING: jacquard_common::xrpc::MessageEncoding = jacquard_common::xrpc::MessageEncoding::DagCbor;
122
122
+
type Message<'de> = SubscribeLabelsMessage<'de>;
123
123
+
type Error<'de> = SubscribeLabelsError<'de>;
124
124
+
}
125
125
+
126
126
+
impl jacquard_common::xrpc::XrpcSubscription for SubscribeLabels {
127
127
+
const NSID: &'static str = "com.atproto.label.subscribeLabels";
128
128
+
const ENCODING: jacquard_common::xrpc::MessageEncoding = jacquard_common::xrpc::MessageEncoding::DagCbor;
129
129
+
type Stream = SubscribeLabelsStream;
130
130
+
}
131
131
+
132
132
+
pub struct SubscribeLabelsEndpoint;
133
133
+
impl jacquard_common::xrpc::SubscriptionEndpoint for SubscribeLabelsEndpoint {
134
134
+
const PATH: &'static str = "/xrpc/com.atproto.label.subscribeLabels";
135
135
+
const ENCODING: jacquard_common::xrpc::MessageEncoding = jacquard_common::xrpc::MessageEncoding::DagCbor;
136
136
+
type Params<'de> = SubscribeLabels;
137
137
+
type Stream = SubscribeLabelsStream;
114
138
}
+42
-21
crates/jacquard-api/src/com_atproto/sync/subscribe_repos.rs
···
15
15
PartialEq,
16
16
Eq,
17
17
jacquard_derive::IntoStatic,
18
18
-
bon::Builder
18
18
+
bon::Builder,
19
19
)]
20
20
#[serde(rename_all = "camelCase")]
21
21
pub struct Account<'a> {
···
42
42
PartialEq,
43
43
Eq,
44
44
jacquard_derive::IntoStatic,
45
45
-
bon::Builder
45
45
+
bon::Builder,
46
46
)]
47
47
#[serde(rename_all = "camelCase")]
48
48
pub struct Commit<'a> {
···
87
87
PartialEq,
88
88
Eq,
89
89
jacquard_derive::IntoStatic,
90
90
-
bon::Builder
90
90
+
bon::Builder,
91
91
)]
92
92
#[serde(rename_all = "camelCase")]
93
93
pub struct Identity<'a> {
···
111
111
PartialEq,
112
112
Eq,
113
113
jacquard_derive::IntoStatic,
114
114
-
Default
114
114
+
Default,
115
115
)]
116
116
#[serde(rename_all = "camelCase")]
117
117
pub struct Info<'a> {
···
130
130
PartialEq,
131
131
Eq,
132
132
bon::Builder,
133
133
-
jacquard_derive::IntoStatic
133
133
+
jacquard_derive::IntoStatic,
134
134
)]
135
135
#[builder(start_fn = new)]
136
136
#[serde(rename_all = "camelCase")]
···
141
141
142
142
#[jacquard_derive::open_union]
143
143
#[derive(
144
144
-
serde::Serialize,
145
145
-
serde::Deserialize,
146
146
-
Debug,
147
147
-
Clone,
148
148
-
PartialEq,
149
149
-
Eq,
150
150
-
jacquard_derive::IntoStatic
144
144
+
serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq, jacquard_derive::IntoStatic,
151
145
)]
152
146
#[serde(tag = "$type")]
153
147
#[serde(bound(deserialize = "'de: 'a"))]
154
148
pub enum SubscribeReposMessage<'a> {
155
149
#[serde(rename = "#commit")]
156
156
-
Commit(Box<jacquard_common::types::value::Data<'a>>),
150
150
+
Commit(Box<crate::com_atproto::sync::subscribe_repos::Commit<'a>>),
157
151
#[serde(rename = "#sync")]
158
158
-
Sync(Box<jacquard_common::types::value::Data<'a>>),
152
152
+
Sync(Box<crate::com_atproto::sync::subscribe_repos::Sync<'a>>),
159
153
#[serde(rename = "#identity")]
160
160
-
Identity(Box<jacquard_common::types::value::Data<'a>>),
154
154
+
Identity(Box<crate::com_atproto::sync::subscribe_repos::Identity<'a>>),
161
155
#[serde(rename = "#account")]
162
162
-
Account(Box<jacquard_common::types::value::Data<'a>>),
156
156
+
Account(Box<crate::com_atproto::sync::subscribe_repos::Account<'a>>),
163
157
#[serde(rename = "#info")]
164
164
-
Info(Box<jacquard_common::types::value::Data<'a>>),
158
158
+
Info(Box<crate::com_atproto::sync::subscribe_repos::Info<'a>>),
165
159
}
166
160
167
161
#[jacquard_derive::open_union]
···
174
168
Eq,
175
169
thiserror::Error,
176
170
miette::Diagnostic,
177
177
-
jacquard_derive::IntoStatic
171
171
+
jacquard_derive::IntoStatic,
178
172
)]
179
173
#[serde(tag = "error", content = "message")]
180
174
#[serde(bound(deserialize = "'de: 'a"))]
···
208
202
}
209
203
}
210
204
205
205
+
///Stream response type for
206
206
+
///com.atproto.sync.subscribeRepos
207
207
+
pub struct SubscribeReposStream;
208
208
+
impl jacquard_common::xrpc::SubscriptionResp for SubscribeReposStream {
209
209
+
const NSID: &'static str = "com.atproto.sync.subscribeRepos";
210
210
+
const ENCODING: jacquard_common::xrpc::MessageEncoding =
211
211
+
jacquard_common::xrpc::MessageEncoding::DagCbor;
212
212
+
type Message<'de> = SubscribeReposMessage<'de>;
213
213
+
type Error<'de> = SubscribeReposError<'de>;
214
214
+
}
215
215
+
216
216
+
impl jacquard_common::xrpc::XrpcSubscription for SubscribeRepos {
217
217
+
const NSID: &'static str = "com.atproto.sync.subscribeRepos";
218
218
+
const ENCODING: jacquard_common::xrpc::MessageEncoding =
219
219
+
jacquard_common::xrpc::MessageEncoding::DagCbor;
220
220
+
type Stream = SubscribeReposStream;
221
221
+
}
222
222
+
223
223
+
pub struct SubscribeReposEndpoint;
224
224
+
impl jacquard_common::xrpc::SubscriptionEndpoint for SubscribeReposEndpoint {
225
225
+
const PATH: &'static str = "/xrpc/com.atproto.sync.subscribeRepos";
226
226
+
const ENCODING: jacquard_common::xrpc::MessageEncoding =
227
227
+
jacquard_common::xrpc::MessageEncoding::DagCbor;
228
228
+
type Params<'de> = SubscribeRepos;
229
229
+
type Stream = SubscribeReposStream;
230
230
+
}
231
231
+
211
232
/// A repo operation, ie a mutation of a single record.
212
233
#[jacquard_derive::lexicon]
213
234
#[derive(
···
218
239
PartialEq,
219
240
Eq,
220
241
jacquard_derive::IntoStatic,
221
221
-
bon::Builder
242
242
+
bon::Builder,
222
243
)]
223
244
#[serde(rename_all = "camelCase")]
224
245
pub struct RepoOp<'a> {
···
248
269
PartialEq,
249
270
Eq,
250
271
jacquard_derive::IntoStatic,
251
251
-
bon::Builder
272
272
+
bon::Builder,
252
273
)]
253
274
#[serde(rename_all = "camelCase")]
254
275
pub struct Sync<'a> {
···
265
286
pub seq: i64,
266
287
/// Timestamp of when this message was originally broadcast.
267
288
pub time: jacquard_common::types::string::Datetime,
268
268
-
}
289
289
+
}
+3
-2
crates/jacquard-common/Cargo.toml
···
44
44
# Streaming support (optional)
45
45
n0-future = { version = "0.1", optional = true }
46
46
futures = { version = "0.3", optional = true }
47
47
+
tokio-tungstenite-wasm = { version = "0.4", optional = true }
47
48
48
49
[target.'cfg(target_family = "wasm")'.dependencies]
49
50
getrandom = { version = "0.3.4", features = ["wasm_js"] }
···
52
53
reqwest = { workspace = true, optional = true, features = [ "http2", "system-proxy", "rustls-tls"] }
53
54
54
55
[features]
55
55
-
default = ["service-auth", "reqwest-client", "crypto"]
56
56
+
default = ["service-auth", "reqwest-client", "crypto", "websocket"]
56
57
crypto = []
57
58
crypto-ed25519 = ["crypto", "dep:ed25519-dalek"]
58
59
crypto-k256 = ["crypto", "dep:k256", "k256/ecdsa"]
···
61
62
reqwest-client = ["dep:reqwest"]
62
63
tracing = ["dep:tracing"]
63
64
streaming = ["n0-future", "futures"]
64
64
-
websocket = ["streaming"]
65
65
+
websocket = ["streaming", "tokio-tungstenite-wasm"]
65
66
66
67
[dependencies.ed25519-dalek]
67
68
version = "2"
+7
crates/jacquard-common/src/error.rs
···
91
91
#[source]
92
92
serde_ipld_dagcbor::DecodeError<HttpError>,
93
93
),
94
94
+
/// DAG-CBOR deserialization failed (in-memory, e.g., WebSocket frames)
95
95
+
#[error("Failed to deserialize DAG-CBOR: {0}")]
96
96
+
DagCborInfallible(
97
97
+
#[from]
98
98
+
#[source]
99
99
+
serde_ipld_dagcbor::DecodeError<std::convert::Infallible>,
100
100
+
),
94
101
}
95
102
96
103
/// HTTP error response (non-200 status codes outside of XRPC error handling)
+4
-1
crates/jacquard-common/src/lib.rs
···
234
234
pub mod websocket;
235
235
236
236
#[cfg(feature = "websocket")]
237
237
-
pub use websocket::{WebSocketClient, WebSocketConnection};
237
237
+
pub use websocket::{
238
238
+
tungstenite_client::TungsteniteClient, CloseCode, CloseFrame, WebSocketClient,
239
239
+
WebSocketConnection, WsMessage, WsSink, WsStream, WsText,
240
240
+
};
238
241
239
242
pub use types::value::*;
240
243
+29
-6
crates/jacquard-common/src/stream.rs
···
64
64
Closed,
65
65
/// Protocol violation or framing error
66
66
Protocol,
67
67
+
/// Message deserialization failed
68
68
+
Decode,
69
69
+
/// Wrong message format (e.g., text frame when expecting binary)
70
70
+
WrongMessageFormat,
67
71
}
68
72
69
73
impl StreamError {
···
105
109
source: Some(msg.into().into()),
106
110
}
107
111
}
112
112
+
113
113
+
/// Create a decode error with source
114
114
+
pub fn decode(source: impl Error + Send + Sync + 'static) -> Self {
115
115
+
Self {
116
116
+
kind: StreamErrorKind::Decode,
117
117
+
source: Some(Box::new(source)),
118
118
+
}
119
119
+
}
120
120
+
121
121
+
/// Create a wrong message format error
122
122
+
pub fn wrong_message_format(msg: impl Into<String>) -> Self {
123
123
+
Self {
124
124
+
kind: StreamErrorKind::WrongMessageFormat,
125
125
+
source: Some(msg.into().into()),
126
126
+
}
127
127
+
}
108
128
}
109
129
110
130
impl fmt::Display for StreamError {
···
113
133
StreamErrorKind::Transport => write!(f, "Transport error"),
114
134
StreamErrorKind::Closed => write!(f, "Stream closed"),
115
135
StreamErrorKind::Protocol => write!(f, "Protocol error"),
136
136
+
StreamErrorKind::Decode => write!(f, "Decode error"),
137
137
+
StreamErrorKind::WrongMessageFormat => write!(f, "Wrong message format"),
116
138
}?;
117
139
118
140
if let Some(source) = &self.source {
···
125
147
126
148
impl Error for StreamError {
127
149
fn source(&self) -> Option<&(dyn Error + 'static)> {
128
128
-
self.source.as_ref().map(|e| e.as_ref() as &(dyn Error + 'static))
150
150
+
self.source
151
151
+
.as_ref()
152
152
+
.map(|e| e.as_ref() as &(dyn Error + 'static))
129
153
}
130
154
}
131
155
···
153
177
}
154
178
155
179
/// Convert into the inner boxed stream
156
156
-
pub fn into_inner(self) -> Box<dyn n0_future::Stream<Item = Result<Bytes, StreamError>> + Unpin> {
180
180
+
pub fn into_inner(
181
181
+
self,
182
182
+
) -> Box<dyn n0_future::Stream<Item = Result<Bytes, StreamError>> + Unpin> {
157
183
self.inner
158
184
}
159
185
}
···
219
245
async fn byte_stream_can_be_created() {
220
246
use futures::stream;
221
247
222
222
-
let data = vec![
223
223
-
Ok(Bytes::from("hello")),
224
224
-
Ok(Bytes::from(" world")),
225
225
-
];
248
248
+
let data = vec![Ok(Bytes::from("hello")), Ok(Bytes::from(" world"))];
226
249
let stream = stream::iter(data);
227
250
228
251
let byte_stream = ByteStream::new(stream);
+15
-1
crates/jacquard-common/src/types/uri.rs
···
6
6
};
7
7
use serde::{Deserialize, Deserializer, Serialize, Serializer};
8
8
use smol_str::ToSmolStr;
9
9
-
use std::{fmt::Display, marker::PhantomData, str::FromStr};
9
9
+
use std::{fmt::Display, marker::PhantomData, ops::Deref, str::FromStr};
10
10
use url::Url;
11
11
12
12
/// Generic URI with type-specific parsing
···
202
202
impl<R: Collection> Display for RecordUri<'_, R> {
203
203
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204
204
self.0.fmt(f)
205
205
+
}
206
206
+
}
207
207
+
208
208
+
impl<'a, R: Collection> AsRef<AtUri<'a>> for RecordUri<'a, R> {
209
209
+
fn as_ref(&self) -> &AtUri<'a> {
210
210
+
&self.0
211
211
+
}
212
212
+
}
213
213
+
214
214
+
impl<'a, R: Collection> Deref for RecordUri<'a, R> {
215
215
+
type Target = AtUri<'a>;
216
216
+
217
217
+
fn deref(&self) -> &Self::Target {
218
218
+
&self.0
205
219
}
206
220
}
207
221
+593
-23
crates/jacquard-common/src/websocket.rs
···
1
1
//! WebSocket client abstraction
2
2
3
3
-
use crate::stream::{ByteStream, ByteSink};
3
3
+
use bytes::Bytes;
4
4
+
use n0_future::Stream;
5
5
+
use n0_future::stream::Boxed;
6
6
+
use std::borrow::Borrow;
7
7
+
use std::fmt::{self, Display};
4
8
use std::future::Future;
9
9
+
use std::ops::Deref;
5
10
use url::Url;
6
11
12
12
+
use crate::CowStr;
13
13
+
use crate::stream::StreamError;
14
14
+
15
15
+
/// UTF-8 validated bytes for WebSocket text messages
16
16
+
#[repr(transparent)]
17
17
+
#[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
18
18
+
pub struct WsText(Bytes);
19
19
+
20
20
+
impl WsText {
21
21
+
/// Create from static string
22
22
+
pub const fn from_static(s: &'static str) -> Self {
23
23
+
Self(Bytes::from_static(s.as_bytes()))
24
24
+
}
25
25
+
26
26
+
/// Get as string slice
27
27
+
pub fn as_str(&self) -> &str {
28
28
+
unsafe { std::str::from_utf8_unchecked(&self.0) }
29
29
+
}
30
30
+
31
31
+
/// Create from bytes without validation (caller must ensure UTF-8)
32
32
+
///
33
33
+
/// # Safety
34
34
+
/// Bytes must be valid UTF-8
35
35
+
pub unsafe fn from_bytes_unchecked(bytes: Bytes) -> Self {
36
36
+
Self(bytes)
37
37
+
}
38
38
+
39
39
+
/// Convert into underlying bytes
40
40
+
pub fn into_bytes(self) -> Bytes {
41
41
+
self.0
42
42
+
}
43
43
+
}
44
44
+
45
45
+
impl Deref for WsText {
46
46
+
type Target = str;
47
47
+
fn deref(&self) -> &str {
48
48
+
self.as_str()
49
49
+
}
50
50
+
}
51
51
+
52
52
+
impl AsRef<str> for WsText {
53
53
+
fn as_ref(&self) -> &str {
54
54
+
self.as_str()
55
55
+
}
56
56
+
}
57
57
+
58
58
+
impl AsRef<[u8]> for WsText {
59
59
+
fn as_ref(&self) -> &[u8] {
60
60
+
&self.0
61
61
+
}
62
62
+
}
63
63
+
64
64
+
impl AsRef<Bytes> for WsText {
65
65
+
fn as_ref(&self) -> &Bytes {
66
66
+
&self.0
67
67
+
}
68
68
+
}
69
69
+
70
70
+
impl Borrow<str> for WsText {
71
71
+
fn borrow(&self) -> &str {
72
72
+
self.as_str()
73
73
+
}
74
74
+
}
75
75
+
76
76
+
impl Display for WsText {
77
77
+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
78
78
+
Display::fmt(self.as_str(), f)
79
79
+
}
80
80
+
}
81
81
+
82
82
+
impl From<String> for WsText {
83
83
+
fn from(s: String) -> Self {
84
84
+
Self(Bytes::from(s))
85
85
+
}
86
86
+
}
87
87
+
88
88
+
impl From<&str> for WsText {
89
89
+
fn from(s: &str) -> Self {
90
90
+
Self(Bytes::copy_from_slice(s.as_bytes()))
91
91
+
}
92
92
+
}
93
93
+
94
94
+
impl From<&String> for WsText {
95
95
+
fn from(s: &String) -> Self {
96
96
+
Self::from(s.as_str())
97
97
+
}
98
98
+
}
99
99
+
100
100
+
impl TryFrom<Bytes> for WsText {
101
101
+
type Error = std::str::Utf8Error;
102
102
+
fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
103
103
+
std::str::from_utf8(&bytes)?;
104
104
+
Ok(Self(bytes))
105
105
+
}
106
106
+
}
107
107
+
108
108
+
impl TryFrom<Vec<u8>> for WsText {
109
109
+
type Error = std::str::Utf8Error;
110
110
+
fn try_from(vec: Vec<u8>) -> Result<Self, Self::Error> {
111
111
+
Self::try_from(Bytes::from(vec))
112
112
+
}
113
113
+
}
114
114
+
115
115
+
impl From<WsText> for Bytes {
116
116
+
fn from(t: WsText) -> Bytes {
117
117
+
t.0
118
118
+
}
119
119
+
}
120
120
+
121
121
+
impl Default for WsText {
122
122
+
fn default() -> Self {
123
123
+
Self(Bytes::new())
124
124
+
}
125
125
+
}
126
126
+
127
127
+
/// WebSocket close code
128
128
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
129
129
+
#[repr(u16)]
130
130
+
pub enum CloseCode {
131
131
+
/// Normal closure
132
132
+
Normal = 1000,
133
133
+
/// Endpoint going away
134
134
+
Away = 1001,
135
135
+
/// Protocol error
136
136
+
Protocol = 1002,
137
137
+
/// Unsupported data
138
138
+
Unsupported = 1003,
139
139
+
/// Invalid frame payload data
140
140
+
Invalid = 1007,
141
141
+
/// Policy violation
142
142
+
Policy = 1008,
143
143
+
/// Message too big
144
144
+
Size = 1009,
145
145
+
/// Extension negotiation failure
146
146
+
Extension = 1010,
147
147
+
/// Unexpected condition
148
148
+
Error = 1011,
149
149
+
/// TLS handshake failure
150
150
+
Tls = 1015,
151
151
+
/// Other code
152
152
+
Other(u16),
153
153
+
}
154
154
+
155
155
+
impl From<u16> for CloseCode {
156
156
+
fn from(code: u16) -> Self {
157
157
+
match code {
158
158
+
1000 => CloseCode::Normal,
159
159
+
1001 => CloseCode::Away,
160
160
+
1002 => CloseCode::Protocol,
161
161
+
1003 => CloseCode::Unsupported,
162
162
+
1007 => CloseCode::Invalid,
163
163
+
1008 => CloseCode::Policy,
164
164
+
1009 => CloseCode::Size,
165
165
+
1010 => CloseCode::Extension,
166
166
+
1011 => CloseCode::Error,
167
167
+
1015 => CloseCode::Tls,
168
168
+
other => CloseCode::Other(other),
169
169
+
}
170
170
+
}
171
171
+
}
172
172
+
173
173
+
impl From<CloseCode> for u16 {
174
174
+
fn from(code: CloseCode) -> u16 {
175
175
+
match code {
176
176
+
CloseCode::Normal => 1000,
177
177
+
CloseCode::Away => 1001,
178
178
+
CloseCode::Protocol => 1002,
179
179
+
CloseCode::Unsupported => 1003,
180
180
+
CloseCode::Invalid => 1007,
181
181
+
CloseCode::Policy => 1008,
182
182
+
CloseCode::Size => 1009,
183
183
+
CloseCode::Extension => 1010,
184
184
+
CloseCode::Error => 1011,
185
185
+
CloseCode::Tls => 1015,
186
186
+
CloseCode::Other(code) => code,
187
187
+
}
188
188
+
}
189
189
+
}
190
190
+
191
191
+
/// WebSocket close frame
192
192
+
#[derive(Debug, Clone, PartialEq, Eq)]
193
193
+
pub struct CloseFrame<'a> {
194
194
+
/// Close code
195
195
+
pub code: CloseCode,
196
196
+
/// Close reason text
197
197
+
pub reason: CowStr<'a>,
198
198
+
}
199
199
+
200
200
+
impl<'a> CloseFrame<'a> {
201
201
+
/// Create a new close frame
202
202
+
pub fn new(code: CloseCode, reason: impl Into<CowStr<'a>>) -> Self {
203
203
+
Self {
204
204
+
code,
205
205
+
reason: reason.into(),
206
206
+
}
207
207
+
}
208
208
+
}
209
209
+
210
210
+
/// WebSocket message
211
211
+
#[derive(Debug, Clone, PartialEq, Eq)]
212
212
+
pub enum WsMessage {
213
213
+
/// Text message (UTF-8)
214
214
+
Text(WsText),
215
215
+
/// Binary message
216
216
+
Binary(Bytes),
217
217
+
/// Close frame
218
218
+
Close(Option<CloseFrame<'static>>),
219
219
+
}
220
220
+
221
221
+
impl WsMessage {
222
222
+
/// Check if this is a text message
223
223
+
pub fn is_text(&self) -> bool {
224
224
+
matches!(self, WsMessage::Text(_))
225
225
+
}
226
226
+
227
227
+
/// Check if this is a binary message
228
228
+
pub fn is_binary(&self) -> bool {
229
229
+
matches!(self, WsMessage::Binary(_))
230
230
+
}
231
231
+
232
232
+
/// Check if this is a close message
233
233
+
pub fn is_close(&self) -> bool {
234
234
+
matches!(self, WsMessage::Close(_))
235
235
+
}
236
236
+
237
237
+
/// Get as text, if this is a text message
238
238
+
pub fn as_text(&self) -> Option<&str> {
239
239
+
match self {
240
240
+
WsMessage::Text(t) => Some(t.as_str()),
241
241
+
_ => None,
242
242
+
}
243
243
+
}
244
244
+
245
245
+
/// Get as bytes
246
246
+
pub fn as_bytes(&self) -> Option<&[u8]> {
247
247
+
match self {
248
248
+
WsMessage::Text(t) => Some(t.as_ref()),
249
249
+
WsMessage::Binary(b) => Some(b),
250
250
+
WsMessage::Close(_) => None,
251
251
+
}
252
252
+
}
253
253
+
}
254
254
+
255
255
+
impl From<WsText> for WsMessage {
256
256
+
fn from(text: WsText) -> Self {
257
257
+
WsMessage::Text(text)
258
258
+
}
259
259
+
}
260
260
+
261
261
+
impl From<String> for WsMessage {
262
262
+
fn from(s: String) -> Self {
263
263
+
WsMessage::Text(WsText::from(s))
264
264
+
}
265
265
+
}
266
266
+
267
267
+
impl From<&str> for WsMessage {
268
268
+
fn from(s: &str) -> Self {
269
269
+
WsMessage::Text(WsText::from(s))
270
270
+
}
271
271
+
}
272
272
+
273
273
+
impl From<Bytes> for WsMessage {
274
274
+
fn from(bytes: Bytes) -> Self {
275
275
+
WsMessage::Binary(bytes)
276
276
+
}
277
277
+
}
278
278
+
279
279
+
impl From<Vec<u8>> for WsMessage {
280
280
+
fn from(vec: Vec<u8>) -> Self {
281
281
+
WsMessage::Binary(Bytes::from(vec))
282
282
+
}
283
283
+
}
284
284
+
285
285
+
/// WebSocket message stream
286
286
+
pub struct WsStream(Boxed<Result<WsMessage, StreamError>>);
287
287
+
288
288
+
impl WsStream {
289
289
+
/// Create a new message stream
290
290
+
#[cfg(not(target_arch = "wasm32"))]
291
291
+
pub fn new<S>(stream: S) -> Self
292
292
+
where
293
293
+
S: Stream<Item = Result<WsMessage, StreamError>> + Send + 'static,
294
294
+
{
295
295
+
Self(Box::pin(stream))
296
296
+
}
297
297
+
298
298
+
/// Create a new message stream
299
299
+
#[cfg(target_arch = "wasm32")]
300
300
+
pub fn new<S>(stream: S) -> Self
301
301
+
where
302
302
+
S: Stream<Item = Result<WsMessage, StreamError>> + 'static,
303
303
+
{
304
304
+
Self(Box::pin(stream))
305
305
+
}
306
306
+
307
307
+
/// Convert into the inner pinned boxed stream
308
308
+
pub fn into_inner(self) -> Boxed<Result<WsMessage, StreamError>> {
309
309
+
self.0
310
310
+
}
311
311
+
}
312
312
+
313
313
+
/// Extension trait for decoding typed messages from WebSocket streams
314
314
+
pub trait WsStreamExt: Sized {
315
315
+
/// Decode JSON text/binary frames into typed messages
316
316
+
///
317
317
+
/// Deserializes messages but does not automatically convert to owned.
318
318
+
/// The caller is responsible for calling `.into_static()` if needed.
319
319
+
fn decode_json<T>(self) -> impl Stream<Item = Result<T, StreamError>>
320
320
+
where
321
321
+
T: for<'de> serde::Deserialize<'de>;
322
322
+
323
323
+
/// Decode DAG-CBOR binary frames into typed messages
324
324
+
///
325
325
+
/// Deserializes messages but does not automatically convert to owned.
326
326
+
/// The caller is responsible for calling `.into_static()` if needed.
327
327
+
fn decode_cbor<T>(self) -> impl Stream<Item = Result<T, StreamError>>
328
328
+
where
329
329
+
T: for<'de> serde::Deserialize<'de>;
330
330
+
}
331
331
+
332
332
+
impl WsStreamExt for WsStream {
333
333
+
fn decode_json<T>(self) -> impl Stream<Item = Result<T, StreamError>>
334
334
+
where
335
335
+
T: for<'de> serde::Deserialize<'de>,
336
336
+
{
337
337
+
use n0_future::StreamExt as _;
338
338
+
339
339
+
Box::pin(self.into_inner().filter_map(|msg_result| match msg_result {
340
340
+
Ok(WsMessage::Text(text)) => {
341
341
+
Some(serde_json::from_slice(text.as_ref()).map_err(StreamError::decode))
342
342
+
}
343
343
+
Ok(WsMessage::Binary(bytes)) => {
344
344
+
Some(serde_json::from_slice(&bytes).map_err(StreamError::decode))
345
345
+
}
346
346
+
Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
347
347
+
Err(e) => Some(Err(e)),
348
348
+
}))
349
349
+
}
350
350
+
351
351
+
fn decode_cbor<T>(self) -> impl Stream<Item = Result<T, StreamError>>
352
352
+
where
353
353
+
T: for<'de> serde::Deserialize<'de>,
354
354
+
{
355
355
+
use n0_future::StreamExt as _;
356
356
+
357
357
+
Box::pin(self.into_inner().filter_map(|msg_result| {
358
358
+
match msg_result {
359
359
+
Ok(WsMessage::Binary(bytes)) => Some(
360
360
+
serde_ipld_dagcbor::from_slice(&bytes)
361
361
+
.map_err(|e| StreamError::decode(crate::error::DecodeError::from(e))),
362
362
+
),
363
363
+
Ok(WsMessage::Text(_)) => Some(Err(StreamError::wrong_message_format(
364
364
+
"expected binary frame for CBOR, got text",
365
365
+
))),
366
366
+
Ok(WsMessage::Close(_)) => Some(Err(StreamError::closed())),
367
367
+
Err(e) => Some(Err(e)),
368
368
+
}
369
369
+
}))
370
370
+
}
371
371
+
}
372
372
+
373
373
+
impl fmt::Debug for WsStream {
374
374
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
375
375
+
f.debug_struct("WsStream").finish_non_exhaustive()
376
376
+
}
377
377
+
}
378
378
+
379
379
+
/// WebSocket message sink
380
380
+
pub struct WsSink(Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>);
381
381
+
382
382
+
impl WsSink {
383
383
+
/// Create a new message sink
384
384
+
pub fn new<S>(sink: S) -> Self
385
385
+
where
386
386
+
S: n0_future::Sink<WsMessage, Error = StreamError> + 'static,
387
387
+
{
388
388
+
Self(Box::new(sink))
389
389
+
}
390
390
+
391
391
+
/// Convert into the inner boxed sink
392
392
+
pub fn into_inner(self) -> Box<dyn n0_future::Sink<WsMessage, Error = StreamError>> {
393
393
+
self.0
394
394
+
}
395
395
+
}
396
396
+
397
397
+
impl fmt::Debug for WsSink {
398
398
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
399
399
+
f.debug_struct("WsSink").finish_non_exhaustive()
400
400
+
}
401
401
+
}
402
402
+
7
403
/// WebSocket client trait
8
404
#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))]
9
405
pub trait WebSocketClient {
···
11
407
type Error: std::error::Error + Send + Sync + 'static;
12
408
13
409
/// Connect to a WebSocket endpoint
14
14
-
fn connect(
15
15
-
&self,
16
16
-
url: Url,
17
17
-
) -> impl Future<Output = Result<WebSocketConnection, Self::Error>>;
410
410
+
fn connect(&self, url: Url) -> impl Future<Output = Result<WebSocketConnection, Self::Error>>;
18
411
}
19
412
20
413
/// WebSocket connection with bidirectional streams
21
414
pub struct WebSocketConnection {
22
22
-
tx: ByteSink,
23
23
-
rx: ByteStream,
415
415
+
tx: WsSink,
416
416
+
rx: WsStream,
24
417
}
25
418
26
419
impl WebSocketConnection {
27
420
/// Create a new WebSocket connection
28
28
-
pub fn new(tx: ByteSink, rx: ByteStream) -> Self {
421
421
+
pub fn new(tx: WsSink, rx: WsStream) -> Self {
29
422
Self { tx, rx }
30
423
}
31
424
32
425
/// Get mutable access to the sender
33
33
-
pub fn sender_mut(&mut self) -> &mut ByteSink {
426
426
+
pub fn sender_mut(&mut self) -> &mut WsSink {
34
427
&mut self.tx
35
428
}
36
429
37
430
/// Get mutable access to the receiver
38
38
-
pub fn receiver_mut(&mut self) -> &mut ByteStream {
431
431
+
pub fn receiver_mut(&mut self) -> &mut WsStream {
39
432
&mut self.rx
40
433
}
41
434
435
435
+
/// Get a reference to the receiver
436
436
+
pub fn receiver(&self) -> &WsStream {
437
437
+
&self.rx
438
438
+
}
439
439
+
440
440
+
/// Get a reference to the sender
441
441
+
pub fn sender(&self) -> &WsSink {
442
442
+
&self.tx
443
443
+
}
444
444
+
42
445
/// Split into sender and receiver
43
43
-
pub fn split(self) -> (ByteSink, ByteStream) {
446
446
+
pub fn split(self) -> (WsSink, WsStream) {
44
447
(self.tx, self.rx)
45
448
}
46
449
···
50
453
}
51
454
}
52
455
53
53
-
impl std::fmt::Debug for WebSocketConnection {
54
54
-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
456
456
+
impl fmt::Debug for WebSocketConnection {
457
457
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55
458
f.debug_struct("WebSocketConnection")
56
459
.finish_non_exhaustive()
57
460
}
58
461
}
59
462
463
463
+
/// Concrete WebSocket client implementation using tokio-tungstenite-wasm
464
464
+
pub mod tungstenite_client {
465
465
+
use super::*;
466
466
+
use crate::IntoStatic;
467
467
+
use futures::{SinkExt, StreamExt};
468
468
+
469
469
+
/// WebSocket client backed by tokio-tungstenite-wasm
470
470
+
#[derive(Debug, Clone, Default)]
471
471
+
pub struct TungsteniteClient;
472
472
+
473
473
+
impl TungsteniteClient {
474
474
+
/// Create a new tungstenite WebSocket client
475
475
+
pub fn new() -> Self {
476
476
+
Self
477
477
+
}
478
478
+
}
479
479
+
480
480
+
impl WebSocketClient for TungsteniteClient {
481
481
+
type Error = tokio_tungstenite_wasm::Error;
482
482
+
483
483
+
async fn connect(&self, url: Url) -> Result<WebSocketConnection, Self::Error> {
484
484
+
let ws_stream = tokio_tungstenite_wasm::connect(url.as_str()).await?;
485
485
+
486
486
+
let (sink, stream) = ws_stream.split();
487
487
+
488
488
+
// Convert tungstenite messages to our WsMessage
489
489
+
let rx_stream = stream.filter_map(|result| async move {
490
490
+
match result {
491
491
+
Ok(msg) => match convert_message(msg) {
492
492
+
Some(ws_msg) => Some(Ok(ws_msg)),
493
493
+
None => None, // Skip ping/pong
494
494
+
},
495
495
+
Err(e) => Some(Err(StreamError::transport(e))),
496
496
+
}
497
497
+
});
498
498
+
499
499
+
let rx = WsStream::new(rx_stream);
500
500
+
501
501
+
// Convert our WsMessage to tungstenite messages
502
502
+
let tx_sink = sink.with(|msg: WsMessage| async move {
503
503
+
Ok::<_, tokio_tungstenite_wasm::Error>(msg.into())
504
504
+
});
505
505
+
506
506
+
let tx_sink_mapped = tx_sink.sink_map_err(|e| StreamError::transport(e));
507
507
+
let tx = WsSink::new(tx_sink_mapped);
508
508
+
509
509
+
Ok(WebSocketConnection::new(tx, rx))
510
510
+
}
511
511
+
}
512
512
+
513
513
+
/// Convert tokio-tungstenite-wasm Message to our WsMessage
514
514
+
/// Returns None for Ping/Pong which we auto-handle
515
515
+
fn convert_message(msg: tokio_tungstenite_wasm::Message) -> Option<WsMessage> {
516
516
+
use tokio_tungstenite_wasm::Message;
517
517
+
518
518
+
match msg {
519
519
+
Message::Text(vec) => {
520
520
+
// tokio-tungstenite-wasm Text contains Vec<u8> (UTF-8 validated)
521
521
+
let bytes = Bytes::from(vec);
522
522
+
Some(WsMessage::Text(unsafe {
523
523
+
WsText::from_bytes_unchecked(bytes)
524
524
+
}))
525
525
+
}
526
526
+
Message::Binary(vec) => Some(WsMessage::Binary(Bytes::from(vec))),
527
527
+
Message::Close(frame) => {
528
528
+
let close_frame = frame.map(|f| {
529
529
+
let code = convert_close_code(f.code);
530
530
+
CloseFrame::new(code, CowStr::from(f.reason.into_owned()))
531
531
+
});
532
532
+
Some(WsMessage::Close(close_frame))
533
533
+
}
534
534
+
}
535
535
+
}
536
536
+
537
537
+
/// Convert tokio-tungstenite-wasm CloseCode to our CloseCode
538
538
+
fn convert_close_code(code: tokio_tungstenite_wasm::CloseCode) -> CloseCode {
539
539
+
use tokio_tungstenite_wasm::CloseCode as TungsteniteCode;
540
540
+
541
541
+
match code {
542
542
+
TungsteniteCode::Normal => CloseCode::Normal,
543
543
+
TungsteniteCode::Away => CloseCode::Away,
544
544
+
TungsteniteCode::Protocol => CloseCode::Protocol,
545
545
+
TungsteniteCode::Unsupported => CloseCode::Unsupported,
546
546
+
TungsteniteCode::Invalid => CloseCode::Invalid,
547
547
+
TungsteniteCode::Policy => CloseCode::Policy,
548
548
+
TungsteniteCode::Size => CloseCode::Size,
549
549
+
TungsteniteCode::Extension => CloseCode::Extension,
550
550
+
TungsteniteCode::Error => CloseCode::Error,
551
551
+
TungsteniteCode::Tls => CloseCode::Tls,
552
552
+
// For other variants, extract raw code
553
553
+
other => {
554
554
+
let raw: u16 = other.into();
555
555
+
CloseCode::from(raw)
556
556
+
}
557
557
+
}
558
558
+
}
559
559
+
560
560
+
impl From<WsMessage> for tokio_tungstenite_wasm::Message {
561
561
+
fn from(msg: WsMessage) -> Self {
562
562
+
use tokio_tungstenite_wasm::Message;
563
563
+
564
564
+
match msg {
565
565
+
WsMessage::Text(text) => {
566
566
+
// tokio-tungstenite-wasm Text expects String
567
567
+
let bytes = text.into_bytes();
568
568
+
// Safe: WsText is already UTF-8 validated
569
569
+
let string = unsafe { String::from_utf8_unchecked(bytes.to_vec()) };
570
570
+
Message::Text(string)
571
571
+
}
572
572
+
WsMessage::Binary(bytes) => Message::Binary(bytes.to_vec()),
573
573
+
WsMessage::Close(frame) => {
574
574
+
let close_frame = frame.map(|f| {
575
575
+
let code = u16::from(f.code).into();
576
576
+
tokio_tungstenite_wasm::CloseFrame {
577
577
+
code,
578
578
+
reason: f.reason.into_static().to_string().into(),
579
579
+
}
580
580
+
});
581
581
+
Message::Close(close_frame)
582
582
+
}
583
583
+
}
584
584
+
}
585
585
+
}
586
586
+
}
587
587
+
60
588
#[cfg(test)]
61
589
mod tests {
62
590
use super::*;
63
63
-
use crate::stream::StreamError;
591
591
+
592
592
+
#[test]
593
593
+
fn ws_text_from_string() {
594
594
+
let text = WsText::from("hello");
595
595
+
assert_eq!(text.as_str(), "hello");
596
596
+
}
597
597
+
598
598
+
#[test]
599
599
+
fn ws_text_deref() {
600
600
+
let text = WsText::from(String::from("world"));
601
601
+
assert_eq!(&*text, "world");
602
602
+
}
603
603
+
604
604
+
#[test]
605
605
+
fn ws_text_try_from_bytes() {
606
606
+
let bytes = Bytes::from("test");
607
607
+
let text = WsText::try_from(bytes).unwrap();
608
608
+
assert_eq!(text.as_str(), "test");
609
609
+
}
610
610
+
611
611
+
#[test]
612
612
+
fn ws_text_invalid_utf8() {
613
613
+
let bytes = Bytes::from(vec![0xFF, 0xFE]);
614
614
+
assert!(WsText::try_from(bytes).is_err());
615
615
+
}
616
616
+
617
617
+
#[test]
618
618
+
fn ws_message_text() {
619
619
+
let msg = WsMessage::from("hello");
620
620
+
assert!(msg.is_text());
621
621
+
assert_eq!(msg.as_text(), Some("hello"));
622
622
+
}
623
623
+
624
624
+
#[test]
625
625
+
fn ws_message_binary() {
626
626
+
let msg = WsMessage::from(vec![1, 2, 3]);
627
627
+
assert!(msg.is_binary());
628
628
+
assert_eq!(msg.as_bytes(), Some(&[1u8, 2, 3][..]));
629
629
+
}
630
630
+
631
631
+
#[test]
632
632
+
fn close_code_conversion() {
633
633
+
assert_eq!(u16::from(CloseCode::Normal), 1000);
634
634
+
assert_eq!(CloseCode::from(1000), CloseCode::Normal);
635
635
+
assert_eq!(CloseCode::from(9999), CloseCode::Other(9999));
636
636
+
}
64
637
65
638
#[test]
66
639
fn websocket_connection_has_tx_and_rx() {
67
67
-
use futures::stream;
68
640
use futures::sink::SinkExt;
69
69
-
use bytes::Bytes;
641
641
+
use futures::stream;
70
642
71
71
-
let rx_stream = stream::iter(vec![Ok(Bytes::from("test"))]);
72
72
-
let rx = ByteStream::new(rx_stream);
643
643
+
let rx_stream = stream::iter(vec![Ok(WsMessage::from("test"))]);
644
644
+
let rx = WsStream::new(rx_stream);
73
645
74
74
-
// Create a sink that converts Infallible to StreamError
75
75
-
let drain_sink = futures::sink::drain().sink_map_err(|_: std::convert::Infallible| {
76
76
-
StreamError::closed()
77
77
-
});
78
78
-
let tx = ByteSink::new(drain_sink);
646
646
+
let drain_sink = futures::sink::drain()
647
647
+
.sink_map_err(|_: std::convert::Infallible| StreamError::closed());
648
648
+
let tx = WsSink::new(drain_sink);
79
649
80
650
let conn = WebSocketConnection::new(tx, rx);
81
651
assert!(conn.is_open());
+6
crates/jacquard-common/src/xrpc.rs
···
16
16
#[cfg(feature = "streaming")]
17
17
pub use streaming::StreamingResponse;
18
18
19
19
+
#[cfg(feature = "websocket")]
20
20
+
pub mod subscription;
21
21
+
22
22
+
#[cfg(feature = "websocket")]
23
23
+
pub use subscription::{MessageEncoding, SubscriptionEndpoint, SubscriptionResp, XrpcSubscription};
24
24
+
19
25
use bytes::Bytes;
20
26
use http::{
21
27
HeaderName, HeaderValue, Request, StatusCode,
+95
crates/jacquard-common/src/xrpc/subscription.rs
···
1
1
+
//! WebSocket subscription support for XRPC
2
2
+
//!
3
3
+
//! This module defines traits and types for typed WebSocket subscriptions,
4
4
+
//! mirroring the request/response pattern used for HTTP XRPC endpoints.
5
5
+
6
6
+
use serde::{Deserialize, Serialize};
7
7
+
use std::error::Error;
8
8
+
9
9
+
use crate::IntoStatic;
10
10
+
11
11
+
/// Encoding format for subscription messages
12
12
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
13
13
+
pub enum MessageEncoding {
14
14
+
/// JSON text frames
15
15
+
Json,
16
16
+
/// DAG-CBOR binary frames
17
17
+
DagCbor,
18
18
+
}
19
19
+
20
20
+
/// XRPC subscription stream response trait
21
21
+
///
22
22
+
/// Analogous to `XrpcResp` but for WebSocket subscriptions.
23
23
+
/// Defines the message and error types for a subscription stream.
24
24
+
///
25
25
+
/// This trait is implemented on a marker struct to keep it lifetime-free
26
26
+
/// while using GATs for the message/error types.
27
27
+
pub trait SubscriptionResp {
28
28
+
/// The NSID for this subscription
29
29
+
const NSID: &'static str;
30
30
+
31
31
+
/// Message encoding (JSON or DAG-CBOR)
32
32
+
const ENCODING: MessageEncoding;
33
33
+
34
34
+
/// Message union type
35
35
+
type Message<'de>: Deserialize<'de> + IntoStatic;
36
36
+
37
37
+
/// Error union type
38
38
+
type Error<'de>: Error + Deserialize<'de> + IntoStatic;
39
39
+
}
40
40
+
41
41
+
/// XRPC subscription (WebSocket)
42
42
+
///
43
43
+
/// This trait is analogous to `XrpcRequest` but for WebSocket subscriptions.
44
44
+
/// It defines the NSID and associated stream response type.
45
45
+
///
46
46
+
/// The trait is implemented on the subscription parameters type.
47
47
+
pub trait XrpcSubscription: Serialize {
48
48
+
/// The NSID for this XRPC subscription
49
49
+
const NSID: &'static str;
50
50
+
51
51
+
/// Message encoding (JSON or DAG-CBOR)
52
52
+
const ENCODING: MessageEncoding;
53
53
+
54
54
+
/// Stream response type (marker struct)
55
55
+
type Stream: SubscriptionResp;
56
56
+
57
57
+
/// Encode query params for WebSocket URL
58
58
+
///
59
59
+
/// Default implementation uses serde_html_form to encode the struct as query parameters.
60
60
+
fn query_params(&self) -> Vec<(String, String)> {
61
61
+
// Default: use serde_html_form to encode self
62
62
+
serde_html_form::to_string(self)
63
63
+
.ok()
64
64
+
.map(|s| {
65
65
+
s.split('&')
66
66
+
.filter_map(|pair| {
67
67
+
let mut parts = pair.splitn(2, '=');
68
68
+
Some((parts.next()?.to_string(), parts.next()?.to_string()))
69
69
+
})
70
70
+
.collect()
71
71
+
})
72
72
+
.unwrap_or_default()
73
73
+
}
74
74
+
}
75
75
+
76
76
+
/// XRPC subscription endpoint trait (server-side)
77
77
+
///
78
78
+
/// Analogous to `XrpcEndpoint` but for WebSocket subscriptions.
79
79
+
/// Defines the fully-qualified path and associated parameter/stream types.
80
80
+
///
81
81
+
/// This exists primarily for server-side frameworks (like Axum) to extract
82
82
+
/// typed subscription parameters without lifetime issues.
83
83
+
pub trait SubscriptionEndpoint {
84
84
+
/// Fully-qualified path ('/xrpc/[nsid]') where this subscription endpoint lives
85
85
+
const PATH: &'static str;
86
86
+
87
87
+
/// Message encoding (JSON or DAG-CBOR)
88
88
+
const ENCODING: MessageEncoding;
89
89
+
90
90
+
/// Subscription parameters type
91
91
+
type Params<'de>: XrpcSubscription + Deserialize<'de> + IntoStatic;
92
92
+
93
93
+
/// Stream response type
94
94
+
type Stream: SubscriptionResp;
95
95
+
}
+254
-26
crates/jacquard-lexicon/src/codegen/xrpc.rs
···
7
7
use proc_macro2::TokenStream;
8
8
use quote::quote;
9
9
10
10
-
use super::utils::make_ident;
11
10
use super::CodeGenerator;
11
11
+
use super::utils::make_ident;
12
12
13
13
impl<'c> CodeGenerator<'c> {
14
14
/// Generate query type
···
178
178
output.push(error_enum);
179
179
}
180
180
181
181
+
// Generate XrpcSubscription trait impl
182
182
+
let params_has_lifetime = sub
183
183
+
.parameters
184
184
+
.as_ref()
185
185
+
.map(|p| match p {
186
186
+
crate::lexicon::LexXrpcSubscriptionParameter::Params(params) => {
187
187
+
self.params_need_lifetime(params)
188
188
+
}
189
189
+
})
190
190
+
.unwrap_or(false);
191
191
+
192
192
+
let has_params = sub.parameters.is_some();
193
193
+
let has_message = sub.message.is_some();
194
194
+
let has_errors = sub.errors.is_some();
195
195
+
196
196
+
let subscription_impl = self.generate_xrpc_subscription_impl(
197
197
+
nsid,
198
198
+
&type_base,
199
199
+
has_params,
200
200
+
params_has_lifetime,
201
201
+
has_message,
202
202
+
has_errors,
203
203
+
)?;
204
204
+
output.push(subscription_impl);
205
205
+
181
206
Ok(quote! {
182
207
#(#output)*
183
208
})
···
200
225
let mut variants = Vec::new();
201
226
for ref_str in &union.refs {
202
227
let ref_str_s = ref_str.as_ref();
228
228
+
229
229
+
// Normalize local refs (starting with #) by prepending current NSID
230
230
+
let normalized_ref = if ref_str.starts_with('#') {
231
231
+
format!("{}{}", nsid, ref_str)
232
232
+
} else {
233
233
+
ref_str.to_string()
234
234
+
};
235
235
+
203
236
// Parse ref to get NSID and def name
204
237
let (ref_nsid, ref_def) =
205
205
-
if let Some((nsid, fragment)) = ref_str.split_once('#') {
206
206
-
(nsid, fragment)
238
238
+
if let Some((nsid_part, fragment)) = normalized_ref.split_once('#') {
239
239
+
(nsid_part, fragment)
207
240
} else {
208
208
-
(ref_str.as_ref(), "main")
241
241
+
(normalized_ref.as_str(), "main")
209
242
};
210
243
211
244
let variant_name = if ref_def == "main" {
···
215
248
};
216
249
let variant_ident =
217
250
syn::Ident::new(&variant_name, proc_macro2::Span::call_site());
218
218
-
let type_path = self.ref_to_rust_type(ref_str)?;
251
251
+
let type_path = self.ref_to_rust_type(&normalized_ref)?;
219
252
220
253
variants.push(quote! {
221
254
#[serde(rename = #ref_str_s)]
···
262
295
match field_type {
263
296
LexObjectProperty::Union(union) => {
264
297
// Skip empty, single-variant unions unless they're self-referential
265
265
-
if !union.refs.is_empty() && (union.refs.len() > 1 || self.is_self_referential_union(nsid, &struct_name, union)) {
266
266
-
let union_name = self.generate_field_type_name(nsid, &struct_name, field_name, "");
298
298
+
if !union.refs.is_empty()
299
299
+
&& (union.refs.len() > 1
300
300
+
|| self.is_self_referential_union(nsid, &struct_name, union))
301
301
+
{
302
302
+
let union_name = self.generate_field_type_name(
303
303
+
nsid,
304
304
+
&struct_name,
305
305
+
field_name,
306
306
+
"",
307
307
+
);
267
308
let refs: Vec<_> = union.refs.iter().cloned().collect();
268
268
-
let union_def =
269
269
-
self.generate_union(nsid, &union_name, &refs, None, union.closed)?;
309
309
+
let union_def = self.generate_union(
310
310
+
nsid,
311
311
+
&union_name,
312
312
+
&refs,
313
313
+
None,
314
314
+
union.closed,
315
315
+
)?;
270
316
unions.push(union_def);
271
317
}
272
318
}
···
274
320
if let LexArrayItem::Union(union) = &array.items {
275
321
// Skip single-variant array unions
276
322
if union.refs.len() > 1 {
277
277
-
let union_name = self.generate_field_type_name(nsid, &struct_name, field_name, "Item");
323
323
+
let union_name = self.generate_field_type_name(
324
324
+
nsid,
325
325
+
&struct_name,
326
326
+
field_name,
327
327
+
"Item",
328
328
+
);
278
329
let refs: Vec<_> = union.refs.iter().cloned().collect();
279
279
-
let union_def = self.generate_union(nsid, &union_name, &refs, None, union.closed)?;
330
330
+
let union_def = self.generate_union(
331
331
+
nsid,
332
332
+
&union_name,
333
333
+
&refs,
334
334
+
None,
335
335
+
union.closed,
336
336
+
)?;
280
337
unions.push(union_def);
281
338
}
282
339
}
···
415
472
let (has_default, has_builder) = if is_binary_body {
416
473
(false, true)
417
474
} else if let Some(crate::lexicon::LexXrpcBodySchema::Object(obj)) = &body.schema {
418
418
-
use crate::codegen::structs::{count_required_fields, all_required_are_defaultable_strings, conflicts_with_builder_macro};
475
475
+
use crate::codegen::structs::{
476
476
+
all_required_are_defaultable_strings, conflicts_with_builder_macro,
477
477
+
count_required_fields,
478
478
+
};
419
479
let required_count = count_required_fields(obj);
420
480
let can_default = required_count == 0 || all_required_are_defaultable_strings(obj);
421
421
-
let can_builder = required_count >= 1 && !can_default && !conflicts_with_builder_macro(type_base);
481
481
+
let can_builder =
482
482
+
required_count >= 1 && !can_default && !conflicts_with_builder_macro(type_base);
422
483
(can_default, can_builder)
423
484
} else {
424
485
(false, false)
···
495
556
match field_type {
496
557
LexObjectProperty::Union(union) => {
497
558
// Skip empty, single-variant unions unless they're self-referential
498
498
-
if !union.refs.is_empty() && (union.refs.len() > 1 || self.is_self_referential_union(nsid, type_base, union)) {
499
499
-
let union_name = self.generate_field_type_name(nsid, type_base, field_name, "");
559
559
+
if !union.refs.is_empty()
560
560
+
&& (union.refs.len() > 1
561
561
+
|| self.is_self_referential_union(nsid, type_base, union))
562
562
+
{
563
563
+
let union_name =
564
564
+
self.generate_field_type_name(nsid, type_base, field_name, "");
500
565
let refs: Vec<_> = union.refs.iter().cloned().collect();
501
566
let union_def =
502
567
self.generate_union(nsid, &union_name, &refs, None, union.closed)?;
···
507
572
if let LexArrayItem::Union(union) = &array.items {
508
573
// Skip single-variant array unions
509
574
if union.refs.len() > 1 {
510
510
-
let union_name = self.generate_field_type_name(nsid, type_base, field_name, "Item");
575
575
+
let union_name = self
576
576
+
.generate_field_type_name(nsid, type_base, field_name, "Item");
511
577
let refs: Vec<_> = union.refs.iter().cloned().collect();
512
512
-
let union_def = self.generate_union(nsid, &union_name, &refs, None, union.closed)?;
578
578
+
let union_def = self.generate_union(
579
579
+
nsid,
580
580
+
&union_name,
581
581
+
&refs,
582
582
+
None,
583
583
+
union.closed,
584
584
+
)?;
513
585
unions.push(union_def);
514
586
}
515
587
}
···
545
617
546
618
// Determine if we should derive Default
547
619
// Check if schema is an Object and apply heuristics
548
548
-
let has_default = if let Some(crate::lexicon::LexXrpcBodySchema::Object(obj)) = &body.schema {
549
549
-
use crate::codegen::structs::{count_required_fields, all_required_are_defaultable_strings};
620
620
+
let has_default = if let Some(crate::lexicon::LexXrpcBodySchema::Object(obj)) = &body.schema
621
621
+
{
622
622
+
use crate::codegen::structs::{
623
623
+
all_required_are_defaultable_strings, count_required_fields,
624
624
+
};
550
625
let required_count = count_required_fields(obj);
551
626
required_count == 0 || all_required_are_defaultable_strings(obj)
552
627
} else {
···
584
659
match field_type {
585
660
LexObjectProperty::Union(union) => {
586
661
// Skip single-variant unions unless they're self-referential
587
587
-
if union.refs.len() > 1 || self.is_self_referential_union(nsid, &struct_name, union) {
588
588
-
let union_name = self.generate_field_type_name(nsid, &struct_name, field_name, "");
662
662
+
if union.refs.len() > 1
663
663
+
|| self.is_self_referential_union(nsid, &struct_name, union)
664
664
+
{
665
665
+
let union_name =
666
666
+
self.generate_field_type_name(nsid, &struct_name, field_name, "");
589
667
let refs: Vec<_> = union.refs.iter().cloned().collect();
590
668
let union_def =
591
669
self.generate_union(nsid, &union_name, &refs, None, union.closed)?;
···
596
674
if let LexArrayItem::Union(union) = &array.items {
597
675
// Skip single-variant array unions
598
676
if union.refs.len() > 1 {
599
599
-
let union_name = self.generate_field_type_name(nsid, &struct_name, field_name, "Item");
677
677
+
let union_name = self.generate_field_type_name(
678
678
+
nsid,
679
679
+
&struct_name,
680
680
+
field_name,
681
681
+
"Item",
682
682
+
);
600
683
let refs: Vec<_> = union.refs.iter().cloned().collect();
601
601
-
let union_def = self.generate_union(nsid, &union_name, &refs, None, union.closed)?;
684
684
+
let union_def = self.generate_union(
685
685
+
nsid,
686
686
+
&union_name,
687
687
+
&refs,
688
688
+
None,
689
689
+
union.closed,
690
690
+
)?;
602
691
unions.push(union_def);
603
692
}
604
693
}
···
953
1042
);
954
1043
955
1044
let response_type = quote! {
956
956
-
#[doc = "Response type for "]
1045
1045
+
#[doc = " Response type for "]
957
1046
#[doc = #nsid]
958
1047
pub struct #response_ident;
959
1048
···
1027
1116
#decode_body_method
1028
1117
}
1029
1118
1030
1030
-
#[doc = "Endpoint type for "]
1119
1119
+
#[doc = " Endpoint type for "]
1031
1120
#[doc = #nsid]
1032
1121
pub struct #endpoint_ident;
1033
1122
···
1057
1146
type Response = #response_ident;
1058
1147
}
1059
1148
1060
1060
-
#[doc = "Endpoint type for "]
1149
1149
+
#[doc = " Endpoint type for "]
1061
1150
#[doc = #nsid]
1062
1151
pub struct #endpoint_ident;
1063
1152
···
1070
1159
}
1071
1160
})
1072
1161
}
1162
1162
+
}
1163
1163
+
1164
1164
+
/// Generate XrpcSubscription trait impl for a subscription endpoint
1165
1165
+
pub(super) fn generate_xrpc_subscription_impl(
1166
1166
+
&self,
1167
1167
+
nsid: &str,
1168
1168
+
type_base: &str,
1169
1169
+
has_params: bool,
1170
1170
+
params_has_lifetime: bool,
1171
1171
+
has_message: bool,
1172
1172
+
has_errors: bool,
1173
1173
+
) -> Result<TokenStream> {
1174
1174
+
// Generate stream response marker struct
1175
1175
+
let stream_ident = syn::Ident::new(
1176
1176
+
&format!("{}Stream", type_base),
1177
1177
+
proc_macro2::Span::call_site(),
1178
1178
+
);
1179
1179
+
1180
1180
+
let message_type = if has_message {
1181
1181
+
let msg_ident = syn::Ident::new(
1182
1182
+
&format!("{}Message", type_base),
1183
1183
+
proc_macro2::Span::call_site(),
1184
1184
+
);
1185
1185
+
quote! { #msg_ident<'de> }
1186
1186
+
} else {
1187
1187
+
quote! { () }
1188
1188
+
};
1189
1189
+
1190
1190
+
let error_type = if has_errors {
1191
1191
+
let err_ident = syn::Ident::new(
1192
1192
+
&format!("{}Error", type_base),
1193
1193
+
proc_macro2::Span::call_site(),
1194
1194
+
);
1195
1195
+
quote! { #err_ident<'de> }
1196
1196
+
} else {
1197
1197
+
quote! { jacquard_common::xrpc::GenericError<'de> }
1198
1198
+
};
1199
1199
+
1200
1200
+
// Determine encoding from nsid convention
1201
1201
+
// ATProto subscriptions use DAG-CBOR, community ones might use JSON
1202
1202
+
let encoding = if nsid.starts_with("com.atproto") {
1203
1203
+
quote! { jacquard_common::xrpc::MessageEncoding::DagCbor }
1204
1204
+
} else {
1205
1205
+
quote! { jacquard_common::xrpc::MessageEncoding::Json }
1206
1206
+
};
1207
1207
+
1208
1208
+
// Generate SubscriptionResp impl
1209
1209
+
let stream_resp_impl = quote! {
1210
1210
+
#[doc = "Stream response type for "]
1211
1211
+
#[doc = #nsid]
1212
1212
+
pub struct #stream_ident;
1213
1213
+
1214
1214
+
impl jacquard_common::xrpc::SubscriptionResp for #stream_ident {
1215
1215
+
const NSID: &'static str = #nsid;
1216
1216
+
const ENCODING: jacquard_common::xrpc::MessageEncoding = #encoding;
1217
1217
+
1218
1218
+
type Message<'de> = #message_type;
1219
1219
+
type Error<'de> = #error_type;
1220
1220
+
}
1221
1221
+
};
1222
1222
+
1223
1223
+
let params_ident = if has_params {
1224
1224
+
syn::Ident::new(type_base, proc_macro2::Span::call_site())
1225
1225
+
} else {
1226
1226
+
// Generate marker struct if no params
1227
1227
+
let marker = syn::Ident::new(type_base, proc_macro2::Span::call_site());
1228
1228
+
let endpoint_ident = syn::Ident::new(
1229
1229
+
&format!("{}Endpoint", type_base),
1230
1230
+
proc_macro2::Span::call_site(),
1231
1231
+
);
1232
1232
+
let endpoint_path = format!("/xrpc/{}", nsid);
1233
1233
+
1234
1234
+
return Ok(quote! {
1235
1235
+
#stream_resp_impl
1236
1236
+
1237
1237
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize)]
1238
1238
+
pub struct #marker;
1239
1239
+
1240
1240
+
impl jacquard_common::xrpc::XrpcSubscription for #marker {
1241
1241
+
const NSID: &'static str = #nsid;
1242
1242
+
const ENCODING: jacquard_common::xrpc::MessageEncoding = #encoding;
1243
1243
+
1244
1244
+
type Stream = #stream_ident;
1245
1245
+
}
1246
1246
+
1247
1247
+
pub struct #endpoint_ident;
1248
1248
+
1249
1249
+
impl jacquard_common::xrpc::SubscriptionEndpoint for #endpoint_ident {
1250
1250
+
const PATH: &'static str = #endpoint_path;
1251
1251
+
const ENCODING: jacquard_common::xrpc::MessageEncoding = #encoding;
1252
1252
+
1253
1253
+
type Params<'de> = #marker;
1254
1254
+
type Stream = #stream_ident;
1255
1255
+
}
1256
1256
+
});
1257
1257
+
};
1258
1258
+
1259
1259
+
let (impl_generics, impl_target, endpoint_params_type) =
1260
1260
+
if has_params && params_has_lifetime {
1261
1261
+
(
1262
1262
+
quote! { <'a> },
1263
1263
+
quote! { #params_ident<'a> },
1264
1264
+
quote! { #params_ident<'de> },
1265
1265
+
)
1266
1266
+
} else {
1267
1267
+
(
1268
1268
+
quote! {},
1269
1269
+
quote! { #params_ident },
1270
1270
+
quote! { #params_ident },
1271
1271
+
)
1272
1272
+
};
1273
1273
+
1274
1274
+
let endpoint_ident = syn::Ident::new(
1275
1275
+
&format!("{}Endpoint", type_base),
1276
1276
+
proc_macro2::Span::call_site(),
1277
1277
+
);
1278
1278
+
1279
1279
+
let endpoint_path = format!("/xrpc/{}", nsid);
1280
1280
+
1281
1281
+
Ok(quote! {
1282
1282
+
#stream_resp_impl
1283
1283
+
1284
1284
+
impl #impl_generics jacquard_common::xrpc::XrpcSubscription for #impl_target {
1285
1285
+
const NSID: &'static str = #nsid;
1286
1286
+
const ENCODING: jacquard_common::xrpc::MessageEncoding = #encoding;
1287
1287
+
1288
1288
+
type Stream = #stream_ident;
1289
1289
+
}
1290
1290
+
1291
1291
+
pub struct #endpoint_ident;
1292
1292
+
1293
1293
+
impl jacquard_common::xrpc::SubscriptionEndpoint for #endpoint_ident {
1294
1294
+
const PATH: &'static str = #endpoint_path;
1295
1295
+
const ENCODING: jacquard_common::xrpc::MessageEncoding = #encoding;
1296
1296
+
1297
1297
+
type Params<'de> = #endpoint_params_type;
1298
1298
+
type Stream = #stream_ident;
1299
1299
+
}
1300
1300
+
})
1073
1301
}
1074
1302
}