this repo has no description
1use crate::api::read_after_write::{ 2 PostRecord, PostView, RecordDescript, extract_repo_rev, format_local_post, 3 format_munged_response, get_local_lag, get_records_since_rev, proxy_to_appview_via_registry, 4}; 5use crate::state::AppState; 6use axum::{ 7 Json, 8 extract::{Query, State}, 9 http::StatusCode, 10 response::{IntoResponse, Response}, 11}; 12use serde::{Deserialize, Serialize}; 13use serde_json::{Value, json}; 14use std::collections::HashMap; 15use tracing::warn; 16 17#[derive(Deserialize)] 18pub struct GetPostThreadParams { 19 pub uri: String, 20 pub depth: Option<u32>, 21 #[serde(rename = "parentHeight")] 22 pub parent_height: Option<u32>, 23} 24 25#[derive(Debug, Clone, Serialize, Deserialize)] 26#[serde(rename_all = "camelCase")] 27pub struct ThreadViewPost { 28 #[serde(rename = "$type")] 29 pub thread_type: Option<String>, 30 pub post: PostView, 31 #[serde(skip_serializing_if = "Option::is_none")] 32 pub parent: Option<Box<ThreadNode>>, 33 #[serde(skip_serializing_if = "Option::is_none")] 34 pub replies: Option<Vec<ThreadNode>>, 35 #[serde(flatten)] 36 pub extra: HashMap<String, Value>, 37} 38 39#[derive(Debug, Clone, Serialize, Deserialize)] 40#[serde(untagged)] 41pub enum ThreadNode { 42 Post(Box<ThreadViewPost>), 43 NotFound(ThreadNotFound), 44 Blocked(ThreadBlocked), 45} 46 47#[derive(Debug, Clone, Serialize, Deserialize)] 48#[serde(rename_all = "camelCase")] 49pub struct ThreadNotFound { 50 #[serde(rename = "$type")] 51 pub thread_type: String, 52 pub uri: String, 53 pub not_found: bool, 54} 55 56#[derive(Debug, Clone, Serialize, Deserialize)] 57#[serde(rename_all = "camelCase")] 58pub struct ThreadBlocked { 59 #[serde(rename = "$type")] 60 pub thread_type: String, 61 pub uri: String, 62 pub blocked: bool, 63 pub author: Value, 64} 65 66#[derive(Debug, Clone, Serialize, Deserialize)] 67pub struct PostThreadOutput { 68 pub thread: ThreadNode, 69 #[serde(skip_serializing_if = "Option::is_none")] 70 pub threadgate: Option<Value>, 71} 72 73const MAX_THREAD_DEPTH: usize = 10; 74 75fn add_replies_to_thread( 76 thread: &mut ThreadViewPost, 77 local_posts: &[RecordDescript<PostRecord>], 78 author_did: &str, 79 author_handle: &str, 80 depth: usize, 81) { 82 if depth >= MAX_THREAD_DEPTH { 83 return; 84 } 85 let thread_uri = &thread.post.uri; 86 let replies: Vec<_> = local_posts 87 .iter() 88 .filter(|p| { 89 p.record 90 .reply 91 .as_ref() 92 .and_then(|r| r.get("parent")) 93 .and_then(|parent| parent.get("uri")) 94 .and_then(|u| u.as_str()) 95 == Some(thread_uri) 96 }) 97 .map(|p| { 98 let post_view = format_local_post(p, author_did, author_handle, None); 99 ThreadNode::Post(Box::new(ThreadViewPost { 100 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()), 101 post: post_view, 102 parent: None, 103 replies: None, 104 extra: HashMap::new(), 105 })) 106 }) 107 .collect(); 108 if !replies.is_empty() { 109 match &mut thread.replies { 110 Some(existing) => existing.extend(replies), 111 None => thread.replies = Some(replies), 112 } 113 } 114 if let Some(ref mut existing_replies) = thread.replies { 115 for reply in existing_replies.iter_mut() { 116 if let ThreadNode::Post(reply_thread) = reply { 117 add_replies_to_thread( 118 reply_thread, 119 local_posts, 120 author_did, 121 author_handle, 122 depth + 1, 123 ); 124 } 125 } 126 } 127} 128 129pub async fn get_post_thread( 130 State(state): State<AppState>, 131 headers: axum::http::HeaderMap, 132 Query(params): Query<GetPostThreadParams>, 133) -> Response { 134 let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 135 let auth_user = if let Some(h) = auth_header { 136 if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { 137 crate::auth::validate_bearer_token(&state.db, &token) 138 .await 139 .ok() 140 } else { 141 None 142 } 143 } else { 144 None 145 }; 146 let auth_did = auth_user.as_ref().map(|u| u.did.clone()); 147 let auth_key_bytes = auth_user.as_ref().and_then(|u| u.key_bytes.clone()); 148 let mut query_params = HashMap::new(); 149 query_params.insert("uri".to_string(), params.uri.clone()); 150 if let Some(depth) = params.depth { 151 query_params.insert("depth".to_string(), depth.to_string()); 152 } 153 if let Some(parent_height) = params.parent_height { 154 query_params.insert("parentHeight".to_string(), parent_height.to_string()); 155 } 156 let proxy_result = match proxy_to_appview_via_registry( 157 &state, 158 "app.bsky.feed.getPostThread", 159 &query_params, 160 auth_did.as_deref().unwrap_or(""), 161 auth_key_bytes.as_deref(), 162 ) 163 .await 164 { 165 Ok(r) => r, 166 Err(e) => return e, 167 }; 168 if proxy_result.status == StatusCode::NOT_FOUND { 169 return handle_not_found(&state, &params.uri, auth_did, &proxy_result.headers).await; 170 } 171 if !proxy_result.status.is_success() { 172 return proxy_result.into_response(); 173 } 174 let rev = match extract_repo_rev(&proxy_result.headers) { 175 Some(r) => r, 176 None => return proxy_result.into_response(), 177 }; 178 let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) { 179 Ok(t) => t, 180 Err(e) => { 181 warn!("Failed to parse post thread response: {:?}", e); 182 return proxy_result.into_response(); 183 } 184 }; 185 let requester_did = match auth_did { 186 Some(d) => d, 187 None => return (StatusCode::OK, Json(thread_output)).into_response(), 188 }; 189 let local_records = match get_records_since_rev(&state, &requester_did, &rev).await { 190 Ok(r) => r, 191 Err(e) => { 192 warn!("Failed to get local records: {}", e); 193 return proxy_result.into_response(); 194 } 195 }; 196 if local_records.posts.is_empty() { 197 return (StatusCode::OK, Json(thread_output)).into_response(); 198 } 199 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) 200 .fetch_optional(&state.db) 201 .await 202 { 203 Ok(Some(h)) => h, 204 Ok(None) => requester_did.clone(), 205 Err(e) => { 206 warn!("Database error fetching handle: {:?}", e); 207 requester_did.clone() 208 } 209 }; 210 if let ThreadNode::Post(ref mut thread_post) = thread_output.thread { 211 add_replies_to_thread( 212 thread_post, 213 &local_records.posts, 214 &requester_did, 215 &handle, 216 0, 217 ); 218 } 219 let lag = get_local_lag(&local_records); 220 format_munged_response(thread_output, lag) 221} 222 223async fn handle_not_found( 224 state: &AppState, 225 uri: &str, 226 auth_did: Option<String>, 227 headers: &axum::http::HeaderMap, 228) -> Response { 229 let rev = match extract_repo_rev(headers) { 230 Some(r) => r, 231 None => { 232 return ( 233 StatusCode::NOT_FOUND, 234 Json(json!({"error": "NotFound", "message": "Post not found"})), 235 ) 236 .into_response(); 237 } 238 }; 239 let requester_did = match auth_did { 240 Some(d) => d, 241 None => { 242 return ( 243 StatusCode::NOT_FOUND, 244 Json(json!({"error": "NotFound", "message": "Post not found"})), 245 ) 246 .into_response(); 247 } 248 }; 249 let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect(); 250 if uri_parts.len() != 3 { 251 return ( 252 StatusCode::NOT_FOUND, 253 Json(json!({"error": "NotFound", "message": "Post not found"})), 254 ) 255 .into_response(); 256 } 257 let post_did = uri_parts[0]; 258 if post_did != requester_did { 259 return ( 260 StatusCode::NOT_FOUND, 261 Json(json!({"error": "NotFound", "message": "Post not found"})), 262 ) 263 .into_response(); 264 } 265 let local_records = match get_records_since_rev(state, &requester_did, &rev).await { 266 Ok(r) => r, 267 Err(_) => { 268 return ( 269 StatusCode::NOT_FOUND, 270 Json(json!({"error": "NotFound", "message": "Post not found"})), 271 ) 272 .into_response(); 273 } 274 }; 275 let local_post = local_records.posts.iter().find(|p| p.uri == uri); 276 let local_post = match local_post { 277 Some(p) => p, 278 None => { 279 return ( 280 StatusCode::NOT_FOUND, 281 Json(json!({"error": "NotFound", "message": "Post not found"})), 282 ) 283 .into_response(); 284 } 285 }; 286 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) 287 .fetch_optional(&state.db) 288 .await 289 { 290 Ok(Some(h)) => h, 291 Ok(None) => requester_did.clone(), 292 Err(e) => { 293 warn!("Database error fetching handle: {:?}", e); 294 requester_did.clone() 295 } 296 }; 297 let post_view = format_local_post( 298 local_post, 299 &requester_did, 300 &handle, 301 local_records.profile.as_ref(), 302 ); 303 let thread = PostThreadOutput { 304 thread: ThreadNode::Post(Box::new(ThreadViewPost { 305 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()), 306 post: post_view, 307 parent: None, 308 replies: None, 309 extra: HashMap::new(), 310 })), 311 threadgate: None, 312 }; 313 let lag = get_local_lag(&local_records); 314 format_munged_response(thread, lag) 315}