// this repo has no description
1use super::validation::validate_record;
2use super::write::has_verified_notification_channel;
3use crate::api::repo::record::utils::{commit_and_log, RecordOp};
4use crate::repo::tracking::TrackingBlockStore;
5use crate::state::AppState;
6use axum::{
7 extract::State,
8 http::StatusCode,
9 response::{IntoResponse, Response},
10 Json,
11};
12use cid::Cid;
13use jacquard::types::{integer::LimitedU32, string::{Nsid, Tid}};
14use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore};
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17use std::str::FromStr;
18use std::sync::Arc;
19use tracing::error;
20
/// Upper bound on the number of write operations accepted in a single
/// `applyWrites` request; larger batches are rejected with `InvalidRequest`.
const MAX_BATCH_WRITES: usize = 200;
22
23#[derive(Deserialize)]
24#[serde(tag = "$type")]
25pub enum WriteOp {
26 #[serde(rename = "com.atproto.repo.applyWrites#create")]
27 Create {
28 collection: String,
29 rkey: Option<String>,
30 value: serde_json::Value,
31 },
32 #[serde(rename = "com.atproto.repo.applyWrites#update")]
33 Update {
34 collection: String,
35 rkey: String,
36 value: serde_json::Value,
37 },
38 #[serde(rename = "com.atproto.repo.applyWrites#delete")]
39 Delete { collection: String, rkey: String },
40}
41
42#[derive(Deserialize)]
43#[serde(rename_all = "camelCase")]
44pub struct ApplyWritesInput {
45 pub repo: String,
46 pub validate: Option<bool>,
47 pub writes: Vec<WriteOp>,
48 pub swap_commit: Option<String>,
49}
50
51#[derive(Serialize)]
52#[serde(tag = "$type")]
53pub enum WriteResult {
54 #[serde(rename = "com.atproto.repo.applyWrites#createResult")]
55 CreateResult { uri: String, cid: String },
56 #[serde(rename = "com.atproto.repo.applyWrites#updateResult")]
57 UpdateResult { uri: String, cid: String },
58 #[serde(rename = "com.atproto.repo.applyWrites#deleteResult")]
59 DeleteResult {},
60}
61
62#[derive(Serialize)]
63pub struct ApplyWritesOutput {
64 pub commit: CommitInfo,
65 pub results: Vec<WriteResult>,
66}
67
68#[derive(Serialize)]
69pub struct CommitInfo {
70 pub cid: String,
71 pub rev: String,
72}
73
74pub async fn apply_writes(
75 State(state): State<AppState>,
76 headers: axum::http::HeaderMap,
77 Json(input): Json<ApplyWritesInput>,
78) -> Response {
79 let token = match crate::auth::extract_bearer_token_from_header(
80 headers.get("Authorization").and_then(|h| h.to_str().ok())
81 ) {
82 Some(t) => t,
83 None => {
84 return (
85 StatusCode::UNAUTHORIZED,
86 Json(json!({"error": "AuthenticationRequired"})),
87 )
88 .into_response();
89 }
90 };
91 let auth_user = match crate::auth::validate_bearer_token(&state.db, &token).await {
92 Ok(user) => user,
93 Err(_) => {
94 return (
95 StatusCode::UNAUTHORIZED,
96 Json(json!({"error": "AuthenticationFailed"})),
97 )
98 .into_response();
99 }
100 };
101 let did = auth_user.did;
102 if input.repo != did {
103 return (
104 StatusCode::FORBIDDEN,
105 Json(json!({"error": "InvalidRepo", "message": "Repo does not match authenticated user"})),
106 )
107 .into_response();
108 }
109 match has_verified_notification_channel(&state.db, &did).await {
110 Ok(true) => {}
111 Ok(false) => {
112 return (
113 StatusCode::FORBIDDEN,
114 Json(json!({
115 "error": "AccountNotVerified",
116 "message": "You must verify at least one notification channel (email, Discord, Telegram, or Signal) before creating records"
117 })),
118 )
119 .into_response();
120 }
121 Err(e) => {
122 error!("DB error checking notification channels: {}", e);
123 return (
124 StatusCode::INTERNAL_SERVER_ERROR,
125 Json(json!({"error": "InternalError"})),
126 )
127 .into_response();
128 }
129 }
130 if input.writes.is_empty() {
131 return (
132 StatusCode::BAD_REQUEST,
133 Json(json!({"error": "InvalidRequest", "message": "writes array is empty"})),
134 )
135 .into_response();
136 }
137 if input.writes.len() > MAX_BATCH_WRITES {
138 return (
139 StatusCode::BAD_REQUEST,
140 Json(json!({"error": "InvalidRequest", "message": format!("Too many writes (max {})", MAX_BATCH_WRITES)})),
141 )
142 .into_response();
143 }
144 let user_id: uuid::Uuid = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did)
145 .fetch_optional(&state.db)
146 .await
147 {
148 Ok(Some(id)) => id,
149 _ => {
150 return (
151 StatusCode::INTERNAL_SERVER_ERROR,
152 Json(json!({"error": "InternalError", "message": "User not found"})),
153 )
154 .into_response();
155 }
156 };
157 let root_cid_str: String =
158 match sqlx::query_scalar!("SELECT repo_root_cid FROM repos WHERE user_id = $1", user_id)
159 .fetch_optional(&state.db)
160 .await
161 {
162 Ok(Some(cid_str)) => cid_str,
163 _ => {
164 return (
165 StatusCode::INTERNAL_SERVER_ERROR,
166 Json(json!({"error": "InternalError", "message": "Repo root not found"})),
167 )
168 .into_response();
169 }
170 };
171 let current_root_cid = match Cid::from_str(&root_cid_str) {
172 Ok(c) => c,
173 Err(_) => {
174 return (
175 StatusCode::INTERNAL_SERVER_ERROR,
176 Json(json!({"error": "InternalError", "message": "Invalid repo root CID"})),
177 )
178 .into_response();
179 }
180 };
181 if let Some(swap_commit) = &input.swap_commit {
182 if Cid::from_str(swap_commit).ok() != Some(current_root_cid) {
183 return (
184 StatusCode::CONFLICT,
185 Json(json!({"error": "InvalidSwap", "message": "Repo has been modified"})),
186 )
187 .into_response();
188 }
189 }
190 let tracking_store = TrackingBlockStore::new(state.block_store.clone());
191 let commit_bytes = match tracking_store.get(¤t_root_cid).await {
192 Ok(Some(b)) => b,
193 _ => {
194 return (
195 StatusCode::INTERNAL_SERVER_ERROR,
196 Json(json!({"error": "InternalError", "message": "Commit block not found"})),
197 )
198 .into_response()
199 }
200 };
201 let commit = match Commit::from_cbor(&commit_bytes) {
202 Ok(c) => c,
203 _ => {
204 return (
205 StatusCode::INTERNAL_SERVER_ERROR,
206 Json(json!({"error": "InternalError", "message": "Failed to parse commit"})),
207 )
208 .into_response()
209 }
210 };
211 let original_mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
212 let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None);
213 let mut results: Vec<WriteResult> = Vec::new();
214 let mut ops: Vec<RecordOp> = Vec::new();
215 let mut modified_keys: Vec<String> = Vec::new();
216 for write in &input.writes {
217 match write {
218 WriteOp::Create {
219 collection,
220 rkey,
221 value,
222 } => {
223 if input.validate.unwrap_or(true) {
224 if let Err(err_response) = validate_record(value, collection) {
225 return err_response;
226 }
227 }
228 let rkey = rkey
229 .clone()
230 .unwrap_or_else(|| Tid::now(LimitedU32::MIN).to_string());
231 let mut record_bytes = Vec::new();
232 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
233 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
234 }
235 let record_cid = match tracking_store.put(&record_bytes).await {
236 Ok(c) => c,
237 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(),
238 };
239 let collection_nsid = match collection.parse::<Nsid>() {
240 Ok(n) => n,
241 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
242 };
243 let key = format!("{}/{}", collection_nsid, rkey);
244 modified_keys.push(key.clone());
245 mst = match mst.add(&key, record_cid).await {
246 Ok(m) => m,
247 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to add to MST"}))).into_response(),
248 };
249 let uri = format!("at://{}/{}/{}", did, collection, rkey);
250 results.push(WriteResult::CreateResult {
251 uri,
252 cid: record_cid.to_string(),
253 });
254 ops.push(RecordOp::Create {
255 collection: collection.clone(),
256 rkey,
257 cid: record_cid,
258 });
259 }
260 WriteOp::Update {
261 collection,
262 rkey,
263 value,
264 } => {
265 if input.validate.unwrap_or(true) {
266 if let Err(err_response) = validate_record(value, collection) {
267 return err_response;
268 }
269 }
270 let mut record_bytes = Vec::new();
271 if serde_ipld_dagcbor::to_writer(&mut record_bytes, value).is_err() {
272 return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidRecord", "message": "Failed to serialize record"}))).into_response();
273 }
274 let record_cid = match tracking_store.put(&record_bytes).await {
275 Ok(c) => c,
276 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to store record"}))).into_response(),
277 };
278 let collection_nsid = match collection.parse::<Nsid>() {
279 Ok(n) => n,
280 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
281 };
282 let key = format!("{}/{}", collection_nsid, rkey);
283 modified_keys.push(key.clone());
284 let prev_record_cid = mst.get(&key).await.ok().flatten();
285 mst = match mst.update(&key, record_cid).await {
286 Ok(m) => m,
287 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to update MST"}))).into_response(),
288 };
289 let uri = format!("at://{}/{}/{}", did, collection, rkey);
290 results.push(WriteResult::UpdateResult {
291 uri,
292 cid: record_cid.to_string(),
293 });
294 ops.push(RecordOp::Update {
295 collection: collection.clone(),
296 rkey: rkey.clone(),
297 cid: record_cid,
298 prev: prev_record_cid,
299 });
300 }
301 WriteOp::Delete { collection, rkey } => {
302 let collection_nsid = match collection.parse::<Nsid>() {
303 Ok(n) => n,
304 Err(_) => return (StatusCode::BAD_REQUEST, Json(json!({"error": "InvalidCollection", "message": "Invalid collection NSID"}))).into_response(),
305 };
306 let key = format!("{}/{}", collection_nsid, rkey);
307 modified_keys.push(key.clone());
308 let prev_record_cid = mst.get(&key).await.ok().flatten();
309 mst = match mst.delete(&key).await {
310 Ok(m) => m,
311 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to delete from MST"}))).into_response(),
312 };
313 results.push(WriteResult::DeleteResult {});
314 ops.push(RecordOp::Delete {
315 collection: collection.clone(),
316 rkey: rkey.clone(),
317 prev: prev_record_cid,
318 });
319 }
320 }
321 }
322 let new_mst_root = match mst.persist().await {
323 Ok(c) => c,
324 Err(_) => return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to persist MST"}))).into_response(),
325 };
326 let mut relevant_blocks = std::collections::BTreeMap::new();
327 for key in &modified_keys {
328 if let Err(_) = mst.blocks_for_path(key, &mut relevant_blocks).await {
329 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get new MST blocks for path"}))).into_response();
330 }
331 if let Err(_) = original_mst.blocks_for_path(key, &mut relevant_blocks).await {
332 return (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "InternalError", "message": "Failed to get old MST blocks for path"}))).into_response();
333 }
334 }
335 let mut written_cids = tracking_store.get_all_relevant_cids();
336 for cid in relevant_blocks.keys() {
337 if !written_cids.contains(cid) {
338 written_cids.push(*cid);
339 }
340 }
341 let written_cids_str = written_cids
342 .iter()
343 .map(|c| c.to_string())
344 .collect::<Vec<_>>();
345 let commit_res = match commit_and_log(
346 &state,
347 &did,
348 user_id,
349 Some(current_root_cid),
350 Some(commit.data),
351 new_mst_root,
352 ops,
353 &written_cids_str,
354 )
355 .await
356 {
357 Ok(res) => res,
358 Err(e) => {
359 error!("Commit failed: {}", e);
360 return (
361 StatusCode::INTERNAL_SERVER_ERROR,
362 Json(json!({"error": "InternalError", "message": "Failed to commit changes"})),
363 )
364 .into_response();
365 }
366 };
367 (
368 StatusCode::OK,
369 Json(ApplyWritesOutput {
370 commit: CommitInfo {
371 cid: commit_res.commit_cid.to_string(),
372 rev: commit_res.rev,
373 },
374 results,
375 }),
376 )
377 .into_response()
378}