Heavily customized version of smokesignal - https://whtwnd.com/kayrozen.com/3lpwe4ymowg2t
1//! AT Protocol client types and utilities
2//!
3//! This module provides client types for AT Protocol operations using the Atrium API.
4
5use serde::{de::DeserializeOwned, Deserialize, Serialize};
6use crate::atproto::lexicon::com_atproto_repo::StrongRef;
7use crate::atproto::atrium_auth::AtriumOAuthManager;
8use crate::http::middleware_auth::WebSession;
9use anyhow::Result;
10use atrium_api::com::atproto::repo::{
11 create_record, put_record, list_records, delete_record,
12};
13use atrium_api::types::{
14 string::{AtIdentifier, Did, Handle, Nsid, RecordKey, Cid},
15 Unknown, LimitedNonZeroU8,
16 TryIntoUnknown, TryFromUnknown
17};
18use std::str::FromStr;
19
20/// Request structure for creating AT Protocol records
21#[derive(Debug, Serialize, Deserialize, Clone)]
22#[serde(bound = "T: Serialize + DeserializeOwned")]
23pub struct CreateRecordRequest<T: DeserializeOwned> {
24 pub repo: String,
25 pub collection: String,
26 #[serde(skip_serializing_if = "Option::is_none", default, rename = "rkey")]
27 pub record_key: Option<String>,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub validate: Option<bool>,
30 pub record: T,
31 #[serde(skip_serializing_if = "Option::is_none", default, rename = "swapCommit")]
32 pub swap_commit: Option<String>,
33}
34
35/// Request structure for updating AT Protocol records
36#[derive(Debug, Serialize, Deserialize, Clone)]
37#[serde(bound = "T: Serialize + DeserializeOwned")]
38pub struct PutRecordRequest<T: DeserializeOwned> {
39 pub repo: String,
40 pub collection: String,
41 #[serde(rename = "rkey")]
42 pub record_key: String,
43 #[serde(skip_serializing_if = "Option::is_none")]
44 pub validate: Option<bool>,
45 pub record: T,
46 #[serde(skip_serializing_if = "Option::is_none", default, rename = "swapCommit")]
47 pub swap_commit: Option<String>,
48 #[serde(skip_serializing_if = "Option::is_none", default, rename = "swapRecord")]
49 pub swap_record: Option<String>,
50}
51
52/// Request structure for deleting AT Protocol records
53#[derive(Debug, Serialize, Deserialize, Clone)]
54pub struct DeleteRecordRequest {
55 pub repo: String,
56 pub collection: String,
57 #[serde(rename = "rkey")]
58 pub record_key: String,
59 #[serde(skip_serializing_if = "Option::is_none", default, rename = "swapRecord")]
60 pub swap_record: Option<String>,
61 #[serde(skip_serializing_if = "Option::is_none", default, rename = "swapCommit")]
62 pub swap_commit: Option<String>,
63}
64
65/// Parameters for listing records
66#[derive(Debug, Serialize, Deserialize, Clone)]
67pub struct ListRecordsParams {
68 pub repo: String,
69 pub collection: String,
70 #[serde(skip_serializing_if = "Option::is_none")]
71 pub limit: Option<u32>,
72 #[serde(skip_serializing_if = "Option::is_none")]
73 pub cursor: Option<String>,
74 #[serde(skip_serializing_if = "Option::is_none", rename = "rkeyStart")]
75 pub rkey_start: Option<String>,
76 #[serde(skip_serializing_if = "Option::is_none", rename = "rkeyEnd")]
77 pub rkey_end: Option<String>,
78 #[serde(skip_serializing_if = "Option::is_none")]
79 pub reverse: Option<bool>,
80}
81
82/// Record in a list response
83#[derive(Debug, Serialize, Deserialize, Clone)]
84pub struct ListRecord<T> {
85 pub uri: String,
86 pub cid: String,
87 pub value: T,
88}
89
90/// Response structure for listing AT Protocol records
91#[derive(Debug, Serialize, Deserialize, Clone)]
92pub struct ListRecordsResponse<T> {
93 pub records: Vec<ListRecord<T>>,
94 #[serde(skip_serializing_if = "Option::is_none")]
95 pub cursor: Option<String>,
96}
97
98/// Simple OAuth-enabled PDS client using Atrium API
99///
100/// This client bridges the existing OAuth session management with the Atrium AT Protocol API.
101pub struct OAuthPdsClient<'a> {
102 pub atrium_manager: &'a AtriumOAuthManager,
103}
104
105impl OAuthPdsClient<'_> {
106 /// Create a record using Atrium API
107 pub async fn create_record<T: DeserializeOwned + Serialize>(
108 &self,
109 web_session: &WebSession,
110 request: CreateRecordRequest<T>,
111 ) -> Result<StrongRef, anyhow::Error> {
112 // Get the authenticated Atrium client
113 let agent = self.atrium_manager
114 .get_atproto_client(&web_session.session_group)
115 .await
116 .map_err(|e| anyhow::anyhow!("Failed to get AT Protocol client: {}", e))?;
117
118 // Convert repo string to AtIdentifier (parse as DID or Handle)
119 let repo = if request.repo.starts_with("did:") {
120 // Parse as DID
121 let did = Did::new(request.repo.clone())
122 .map_err(|e| anyhow::anyhow!("Invalid DID format: {}", e))?;
123 AtIdentifier::from(did)
124 } else {
125 // Parse as Handle
126 let handle = Handle::new(request.repo.clone())
127 .map_err(|e| anyhow::anyhow!("Invalid handle format: {}", e))?;
128 AtIdentifier::from(handle)
129 };
130
131 let collection = Nsid::new(request.collection.clone())
132 .map_err(|e| anyhow::anyhow!("Invalid collection NSID: {}", e))?;
133
134 // Convert the record to Unknown type for Atrium
135 let record_unknown: Unknown = request.record
136 .try_into_unknown()
137 .map_err(|e| anyhow::anyhow!("Failed to convert record to Unknown: {}", e))?;
138
139 // Convert record key to proper type if provided
140 let rkey = request.record_key
141 .map(|k| RecordKey::new(k))
142 .transpose()
143 .map_err(|e| anyhow::anyhow!("Invalid record key: {}", e))?;
144
145 // Convert swap_commit to Cid if provided
146 let swap_commit = request.swap_commit
147 .map(|c| Cid::from_str(&c))
148 .transpose()
149 .map_err(|e| anyhow::anyhow!("Invalid swap commit CID: {}", e))?;
150
151 // Build the input data
152 let input_data = create_record::InputData {
153 repo,
154 collection,
155 rkey,
156 validate: request.validate, // Use the optional validate field directly
157 record: record_unknown,
158 swap_commit,
159 };
160
161 // Make the API call using the correct method pattern
162 let response = agent
163 .api
164 .com
165 .atproto
166 .repo
167 .create_record(input_data.into())
168 .await
169 .map_err(|e| {
170 tracing::error!("AT Protocol create_record failed: {:?}", e);
171 anyhow::anyhow!("Create record failed: {}", e)
172 })?;
173
174 // Return the strong reference
175 Ok(StrongRef {
176 uri: response.data.uri,
177 cid: response.data.cid.as_ref().to_string(),
178 })
179 }
180
181 /// Update a record using Atrium API
182 pub async fn put_record<T: DeserializeOwned + Serialize>(
183 &self,
184 web_session: &WebSession,
185 request: PutRecordRequest<T>,
186 ) -> Result<StrongRef, anyhow::Error> {
187 // Get the authenticated Atrium client
188 let agent = self.atrium_manager
189 .get_atproto_client(&web_session.session_group)
190 .await
191 .map_err(|e| anyhow::anyhow!("Failed to get AT Protocol client: {}", e))?;
192
193 // Convert our request to Atrium format
194 let repo = if request.repo.starts_with("did:") {
195 let did = Did::new(request.repo.clone())
196 .map_err(|e| anyhow::anyhow!("Invalid DID format: {}", e))?;
197 AtIdentifier::from(did)
198 } else {
199 let handle = Handle::new(request.repo.clone())
200 .map_err(|e| anyhow::anyhow!("Invalid handle format: {}", e))?;
201 AtIdentifier::from(handle)
202 };
203
204 let collection = Nsid::new(request.collection.clone())
205 .map_err(|e| anyhow::anyhow!("Invalid collection NSID: {}", e))?;
206
207 let rkey = RecordKey::new(request.record_key)
208 .map_err(|e| anyhow::anyhow!("Invalid record key: {}", e))?;
209
210 // Convert the record to Unknown type for Atrium
211 let record_unknown: Unknown = request.record
212 .try_into_unknown()
213 .map_err(|e| anyhow::anyhow!("Failed to convert record to Unknown: {}", e))?;
214
215 // Convert optional fields
216 let swap_commit = request.swap_commit
217 .map(|c| Cid::from_str(&c))
218 .transpose()
219 .map_err(|e| anyhow::anyhow!("Invalid swap commit CID: {}", e))?;
220
221 let swap_record = request.swap_record
222 .map(|c| Cid::from_str(&c))
223 .transpose()
224 .map_err(|e| anyhow::anyhow!("Invalid swap record CID: {}", e))?;
225
226 // Build the input data
227 let input_data = put_record::InputData {
228 repo,
229 collection,
230 rkey,
231 validate: request.validate, // Use the optional validate field directly
232 record: record_unknown,
233 swap_commit,
234 swap_record,
235 };
236
237 // Make the API call
238 let response = agent.api.com.atproto.repo.put_record(input_data.into())
239 .await
240 .map_err(|e| {
241 tracing::error!("AT Protocol put_record failed: {:?}", e);
242 anyhow::anyhow!("Put record failed: {}", e)
243 })?;
244
245 // Return the strong reference
246 Ok(StrongRef {
247 uri: response.data.uri,
248 cid: response.data.cid.as_ref().to_string(),
249 })
250 }
251
252 /// List records using Atrium API
253 pub async fn list_records<T: DeserializeOwned>(
254 &self,
255 web_session: &WebSession,
256 params: ListRecordsParams,
257 ) -> Result<ListRecordsResponse<T>, anyhow::Error> {
258 // Get the authenticated Atrium client
259 let agent = self.atrium_manager
260 .get_atproto_client(&web_session.session_group)
261 .await
262 .map_err(|e| anyhow::anyhow!("Failed to get AT Protocol client: {}", e))?;
263
264 // Convert our parameters to Atrium format
265 let repo = if params.repo.starts_with("did:") {
266 let did = Did::new(params.repo.clone())
267 .map_err(|e| anyhow::anyhow!("Invalid DID format: {}", e))?;
268 AtIdentifier::from(did)
269 } else {
270 let handle = Handle::new(params.repo.clone())
271 .map_err(|e| anyhow::anyhow!("Invalid handle format: {}", e))?;
272 AtIdentifier::from(handle)
273 };
274
275 let collection = Nsid::new(params.collection.clone())
276 .map_err(|e| anyhow::anyhow!("Invalid collection NSID: {}", e))?;
277
278 // Convert optional record keys (Note: ParametersData doesn't have rkey_start/rkey_end according to docs)
279
280 // Convert limit to the proper bounded type
281 let limit = params.limit
282 .map(|l| LimitedNonZeroU8::try_from(l as u8))
283 .transpose()
284 .map_err(|e| anyhow::anyhow!("Invalid limit value: {}", e))?;
285
286 // Build the parameters data (only fields that exist according to docs)
287 let parameters_data = list_records::ParametersData {
288 repo,
289 collection,
290 limit,
291 cursor: params.cursor,
292 reverse: params.reverse,
293 };
294
295 // Make the API call
296 let response = agent.api.com.atproto.repo.list_records(parameters_data.into())
297 .await
298 .map_err(|e| anyhow::anyhow!("List records failed: {}", e))?;
299
300 // Convert the response to our format
301 let mut records = Vec::new();
302 for record in response.data.records {
303 // Convert Unknown value back to our type using TryFromUnknown trait
304 let typed_value: T = T::try_from_unknown(record.data.value)
305 .map_err(|e| anyhow::anyhow!("Failed to deserialize record: {}", e))?;
306
307 records.push(ListRecord {
308 uri: record.data.uri,
309 cid: record.data.cid.as_ref().to_string(),
310 value: typed_value,
311 });
312 }
313
314 Ok(ListRecordsResponse {
315 records,
316 cursor: response.data.cursor,
317 })
318 }
319
320 /// Delete a record using Atrium API
321 pub async fn delete_record(
322 &self,
323 web_session: &WebSession,
324 request: DeleteRecordRequest,
325 ) -> Result<(), anyhow::Error> {
326 // Get the authenticated Atrium client
327 let agent = self.atrium_manager
328 .get_atproto_client(&web_session.session_group)
329 .await
330 .map_err(|e| anyhow::anyhow!("Failed to get AT Protocol client: {}", e))?;
331
332 // Convert our request to Atrium format
333 let repo = if request.repo.starts_with("did:") {
334 let did = Did::new(request.repo.clone())
335 .map_err(|e| anyhow::anyhow!("Invalid DID format: {}", e))?;
336 AtIdentifier::from(did)
337 } else {
338 let handle = Handle::new(request.repo.clone())
339 .map_err(|e| anyhow::anyhow!("Invalid handle format: {}", e))?;
340 AtIdentifier::from(handle)
341 };
342
343 let collection = Nsid::new(request.collection.clone())
344 .map_err(|e| anyhow::anyhow!("Invalid collection NSID: {}", e))?;
345
346 let rkey = RecordKey::new(request.record_key)
347 .map_err(|e| anyhow::anyhow!("Invalid record key: {}", e))?;
348
349 // Build the input data
350 let input_data = delete_record::InputData {
351 repo,
352 collection,
353 rkey,
354 swap_commit: None,
355 swap_record: None,
356 };
357
358 // Make the API call
359 agent.api.com.atproto.repo.delete_record(input_data.into())
360 .await
361 .map_err(|e| {
362 tracing::error!("AT Protocol delete_record failed: {:?}", e);
363 anyhow::anyhow!("Delete record failed: {}", e)
364 })?;
365
366 Ok(())
367 }
368}