this repo has no description
1use crate::api::repo::record::utils::{commit_and_log, RecordOp};
2use crate::repo::tracking::TrackingBlockStore;
3use crate::state::AppState;
4use axum::{
5 extract::State,
6 http::StatusCode,
7 response::{IntoResponse, Response},
8 Json,
9};
10use chrono::Utc;
11use cid::Cid;
12use jacquard::types::string::Nsid;
13use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
14use serde::{Deserialize, Serialize};
15use serde_json::json;
16use std::str::FromStr;
17use std::sync::Arc;
18use tracing::error;
19
20#[derive(Deserialize)]
21#[serde(tag = "$type")]
22pub enum WriteOp {
23 #[serde(rename = "com.atproto.repo.applyWrites#create")]
24 Create {
25 collection: String,
26 rkey: Option<String>,
27 value: serde_json::Value,
28 },
29 #[serde(rename = "com.atproto.repo.applyWrites#update")]
30 Update {
31 collection: String,
32 rkey: String,
33 value: serde_json::Value,
34 },
35 #[serde(rename = "com.atproto.repo.applyWrites#delete")]
36 Delete { collection: String, rkey: String },
37}
38
39#[derive(Deserialize)]
40#[serde(rename_all = "camelCase")]
41pub struct ApplyWritesInput {
42 pub repo: String,
43 pub validate: Option<bool>,
44 pub writes: Vec<WriteOp>,
45 pub swap_commit: Option<String>,
46}
47
48#[derive(Serialize)]
49#[serde(tag = "$type")]
50pub enum WriteResult {
51 #[serde(rename = "com.atproto.repo.applyWrites#createResult")]
52 CreateResult { uri: String, cid: String },
53 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")]
54 UpdateResult { uri: String, cid: String },
55 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")]
56 DeleteResult {},
57}
58
59#[derive(Serialize)]
60pub struct ApplyWritesOutput {
61 pub commit: CommitInfo,
62 pub results: Vec<WriteResult>,
63}
64
65#[derive(Serialize)]
66pub struct CommitInfo {
67 pub cid: String,
68 pub rev: String,
69}
70
71pub async fn apply_writes(
72 State(state): State<AppState>,
73 headers: axum::http::HeaderMap,
74 Json(input): Json<ApplyWritesInput>,
75) -> Response {
76 let auth_header = headers.get("Authorization");
77 if auth_header.is_none() {
78 return (
79 StatusCode::UNAUTHORIZED,
80 Json(json!({"error": "AuthenticationRequired"})),
81 )
82 .into_response();
83 }
84 let token = auth_header
85 .unwrap()
86 .to_str()
87 .unwrap_or("")
88 .replace("Bearer ", "");
89
90 let session = sqlx::query!(
91 "SELECT s.did, k.key_bytes FROM sessions s JOIN users u ON s.did = u.did JOIN user_keys k ON u.id = k.user_id WHERE s.access_jwt = $1",
92 token
93 )
94 .fetch_optional(&state.db)
95 .await
96 .unwrap_or(None);
97
98 let (did, key_bytes) = match session {
99 Some(row) => (row.did, row.key_bytes),
100 None => {
101 return (
102 StatusCode::UNAUTHORIZED,
103 Json(json!({"error": "AuthenticationFailed"})),
104 )
105 .into_response();
106 }
107 };
108
109 if let Err(_) = crate::auth::verify_token(&token, &key_bytes) {
110 return (
111 StatusCode::UNAUTHORIZED,
112 Json(json!({"error": "AuthenticationFailed", "message": "Invalid token signature"})),
113 )
114 .into_response();
115 }
116
117 if input.repo != did {
118 return (
119 StatusCode::FORBIDDEN,
120 Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"})),
121 )
122 .into_response();
123 }
124
125 if input.writes.is_empty() {
126 return (
127 StatusCode::BAD_REQUEST,
128 Json(json!({"error": "InvalidRequest", "message": "writes array is empty"})),
129 )
130 .into_response();
131 }
132
133 if input.writes.len() > 200 {
134 return (
135 StatusCode::BAD_REQUEST,
136 Json(json!({"error": "InvalidRequest", "message": "Too many writes (max 200)"})),
137 )
138 .into_response();
139 }
140
141 let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
142 .fetch_optional(&state.db)
143 .await
144 {
145 Ok(Some(id)) => id,
146 _ => {
147 return (
148 StatusCode::INTERNAL_SERVER_ERROR,
149 Json(json!({"error": "InternalError", "message": "User not found"})),
150 )
151 .into_response();
152 }
153 };
154
155 let root_cid_str: String =
156 match sqlx::query_scalar!("SELECT repo_root_cid FROM repos WHERE user_id = $1", user_id)
157 .fetch_optional(&state.db)
158 .await
159 {
160 Ok(Some(cid_str)) => cid_str,
161 _ => {
162 return (
163 StatusCode::INTERNAL_SERVER_ERROR,
164 Json(json!({"error": "InternalError", "message": "Repo root not found"})),
165 )
166 .into_response();
167 }
168 };
169
170 let current_root_cid = match Cid::from_str(&root_cid_str) {
171 Ok(c) => c,
172 Err(_) => {
173 return (
174 StatusCode::INTERNAL_SERVER_ERROR,
175 Json(json!({"error": "InternalError", "message": "Invalid repo root CID"})),
176 )
177 .into_response();
178 }
179 };
180
181 if let Some(swap_commit) = &input.swap_commit {
182 if Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
183 return (
184 StatusCode::CONFLICT,
185 Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
186 )
187 .into_response();
188 }
189 }
190
191 let tracking_store = TrackingBlockStore::new(state.block_store.clone());
192
193 let commit_bytes = match tracking_store.get(¤t_root_cid).await {
194 Ok(Some(b)) => b,
195 _ => {
196 return (
197 StatusCode::INTERNAL_SERVER_ERROR,
198 Json(json!({"error": "InternalError", "message": "Commit block not found"})),
199 )
200 .into_response()
201 }
202 };
203
204 let commit = match Commit::from_cbor(&commit_bytes) {
205 Ok(c) => c,
206 _ => {
207 return (
208 StatusCode::INTERNAL_SERVER_ERROR,
209 Json(json!({"error": "InternalError", "message": "Failed to parse commit"})),
210 )
211 .into_response()
212 }
213 };
214
215 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
216
217 let mut results: Vec<WriteResult> = Vec::new();
218 let mut ops: Vec<RecordOp> = Vec::new();
219
220 for write in &input.writes {
221 match write {
222 WriteOp::Create {
223 collection,
224 rkey,
225 value,
226 } => {
227 let rkey = rkey
228 .clone()
229 .unwrap_or_else(|| Utc::now().format("%Y%m%d%H%M%S%f").to_string());
230 let mut record_bytes = Vec::new();
231 serde_ipld_dagcbor::to_writer(&mut record_bytes, value).unwrap();
232 let record_cid = tracking_store.put(&record_bytes).await.unwrap();
233
234 let key = format!("{}/{}", collection.parse::<Nsid>().unwrap(), rkey);
235 mst = mst.add(&key, record_cid).await.unwrap();
236
237 let uri = format!("at://{}/{}/{}", did, collection, rkey);
238 results.push(WriteResult::CreateResult {
239 uri,
240 cid: record_cid.to_string(),
241 });
242 ops.push(RecordOp::Create {
243 collection: collection.clone(),
244 rkey,
245 cid: record_cid,
246 });
247 }
248 WriteOp::Update {
249 collection,
250 rkey,
251 value,
252 } => {
253 let mut record_bytes = Vec::new();
254 serde_ipld_dagcbor::to_writer(&mut record_bytes, value).unwrap();
255 let record_cid = tracking_store.put(&record_bytes).await.unwrap();
256
257 let key = format!("{}/{}", collection.parse::<Nsid>().unwrap(), rkey);
258 mst = mst.update(&key, record_cid).await.unwrap();
259
260 let uri = format!("at://{}/{}/{}", did, collection, rkey);
261 results.push(WriteResult::UpdateResult {
262 uri,
263 cid: record_cid.to_string(),
264 });
265 ops.push(RecordOp::Update {
266 collection: collection.clone(),
267 rkey: rkey.clone(),
268 cid: record_cid,
269 });
270 }
271 WriteOp::Delete { collection, rkey } => {
272 let key = format!("{}/{}", collection.parse::<Nsid>().unwrap(), rkey);
273 mst = mst.delete(&key).await.unwrap();
274
275 results.push(WriteResult::DeleteResult {});
276 ops.push(RecordOp::Delete {
277 collection: collection.clone(),
278 rkey: rkey.clone(),
279 });
280 }
281 }
282 }
283
284 let new_mst_root = mst.persist().await.unwrap();
285 let written_cids = tracking_store.get_written_cids();
286 let written_cids_str = written_cids
287 .iter()
288 .map(|c| c.to_string())
289 .collect::<Vec<_>>();
290
291 let commit_res = match commit_and_log(
292 &state,
293 &did,
294 user_id,
295 Some(current_root_cid),
296 new_mst_root,
297 ops,
298 &written_cids_str,
299 )
300 .await
301 {
302 Ok(res) => res,
303 Err(e) => {
304 error!("Commit failed: {}", e);
305 return (
306 StatusCode::INTERNAL_SERVER_ERROR,
307 Json(json!({"error": "InternalError", "message": "Failed to commit changes"})),
308 )
309 .into_response();
310 }
311 };
312
313 (
314 StatusCode::OK,
315 Json(ApplyWritesOutput {
316 commit: CommitInfo {
317 cid: commit_res.commit_cid.to_string(),
318 rev: commit_res.rev,
319 },
320 results,
321 }),
322 )
323 .into_response()
324}