this repo has no description
1mod common;
2use common::*;
3use cid::Cid;
4use futures::{stream::StreamExt, SinkExt};
5use iroh_car::CarReader;
6use reqwest::StatusCode;
7use serde::Deserialize;
8use serde_json::{json, Value};
9use std::io::Cursor;
10use tokio_tungstenite::{connect_async, tungstenite};
11
12#[derive(Debug, Deserialize)]
13struct FrameHeader {
14 op: i64,
15 t: String,
16}
17
18#[derive(Debug, Deserialize)]
19struct CommitFrame {
20 seq: i64,
21 rebase: bool,
22 #[serde(rename = "tooBig")]
23 too_big: bool,
24 repo: String,
25 commit: Cid,
26 rev: String,
27 since: Option<String>,
28 #[serde(with = "serde_bytes")]
29 blocks: Vec<u8>,
30 ops: Vec<RepoOp>,
31 blobs: Vec<Cid>,
32 time: String,
33}
34
35#[derive(Debug, Deserialize)]
36struct RepoOp {
37 action: String,
38 path: String,
39 cid: Option<Cid>,
40}
41
/// Returns the encoded length, in bytes, of the first CBOR data item in `bytes`.
///
/// Firehose frames are two concatenated DAG-CBOR values (a header map followed by a
/// body map), so the length of the first item is the offset at which the body begins.
/// Only definite-length encodings are supported, which is all DAG-CBOR permits.
///
/// # Errors
///
/// Returns an error string if the input ends before the first item is complete, or
/// if an initial byte carries an additional-info value other than 0..=27 (reserved
/// values and indefinite lengths).
fn find_cbor_map_end(bytes: &[u8]) -> Result<usize, String> {
    /// Consumes `n` bytes starting at `*pos` and returns them, failing (without moving
    /// `*pos`) if fewer than `n` bytes remain.
    fn take<'a>(bytes: &'a [u8], pos: &mut usize, n: usize) -> Result<&'a [u8], String> {
        let end = (*pos)
            .checked_add(n)
            .filter(|&end| end <= bytes.len())
            .ok_or_else(|| String::from("Unexpected end"))?;
        let taken = &bytes[*pos..end];
        *pos = end;
        Ok(taken)
    }

    /// Decodes the big-endian argument selected by an initial byte's additional-info
    /// bits: values 0..=23 are the argument itself, 24..=27 mean 1/2/4/8 follow-on bytes.
    fn read_uint(bytes: &[u8], pos: &mut usize, additional: u8) -> Result<u64, String> {
        let width = match additional {
            0..=23 => return Ok(u64::from(additional)),
            24 => 1,
            25 => 2,
            26 => 4,
            27 => 8,
            _ => return Err(format!("Invalid additional info: {}", additional)),
        };
        let be_bytes = take(bytes, pos, width)?;
        Ok(be_bytes.iter().fold(0, |acc, &b| (acc << 8) | u64::from(b)))
    }

    /// Advances `*pos` past exactly one complete data item (recursing into containers).
    fn skip_value(bytes: &[u8], pos: &mut usize) -> Result<(), String> {
        let initial = take(bytes, pos, 1)?[0];
        let major = initial >> 5;
        let additional = initial & 0x1f;
        match major {
            // Unsigned / negative integers: the argument is the whole value.
            0 | 1 => read_uint(bytes, pos, additional).map(|_| ()),
            // Byte / text strings: the argument is the payload length, which must be
            // present in full (previously this could run past the end of the input).
            2 | 3 => {
                let len = read_uint(bytes, pos, additional)?;
                let len = usize::try_from(len).map_err(|_| String::from("Unexpected end"))?;
                take(bytes, pos, len).map(|_| ())
            }
            // Arrays: the argument is the element count.
            4 => {
                for _ in 0..read_uint(bytes, pos, additional)? {
                    skip_value(bytes, pos)?;
                }
                Ok(())
            }
            // Maps: the argument is the entry count; each entry is a key then a value.
            5 => {
                for _ in 0..read_uint(bytes, pos, additional)? {
                    skip_value(bytes, pos)?;
                    skip_value(bytes, pos)?;
                }
                Ok(())
            }
            // Tags: the argument is the tag number, followed by the tagged item.
            6 => {
                read_uint(bytes, pos, additional)?;
                skip_value(bytes, pos)
            }
            // Simple values and floats: 0..=23 have no payload, 24 is followed by a
            // one-byte simple value, 25/26/27 by a half/single/double-precision float.
            // The payload width matches `read_uint`'s, so reuse it to consume it.
            7 => read_uint(bytes, pos, additional).map(|_| ()),
            _ => unreachable!("a CBOR major type is only three bits wide"),
        }
    }

    let mut pos = 0;
    skip_value(bytes, &mut pos)?;
    Ok(pos)
}
111
112fn parse_frame(bytes: &[u8]) -> Result<(FrameHeader, CommitFrame), String> {
113 let header_len = find_cbor_map_end(bytes)?;
114 let header: FrameHeader = serde_ipld_dagcbor::from_slice(&bytes[..header_len])
115 .map_err(|e| format!("Failed to parse header: {:?}", e))?;
116 let remaining = &bytes[header_len..];
117 let frame: CommitFrame = serde_ipld_dagcbor::from_slice(remaining)
118 .map_err(|e| format!("Failed to parse commit frame: {:?}", e))?;
119 Ok((header, frame))
120}
121
122#[tokio::test]
123async fn test_firehose_subscription() {
124 let client = client();
125 let (token, did) = create_account_and_login(&client).await;
126 let url = format!(
127 "ws://127.0.0.1:{}/xrpc/com.atproto.sync.subscribeRepos",
128 app_port()
129 );
130 let (mut ws_stream, _) = connect_async(&url).await.expect("Failed to connect");
131 let post_text = "Hello from the firehose test!";
132 let post_payload = json!({
133 "repo": did,
134 "collection": "app.bsky.feed.post",
135 "record": {
136 "$type": "app.bsky.feed.post",
137 "text": post_text,
138 "createdAt": chrono::Utc::now().to_rfc3339(),
139 }
140 });
141 let res = client
142 .post(format!(
143 "{}/xrpc/com.atproto.repo.createRecord",
144 base_url().await
145 ))
146 .bearer_auth(token)
147 .json(&post_payload)
148 .send()
149 .await
150 .expect("Failed to create post");
151 assert_eq!(res.status(), StatusCode::OK);
152 let mut frame_opt: Option<(FrameHeader, CommitFrame)> = None;
153 let timeout = tokio::time::timeout(std::time::Duration::from_secs(5), async {
154 loop {
155 let msg = ws_stream.next().await.unwrap().unwrap();
156 let raw_bytes = match msg {
157 tungstenite::Message::Binary(bin) => bin,
158 _ => continue,
159 };
160 if let Ok((h, f)) = parse_frame(&raw_bytes) {
161 if f.repo == did {
162 frame_opt = Some((h, f));
163 break;
164 }
165 }
166 }
167 })
168 .await;
169 assert!(timeout.is_ok(), "Timed out waiting for event for our DID");
170 let (header, commit) = frame_opt.expect("No matching frame found");
171 assert_eq!(header.op, 1);
172 assert_eq!(header.t, "#commit");
173 assert_eq!(commit.ops.len(), 1);
174 assert!(!commit.blocks.is_empty());
175 let op = &commit.ops[0];
176 let record_cid = op.cid.clone().expect("Op should have CID");
177 let mut car_reader = CarReader::new(Cursor::new(&commit.blocks)).await.unwrap();
178 let mut record_block: Option<Vec<u8>> = None;
179 while let Ok(Some((cid, block))) = car_reader.next_block().await {
180 if cid == record_cid {
181 record_block = Some(block);
182 break;
183 }
184 }
185 let record_block = record_block.expect("Record block not found in CAR");
186 let record: Value = serde_ipld_dagcbor::from_slice(&record_block).unwrap();
187 assert_eq!(record["text"], post_text);
188 ws_stream
189 .send(tungstenite::Message::Close(None))
190 .await
191 .ok();
192}