forked from
smokesignal.events/smokesignal
i18n+filtering fork - fluent-templates v2
1use std::borrow::Cow;
2
3use chrono::{DateTime, Utc};
4use serde_json::json;
5
6use crate::{
7 jose::jwk::WrappedJsonWebKey,
8 storage::{errors::StorageError, handle::model::Handle, StoragePool},
9};
10use model::{OAuthRequest, OAuthSession};
11
12pub struct OAuthRequestParams {
13 pub oauth_state: Cow<'static, str>,
14 pub issuer: Cow<'static, str>,
15 pub did: Cow<'static, str>,
16 pub nonce: Cow<'static, str>,
17 pub pkce_verifier: Cow<'static, str>,
18 pub secret_jwk_id: Cow<'static, str>,
19 pub dpop_jwk: Option<WrappedJsonWebKey>,
20 pub destination: Option<Cow<'static, str>>,
21 pub created_at: DateTime<Utc>,
22 pub expires_at: DateTime<Utc>,
23}
24
25pub async fn oauth_request_insert(
26 pool: &StoragePool,
27 params: OAuthRequestParams,
28) -> Result<(), StorageError> {
29 // Validate required input parameters
30 if params.oauth_state.trim().is_empty() {
31 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
32 "OAuth state cannot be empty".into(),
33 )));
34 }
35
36 if params.issuer.trim().is_empty() {
37 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
38 "Issuer cannot be empty".into(),
39 )));
40 }
41
42 if params.did.trim().is_empty() {
43 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
44 "DID cannot be empty".into(),
45 )));
46 }
47
48 if params.nonce.trim().is_empty() {
49 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
50 "Nonce cannot be empty".into(),
51 )));
52 }
53
54 if params.pkce_verifier.trim().is_empty() {
55 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
56 "PKCE verifier cannot be empty".into(),
57 )));
58 }
59
60 if params.secret_jwk_id.trim().is_empty() {
61 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
62 "Secret JWK ID cannot be empty".into(),
63 )));
64 }
65
66 let mut tx = pool
67 .begin()
68 .await
69 .map_err(StorageError::CannotBeginDatabaseTransaction)?;
70
71 let dpop_jwk_value = params
72 .dpop_jwk
73 .map(|jwk| json!(jwk))
74 .unwrap_or_else(|| json!({}));
75
76 sqlx::query("INSERT INTO oauth_requests (oauth_state, issuer, did, nonce, pkce_verifier, secret_jwk_id, dpop_jwk, destination, created_at, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)")
77 .bind(¶ms.oauth_state)
78 .bind(¶ms.issuer)
79 .bind(¶ms.did)
80 .bind(¶ms.nonce)
81 .bind(¶ms.pkce_verifier)
82 .bind(¶ms.secret_jwk_id)
83 .bind(dpop_jwk_value)
84 .bind(params.destination)
85 .bind(params.created_at)
86 .bind(params.expires_at)
87 .execute(tx.as_mut())
88 .await
89 .map_err(StorageError::UnableToExecuteQuery)?;
90
91 tx.commit()
92 .await
93 .map_err(StorageError::CannotCommitDatabaseTransaction)
94}
95
96pub async fn oauth_request_get(
97 pool: &StoragePool,
98 oauth_state: &str,
99) -> Result<OAuthRequest, StorageError> {
100 // Validate oauth_state is not empty
101 if oauth_state.trim().is_empty() {
102 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
103 "OAuth state cannot be empty".into(),
104 )));
105 }
106
107 let mut tx = pool
108 .begin()
109 .await
110 .map_err(StorageError::CannotBeginDatabaseTransaction)?;
111
112 let record =
113 sqlx::query_as::<_, OAuthRequest>("SELECT * FROM oauth_requests WHERE oauth_state = $1")
114 .bind(oauth_state)
115 .fetch_one(tx.as_mut())
116 .await
117 .map_err(|err| match err {
118 sqlx::Error::RowNotFound => StorageError::OAuthRequestNotFound,
119 other => StorageError::UnableToExecuteQuery(other),
120 })?;
121
122 tx.commit()
123 .await
124 .map_err(StorageError::CannotCommitDatabaseTransaction)?;
125
126 Ok(record)
127}
128
129pub async fn oauth_request_remove(
130 pool: &StoragePool,
131 oauth_state: &str,
132) -> Result<(), StorageError> {
133 // Validate oauth_state is not empty
134 if oauth_state.trim().is_empty() {
135 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
136 "OAuth state cannot be empty".into(),
137 )));
138 }
139
140 let mut tx = pool
141 .begin()
142 .await
143 .map_err(StorageError::CannotBeginDatabaseTransaction)?;
144
145 sqlx::query("DELETE FROM oauth_requests WHERE oauth_state = $1")
146 .bind(oauth_state)
147 .execute(tx.as_mut())
148 .await
149 .map_err(StorageError::UnableToExecuteQuery)?;
150
151 tx.commit()
152 .await
153 .map_err(StorageError::CannotCommitDatabaseTransaction)
154}
155
156pub struct OAuthSessionParams {
157 pub session_group: Cow<'static, str>,
158 pub access_token: Cow<'static, str>,
159 pub did: Cow<'static, str>,
160 pub issuer: Cow<'static, str>,
161 pub refresh_token: Cow<'static, str>,
162 pub secret_jwk_id: Cow<'static, str>,
163 pub dpop_jwk: WrappedJsonWebKey,
164 pub created_at: DateTime<Utc>,
165 pub access_token_expires_at: DateTime<Utc>,
166}
167
168pub async fn oauth_session_insert(
169 pool: &StoragePool,
170 params: OAuthSessionParams,
171) -> Result<(), StorageError> {
172 // Validate required input parameters
173 if params.session_group.trim().is_empty() {
174 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
175 "Session group cannot be empty".into(),
176 )));
177 }
178
179 if params.access_token.trim().is_empty() {
180 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
181 "Access token cannot be empty".into(),
182 )));
183 }
184
185 if params.did.trim().is_empty() {
186 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
187 "DID cannot be empty".into(),
188 )));
189 }
190
191 if params.issuer.trim().is_empty() {
192 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
193 "Issuer cannot be empty".into(),
194 )));
195 }
196
197 if params.refresh_token.trim().is_empty() {
198 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
199 "Refresh token cannot be empty".into(),
200 )));
201 }
202
203 if params.secret_jwk_id.trim().is_empty() {
204 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
205 "Secret JWK ID cannot be empty".into(),
206 )));
207 }
208
209 let mut tx = pool
210 .begin()
211 .await
212 .map_err(StorageError::CannotBeginDatabaseTransaction)?;
213
214 sqlx::query("INSERT INTO oauth_sessions (session_group, access_token, did, issuer, refresh_token, secret_jwk_id, dpop_jwk, created_at, access_token_expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)")
215 .bind(¶ms.session_group)
216 .bind(¶ms.access_token)
217 .bind(¶ms.did)
218 .bind(¶ms.issuer)
219 .bind(¶ms.refresh_token)
220 .bind(¶ms.secret_jwk_id)
221 .bind(json!(params.dpop_jwk))
222 .bind(params.created_at)
223 .bind(params.access_token_expires_at)
224 .execute(tx.as_mut())
225 .await
226 .map_err(StorageError::UnableToExecuteQuery)?;
227
228 tx.commit()
229 .await
230 .map_err(StorageError::CannotCommitDatabaseTransaction)
231}
232
233pub async fn oauth_session_update(
234 pool: &StoragePool,
235 session_group: Cow<'_, str>,
236 access_token: Cow<'_, str>,
237 refresh_token: Cow<'_, str>,
238 access_token_expires_at: DateTime<Utc>,
239) -> Result<(), StorageError> {
240 // Validate input parameters
241 if session_group.trim().is_empty() {
242 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
243 "Session group cannot be empty".into(),
244 )));
245 }
246
247 if access_token.trim().is_empty() {
248 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
249 "Access token cannot be empty".into(),
250 )));
251 }
252
253 if refresh_token.trim().is_empty() {
254 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
255 "Refresh token cannot be empty".into(),
256 )));
257 }
258
259 let mut tx = pool
260 .begin()
261 .await
262 .map_err(StorageError::CannotBeginDatabaseTransaction)?;
263
264 sqlx::query("UPDATE oauth_sessions SET access_token = $1, refresh_token = $2, access_token_expires_at = $3 WHERE session_group = $4")
265 .bind(access_token)
266 .bind(refresh_token)
267 .bind(access_token_expires_at)
268 .bind(session_group)
269 .execute(tx.as_mut())
270 .await
271 .map_err(StorageError::UnableToExecuteQuery)?;
272
273 tx.commit()
274 .await
275 .map_err(StorageError::CannotCommitDatabaseTransaction)
276}
277
278/// Delete an OAuth session by its session group.
279pub async fn oauth_session_delete(
280 pool: &StoragePool,
281 session_group: &str,
282) -> Result<(), StorageError> {
283 // Validate session_group is not empty
284 if session_group.trim().is_empty() {
285 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
286 "Session group cannot be empty".into(),
287 )));
288 }
289
290 let mut tx = pool
291 .begin()
292 .await
293 .map_err(StorageError::CannotBeginDatabaseTransaction)?;
294
295 sqlx::query("DELETE FROM oauth_sessions WHERE session_group = $1")
296 .bind(session_group)
297 .execute(tx.as_mut())
298 .await
299 .map_err(StorageError::UnableToExecuteQuery)?;
300
301 tx.commit()
302 .await
303 .map_err(StorageError::CannotCommitDatabaseTransaction)
304}
305
306/// Look up a web session by session group and optionally filter by DID.
307pub async fn web_session_lookup(
308 pool: &StoragePool,
309 session_group: &str,
310 did: Option<&str>,
311) -> Result<(Handle, OAuthSession), StorageError> {
312 // Validate session_group is not empty
313 if session_group.trim().is_empty() {
314 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
315 "Session group cannot be empty".into(),
316 )));
317 }
318
319 // If did is provided, validate it's not empty
320 if let Some(did_value) = did {
321 if did_value.trim().is_empty() {
322 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
323 "DID cannot be empty".into(),
324 )));
325 }
326 }
327
328 let mut tx = pool
329 .begin()
330 .await
331 .map_err(StorageError::CannotBeginDatabaseTransaction)?;
332
333 let oauth_session = match did {
334 Some(did_value) => {
335 sqlx::query_as::<_, OAuthSession>(
336 "SELECT * FROM oauth_sessions WHERE session_group = $1 AND did = $2 ORDER BY created_at DESC LIMIT 1",
337 )
338 .bind(session_group)
339 .bind(did_value)
340 .fetch_one(tx.as_mut())
341 .await
342 },
343 None => {
344 sqlx::query_as::<_, OAuthSession>(
345 "SELECT * FROM oauth_sessions WHERE session_group = $1 ORDER BY created_at DESC LIMIT 1",
346 )
347 .bind(session_group)
348 .fetch_one(tx.as_mut())
349 .await
350 }
351 }
352 .map_err(|err| match err {
353 sqlx::Error::RowNotFound => StorageError::WebSessionNotFound,
354 other => StorageError::UnableToExecuteQuery(other),
355 })?;
356
357 let did_for_handle = did.unwrap_or(&oauth_session.did);
358
359 let handle = sqlx::query_as::<_, Handle>("SELECT * FROM handles WHERE did = $1")
360 .bind(did_for_handle)
361 .fetch_one(tx.as_mut())
362 .await
363 .map_err(|err| match err {
364 sqlx::Error::RowNotFound => StorageError::HandleNotFound,
365 other => StorageError::UnableToExecuteQuery(other),
366 })?;
367
368 tx.commit()
369 .await
370 .map_err(StorageError::CannotCommitDatabaseTransaction)?;
371
372 Ok((handle, oauth_session))
373}
374
375pub mod model {
376 use anyhow::Error;
377 use chrono::{DateTime, Utc};
378 use p256::SecretKey;
379 use serde::Deserialize;
380 use sqlx::FromRow;
381
382 use crate::{
383 atproto::auth::SimpleOAuthSessionProvider, jose::jwk::WrappedJsonWebKey,
384 storage::errors::OAuthModelError,
385 };
386
387 #[derive(Clone, FromRow, Deserialize)]
388 pub struct OAuthRequest {
389 pub oauth_state: String,
390 pub issuer: String,
391 pub did: String,
392 pub nonce: String,
393 pub pkce_verifier: String,
394 pub secret_jwk_id: String,
395 pub destination: Option<String>,
396 pub dpop_jwk: sqlx::types::Json<WrappedJsonWebKey>,
397 pub created_at: DateTime<Utc>,
398 pub expires_at: DateTime<Utc>,
399 }
400
401 pub struct OAuthRequestState {
402 pub state: String,
403 pub nonce: String,
404 pub code_challenge: String,
405 }
406
407 #[derive(Clone, FromRow, Deserialize)]
408 pub struct OAuthSession {
409 pub session_group: String,
410 pub access_token: String,
411 pub did: String,
412 pub issuer: String,
413 pub refresh_token: String,
414 pub secret_jwk_id: String,
415 pub dpop_jwk: sqlx::types::Json<WrappedJsonWebKey>,
416 pub created_at: DateTime<Utc>,
417 pub access_token_expires_at: DateTime<Utc>,
418 }
419
420 impl TryFrom<OAuthSession> for SimpleOAuthSessionProvider {
421 type Error = Error;
422
423 fn try_from(value: OAuthSession) -> Result<Self, Self::Error> {
424 let dpop_secret = SecretKey::from_jwk(&value.dpop_jwk.jwk)
425 .map_err(OAuthModelError::DpopSecretFromJwkFailed)?;
426
427 Ok(SimpleOAuthSessionProvider {
428 access_token: value.access_token,
429 issuer: value.issuer,
430 dpop_secret,
431 })
432 }
433 }
434}
435
#[cfg(test)]
pub mod test {
    use sqlx::PgPool;

    use crate::{
        jose,
        storage::{
            errors::StorageError,
            oauth::{
                oauth_request_get, oauth_request_insert, oauth_request_remove,
                oauth_session_delete, oauth_session_insert, oauth_session_update,
                web_session_lookup, OAuthRequestParams, OAuthSessionParams,
            },
        },
    };

    /// DID used by the tests; the `handles` fixture must contain a handle for it.
    const TEST_DID: &str = "did:plc:d5c1ed6d01421a67b96f68fa";

    /// Insert, read back, remove, then confirm the request is reported as not found.
    #[sqlx::test(fixtures(path = "../../fixtures/storage", scripts("handles")))]
    async fn test_oauth_request(pool: PgPool) -> anyhow::Result<()> {
        let dpop_jwk = jose::jwk::generate();
        let created_at = chrono::Utc::now();
        let expires_at = created_at + chrono::Duration::seconds(60);

        oauth_request_insert(
            &pool,
            OAuthRequestParams {
                oauth_state: "oauth_state".to_string().into(),
                issuer: "pds.examplepds.com".to_string().into(),
                did: TEST_DID.to_string().into(),
                nonce: "nonce".to_string().into(),
                pkce_verifier: "pkce_verifier".to_string().into(),
                secret_jwk_id: "secret_jwk_id".to_string().into(),
                dpop_jwk: Some(dpop_jwk.clone()),
                destination: None,
                created_at,
                expires_at,
            },
        )
        .await
        .expect("request insert should succeed");

        let oauth_request = oauth_request_get(&pool, "oauth_state")
            .await
            .expect("inserted request should be readable");
        assert_eq!(oauth_request.did, TEST_DID);
        assert_eq!(oauth_request.dpop_jwk.as_ref(), &dpop_jwk);

        oauth_request_remove(&pool, "oauth_state")
            .await
            .expect("request removal should succeed");

        // A removed request must surface as the dedicated not-found variant,
        // not as a generic query failure.
        assert!(matches!(
            oauth_request_get(&pool, "oauth_state").await,
            Err(StorageError::OAuthRequestNotFound)
        ));

        Ok(())
    }

    /// Insert, look up, update, delete, then confirm the session is gone.
    #[sqlx::test(fixtures(path = "../../fixtures/storage", scripts("handles")))]
    async fn test_oauth_session(pool: PgPool) -> anyhow::Result<()> {
        let dpop_jwk = jose::jwk::generate();

        let session_group = ulid::Ulid::new().to_string();
        let now = chrono::Utc::now();

        oauth_session_insert(
            &pool,
            OAuthSessionParams {
                session_group: session_group.clone().into(),
                access_token: "access_token".to_string().into(),
                did: TEST_DID.to_string().into(),
                issuer: "pds.examplepds.com".to_string().into(),
                refresh_token: "refresh_token".to_string().into(),
                secret_jwk_id: "secret_jwk_id".to_string().into(),
                dpop_jwk,
                created_at: now,
                access_token_expires_at: now + chrono::Duration::seconds(60),
            },
        )
        .await
        .expect("session insert should succeed");

        let (_handle, session) = web_session_lookup(&pool, &session_group, Some(TEST_DID))
            .await
            .expect("lookup by group and DID should succeed");
        assert_eq!(session.did, TEST_DID);
        assert_eq!(session.access_token, "access_token");

        oauth_session_update(
            &pool,
            session_group.as_str().into(),
            "new_access_token".into(),
            "new_refresh_token".into(),
            now + chrono::Duration::seconds(120),
        )
        .await
        .expect("session update should succeed");

        let (_handle, session) = web_session_lookup(&pool, &session_group, None)
            .await
            .expect("lookup by group only should succeed");
        assert_eq!(session.access_token, "new_access_token");
        assert_eq!(session.refresh_token, "new_refresh_token");

        oauth_session_delete(&pool, &session_group)
            .await
            .expect("session delete should succeed");

        assert!(matches!(
            web_session_lookup(&pool, &session_group, None).await,
            Err(StorageError::WebSessionNotFound)
        ));

        Ok(())
    }
}