this repo has no description
1use crate::api::read_after_write::{
2 PostRecord, PostView, RecordDescript, extract_repo_rev, format_local_post,
3 format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview,
4};
5use crate::state::AppState;
6use axum::{
7 Json,
8 extract::{Query, State},
9 http::StatusCode,
10 response::{IntoResponse, Response},
11};
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14use std::collections::HashMap;
15use tracing::warn;
16
17#[derive(Deserialize)]
18pub struct GetPostThreadParams {
19 pub uri: String,
20 pub depth: Option<u32>,
21 #[serde(rename = "parentHeight")]
22 pub parent_height: Option<u32>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(rename_all = "camelCase")]
27pub struct ThreadViewPost {
28 #[serde(rename = "$type")]
29 pub thread_type: Option<String>,
30 pub post: PostView,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub parent: Option<Box<ThreadNode>>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub replies: Option<Vec<ThreadNode>>,
35 #[serde(flatten)]
36 pub extra: HashMap<String, Value>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40#[serde(untagged)]
41pub enum ThreadNode {
42 Post(Box<ThreadViewPost>),
43 NotFound(ThreadNotFound),
44 Blocked(ThreadBlocked),
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48#[serde(rename_all = "camelCase")]
49pub struct ThreadNotFound {
50 #[serde(rename = "$type")]
51 pub thread_type: String,
52 pub uri: String,
53 pub not_found: bool,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57#[serde(rename_all = "camelCase")]
58pub struct ThreadBlocked {
59 #[serde(rename = "$type")]
60 pub thread_type: String,
61 pub uri: String,
62 pub blocked: bool,
63 pub author: Value,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct PostThreadOutput {
68 pub thread: ThreadNode,
69 #[serde(skip_serializing_if = "Option::is_none")]
70 pub threadgate: Option<Value>,
71}
72
73const MAX_THREAD_DEPTH: usize = 10;
74
75fn add_replies_to_thread(
76 thread: &mut ThreadViewPost,
77 local_posts: &[RecordDescript<PostRecord>],
78 author_did: &str,
79 author_handle: &str,
80 depth: usize,
81) {
82 if depth >= MAX_THREAD_DEPTH {
83 return;
84 }
85 let thread_uri = &thread.post.uri;
86 let replies: Vec<_> = local_posts
87 .iter()
88 .filter(|p| {
89 p.record
90 .reply
91 .as_ref()
92 .and_then(|r| r.get("parent"))
93 .and_then(|parent| parent.get("uri"))
94 .and_then(|u| u.as_str())
95 == Some(thread_uri)
96 })
97 .map(|p| {
98 let post_view = format_local_post(p, author_did, author_handle, None);
99 ThreadNode::Post(Box::new(ThreadViewPost {
100 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
101 post: post_view,
102 parent: None,
103 replies: None,
104 extra: HashMap::new(),
105 }))
106 })
107 .collect();
108 if !replies.is_empty() {
109 match &mut thread.replies {
110 Some(existing) => existing.extend(replies),
111 None => thread.replies = Some(replies),
112 }
113 }
114 if let Some(ref mut existing_replies) = thread.replies {
115 for reply in existing_replies.iter_mut() {
116 if let ThreadNode::Post(reply_thread) = reply {
117 add_replies_to_thread(
118 reply_thread,
119 local_posts,
120 author_did,
121 author_handle,
122 depth + 1,
123 );
124 }
125 }
126 }
127}
128
129pub async fn get_post_thread(
130 State(state): State<AppState>,
131 headers: axum::http::HeaderMap,
132 Query(params): Query<GetPostThreadParams>,
133) -> Response {
134 let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
135 let auth_user = if let Some(h) = auth_header {
136 if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
137 crate::auth::validate_bearer_token(&state.db, &token)
138 .await
139 .ok()
140 } else {
141 None
142 }
143 } else {
144 None
145 };
146 let auth_did = auth_user.as_ref().map(|u| u.did.clone());
147 let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone());
148 let mut query_params = HashMap::new();
149 query_params.insert("uri".to_string(), params.uri.clone());
150 if let Some(depth) = params.depth {
151 query_params.insert("depth".to_string(), depth.to_string());
152 }
153 if let Some(parent_height) = params.parent_height {
154 query_params.insert("parentHeight".to_string(), parent_height.to_string());
155 }
156 let proxy_result = match proxy_to_appview(
157 "app.bsky.feed.getPostThread",
158 &query_params,
159 auth_did.as_deref().unwrap_or(""),
160 auth_key_bytes.as_deref(),
161 )
162 .await
163 {
164 Ok(r) => r,
165 Err(e) => return e,
166 };
167 if proxy_result.status == StatusCode::NOT_FOUND {
168 return handle_not_found(&state, ¶ms.uri, auth_did, &proxy_result.headers).await;
169 }
170 if !proxy_result.status.is_success() {
171 return proxy_result.into_response();
172 }
173 let rev = match extract_repo_rev(&proxy_result.headers) {
174 Some(r) => r,
175 None => return proxy_result.into_response(),
176 };
177 let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) {
178 Ok(t) => t,
179 Err(e) => {
180 warn!("Failed to parse post thread response: {:?}", e);
181 return proxy_result.into_response();
182 }
183 };
184 let requester_did = match auth_did {
185 Some(d) => d,
186 None => return (StatusCode::OK, Json(thread_output)).into_response(),
187 };
188 let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
189 Ok(r) => r,
190 Err(e) => {
191 warn!("Failed to get local records: {}", e);
192 return proxy_result.into_response();
193 }
194 };
195 if local_records.posts.is_empty() {
196 return (StatusCode::OK, Json(thread_output)).into_response();
197 }
198 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
199 .fetch_optional(&state.db)
200 .await
201 {
202 Ok(Some(h)) => h,
203 Ok(None) => requester_did.clone(),
204 Err(e) => {
205 warn!("Database error fetching handle: {:?}", e);
206 requester_did.clone()
207 }
208 };
209 if let ThreadNode::Post(ref mut thread_post) = thread_output.thread {
210 add_replies_to_thread(
211 thread_post,
212 &local_records.posts,
213 &requester_did,
214 &handle,
215 0,
216 );
217 }
218 let lag = get_local_lag(&local_records);
219 format_munged_response(thread_output, lag)
220}
221
222async fn handle_not_found(
223 state: &AppState,
224 uri: &str,
225 auth_did: Option<String>,
226 headers: &axum::http::HeaderMap,
227) -> Response {
228 let rev = match extract_repo_rev(headers) {
229 Some(r) => r,
230 None => {
231 return (
232 StatusCode::NOT_FOUND,
233 Json(json!({"error": "NotFound", "message": "Post not found"})),
234 )
235 .into_response();
236 }
237 };
238 let requester_did = match auth_did {
239 Some(d) => d,
240 None => {
241 return (
242 StatusCode::NOT_FOUND,
243 Json(json!({"error": "NotFound", "message": "Post not found"})),
244 )
245 .into_response();
246 }
247 };
248 let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect();
249 if uri_parts.len() != 3 {
250 return (
251 StatusCode::NOT_FOUND,
252 Json(json!({"error": "NotFound", "message": "Post not found"})),
253 )
254 .into_response();
255 }
256 let post_did = uri_parts[0];
257 if post_did != requester_did {
258 return (
259 StatusCode::NOT_FOUND,
260 Json(json!({"error": "NotFound", "message": "Post not found"})),
261 )
262 .into_response();
263 }
264 let local_records = match get_records_since_rev(state, &requester_did, &rev).await {
265 Ok(r) => r,
266 Err(_) => {
267 return (
268 StatusCode::NOT_FOUND,
269 Json(json!({"error": "NotFound", "message": "Post not found"})),
270 )
271 .into_response();
272 }
273 };
274 let local_post = local_records.posts.iter().find(|p| p.uri == uri);
275 let local_post = match local_post {
276 Some(p) => p,
277 None => {
278 return (
279 StatusCode::NOT_FOUND,
280 Json(json!({"error": "NotFound", "message": "Post not found"})),
281 )
282 .into_response();
283 }
284 };
285 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
286 .fetch_optional(&state.db)
287 .await
288 {
289 Ok(Some(h)) => h,
290 Ok(None) => requester_did.clone(),
291 Err(e) => {
292 warn!("Database error fetching handle: {:?}", e);
293 requester_did.clone()
294 }
295 };
296 let post_view = format_local_post(
297 local_post,
298 &requester_did,
299 &handle,
300 local_records.profile.as_ref(),
301 );
302 let thread = PostThreadOutput {
303 thread: ThreadNode::Post(Box::new(ThreadViewPost {
304 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
305 post: post_view,
306 parent: None,
307 replies: None,
308 extra: HashMap::new(),
309 })),
310 threadgate: None,
311 };
312 let lag = get_local_lag(&local_records);
313 format_munged_response(thread, lag)
314}