// Heavily customized version of smokesignal: https://whtwnd.com/kayrozen.com/3lpwe4ymowg2t
// (Source view: branch `main`, 368 lines, 14 kB.)
1//! AT Protocol client types and utilities 2//! 3//! This module provides client types for AT Protocol operations using the Atrium API. 4 5use serde::{de::DeserializeOwned, Deserialize, Serialize}; 6use crate::atproto::lexicon::com_atproto_repo::StrongRef; 7use crate::atproto::atrium_auth::AtriumOAuthManager; 8use crate::http::middleware_auth::WebSession; 9use anyhow::Result; 10use atrium_api::com::atproto::repo::{ 11 create_record, put_record, list_records, delete_record, 12}; 13use atrium_api::types::{ 14 string::{AtIdentifier, Did, Handle, Nsid, RecordKey, Cid}, 15 Unknown, LimitedNonZeroU8, 16 TryIntoUnknown, TryFromUnknown 17}; 18use std::str::FromStr; 19 20/// Request structure for creating AT Protocol records 21#[derive(Debug, Serialize, Deserialize, Clone)] 22#[serde(bound = "T: Serialize + DeserializeOwned")] 23pub struct CreateRecordRequest<T: DeserializeOwned> { 24 pub repo: String, 25 pub collection: String, 26 #[serde(skip_serializing_if = "Option::is_none", default, rename = "rkey")] 27 pub record_key: Option<String>, 28 #[serde(skip_serializing_if = "Option::is_none")] 29 pub validate: Option<bool>, 30 pub record: T, 31 #[serde(skip_serializing_if = "Option::is_none", default, rename = "swapCommit")] 32 pub swap_commit: Option<String>, 33} 34 35/// Request structure for updating AT Protocol records 36#[derive(Debug, Serialize, Deserialize, Clone)] 37#[serde(bound = "T: Serialize + DeserializeOwned")] 38pub struct PutRecordRequest<T: DeserializeOwned> { 39 pub repo: String, 40 pub collection: String, 41 #[serde(rename = "rkey")] 42 pub record_key: String, 43 #[serde(skip_serializing_if = "Option::is_none")] 44 pub validate: Option<bool>, 45 pub record: T, 46 #[serde(skip_serializing_if = "Option::is_none", default, rename = "swapCommit")] 47 pub swap_commit: Option<String>, 48 #[serde(skip_serializing_if = "Option::is_none", default, rename = "swapRecord")] 49 pub swap_record: Option<String>, 50} 51 52/// Request structure for deleting AT Protocol 
records 53#[derive(Debug, Serialize, Deserialize, Clone)] 54pub struct DeleteRecordRequest { 55 pub repo: String, 56 pub collection: String, 57 #[serde(rename = "rkey")] 58 pub record_key: String, 59 #[serde(skip_serializing_if = "Option::is_none", default, rename = "swapRecord")] 60 pub swap_record: Option<String>, 61 #[serde(skip_serializing_if = "Option::is_none", default, rename = "swapCommit")] 62 pub swap_commit: Option<String>, 63} 64 65/// Parameters for listing records 66#[derive(Debug, Serialize, Deserialize, Clone)] 67pub struct ListRecordsParams { 68 pub repo: String, 69 pub collection: String, 70 #[serde(skip_serializing_if = "Option::is_none")] 71 pub limit: Option<u32>, 72 #[serde(skip_serializing_if = "Option::is_none")] 73 pub cursor: Option<String>, 74 #[serde(skip_serializing_if = "Option::is_none", rename = "rkeyStart")] 75 pub rkey_start: Option<String>, 76 #[serde(skip_serializing_if = "Option::is_none", rename = "rkeyEnd")] 77 pub rkey_end: Option<String>, 78 #[serde(skip_serializing_if = "Option::is_none")] 79 pub reverse: Option<bool>, 80} 81 82/// Record in a list response 83#[derive(Debug, Serialize, Deserialize, Clone)] 84pub struct ListRecord<T> { 85 pub uri: String, 86 pub cid: String, 87 pub value: T, 88} 89 90/// Response structure for listing AT Protocol records 91#[derive(Debug, Serialize, Deserialize, Clone)] 92pub struct ListRecordsResponse<T> { 93 pub records: Vec<ListRecord<T>>, 94 #[serde(skip_serializing_if = "Option::is_none")] 95 pub cursor: Option<String>, 96} 97 98/// Simple OAuth-enabled PDS client using Atrium API 99/// 100/// This client bridges the existing OAuth session management with the Atrium AT Protocol API. 
101pub struct OAuthPdsClient<'a> { 102 pub atrium_manager: &'a AtriumOAuthManager, 103} 104 105impl OAuthPdsClient<'_> { 106 /// Create a record using Atrium API 107 pub async fn create_record<T: DeserializeOwned + Serialize>( 108 &self, 109 web_session: &WebSession, 110 request: CreateRecordRequest<T>, 111 ) -> Result<StrongRef, anyhow::Error> { 112 // Get the authenticated Atrium client 113 let agent = self.atrium_manager 114 .get_atproto_client(&web_session.session_group) 115 .await 116 .map_err(|e| anyhow::anyhow!("Failed to get AT Protocol client: {}", e))?; 117 118 // Convert repo string to AtIdentifier (parse as DID or Handle) 119 let repo = if request.repo.starts_with("did:") { 120 // Parse as DID 121 let did = Did::new(request.repo.clone()) 122 .map_err(|e| anyhow::anyhow!("Invalid DID format: {}", e))?; 123 AtIdentifier::from(did) 124 } else { 125 // Parse as Handle 126 let handle = Handle::new(request.repo.clone()) 127 .map_err(|e| anyhow::anyhow!("Invalid handle format: {}", e))?; 128 AtIdentifier::from(handle) 129 }; 130 131 let collection = Nsid::new(request.collection.clone()) 132 .map_err(|e| anyhow::anyhow!("Invalid collection NSID: {}", e))?; 133 134 // Convert the record to Unknown type for Atrium 135 let record_unknown: Unknown = request.record 136 .try_into_unknown() 137 .map_err(|e| anyhow::anyhow!("Failed to convert record to Unknown: {}", e))?; 138 139 // Convert record key to proper type if provided 140 let rkey = request.record_key 141 .map(|k| RecordKey::new(k)) 142 .transpose() 143 .map_err(|e| anyhow::anyhow!("Invalid record key: {}", e))?; 144 145 // Convert swap_commit to Cid if provided 146 let swap_commit = request.swap_commit 147 .map(|c| Cid::from_str(&c)) 148 .transpose() 149 .map_err(|e| anyhow::anyhow!("Invalid swap commit CID: {}", e))?; 150 151 // Build the input data 152 let input_data = create_record::InputData { 153 repo, 154 collection, 155 rkey, 156 validate: request.validate, // Use the optional validate field directly 
157 record: record_unknown, 158 swap_commit, 159 }; 160 161 // Make the API call using the correct method pattern 162 let response = agent 163 .api 164 .com 165 .atproto 166 .repo 167 .create_record(input_data.into()) 168 .await 169 .map_err(|e| { 170 tracing::error!("AT Protocol create_record failed: {:?}", e); 171 anyhow::anyhow!("Create record failed: {}", e) 172 })?; 173 174 // Return the strong reference 175 Ok(StrongRef { 176 uri: response.data.uri, 177 cid: response.data.cid.as_ref().to_string(), 178 }) 179 } 180 181 /// Update a record using Atrium API 182 pub async fn put_record<T: DeserializeOwned + Serialize>( 183 &self, 184 web_session: &WebSession, 185 request: PutRecordRequest<T>, 186 ) -> Result<StrongRef, anyhow::Error> { 187 // Get the authenticated Atrium client 188 let agent = self.atrium_manager 189 .get_atproto_client(&web_session.session_group) 190 .await 191 .map_err(|e| anyhow::anyhow!("Failed to get AT Protocol client: {}", e))?; 192 193 // Convert our request to Atrium format 194 let repo = if request.repo.starts_with("did:") { 195 let did = Did::new(request.repo.clone()) 196 .map_err(|e| anyhow::anyhow!("Invalid DID format: {}", e))?; 197 AtIdentifier::from(did) 198 } else { 199 let handle = Handle::new(request.repo.clone()) 200 .map_err(|e| anyhow::anyhow!("Invalid handle format: {}", e))?; 201 AtIdentifier::from(handle) 202 }; 203 204 let collection = Nsid::new(request.collection.clone()) 205 .map_err(|e| anyhow::anyhow!("Invalid collection NSID: {}", e))?; 206 207 let rkey = RecordKey::new(request.record_key) 208 .map_err(|e| anyhow::anyhow!("Invalid record key: {}", e))?; 209 210 // Convert the record to Unknown type for Atrium 211 let record_unknown: Unknown = request.record 212 .try_into_unknown() 213 .map_err(|e| anyhow::anyhow!("Failed to convert record to Unknown: {}", e))?; 214 215 // Convert optional fields 216 let swap_commit = request.swap_commit 217 .map(|c| Cid::from_str(&c)) 218 .transpose() 219 .map_err(|e| 
anyhow::anyhow!("Invalid swap commit CID: {}", e))?; 220 221 let swap_record = request.swap_record 222 .map(|c| Cid::from_str(&c)) 223 .transpose() 224 .map_err(|e| anyhow::anyhow!("Invalid swap record CID: {}", e))?; 225 226 // Build the input data 227 let input_data = put_record::InputData { 228 repo, 229 collection, 230 rkey, 231 validate: request.validate, // Use the optional validate field directly 232 record: record_unknown, 233 swap_commit, 234 swap_record, 235 }; 236 237 // Make the API call 238 let response = agent.api.com.atproto.repo.put_record(input_data.into()) 239 .await 240 .map_err(|e| { 241 tracing::error!("AT Protocol put_record failed: {:?}", e); 242 anyhow::anyhow!("Put record failed: {}", e) 243 })?; 244 245 // Return the strong reference 246 Ok(StrongRef { 247 uri: response.data.uri, 248 cid: response.data.cid.as_ref().to_string(), 249 }) 250 } 251 252 /// List records using Atrium API 253 pub async fn list_records<T: DeserializeOwned>( 254 &self, 255 web_session: &WebSession, 256 params: ListRecordsParams, 257 ) -> Result<ListRecordsResponse<T>, anyhow::Error> { 258 // Get the authenticated Atrium client 259 let agent = self.atrium_manager 260 .get_atproto_client(&web_session.session_group) 261 .await 262 .map_err(|e| anyhow::anyhow!("Failed to get AT Protocol client: {}", e))?; 263 264 // Convert our parameters to Atrium format 265 let repo = if params.repo.starts_with("did:") { 266 let did = Did::new(params.repo.clone()) 267 .map_err(|e| anyhow::anyhow!("Invalid DID format: {}", e))?; 268 AtIdentifier::from(did) 269 } else { 270 let handle = Handle::new(params.repo.clone()) 271 .map_err(|e| anyhow::anyhow!("Invalid handle format: {}", e))?; 272 AtIdentifier::from(handle) 273 }; 274 275 let collection = Nsid::new(params.collection.clone()) 276 .map_err(|e| anyhow::anyhow!("Invalid collection NSID: {}", e))?; 277 278 // Convert optional record keys (Note: ParametersData doesn't have rkey_start/rkey_end according to docs) 279 280 // Convert 
limit to the proper bounded type 281 let limit = params.limit 282 .map(|l| LimitedNonZeroU8::try_from(l as u8)) 283 .transpose() 284 .map_err(|e| anyhow::anyhow!("Invalid limit value: {}", e))?; 285 286 // Build the parameters data (only fields that exist according to docs) 287 let parameters_data = list_records::ParametersData { 288 repo, 289 collection, 290 limit, 291 cursor: params.cursor, 292 reverse: params.reverse, 293 }; 294 295 // Make the API call 296 let response = agent.api.com.atproto.repo.list_records(parameters_data.into()) 297 .await 298 .map_err(|e| anyhow::anyhow!("List records failed: {}", e))?; 299 300 // Convert the response to our format 301 let mut records = Vec::new(); 302 for record in response.data.records { 303 // Convert Unknown value back to our type using TryFromUnknown trait 304 let typed_value: T = T::try_from_unknown(record.data.value) 305 .map_err(|e| anyhow::anyhow!("Failed to deserialize record: {}", e))?; 306 307 records.push(ListRecord { 308 uri: record.data.uri, 309 cid: record.data.cid.as_ref().to_string(), 310 value: typed_value, 311 }); 312 } 313 314 Ok(ListRecordsResponse { 315 records, 316 cursor: response.data.cursor, 317 }) 318 } 319 320 /// Delete a record using Atrium API 321 pub async fn delete_record( 322 &self, 323 web_session: &WebSession, 324 request: DeleteRecordRequest, 325 ) -> Result<(), anyhow::Error> { 326 // Get the authenticated Atrium client 327 let agent = self.atrium_manager 328 .get_atproto_client(&web_session.session_group) 329 .await 330 .map_err(|e| anyhow::anyhow!("Failed to get AT Protocol client: {}", e))?; 331 332 // Convert our request to Atrium format 333 let repo = if request.repo.starts_with("did:") { 334 let did = Did::new(request.repo.clone()) 335 .map_err(|e| anyhow::anyhow!("Invalid DID format: {}", e))?; 336 AtIdentifier::from(did) 337 } else { 338 let handle = Handle::new(request.repo.clone()) 339 .map_err(|e| anyhow::anyhow!("Invalid handle format: {}", e))?; 340 
AtIdentifier::from(handle) 341 }; 342 343 let collection = Nsid::new(request.collection.clone()) 344 .map_err(|e| anyhow::anyhow!("Invalid collection NSID: {}", e))?; 345 346 let rkey = RecordKey::new(request.record_key) 347 .map_err(|e| anyhow::anyhow!("Invalid record key: {}", e))?; 348 349 // Build the input data 350 let input_data = delete_record::InputData { 351 repo, 352 collection, 353 rkey, 354 swap_commit: None, 355 swap_record: None, 356 }; 357 358 // Make the API call 359 agent.api.com.atproto.repo.delete_record(input_data.into()) 360 .await 361 .map_err(|e| { 362 tracing::error!("AT Protocol delete_record failed: {:?}", e); 363 anyhow::anyhow!("Delete record failed: {}", e) 364 })?; 365 366 Ok(()) 367 } 368}