// This repository has no description.
1use super::super::{OAuthError, RefreshTokenState, TokenData}; 2use super::helpers::{from_json, to_json}; 3use chrono::{DateTime, Utc}; 4use sqlx::PgPool; 5 6pub enum RefreshTokenLookup { 7 Valid { db_id: i32, token_data: TokenData }, 8 InGracePeriod { db_id: i32, token_data: TokenData, rotated_at: DateTime<Utc> }, 9 Used { original_token_id: i32 }, 10 Expired { db_id: i32 }, 11 NotFound, 12} 13 14impl RefreshTokenLookup { 15 pub fn state(&self) -> RefreshTokenState { 16 match self { 17 RefreshTokenLookup::Valid { .. } => RefreshTokenState::Valid, 18 RefreshTokenLookup::InGracePeriod { rotated_at, .. } => { 19 RefreshTokenState::InGracePeriod { rotated_at: *rotated_at } 20 } 21 RefreshTokenLookup::Used { .. } => RefreshTokenState::Used { at: Utc::now() }, 22 RefreshTokenLookup::Expired { .. } => RefreshTokenState::Expired, 23 RefreshTokenLookup::NotFound => RefreshTokenState::Revoked, 24 } 25 } 26} 27 28pub async fn lookup_refresh_token( 29 pool: &PgPool, 30 refresh_token: &str, 31) -> Result<RefreshTokenLookup, OAuthError> { 32 if let Some(token_id) = check_refresh_token_used(pool, refresh_token).await? { 33 if let Some((db_id, token_data)) = get_token_by_previous_refresh_token(pool, refresh_token).await? { 34 let rotated_at = token_data.updated_at; 35 return Ok(RefreshTokenLookup::InGracePeriod { db_id, token_data, rotated_at }); 36 } 37 return Ok(RefreshTokenLookup::Used { original_token_id: token_id }); 38 } 39 40 match get_token_by_refresh_token(pool, refresh_token).await? 
{ 41 Some((db_id, token_data)) => { 42 if token_data.expires_at < Utc::now() { 43 Ok(RefreshTokenLookup::Expired { db_id }) 44 } else { 45 Ok(RefreshTokenLookup::Valid { db_id, token_data }) 46 } 47 } 48 None => Ok(RefreshTokenLookup::NotFound), 49 } 50} 51 52pub async fn create_token(pool: &PgPool, data: &TokenData) -> Result<i32, OAuthError> { 53 let client_auth_json = to_json(&data.client_auth)?; 54 let parameters_json = to_json(&data.parameters)?; 55 let row = sqlx::query!( 56 r#" 57 INSERT INTO oauth_token 58 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 59 device_id, parameters, details, code, current_refresh_token, scope, controller_did) 60 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) 61 RETURNING id 62 "#, 63 data.did, 64 data.token_id, 65 data.created_at, 66 data.updated_at, 67 data.expires_at, 68 data.client_id, 69 client_auth_json, 70 data.device_id, 71 parameters_json, 72 data.details, 73 data.code, 74 data.current_refresh_token, 75 data.scope, 76 data.controller_did, 77 ) 78 .fetch_one(pool) 79 .await?; 80 Ok(row.id) 81} 82 83pub async fn get_token_by_id( 84 pool: &PgPool, 85 token_id: &str, 86) -> Result<Option<TokenData>, OAuthError> { 87 let row = sqlx::query!( 88 r#" 89 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 90 device_id, parameters, details, code, current_refresh_token, scope, controller_did 91 FROM oauth_token 92 WHERE token_id = $1 93 "#, 94 token_id 95 ) 96 .fetch_optional(pool) 97 .await?; 98 match row { 99 Some(r) => Ok(Some(TokenData { 100 did: r.did, 101 token_id: r.token_id, 102 created_at: r.created_at, 103 updated_at: r.updated_at, 104 expires_at: r.expires_at, 105 client_id: r.client_id, 106 client_auth: from_json(r.client_auth)?, 107 device_id: r.device_id, 108 parameters: from_json(r.parameters)?, 109 details: r.details, 110 code: r.code, 111 current_refresh_token: r.current_refresh_token, 112 scope: r.scope, 113 controller_did: 
r.controller_did, 114 })), 115 None => Ok(None), 116 } 117} 118 119pub async fn get_token_by_refresh_token( 120 pool: &PgPool, 121 refresh_token: &str, 122) -> Result<Option<(i32, TokenData)>, OAuthError> { 123 let row = sqlx::query!( 124 r#" 125 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 126 device_id, parameters, details, code, current_refresh_token, scope, controller_did 127 FROM oauth_token 128 WHERE current_refresh_token = $1 129 "#, 130 refresh_token 131 ) 132 .fetch_optional(pool) 133 .await?; 134 match row { 135 Some(r) => Ok(Some(( 136 r.id, 137 TokenData { 138 did: r.did, 139 token_id: r.token_id, 140 created_at: r.created_at, 141 updated_at: r.updated_at, 142 expires_at: r.expires_at, 143 client_id: r.client_id, 144 client_auth: from_json(r.client_auth)?, 145 device_id: r.device_id, 146 parameters: from_json(r.parameters)?, 147 details: r.details, 148 code: r.code, 149 current_refresh_token: r.current_refresh_token, 150 scope: r.scope, 151 controller_did: r.controller_did, 152 }, 153 ))), 154 None => Ok(None), 155 } 156} 157 158pub async fn rotate_token( 159 pool: &PgPool, 160 old_db_id: i32, 161 new_token_id: &str, 162 new_refresh_token: &str, 163 new_expires_at: DateTime<Utc>, 164) -> Result<(), OAuthError> { 165 let mut tx = pool.begin().await?; 166 let old_refresh = sqlx::query_scalar!( 167 r#" 168 SELECT current_refresh_token FROM oauth_token WHERE id = $1 169 "#, 170 old_db_id 171 ) 172 .fetch_one(&mut *tx) 173 .await?; 174 if let Some(ref old_rt) = old_refresh { 175 sqlx::query!( 176 r#" 177 INSERT INTO oauth_used_refresh_token (refresh_token, token_id) 178 VALUES ($1, $2) 179 "#, 180 old_rt, 181 old_db_id 182 ) 183 .execute(&mut *tx) 184 .await?; 185 } 186 sqlx::query!( 187 r#" 188 UPDATE oauth_token 189 SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW(), 190 previous_refresh_token = $5, rotated_at = NOW() 191 WHERE id = $1 192 "#, 193 old_db_id, 194 new_token_id, 195 
new_refresh_token, 196 new_expires_at, 197 old_refresh 198 ) 199 .execute(&mut *tx) 200 .await?; 201 tx.commit().await?; 202 Ok(()) 203} 204 205pub async fn check_refresh_token_used( 206 pool: &PgPool, 207 refresh_token: &str, 208) -> Result<Option<i32>, OAuthError> { 209 let row = sqlx::query_scalar!( 210 r#" 211 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1 212 "#, 213 refresh_token 214 ) 215 .fetch_optional(pool) 216 .await?; 217 Ok(row) 218} 219 220const REFRESH_GRACE_PERIOD_SECS: i64 = 60; 221 222pub async fn get_token_by_previous_refresh_token( 223 pool: &PgPool, 224 refresh_token: &str, 225) -> Result<Option<(i32, TokenData)>, OAuthError> { 226 let grace_cutoff = Utc::now() - chrono::Duration::seconds(REFRESH_GRACE_PERIOD_SECS); 227 let row = sqlx::query!( 228 r#" 229 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 230 device_id, parameters, details, code, current_refresh_token, scope, controller_did 231 FROM oauth_token 232 WHERE previous_refresh_token = $1 AND rotated_at > $2 233 "#, 234 refresh_token, 235 grace_cutoff 236 ) 237 .fetch_optional(pool) 238 .await?; 239 match row { 240 Some(r) => Ok(Some(( 241 r.id, 242 TokenData { 243 did: r.did, 244 token_id: r.token_id, 245 created_at: r.created_at, 246 updated_at: r.updated_at, 247 expires_at: r.expires_at, 248 client_id: r.client_id, 249 client_auth: from_json(r.client_auth)?, 250 device_id: r.device_id, 251 parameters: from_json(r.parameters)?, 252 details: r.details, 253 code: r.code, 254 current_refresh_token: r.current_refresh_token, 255 scope: r.scope, 256 controller_did: r.controller_did, 257 }, 258 ))), 259 None => Ok(None), 260 } 261} 262 263pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> { 264 sqlx::query!( 265 r#" 266 DELETE FROM oauth_token WHERE token_id = $1 267 "#, 268 token_id 269 ) 270 .execute(pool) 271 .await?; 272 Ok(()) 273} 274 275pub async fn delete_token_family(pool: &PgPool, db_id: 
i32) -> Result<(), OAuthError> { 276 sqlx::query!( 277 r#" 278 DELETE FROM oauth_token WHERE id = $1 279 "#, 280 db_id 281 ) 282 .execute(pool) 283 .await?; 284 Ok(()) 285} 286 287pub async fn list_tokens_for_user(pool: &PgPool, did: &str) -> Result<Vec<TokenData>, OAuthError> { 288 let rows = sqlx::query!( 289 r#" 290 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 291 device_id, parameters, details, code, current_refresh_token, scope, controller_did 292 FROM oauth_token 293 WHERE did = $1 294 "#, 295 did 296 ) 297 .fetch_all(pool) 298 .await?; 299 let mut tokens = Vec::with_capacity(rows.len()); 300 for r in rows { 301 tokens.push(TokenData { 302 did: r.did, 303 token_id: r.token_id, 304 created_at: r.created_at, 305 updated_at: r.updated_at, 306 expires_at: r.expires_at, 307 client_id: r.client_id, 308 client_auth: from_json(r.client_auth)?, 309 device_id: r.device_id, 310 parameters: from_json(r.parameters)?, 311 details: r.details, 312 code: r.code, 313 current_refresh_token: r.current_refresh_token, 314 scope: r.scope, 315 controller_did: r.controller_did, 316 }); 317 } 318 Ok(tokens) 319} 320 321pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> { 322 let count = sqlx::query_scalar!( 323 r#" 324 SELECT COUNT(*) as "count!" 
FROM oauth_token WHERE did = $1 325 "#, 326 did 327 ) 328 .fetch_one(pool) 329 .await?; 330 Ok(count) 331} 332 333pub async fn delete_oldest_tokens_for_user( 334 pool: &PgPool, 335 did: &str, 336 keep_count: i64, 337) -> Result<u64, OAuthError> { 338 let result = sqlx::query!( 339 r#" 340 DELETE FROM oauth_token 341 WHERE id IN ( 342 SELECT id FROM oauth_token 343 WHERE did = $1 344 ORDER BY updated_at ASC 345 OFFSET $2 346 ) 347 "#, 348 did, 349 keep_count 350 ) 351 .execute(pool) 352 .await?; 353 Ok(result.rows_affected()) 354} 355 356const MAX_TOKENS_PER_USER: i64 = 100; 357 358pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> { 359 let count = count_tokens_for_user(pool, did).await?; 360 if count > MAX_TOKENS_PER_USER { 361 let to_keep = MAX_TOKENS_PER_USER - 1; 362 delete_oldest_tokens_for_user(pool, did, to_keep).await?; 363 } 364 Ok(()) 365} 366 367pub async fn revoke_tokens_for_client( 368 pool: &PgPool, 369 did: &str, 370 client_id: &str, 371) -> Result<u64, OAuthError> { 372 let result = sqlx::query!( 373 "DELETE FROM oauth_token WHERE did = $1 AND client_id = $2", 374 did, 375 client_id 376 ) 377 .execute(pool) 378 .await?; 379 Ok(result.rows_affected()) 380} 381 382pub async fn revoke_tokens_for_controller( 383 pool: &PgPool, 384 delegated_did: &str, 385 controller_did: &str, 386) -> Result<u64, OAuthError> { 387 let result = sqlx::query!( 388 "DELETE FROM oauth_token WHERE did = $1 AND controller_did = $2", 389 delegated_did, 390 controller_did 391 ) 392 .execute(pool) 393 .await?; 394 Ok(result.rows_affected()) 395}