this repo has no description
1use super::super::{OAuthError, RefreshTokenState, TokenData}; 2use super::helpers::{from_json, to_json}; 3use chrono::{DateTime, Utc}; 4use sqlx::PgPool; 5 6pub enum RefreshTokenLookup { 7 Valid { 8 db_id: i32, 9 token_data: TokenData, 10 }, 11 InGracePeriod { 12 db_id: i32, 13 token_data: TokenData, 14 rotated_at: DateTime<Utc>, 15 }, 16 Used { 17 original_token_id: i32, 18 }, 19 Expired { 20 db_id: i32, 21 }, 22 NotFound, 23} 24 25impl RefreshTokenLookup { 26 pub fn state(&self) -> RefreshTokenState { 27 match self { 28 RefreshTokenLookup::Valid { .. } => RefreshTokenState::Valid, 29 RefreshTokenLookup::InGracePeriod { rotated_at, .. } => { 30 RefreshTokenState::InGracePeriod { 31 rotated_at: *rotated_at, 32 } 33 } 34 RefreshTokenLookup::Used { .. } => RefreshTokenState::Used { at: Utc::now() }, 35 RefreshTokenLookup::Expired { .. } => RefreshTokenState::Expired, 36 RefreshTokenLookup::NotFound => RefreshTokenState::Revoked, 37 } 38 } 39} 40 41pub async fn lookup_refresh_token( 42 pool: &PgPool, 43 refresh_token: &str, 44) -> Result<RefreshTokenLookup, OAuthError> { 45 if let Some(token_id) = check_refresh_token_used(pool, refresh_token).await? { 46 if let Some((db_id, token_data)) = 47 get_token_by_previous_refresh_token(pool, refresh_token).await? 48 { 49 let rotated_at = token_data.updated_at; 50 return Ok(RefreshTokenLookup::InGracePeriod { 51 db_id, 52 token_data, 53 rotated_at, 54 }); 55 } 56 return Ok(RefreshTokenLookup::Used { 57 original_token_id: token_id, 58 }); 59 } 60 61 match get_token_by_refresh_token(pool, refresh_token).await? { 62 Some((db_id, token_data)) => { 63 if token_data.expires_at < Utc::now() { 64 Ok(RefreshTokenLookup::Expired { db_id }) 65 } else { 66 Ok(RefreshTokenLookup::Valid { db_id, token_data }) 67 } 68 } 69 None => Ok(RefreshTokenLookup::NotFound), 70 } 71} 72 73pub async fn create_token(pool: &PgPool, data: &TokenData) -> Result<i32, OAuthError> { 74 let client_auth_json = to_json(&data.client_auth)?; 75 let parameters_json = to_json(&data.parameters)?; 76 let row = sqlx::query!( 77 r#" 78 INSERT INTO oauth_token 79 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 80 device_id, parameters, details, code, current_refresh_token, scope, controller_did) 81 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) 82 RETURNING id 83 "#, 84 data.did, 85 data.token_id, 86 data.created_at, 87 data.updated_at, 88 data.expires_at, 89 data.client_id, 90 client_auth_json, 91 data.device_id, 92 parameters_json, 93 data.details, 94 data.code, 95 data.current_refresh_token, 96 data.scope, 97 data.controller_did, 98 ) 99 .fetch_one(pool) 100 .await?; 101 Ok(row.id) 102} 103 104pub async fn get_token_by_id( 105 pool: &PgPool, 106 token_id: &str, 107) -> Result<Option<TokenData>, OAuthError> { 108 let row = sqlx::query!( 109 r#" 110 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 111 device_id, parameters, details, code, current_refresh_token, scope, controller_did 112 FROM oauth_token 113 WHERE token_id = $1 114 "#, 115 token_id 116 ) 117 .fetch_optional(pool) 118 .await?; 119 match row { 120 Some(r) => Ok(Some(TokenData { 121 did: r.did, 122 token_id: r.token_id, 123 created_at: r.created_at, 124 updated_at: r.updated_at, 125 expires_at: r.expires_at, 126 client_id: r.client_id, 127 client_auth: from_json(r.client_auth)?, 128 device_id: r.device_id, 129 parameters: from_json(r.parameters)?, 130 details: r.details, 131 code: r.code, 132 current_refresh_token: r.current_refresh_token, 133 scope: r.scope, 134 controller_did: r.controller_did, 135 })), 136 None => Ok(None), 137 } 138} 139 140pub async fn get_token_by_refresh_token( 141 pool: &PgPool, 142 refresh_token: &str, 143) -> Result<Option<(i32, TokenData)>, OAuthError> { 144 let row = sqlx::query!( 145 r#" 146 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 147 device_id, parameters, details, code, current_refresh_token, scope, controller_did 148 FROM oauth_token 149 WHERE current_refresh_token = $1 150 "#, 151 refresh_token 152 ) 153 .fetch_optional(pool) 154 .await?; 155 match row { 156 Some(r) => Ok(Some(( 157 r.id, 158 TokenData { 159 did: r.did, 160 token_id: r.token_id, 161 created_at: r.created_at, 162 updated_at: r.updated_at, 163 expires_at: r.expires_at, 164 client_id: r.client_id, 165 client_auth: from_json(r.client_auth)?, 166 device_id: r.device_id, 167 parameters: from_json(r.parameters)?, 168 details: r.details, 169 code: r.code, 170 current_refresh_token: r.current_refresh_token, 171 scope: r.scope, 172 controller_did: r.controller_did, 173 }, 174 ))), 175 None => Ok(None), 176 } 177} 178 179pub async fn rotate_token( 180 pool: &PgPool, 181 old_db_id: i32, 182 new_token_id: &str, 183 new_refresh_token: &str, 184 new_expires_at: DateTime<Utc>, 185) -> Result<(), OAuthError> { 186 let mut tx = pool.begin().await?; 187 let old_refresh = sqlx::query_scalar!( 188 r#" 189 SELECT current_refresh_token FROM oauth_token WHERE id = $1 190 "#, 191 old_db_id 192 ) 193 .fetch_one(&mut *tx) 194 .await?; 195 if let Some(ref old_rt) = old_refresh { 196 sqlx::query!( 197 r#" 198 INSERT INTO oauth_used_refresh_token (refresh_token, token_id) 199 VALUES ($1, $2) 200 "#, 201 old_rt, 202 old_db_id 203 ) 204 .execute(&mut *tx) 205 .await?; 206 } 207 sqlx::query!( 208 r#" 209 UPDATE oauth_token 210 SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW(), 211 previous_refresh_token = $5, rotated_at = NOW() 212 WHERE id = $1 213 "#, 214 old_db_id, 215 new_token_id, 216 new_refresh_token, 217 new_expires_at, 218 old_refresh 219 ) 220 .execute(&mut *tx) 221 .await?; 222 tx.commit().await?; 223 Ok(()) 224} 225 226pub async fn check_refresh_token_used( 227 pool: &PgPool, 228 refresh_token: &str, 229) -> Result<Option<i32>, OAuthError> { 230 let row = sqlx::query_scalar!( 231 r#" 232 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1 233 "#, 234 refresh_token 235 ) 236 .fetch_optional(pool) 237 .await?; 238 Ok(row) 239} 240 241const REFRESH_GRACE_PERIOD_SECS: i64 = 60; 242 243pub async fn get_token_by_previous_refresh_token( 244 pool: &PgPool, 245 refresh_token: &str, 246) -> Result<Option<(i32, TokenData)>, OAuthError> { 247 let grace_cutoff = Utc::now() - chrono::Duration::seconds(REFRESH_GRACE_PERIOD_SECS); 248 let row = sqlx::query!( 249 r#" 250 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 251 device_id, parameters, details, code, current_refresh_token, scope, controller_did 252 FROM oauth_token 253 WHERE previous_refresh_token = $1 AND rotated_at > $2 254 "#, 255 refresh_token, 256 grace_cutoff 257 ) 258 .fetch_optional(pool) 259 .await?; 260 match row { 261 Some(r) => Ok(Some(( 262 r.id, 263 TokenData { 264 did: r.did, 265 token_id: r.token_id, 266 created_at: r.created_at, 267 updated_at: r.updated_at, 268 expires_at: r.expires_at, 269 client_id: r.client_id, 270 client_auth: from_json(r.client_auth)?, 271 device_id: r.device_id, 272 parameters: from_json(r.parameters)?, 273 details: r.details, 274 code: r.code, 275 current_refresh_token: r.current_refresh_token, 276 scope: r.scope, 277 controller_did: r.controller_did, 278 }, 279 ))), 280 None => Ok(None), 281 } 282} 283 284pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> { 285 sqlx::query!( 286 r#" 287 DELETE FROM oauth_token WHERE token_id = $1 288 "#, 289 token_id 290 ) 291 .execute(pool) 292 .await?; 293 Ok(()) 294} 295 296pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> { 297 sqlx::query!( 298 r#" 299 DELETE FROM oauth_token WHERE id = $1 300 "#, 301 db_id 302 ) 303 .execute(pool) 304 .await?; 305 Ok(()) 306} 307 308pub async fn list_tokens_for_user(pool: &PgPool, did: &str) -> Result<Vec<TokenData>, OAuthError> { 309 let rows = sqlx::query!( 310 r#" 311 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 312 device_id, parameters, details, code, current_refresh_token, scope, controller_did 313 FROM oauth_token 314 WHERE did = $1 315 "#, 316 did 317 ) 318 .fetch_all(pool) 319 .await?; 320 let mut tokens = Vec::with_capacity(rows.len()); 321 for r in rows { 322 tokens.push(TokenData { 323 did: r.did, 324 token_id: r.token_id, 325 created_at: r.created_at, 326 updated_at: r.updated_at, 327 expires_at: r.expires_at, 328 client_id: r.client_id, 329 client_auth: from_json(r.client_auth)?, 330 device_id: r.device_id, 331 parameters: from_json(r.parameters)?, 332 details: r.details, 333 code: r.code, 334 current_refresh_token: r.current_refresh_token, 335 scope: r.scope, 336 controller_did: r.controller_did, 337 }); 338 } 339 Ok(tokens) 340} 341 342pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> { 343 let count = sqlx::query_scalar!( 344 r#" 345 SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1 346 "#, 347 did 348 ) 349 .fetch_one(pool) 350 .await?; 351 Ok(count) 352} 353 354pub async fn delete_oldest_tokens_for_user( 355 pool: &PgPool, 356 did: &str, 357 keep_count: i64, 358) -> Result<u64, OAuthError> { 359 let result = sqlx::query!( 360 r#" 361 DELETE FROM oauth_token 362 WHERE id IN ( 363 SELECT id FROM oauth_token 364 WHERE did = $1 365 ORDER BY updated_at ASC 366 OFFSET $2 367 ) 368 "#, 369 did, 370 keep_count 371 ) 372 .execute(pool) 373 .await?; 374 Ok(result.rows_affected()) 375} 376 377const MAX_TOKENS_PER_USER: i64 = 100; 378 379pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> { 380 let count = count_tokens_for_user(pool, did).await?; 381 if count > MAX_TOKENS_PER_USER { 382 let to_keep = MAX_TOKENS_PER_USER - 1; 383 delete_oldest_tokens_for_user(pool, did, to_keep).await?; 384 } 385 Ok(()) 386} 387 388pub async fn revoke_tokens_for_client( 389 pool: &PgPool, 390 did: &str, 391 client_id: &str, 392) -> Result<u64, OAuthError> { 393 let result = sqlx::query!( 394 "DELETE FROM oauth_token WHERE did = $1 AND client_id = $2", 395 did, 396 client_id 397 ) 398 .execute(pool) 399 .await?; 400 Ok(result.rows_affected()) 401} 402 403pub async fn revoke_tokens_for_controller( 404 pool: &PgPool, 405 delegated_did: &str, 406 controller_did: &str, 407) -> Result<u64, OAuthError> { 408 let result = sqlx::query!( 409 "DELETE FROM oauth_token WHERE did = $1 AND controller_did = $2", 410 delegated_did, 411 controller_did 412 ) 413 .execute(pool) 414 .await?; 415 Ok(result.rows_affected()) 416}