this repo has no description
1use crate::api::read_after_write::{ 2 extract_repo_rev, format_local_post, format_munged_response, get_local_lag, 3 get_records_since_rev, proxy_to_appview, PostRecord, PostView, RecordDescript, 4}; 5use crate::state::AppState; 6use axum::{ 7 extract::{Query, State}, 8 http::StatusCode, 9 response::{IntoResponse, Response}, 10 Json, 11}; 12use serde::{Deserialize, Serialize}; 13use serde_json::{json, Value}; 14use std::collections::HashMap; 15use tracing::warn; 16#[derive(Deserialize)] 17pub struct GetPostThreadParams { 18 pub uri: String, 19 pub depth: Option<u32>, 20 #[serde(rename = "parentHeight")] 21 pub parent_height: Option<u32>, 22} 23#[derive(Debug, Clone, Serialize, Deserialize)] 24#[serde(rename_all = "camelCase")] 25pub struct ThreadViewPost { 26 #[serde(rename = "$type")] 27 pub thread_type: Option<String>, 28 pub post: PostView, 29 #[serde(skip_serializing_if = "Option::is_none")] 30 pub parent: Option<Box<ThreadNode>>, 31 #[serde(skip_serializing_if = "Option::is_none")] 32 pub replies: Option<Vec<ThreadNode>>, 33 #[serde(flatten)] 34 pub extra: HashMap<String, Value>, 35} 36#[derive(Debug, Clone, Serialize, Deserialize)] 37#[serde(untagged)] 38pub enum ThreadNode { 39 Post(ThreadViewPost), 40 NotFound(ThreadNotFound), 41 Blocked(ThreadBlocked), 42} 43#[derive(Debug, Clone, Serialize, Deserialize)] 44#[serde(rename_all = "camelCase")] 45pub struct ThreadNotFound { 46 #[serde(rename = "$type")] 47 pub thread_type: String, 48 pub uri: String, 49 pub not_found: bool, 50} 51#[derive(Debug, Clone, Serialize, Deserialize)] 52#[serde(rename_all = "camelCase")] 53pub struct ThreadBlocked { 54 #[serde(rename = "$type")] 55 pub thread_type: String, 56 pub uri: String, 57 pub blocked: bool, 58 pub author: Value, 59} 60#[derive(Debug, Clone, Serialize, Deserialize)] 61pub struct PostThreadOutput { 62 pub thread: ThreadNode, 63 #[serde(skip_serializing_if = "Option::is_none")] 64 pub threadgate: Option<Value>, 65} 66const MAX_THREAD_DEPTH: usize = 10; 67fn add_replies_to_thread( 68 thread: &mut ThreadViewPost, 69 local_posts: &[RecordDescript<PostRecord>], 70 author_did: &str, 71 author_handle: &str, 72 depth: usize, 73) { 74 if depth >= MAX_THREAD_DEPTH { 75 return; 76 } 77 let thread_uri = &thread.post.uri; 78 let replies: Vec<_> = local_posts 79 .iter() 80 .filter(|p| { 81 p.record 82 .reply 83 .as_ref() 84 .and_then(|r| r.get("parent")) 85 .and_then(|parent| parent.get("uri")) 86 .and_then(|u| u.as_str()) 87 == Some(thread_uri) 88 }) 89 .map(|p| { 90 let post_view = format_local_post(p, author_did, author_handle, None); 91 ThreadNode::Post(ThreadViewPost { 92 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()), 93 post: post_view, 94 parent: None, 95 replies: None, 96 extra: HashMap::new(), 97 }) 98 }) 99 .collect(); 100 if !replies.is_empty() { 101 match &mut thread.replies { 102 Some(existing) => existing.extend(replies), 103 None => thread.replies = Some(replies), 104 } 105 } 106 if let Some(ref mut existing_replies) = thread.replies { 107 for reply in existing_replies.iter_mut() { 108 if let ThreadNode::Post(reply_thread) = reply { 109 add_replies_to_thread(reply_thread, local_posts, author_did, author_handle, depth + 1); 110 } 111 } 112 } 113} 114pub async fn get_post_thread( 115 State(state): State<AppState>, 116 headers: axum::http::HeaderMap, 117 Query(params): Query<GetPostThreadParams>, 118) -> Response { 119 let auth_header = headers.get("Authorization").and_then(|h| h.to_str().ok()); 120 let auth_did = if let Some(h) = auth_header { 121 if let Some(token) = crate::auth::extract_bearer_token_from_header(Some(h)) { 122 match crate::auth::validate_bearer_token(&state.db, &token).await { 123 Ok(user) => Some(user.did), 124 Err(_) => None, 125 } 126 } else { 127 None 128 } 129 } else { 130 None 131 }; 132 let mut query_params = HashMap::new(); 133 query_params.insert("uri".to_string(), params.uri.clone()); 134 if let Some(depth) = params.depth { 135 query_params.insert("depth".to_string(), depth.to_string()); 136 } 137 if let Some(parent_height) = params.parent_height { 138 query_params.insert("parentHeight".to_string(), parent_height.to_string()); 139 } 140 let proxy_result = 141 match proxy_to_appview("app.bsky.feed.getPostThread", &query_params, auth_header).await { 142 Ok(r) => r, 143 Err(e) => return e, 144 }; 145 if proxy_result.status == StatusCode::NOT_FOUND { 146 return handle_not_found(&state, &params.uri, auth_did, &proxy_result.headers).await; 147 } 148 if !proxy_result.status.is_success() { 149 return (proxy_result.status, proxy_result.body).into_response(); 150 } 151 let rev = match extract_repo_rev(&proxy_result.headers) { 152 Some(r) => r, 153 None => return (proxy_result.status, proxy_result.body).into_response(), 154 }; 155 let mut thread_output: PostThreadOutput = match serde_json::from_slice(&proxy_result.body) { 156 Ok(t) => t, 157 Err(e) => { 158 warn!("Failed to parse post thread response: {:?}", e); 159 return (proxy_result.status, proxy_result.body).into_response(); 160 } 161 }; 162 let requester_did = match auth_did { 163 Some(d) => d, 164 None => return (StatusCode::OK, Json(thread_output)).into_response(), 165 }; 166 let local_records = match get_records_since_rev(&state, &requester_did, &rev).await { 167 Ok(r) => r, 168 Err(e) => { 169 warn!("Failed to get local records: {}", e); 170 return (proxy_result.status, proxy_result.body).into_response(); 171 } 172 }; 173 if local_records.posts.is_empty() { 174 return (StatusCode::OK, Json(thread_output)).into_response(); 175 } 176 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) 177 .fetch_optional(&state.db) 178 .await 179 { 180 Ok(Some(h)) => h, 181 Ok(None) => requester_did.clone(), 182 Err(e) => { 183 warn!("Database error fetching handle: {:?}", e); 184 requester_did.clone() 185 } 186 }; 187 if let ThreadNode::Post(ref mut thread_post) = thread_output.thread { 188 add_replies_to_thread(thread_post, &local_records.posts, &requester_did, &handle, 0); 189 } 190 let lag = get_local_lag(&local_records); 191 format_munged_response(thread_output, lag) 192} 193async fn handle_not_found( 194 state: &AppState, 195 uri: &str, 196 auth_did: Option<String>, 197 headers: &axum::http::HeaderMap, 198) -> Response { 199 let rev = match extract_repo_rev(headers) { 200 Some(r) => r, 201 None => { 202 return ( 203 StatusCode::NOT_FOUND, 204 Json(json!({"error": "NotFound", "message": "Post not found"})), 205 ) 206 .into_response() 207 } 208 }; 209 let requester_did = match auth_did { 210 Some(d) => d, 211 None => { 212 return ( 213 StatusCode::NOT_FOUND, 214 Json(json!({"error": "NotFound", "message": "Post not found"})), 215 ) 216 .into_response() 217 } 218 }; 219 let uri_parts: Vec<&str> = uri.trim_start_matches("at://").split('/').collect(); 220 if uri_parts.len() != 3 { 221 return ( 222 StatusCode::NOT_FOUND, 223 Json(json!({"error": "NotFound", "message": "Post not found"})), 224 ) 225 .into_response(); 226 } 227 let post_did = uri_parts[0]; 228 if post_did != requester_did { 229 return ( 230 StatusCode::NOT_FOUND, 231 Json(json!({"error": "NotFound", "message": "Post not found"})), 232 ) 233 .into_response(); 234 } 235 let local_records = match get_records_since_rev(state, &requester_did, &rev).await { 236 Ok(r) => r, 237 Err(_) => { 238 return ( 239 StatusCode::NOT_FOUND, 240 Json(json!({"error": "NotFound", "message": "Post not found"})), 241 ) 242 .into_response() 243 } 244 }; 245 let local_post = local_records.posts.iter().find(|p| p.uri == uri); 246 let local_post = match local_post { 247 Some(p) => p, 248 None => { 249 return ( 250 StatusCode::NOT_FOUND, 251 Json(json!({"error": "NotFound", "message": "Post not found"})), 252 ) 253 .into_response() 254 } 255 }; 256 let handle = match sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", requester_did) 257 .fetch_optional(&state.db) 258 .await 259 { 260 Ok(Some(h)) => h, 261 Ok(None) => requester_did.clone(), 262 Err(e) => { 263 warn!("Database error fetching handle: {:?}", e); 264 requester_did.clone() 265 } 266 }; 267 let post_view = format_local_post( 268 local_post, 269 &requester_did, 270 &handle, 271 local_records.profile.as_ref(), 272 ); 273 let thread = PostThreadOutput { 274 thread: ThreadNode::Post(ThreadViewPost { 275 thread_type: Some("app.bsky.feed.defs#threadViewPost".to_string()), 276 post: post_view, 277 parent: None, 278 replies: None, 279 extra: HashMap::new(), 280 }), 281 threadgate: None, 282 }; 283 let lag = get_local_lag(&local_records); 284 format_munged_response(thread, lag) 285}