this repo has no description
1use crate::api::read_after_write::{
2 extract_repo_rev, format_local_post, format_munged_response, get_local_lag,
3 get_records_since_rev, proxy_to_appview, PostRecord, PostView, RecordDescript,
4};
5use crate::state::AppState;
6use axum::{
7 extract::{Query, State},
8 http::StatusCode,
9 response::{IntoResponse, Response},
10 Json,
11};
12use serde::{Deserialize, Serialize};
13use serde_json::{json, Value};
14use std::collections::HashMap;
15use tracing::warn;
16
17#[derive(Deserialize)]
18pub struct GetPostThreadParams {
19 pub uri: String,
20 pub depth: Option<u32>,
21 #[serde(rename = "parentHeight")]
22 pub parent_height: Option<u32>,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26#[serde(rename_all = "camelCase")]
27pub struct ThreadViewPost {
28 #[serde(rename = "$type")]
29 pub thread_type: Option<String>,
30 pub post: PostView,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub parent: Option<Box<ThreadNode>>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub replies: Option<Vec<ThreadNode>>,
35 #[serde(flatten)]
36 pub extra: HashMap<String, Value>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
40#[serde(untagged)]
41pub enum ThreadNode {
42 Post(ThreadViewPost),
43 NotFound(ThreadNotFound),
44 Blocked(ThreadBlocked),
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48#[serde(rename_all = "camelCase")]
49pub struct ThreadNotFound {
50 #[serde(rename = "$type")]
51 pub thread_type: String,
52 pub uri: String,
53 pub not_found: bool,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57#[serde(rename_all = "camelCase")]
58pub struct ThreadBlocked {
59 #[serde(rename = "$type")]
60 pub thread_type: String,
61 pub uri: String,
62 pub blocked: bool,
63 pub author: Value,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct PostThreadOutput {
68 pub thread: ThreadNode,
69 #[serde(skip_serializing_if = "Option::is_none")]
70 pub threadgate: Option<Value>,
71}
72
73const MAX_THREAD_DEPTH: usize = 10;
74
75fn add_replies_to_thread(
76 thread: &mut ThreadViewPost,
77 local_posts: &[RecordDescript<PostRecord>],
78 author_did: &str,
79 author_handle: &str,
80 depth: usize,
81) {
82 if depth >= MAX_THREAD_DEPTH {
83 return;
84 }
85 let thread_uri = &thread.post.uri;
86 let replies: Vec<_> = local_posts
87 .iter()
88 .filter(|p| {
89 p.record
90 .reply
91 .as_ref()
92 .and_then(|r| r.get("parent"))
93 .and_then(|parent| parent.get("uri"))
94 .and_then(|u| u.as_str())
95 == Some(thread_uri)
96 })
97 .map(|p| {
98 let post_view = format_local_post(p, author_did, author_handle, None);
99 ThreadNode::Post(ThreadViewPost {
100 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
101 post: post_view,
102 parent: None,
103 replies: None,
104 extra: HashMap::new(),
105 })
106 })
107 .collect();
108 if !replies.is_empty() {
109 match &mut thread.replies {
110 Some(existing) => existing.extend(replies),
111 None => thread.replies = Some(replies),
112 }
113 }
114 if let Some(ref mut existing_replies) = thread.replies {
115 for reply in existing_replies.iter_mut() {
116 if let ThreadNode::Post(reply_thread) = reply {
117 add_replies_to_thread(reply_thread, local_posts, author_did, author_handle, depth + 1);
118 }
119 }
120 }
121}
122
123pub async fn get_post_thread(
124 State(state): State<AppState>,
125 headers: axum::http::HeaderMap,
126 Query(params): Query<GetPostThreadParams>,
127) -> Response {
128 let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok());
129 let auth_did = if let Some(h) = auth_header {
130 if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) {
131 match crate::auth::validate_bearer_token(&state.db, &token).await {
132 Ok(user) => Some(user.did),
133 Err(_) => None,
134 }
135 } else {
136 None
137 }
138 } else {
139 None
140 };
141 let mut query_params = HashMap::new();
142 query_params.insert("uri".to_string(), params.uri.clone());
143 if let Some(depth) = params.depth {
144 query_params.insert("depth".to_string(), depth.to_string());
145 }
146 if let Some(parent_height) = params.parent_height {
147 query_params.insert("parentHeight".to_string(), parent_height.to_string());
148 }
149 let proxy_result =
150 match proxy_to_appview("app.bsky.feed.getPostThread", &query_params, auth_header).await {
151 Ok(r) => r,
152 Err(e) => return e,
153 };
154 if proxy_result.status == StatusCode::NOT_FOUND {
155 return handle_not_found(&state, ¶ms.uri, auth_did, &proxy_result.headers).await;
156 }
157 if !proxy_result.status.is_success() {
158 return (proxy_result.status, proxy_result.body).into_response();
159 }
160 let rev = match extract_repo_rev(&proxy_result.headers) {
161 Some(r) => r,
162 None => return (proxy_result.status, proxy_result.body).into_response(),
163 };
164 let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) {
165 Ok(t) => t,
166 Err(e) => {
167 warn!("Failed to parse post thread response: {:?}", e);
168 return (proxy_result.status, proxy_result.body).into_response();
169 }
170 };
171 let requester_did = match auth_did {
172 Some(d) => d,
173 None => return (StatusCode::OK, Json(thread_output)).into_response(),
174 };
175 let local_records = match get_records_since_rev(&state, &requester_did, &rev).await {
176 Ok(r) => r,
177 Err(e) => {
178 warn!("Failed to get local records: {}", e);
179 return (proxy_result.status, proxy_result.body).into_response();
180 }
181 };
182 if local_records.posts.is_empty() {
183 return (StatusCode::OK, Json(thread_output)).into_response();
184 }
185 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
186 .fetch_optional(&state.db)
187 .await
188 {
189 Ok(Some(h)) => h,
190 Ok(None) => requester_did.clone(),
191 Err(e) => {
192 warn!("Database error fetching handle: {:?}", e);
193 requester_did.clone()
194 }
195 };
196 if let ThreadNode::Post(ref mut thread_post) = thread_output.thread {
197 add_replies_to_thread(thread_post, &local_records.posts, &requester_did, &handle, 0);
198 }
199 let lag = get_local_lag(&local_records);
200 format_munged_response(thread_output, lag)
201}
202
203async fn handle_not_found(
204 state: &AppState,
205 uri: &str,
206 auth_did: Option<String>,
207 headers: &axum::http::HeaderMap,
208) -> Response {
209 let rev = match extract_repo_rev(headers) {
210 Some(r) => r,
211 None => {
212 return (
213 StatusCode::NOT_FOUND,
214 Json(json!({"error": "NotFound", "message": "Post not found"})),
215 )
216 .into_response()
217 }
218 };
219 let requester_did = match auth_did {
220 Some(d) => d,
221 None => {
222 return (
223 StatusCode::NOT_FOUND,
224 Json(json!({"error": "NotFound", "message": "Post not found"})),
225 )
226 .into_response()
227 }
228 };
229 let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect();
230 if uri_parts.len() != 3 {
231 return (
232 StatusCode::NOT_FOUND,
233 Json(json!({"error": "NotFound", "message": "Post not found"})),
234 )
235 .into_response();
236 }
237 let post_did = uri_parts[0];
238 if post_did != requester_did {
239 return (
240 StatusCode::NOT_FOUND,
241 Json(json!({"error": "NotFound", "message": "Post not found"})),
242 )
243 .into_response();
244 }
245 let local_records = match get_records_since_rev(state, &requester_did, &rev).await {
246 Ok(r) => r,
247 Err(_) => {
248 return (
249 StatusCode::NOT_FOUND,
250 Json(json!({"error": "NotFound", "message": "Post not found"})),
251 )
252 .into_response()
253 }
254 };
255 let local_post = local_records.posts.iter().find(|p| p.uri == uri);
256 let local_post = match local_post {
257 Some(p) => p,
258 None => {
259 return (
260 StatusCode::NOT_FOUND,
261 Json(json!({"error": "NotFound", "message": "Post not found"})),
262 )
263 .into_response()
264 }
265 };
266 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did)
267 .fetch_optional(&state.db)
268 .await
269 {
270 Ok(Some(h)) => h,
271 Ok(None) => requester_did.clone(),
272 Err(e) => {
273 warn!("Database error fetching handle: {:?}", e);
274 requester_did.clone()
275 }
276 };
277 let post_view = format_local_post(
278 local_post,
279 &requester_did,
280 &handle,
281 local_records.profile.as_ref(),
282 );
283 let thread = PostThreadOutput {
284 thread: ThreadNode::Post(ThreadViewPost {
285 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()),
286 post: post_view,
287 parent: None,
288 replies: None,
289 extra: HashMap::new(),
290 }),
291 threadgate: None,
292 };
293 let lag = get_local_lag(&local_records);
294 format_munged_response(thread, lag)
295}