this repo has no description
1use crate::api::read_after_write::{
2 extract_repo_rev, format_local_post, format_munged_response, get_local_lag,
3 get_records_since_rev, proxy_to_appview, PostRecord, PostView, RecordDescript,
4};
5use crate::state::AppState;
6use axum::{
7 extract::{Query, State},
8 http::StatusCode,
9 response::{IntoResponse, Response},
10 Json,
11};
12use serde::{Deserialize, Serialize};
13use serde_json::{json, Value};
14use std::collections::HashMap;
15use tracing::warn;
16
17#[derive(Deserialize)]
18pub struct GetPostThreadParams {
19 pub uri: String,
20 pub depth: Option<u32>,
21 #[serde(rename = "parentHeight")]
22 pub parent_height: Option<u32>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(rename_all = "camelCase")]
27pub struct ThreadViewPost {
28 #[serde(rename = "$type")]
29 pub thread_type: Option<String>,
30 pub post: PostView,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub parent: Option<Box<ThreadNode>>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub replies: Option<Vec<ThreadNode>>,
35 #[serde(flatten)]
36 pub extra: HashMap<String, Value>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40#[serde(untagged)]
41pub enum ThreadNode {
42 Post(ThreadViewPost),
43 NotFound(ThreadNotFound),
44 Blocked(ThreadBlocked),
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48#[serde(rename_all = "camelCase")]
49pub struct ThreadNotFound {
50 #[serde(rename = "$type")]
51 pub thread_type: String,
52 pub uri: String,
53 pub not_found: bool,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57#[serde(rename_all = "camelCase")]
58pub struct ThreadBlocked {
59 #[serde(rename = "$type")]
60 pub thread_type: String,
61 pub uri: String,
62 pub blocked: bool,
63 pub author: Value,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct PostThreadOutput {
68 pub thread: ThreadNode,
69 #[serde(skip_serializing_if = "Option::is_none")]
70 pub threadgate: Option<Value>,
71}
72
73const MAX_THREAD_DEPTH: usize = 10;
74
75fn add_replies_to_thread(
76 thread: &mut ThreadViewPost,
77 local_posts: &[RecordDescript<PostRecord>],
78 author_did: &str,
79 author_handle: &str,
80 depth: usize,
81) {
82 if depth >= MAX_THREAD_DEPTH {
83 return;
84 }
85
86 let thread_uri = &thread.post.uri;
87
88 let replies: Vec<_> = local_posts
89 .iter()
90 .filter(|p| {
91 p.record
92 .reply
93 .as_ref()
94 .and_then(|r| r.get("parent"))
95 .and_then(|parent| parent.get("uri"))
96 .and_then(|u| u.as_str())
97 == Some(thread_uri)
98 })
99 .map(|p| {
100 let post_view = format_local_post(p, author_did, author_handle, None);
101 ThreadNode::Post(ThreadViewPost {
102 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
103 post: post_view,
104 parent: None,
105 replies: None,
106 extra: HashMap::new(),
107 })
108 })
109 .collect();
110
111 if !replies.is_empty() {
112 match &mut thread.replies {
113 Some(existing) => existing.extend(replies),
114 None => thread.replies = Some(replies),
115 }
116 }
117
118 if let Some(ref mut existing_replies) = thread.replies {
119 for reply in existing_replies.iter_mut() {
120 if let ThreadNode::Post(reply_thread) = reply {
121 add_replies_to_thread(reply_thread, local_posts, author_did, author_handle, depth + 1);
122 }
123 }
124 }
125}
126
127pub async fn get_post_thread(
128 State(state): State<AppState>,
129 headers: axum::http::HeaderMap,
130 Query(params): Query<GetPostThreadParams>,
131) -> Response {
132 let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
133
134 let auth_did = if let Some(h) = auth_header {
135 if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
136 match crate::auth::validate_bearer_token(&state.db, &token).await {
137 Ok(user) => Some(user.did),
138 Err(_) => None,
139 }
140 } else {
141 None
142 }
143 } else {
144 None
145 };
146
147 let mut query_params = HashMap::new();
148 query_params.insert("uri".to_string(), params.uri.clone());
149 if let Some(depth) = params.depth {
150 query_params.insert("depth".to_string(), depth.to_string());
151 }
152 if let Some(parent_height) = params.parent_height {
153 query_params.insert("parentHeight".to_string(), parent_height.to_string());
154 }
155
156 let proxy_result =
157 match proxy_to_appview("app.bsky.feed.getPostThread", &query_params, auth_header).await {
158 Ok(r) => r,
159 Err(e) => return e,
160 };
161
162 if proxy_result.status == StatusCode::NOT_FOUND {
163 return handle_not_found(&state, ¶ms.uri, auth_did, &proxy_result.headers).await;
164 }
165
166 if !proxy_result.status.is_success() {
167 return (proxy_result.status, proxy_result.body).into_response();
168 }
169
170 let rev = match extract_repo_rev(&proxy_result.headers) {
171 Some(r) => r,
172 None => return (proxy_result.status, proxy_result.body).into_response(),
173 };
174
175 let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) {
176 Ok(t) => t,
177 Err(e) => {
178 warn!("Failed to parse post thread response: {:?}", e);
179 return (proxy_result.status, proxy_result.body).into_response();
180 }
181 };
182
183 let requester_did = match auth_did {
184 Some(d) => d,
185 None => return (StatusCode::OK, Json(thread_output)).into_response(),
186 };
187
188 let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
189 Ok(r) => r,
190 Err(e) => {
191 warn!("Failed to get local records: {}", e);
192 return (proxy_result.status, proxy_result.body).into_response();
193 }
194 };
195
196 if local_records.posts.is_empty() {
197 return (StatusCode::OK, Json(thread_output)).into_response();
198 }
199
200 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
201 .fetch_optional(&state.db)
202 .await
203 {
204 Ok(Some(h)) => h,
205 Ok(None) => requester_did.clone(),
206 Err(e) => {
207 warn!("Database error fetching handle: {:?}", e);
208 requester_did.clone()
209 }
210 };
211
212 if let ThreadNode::Post(ref mut thread_post) = thread_output.thread {
213 add_replies_to_thread(thread_post, &local_records.posts, &requester_did, &handle, 0);
214 }
215
216 let lag = get_local_lag(&local_records);
217 format_munged_response(thread_output, lag)
218}
219
220async fn handle_not_found(
221 state: &AppState,
222 uri: &str,
223 auth_did: Option<String>,
224 headers: &axum::http::HeaderMap,
225) -> Response {
226 let rev = match extract_repo_rev(headers) {
227 Some(r) => r,
228 None => {
229 return (
230 StatusCode::NOT_FOUND,
231 Json(json!({"error": "NotFound", "message": "Post not found"})),
232 )
233 .into_response()
234 }
235 };
236
237 let requester_did = match auth_did {
238 Some(d) => d,
239 None => {
240 return (
241 StatusCode::NOT_FOUND,
242 Json(json!({"error": "NotFound", "message": "Post not found"})),
243 )
244 .into_response()
245 }
246 };
247
248 let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect();
249 if uri_parts.len() != 3 {
250 return (
251 StatusCode::NOT_FOUND,
252 Json(json!({"error": "NotFound", "message": "Post not found"})),
253 )
254 .into_response();
255 }
256
257 let post_did = uri_parts[0];
258 if post_did != requester_did {
259 return (
260 StatusCode::NOT_FOUND,
261 Json(json!({"error": "NotFound", "message": "Post not found"})),
262 )
263 .into_response();
264 }
265
266 let local_records = match get_records_since_rev(state, &requester_did, &rev).await {
267 Ok(r) => r,
268 Err(_) => {
269 return (
270 StatusCode::NOT_FOUND,
271 Json(json!({"error": "NotFound", "message": "Post not found"})),
272 )
273 .into_response()
274 }
275 };
276
277 let local_post = local_records.posts.iter().find(|p| p.uri == uri);
278
279 let local_post = match local_post {
280 Some(p) => p,
281 None => {
282 return (
283 StatusCode::NOT_FOUND,
284 Json(json!({"error": "NotFound", "message": "Post not found"})),
285 )
286 .into_response()
287 }
288 };
289
290 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
291 .fetch_optional(&state.db)
292 .await
293 {
294 Ok(Some(h)) => h,
295 Ok(None) => requester_did.clone(),
296 Err(e) => {
297 warn!("Database error fetching handle: {:?}", e);
298 requester_did.clone()
299 }
300 };
301
302 let post_view = format_local_post(
303 local_post,
304 &requester_did,
305 &handle,
306 local_records.profile.as_ref(),
307 );
308
309 let thread = PostThreadOutput {
310 thread: ThreadNode::Post(ThreadViewPost {
311 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
312 post: post_view,
313 parent: None,
314 replies: None,
315 extra: HashMap::new(),
316 }),
317 threadgate: None,
318 };
319
320 let lag = get_local_lag(&local_records);
321 format_munged_response(thread, lag)
322}