this repo has no description
1use super::validation::validate_record;
2use crate::api::repo::record::utils::{commit_and_log, RecordOp};
3use crate::repo::tracking::TrackingBlockStore;
4use crate::state::AppState;
5use axum::{
6 extract::State,
7 http::StatusCode,
8 response::{IntoResponse, Response},
9 Json,
10};
11use chrono::Utc;
12use cid::Cid;
13use jacquard::types::string::Nsid;
14use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17use std::str::FromStr;
18use std::sync::Arc;
19use tracing::error;
20
/// Upper bound on the number of operations accepted in a single `applyWrites`
/// request; larger batches are rejected with `400 InvalidRequest`.
const MAX_BATCH_WRITES: usize = 200;
22
23#[derive(Deserialize)]
24#[serde(tag = "$type")]
25pub enum WriteOp {
26 #[serde(rename = "com.atproto.repo.applyWrites#create")]
27 Create {
28 collection: String,
29 rkey: Option<String>,
30 value: serde_json::Value,
31 },
32 #[serde(rename = "com.atproto.repo.applyWrites#update")]
33 Update {
34 collection: String,
35 rkey: String,
36 value: serde_json::Value,
37 },
38 #[serde(rename = "com.atproto.repo.applyWrites#delete")]
39 Delete { collection: String, rkey: String },
40}
41
42#[derive(Deserialize)]
43#[serde(rename_all = "camelCase")]
44pub struct ApplyWritesInput {
45 pub repo: String,
46 pub validate: Option<bool>,
47 pub writes: Vec<WriteOp>,
48 pub swap_commit: Option<String>,
49}
50
51#[derive(Serialize)]
52#[serde(tag = "$type")]
53pub enum WriteResult {
54 #[serde(rename = "com.atproto.repo.applyWrites#createResult")]
55 CreateResult { uri: String, cid: String },
56 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")]
57 UpdateResult { uri: String, cid: String },
58 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")]
59 DeleteResult {},
60}
61
62#[derive(Serialize)]
63pub struct ApplyWritesOutput {
64 pub commit: CommitInfo,
65 pub results: Vec<WriteResult>,
66}
67
68#[derive(Serialize)]
69pub struct CommitInfo {
70 pub cid: String,
71 pub rev: String,
72}
73
74pub async fn apply_writes(
75 State(state): State<AppState>,
76 headers: axum::http::HeaderMap,
77 Json(input): Json<ApplyWritesInput>,
78) -> Response {
79 let token = match crate::auth::extract_bearer_token_from_header(
80 headers.get("Authorization").and_then(|h| h.to_str().ok())
81 ) {
82 Some(t) => t,
83 None => {
84 return (
85 StatusCode::UNAUTHORIZED,
86 Json(json!({"error": "AuthenticationRequired"})),
87 )
88 .into_response();
89 }
90 };
91
92 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
93 Ok(user) => user,
94 Err(_) => {
95 return (
96 StatusCode::UNAUTHORIZED,
97 Json(json!({"error": "AuthenticationFailed"})),
98 )
99 .into_response();
100 }
101 };
102
103 let did = auth_user.did;
104
105 if input.repo != did {
106 return (
107 StatusCode::FORBIDDEN,
108 Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"})),
109 )
110 .into_response();
111 }
112
113 if input.writes.is_empty() {
114 return (
115 StatusCode::BAD_REQUEST,
116 Json(json!({"error": "InvalidRequest", "message": "writes array is empty"})),
117 )
118 .into_response();
119 }
120
121 if input.writes.len() > MAX_BATCH_WRITES {
122 return (
123 StatusCode::BAD_REQUEST,
124 Json(json!({"error": "InvalidRequest", "message": format!("Too many writes (max {})", MAX_BATCH_WRITES)})),
125 )
126 .into_response();
127 }
128
129 let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
130 .fetch_optional(&state.db)
131 .await
132 {
133 Ok(Some(id)) => id,
134 _ => {
135 return (
136 StatusCode::INTERNAL_SERVER_ERROR,
137 Json(json!({"error": "InternalError", "message": "User not found"})),
138 )
139 .into_response();
140 }
141 };
142
143 let root_cid_str: String =
144 match sqlx::query_scalar!("SELECT repo_root_cid FROM repos WHERE user_id = $1", user_id)
145 .fetch_optional(&state.db)
146 .await
147 {
148 Ok(Some(cid_str)) => cid_str,
149 _ => {
150 return (
151 StatusCode::INTERNAL_SERVER_ERROR,
152 Json(json!({"error": "InternalError", "message": "Repo root not found"})),
153 )
154 .into_response();
155 }
156 };
157
158 let current_root_cid = match Cid::from_str(&root_cid_str) {
159 Ok(c) => c,
160 Err(_) => {
161 return (
162 StatusCode::INTERNAL_SERVER_ERROR,
163 Json(json!({"error": "InternalError", "message": "Invalid repo root CID"})),
164 )
165 .into_response();
166 }
167 };
168
169 if let Some(swap_commit) = &input.swap_commit {
170 if Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
171 return (
172 StatusCode::CONFLICT,
173 Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
174 )
175 .into_response();
176 }
177 }
178
179 let tracking_store = TrackingBlockStore::new(state.block_store.clone());
180
181 let commit_bytes = match tracking_store.get(¤t_root_cid).await {
182 Ok(Some(b)) => b,
183 _ => {
184 return (
185 StatusCode::INTERNAL_SERVER_ERROR,
186 Json(json!({"error": "InternalError", "message": "Commit block not found"})),
187 )
188 .into_response()
189 }
190 };
191
192 let commit = match Commit::from_cbor(&commit_bytes) {
193 Ok(c) => c,
194 _ => {
195 return (
196 StatusCode::INTERNAL_SERVER_ERROR,
197 Json(json!({"error": "InternalError", "message": "Failed to parse commit"})),
198 )
199 .into_response()
200 }
201 };
202
203 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
204
205 let mut results: Vec<WriteResult> = Vec::new();
206 let mut ops: Vec<RecordOp> = Vec::new();
207
208 for write in &input.writes {
209 match write {
210 WriteOp::Create {
211 collection,
212 rkey,
213 value,
214 } => {
215 if input.validate.unwrap_or(true) {
216 if let Err(err_response) = validate_record(value, collection) {
217 return err_response;
218 }
219 }
220 let rkey = rkey
221 .clone()
222 .unwrap_or_else(|| Utc::now().format("%Y%m%d%H%M%S%f").to_string());
223 let mut record_bytes = Vec::new();
224 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
225 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
226 }
227 let record_cid = match tracking_store.put(&record_bytes).await {
228 Ok(c) => c,
229 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(),
230 };
231
232 let collection_nsid = match collection.parse::<Nsid>() {
233 Ok(n) => n,
234 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
235 };
236 let key = format!("{}/{}", collection_nsid, rkey);
237 mst = match mst.add(&key, record_cid).await {
238 Ok(m) => m,
239 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(),
240 };
241
242 let uri = format!("at://{}/{}/{}", did, collection, rkey);
243 results.push(WriteResult::CreateResult {
244 uri,
245 cid: record_cid.to_string(),
246 });
247 ops.push(RecordOp::Create {
248 collection: collection.clone(),
249 rkey,
250 cid: record_cid,
251 });
252 }
253 WriteOp::Update {
254 collection,
255 rkey,
256 value,
257 } => {
258 if input.validate.unwrap_or(true) {
259 if let Err(err_response) = validate_record(value, collection) {
260 return err_response;
261 }
262 }
263 let mut record_bytes = Vec::new();
264 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
265 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
266 }
267 let record_cid = match tracking_store.put(&record_bytes).await {
268 Ok(c) => c,
269 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(),
270 };
271
272 let collection_nsid = match collection.parse::<Nsid>() {
273 Ok(n) => n,
274 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
275 };
276 let key = format!("{}/{}", collection_nsid, rkey);
277 mst = match mst.update(&key, record_cid).await {
278 Ok(m) => m,
279 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update MST"}))).into_response(),
280 };
281
282 let uri = format!("at://{}/{}/{}", did, collection, rkey);
283 results.push(WriteResult::UpdateResult {
284 uri,
285 cid: record_cid.to_string(),
286 });
287 ops.push(RecordOp::Update {
288 collection: collection.clone(),
289 rkey: rkey.clone(),
290 cid: record_cid,
291 });
292 }
293 WriteOp::Delete { collection, rkey } => {
294 let collection_nsid = match collection.parse::<Nsid>() {
295 Ok(n) => n,
296 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
297 };
298 let key = format!("{}/{}", collection_nsid, rkey);
299 mst = match mst.delete(&key).await {
300 Ok(m) => m,
301 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to delete from MST"}))).into_response(),
302 };
303
304 results.push(WriteResult::DeleteResult {});
305 ops.push(RecordOp::Delete {
306 collection: collection.clone(),
307 rkey: rkey.clone(),
308 });
309 }
310 }
311 }
312
313 let new_mst_root = match mst.persist().await {
314 Ok(c) => c,
315 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(),
316 };
317 let written_cids = tracking_store.get_written_cids();
318 let written_cids_str = written_cids
319 .iter()
320 .map(|c| c.to_string())
321 .collect::<Vec<_>>();
322
323 let commit_res = match commit_and_log(
324 &state,
325 &did,
326 user_id,
327 Some(current_root_cid),
328 new_mst_root,
329 ops,
330 &written_cids_str,
331 )
332 .await
333 {
334 Ok(res) => res,
335 Err(e) => {
336 error!("Commit failed: {}", e);
337 return (
338 StatusCode::INTERNAL_SERVER_ERROR,
339 Json(json!({"error": "InternalError", "message": "Failed to commit changes"})),
340 )
341 .into_response();
342 }
343 };
344
345 (
346 StatusCode::OK,
347 Json(ApplyWritesOutput {
348 commit: CommitInfo {
349 cid: commit_res.commit_cid.to_string(),
350 rev: commit_res.rev,
351 },
352 results,
353 }),
354 )
355 .into_response()
356}