this repo has no description
at main 12 kB view raw
1use super::super::{OAuthError, RefreshTokenState, TokenData}; 2use super::helpers::{from_json, to_json}; 3use chrono::{DateTime, Utc}; 4use sqlx::PgPool; 5 6pub enum RefreshTokenLookup { 7 Valid { 8 db_id: i32, 9 token_data: TokenData, 10 }, 11 InGracePeriod { 12 db_id: i32, 13 token_data: TokenData, 14 rotated_at: DateTime<Utc>, 15 }, 16 Used { 17 original_token_id: i32, 18 }, 19 Expired { 20 db_id: i32, 21 }, 22 NotFound, 23} 24 25impl RefreshTokenLookup { 26 pub fn state(&self) -> RefreshTokenState { 27 match self { 28 RefreshTokenLookup::Valid { .. } => RefreshTokenState::Valid, 29 RefreshTokenLookup::InGracePeriod { rotated_at, .. } => { 30 RefreshTokenState::InGracePeriod { 31 rotated_at: *rotated_at, 32 } 33 } 34 RefreshTokenLookup::Used { .. } => RefreshTokenState::Used { at: Utc::now() }, 35 RefreshTokenLookup::Expired { .. } => RefreshTokenState::Expired, 36 RefreshTokenLookup::NotFound => RefreshTokenState::Revoked, 37 } 38 } 39} 40 41pub async fn lookup_refresh_token( 42 pool: &PgPool, 43 refresh_token: &str, 44) -> Result<RefreshTokenLookup, OAuthError> { 45 if let Some(token_id) = check_refresh_token_used(pool, refresh_token).await? { 46 if let Some((db_id, token_data)) = 47 get_token_by_previous_refresh_token(pool, refresh_token).await? 48 { 49 let rotated_at = token_data.updated_at; 50 return Ok(RefreshTokenLookup::InGracePeriod { 51 db_id, 52 token_data, 53 rotated_at, 54 }); 55 } 56 return Ok(RefreshTokenLookup::Used { 57 original_token_id: token_id, 58 }); 59 } 60 61 match get_token_by_refresh_token(pool, refresh_token).await? { 62 Some((db_id, token_data)) => { 63 if token_data.expires_at < Utc::now() { 64 Ok(RefreshTokenLookup::Expired { db_id }) 65 } else { 66 Ok(RefreshTokenLookup::Valid { db_id, token_data }) 67 } 68 } 69 None => Ok(RefreshTokenLookup::NotFound), 70 } 71} 72 73pub async fn create_token(pool: &PgPool, data: &TokenData) -> Result<i32, OAuthError> { 74 let client_auth_json = to_json(&data.client_auth)?; 75 let parameters_json = to_json(&data.parameters)?; 76 let row = sqlx::query!( 77 r#" 78 INSERT INTO oauth_token 79 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 80 device_id, parameters, details, code, current_refresh_token, scope, controller_did) 81 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) 82 RETURNING id 83 "#, 84 data.did, 85 data.token_id, 86 data.created_at, 87 data.updated_at, 88 data.expires_at, 89 data.client_id, 90 client_auth_json, 91 data.device_id, 92 parameters_json, 93 data.details, 94 data.code, 95 data.current_refresh_token, 96 data.scope, 97 data.controller_did, 98 ) 99 .fetch_one(pool) 100 .await?; 101 Ok(row.id) 102} 103 104pub async fn get_token_by_id( 105 pool: &PgPool, 106 token_id: &str, 107) -> Result<Option<TokenData>, OAuthError> { 108 let row = sqlx::query!( 109 r#" 110 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 111 device_id, parameters, details, code, current_refresh_token, scope, controller_did 112 FROM oauth_token 113 WHERE token_id = $1 114 "#, 115 token_id 116 ) 117 .fetch_optional(pool) 118 .await?; 119 match row { 120 Some(r) => Ok(Some(TokenData { 121 did: r.did, 122 token_id: r.token_id, 123 created_at: r.created_at, 124 updated_at: r.updated_at, 125 expires_at: r.expires_at, 126 client_id: r.client_id, 127 client_auth: from_json(r.client_auth)?, 128 device_id: r.device_id, 129 parameters: from_json(r.parameters)?, 130 details: r.details, 131 code: r.code, 132 current_refresh_token: r.current_refresh_token, 133 scope: r.scope, 134 controller_did: r.controller_did, 135 })), 136 None => Ok(None), 137 } 138} 139 140pub async fn get_token_by_refresh_token( 141 pool: &PgPool, 142 refresh_token: &str, 143) -> Result<Option<(i32, TokenData)>, OAuthError> { 144 let row = sqlx::query!( 145 r#" 146 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 147 device_id, parameters, details, code, current_refresh_token, scope, controller_did 148 FROM oauth_token 149 WHERE current_refresh_token = $1 150 "#, 151 refresh_token 152 ) 153 .fetch_optional(pool) 154 .await?; 155 match row { 156 Some(r) => Ok(Some(( 157 r.id, 158 TokenData { 159 did: r.did, 160 token_id: r.token_id, 161 created_at: r.created_at, 162 updated_at: r.updated_at, 163 expires_at: r.expires_at, 164 client_id: r.client_id, 165 client_auth: from_json(r.client_auth)?, 166 device_id: r.device_id, 167 parameters: from_json(r.parameters)?, 168 details: r.details, 169 code: r.code, 170 current_refresh_token: r.current_refresh_token, 171 scope: r.scope, 172 controller_did: r.controller_did, 173 }, 174 ))), 175 None => Ok(None), 176 } 177} 178 179pub async fn rotate_token( 180 pool: &PgPool, 181 old_db_id: i32, 182 new_refresh_token: &str, 183 new_expires_at: DateTime<Utc>, 184) -> Result<(), OAuthError> { 185 let mut tx = pool.begin().await?; 186 let old_refresh = sqlx::query_scalar!( 187 r#" 188 SELECT current_refresh_token FROM oauth_token WHERE id = $1 189 "#, 190 old_db_id 191 ) 192 .fetch_one(&mut *tx) 193 .await?; 194 if let Some(ref old_rt) = old_refresh { 195 sqlx::query!( 196 r#" 197 INSERT INTO oauth_used_refresh_token (refresh_token, token_id) 198 VALUES ($1, $2) 199 "#, 200 old_rt, 201 old_db_id 202 ) 203 .execute(&mut *tx) 204 .await?; 205 } 206 sqlx::query!( 207 r#" 208 UPDATE oauth_token 209 SET current_refresh_token = $2, expires_at = $3, updated_at = NOW(), 210 previous_refresh_token = $4, rotated_at = NOW() 211 WHERE id = $1 212 "#, 213 old_db_id, 214 new_refresh_token, 215 new_expires_at, 216 old_refresh 217 ) 218 .execute(&mut *tx) 219 .await?; 220 tx.commit().await?; 221 Ok(()) 222} 223 224pub async fn check_refresh_token_used( 225 pool: &PgPool, 226 refresh_token: &str, 227) -> Result<Option<i32>, OAuthError> { 228 let row = sqlx::query_scalar!( 229 r#" 230 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1 231 "#, 232 refresh_token 233 ) 234 .fetch_optional(pool) 235 .await?; 236 Ok(row) 237} 238 239const REFRESH_GRACE_PERIOD_SECS: i64 = 60; 240 241pub async fn get_token_by_previous_refresh_token( 242 pool: &PgPool, 243 refresh_token: &str, 244) -> Result<Option<(i32, TokenData)>, OAuthError> { 245 let grace_cutoff = Utc::now() - chrono::Duration::seconds(REFRESH_GRACE_PERIOD_SECS); 246 let row = sqlx::query!( 247 r#" 248 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 249 device_id, parameters, details, code, current_refresh_token, scope, controller_did 250 FROM oauth_token 251 WHERE previous_refresh_token = $1 AND rotated_at > $2 252 "#, 253 refresh_token, 254 grace_cutoff 255 ) 256 .fetch_optional(pool) 257 .await?; 258 match row { 259 Some(r) => Ok(Some(( 260 r.id, 261 TokenData { 262 did: r.did, 263 token_id: r.token_id, 264 created_at: r.created_at, 265 updated_at: r.updated_at, 266 expires_at: r.expires_at, 267 client_id: r.client_id, 268 client_auth: from_json(r.client_auth)?, 269 device_id: r.device_id, 270 parameters: from_json(r.parameters)?, 271 details: r.details, 272 code: r.code, 273 current_refresh_token: r.current_refresh_token, 274 scope: r.scope, 275 controller_did: r.controller_did, 276 }, 277 ))), 278 None => Ok(None), 279 } 280} 281 282pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> { 283 sqlx::query!( 284 r#" 285 DELETE FROM oauth_token WHERE token_id = $1 286 "#, 287 token_id 288 ) 289 .execute(pool) 290 .await?; 291 Ok(()) 292} 293 294pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> { 295 sqlx::query!( 296 r#" 297 DELETE FROM oauth_token WHERE id = $1 298 "#, 299 db_id 300 ) 301 .execute(pool) 302 .await?; 303 Ok(()) 304} 305 306pub async fn list_tokens_for_user(pool: &PgPool, did: &str) -> Result<Vec<TokenData>, OAuthError> { 307 let rows = sqlx::query!( 308 r#" 309 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 310 device_id, parameters, details, code, current_refresh_token, scope, controller_did 311 FROM oauth_token 312 WHERE did = $1 313 "#, 314 did 315 ) 316 .fetch_all(pool) 317 .await?; 318 rows.into_iter() 319 .map(|r| { 320 Ok(TokenData { 321 did: r.did, 322 token_id: r.token_id, 323 created_at: r.created_at, 324 updated_at: r.updated_at, 325 expires_at: r.expires_at, 326 client_id: r.client_id, 327 client_auth: from_json(r.client_auth)?, 328 device_id: r.device_id, 329 parameters: from_json(r.parameters)?, 330 details: r.details, 331 code: r.code, 332 current_refresh_token: r.current_refresh_token, 333 scope: r.scope, 334 controller_did: r.controller_did, 335 }) 336 }) 337 .collect() 338} 339 340pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> { 341 let count = sqlx::query_scalar!( 342 r#" 343 SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1 344 "#, 345 did 346 ) 347 .fetch_one(pool) 348 .await?; 349 Ok(count) 350} 351 352pub async fn delete_oldest_tokens_for_user( 353 pool: &PgPool, 354 did: &str, 355 keep_count: i64, 356) -> Result<u64, OAuthError> { 357 let result = sqlx::query!( 358 r#" 359 DELETE FROM oauth_token 360 WHERE id IN ( 361 SELECT id FROM oauth_token 362 WHERE did = $1 363 ORDER BY updated_at ASC 364 OFFSET $2 365 ) 366 "#, 367 did, 368 keep_count 369 ) 370 .execute(pool) 371 .await?; 372 Ok(result.rows_affected()) 373} 374 375const MAX_TOKENS_PER_USER: i64 = 100; 376 377pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> { 378 let count = count_tokens_for_user(pool, did).await?; 379 if count > MAX_TOKENS_PER_USER { 380 let to_keep = MAX_TOKENS_PER_USER - 1; 381 delete_oldest_tokens_for_user(pool, did, to_keep).await?; 382 } 383 Ok(()) 384} 385 386pub async fn revoke_tokens_for_client( 387 pool: &PgPool, 388 did: &str, 389 client_id: &str, 390) -> Result<u64, OAuthError> { 391 let result = sqlx::query!( 392 "DELETE FROM oauth_token WHERE did = $1 AND client_id = $2", 393 did, 394 client_id 395 ) 396 .execute(pool) 397 .await?; 398 Ok(result.rows_affected()) 399} 400 401pub async fn revoke_tokens_for_controller( 402 pool: &PgPool, 403 delegated_did: &str, 404 controller_did: &str, 405) -> Result<u64, OAuthError> { 406 let result = sqlx::query!( 407 "DELETE FROM oauth_token WHERE did = $1 AND controller_did = $2", 408 delegated_did, 409 controller_did 410 ) 411 .execute(pool) 412 .await?; 413 Ok(result.rows_affected()) 414}