// This repository has no description.
1use super::super::{OAuthError, TokenData}; 2use super::helpers::{from_json, to_json}; 3use chrono::{DateTime, Utc}; 4use sqlx::PgPool; 5 6pub async fn create_token(pool: &PgPool, data: &TokenData) -> Result<i32, OAuthError> { 7 let client_auth_json = to_json(&data.client_auth)?; 8 let parameters_json = to_json(&data.parameters)?; 9 let row = sqlx::query!( 10 r#" 11 INSERT INTO oauth_token 12 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 13 device_id, parameters, details, code, current_refresh_token, scope) 14 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) 15 RETURNING id 16 "#, 17 data.did, 18 data.token_id, 19 data.created_at, 20 data.updated_at, 21 data.expires_at, 22 data.client_id, 23 client_auth_json, 24 data.device_id, 25 parameters_json, 26 data.details, 27 data.code, 28 data.current_refresh_token, 29 data.scope, 30 ) 31 .fetch_one(pool) 32 .await?; 33 Ok(row.id) 34} 35 36pub async fn get_token_by_id( 37 pool: &PgPool, 38 token_id: &str, 39) -> Result<Option<TokenData>, OAuthError> { 40 let row = sqlx::query!( 41 r#" 42 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 43 device_id, parameters, details, code, current_refresh_token, scope 44 FROM oauth_token 45 WHERE token_id = $1 46 "#, 47 token_id 48 ) 49 .fetch_optional(pool) 50 .await?; 51 match row { 52 Some(r) => Ok(Some(TokenData { 53 did: r.did, 54 token_id: r.token_id, 55 created_at: r.created_at, 56 updated_at: r.updated_at, 57 expires_at: r.expires_at, 58 client_id: r.client_id, 59 client_auth: from_json(r.client_auth)?, 60 device_id: r.device_id, 61 parameters: from_json(r.parameters)?, 62 details: r.details, 63 code: r.code, 64 current_refresh_token: r.current_refresh_token, 65 scope: r.scope, 66 })), 67 None => Ok(None), 68 } 69} 70 71pub async fn get_token_by_refresh_token( 72 pool: &PgPool, 73 refresh_token: &str, 74) -> Result<Option<(i32, TokenData)>, OAuthError> { 75 let row = sqlx::query!( 76 r#" 77 SELECT id, 
did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 78 device_id, parameters, details, code, current_refresh_token, scope 79 FROM oauth_token 80 WHERE current_refresh_token = $1 81 "#, 82 refresh_token 83 ) 84 .fetch_optional(pool) 85 .await?; 86 match row { 87 Some(r) => Ok(Some(( 88 r.id, 89 TokenData { 90 did: r.did, 91 token_id: r.token_id, 92 created_at: r.created_at, 93 updated_at: r.updated_at, 94 expires_at: r.expires_at, 95 client_id: r.client_id, 96 client_auth: from_json(r.client_auth)?, 97 device_id: r.device_id, 98 parameters: from_json(r.parameters)?, 99 details: r.details, 100 code: r.code, 101 current_refresh_token: r.current_refresh_token, 102 scope: r.scope, 103 }, 104 ))), 105 None => Ok(None), 106 } 107} 108 109pub async fn rotate_token( 110 pool: &PgPool, 111 old_db_id: i32, 112 new_token_id: &str, 113 new_refresh_token: &str, 114 new_expires_at: DateTime<Utc>, 115) -> Result<(), OAuthError> { 116 let mut tx = pool.begin().await?; 117 let old_refresh = sqlx::query_scalar!( 118 r#" 119 SELECT current_refresh_token FROM oauth_token WHERE id = $1 120 "#, 121 old_db_id 122 ) 123 .fetch_one(&mut *tx) 124 .await?; 125 if let Some(ref old_rt) = old_refresh { 126 sqlx::query!( 127 r#" 128 INSERT INTO oauth_used_refresh_token (refresh_token, token_id) 129 VALUES ($1, $2) 130 "#, 131 old_rt, 132 old_db_id 133 ) 134 .execute(&mut *tx) 135 .await?; 136 } 137 sqlx::query!( 138 r#" 139 UPDATE oauth_token 140 SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW(), 141 previous_refresh_token = $5, rotated_at = NOW() 142 WHERE id = $1 143 "#, 144 old_db_id, 145 new_token_id, 146 new_refresh_token, 147 new_expires_at, 148 old_refresh 149 ) 150 .execute(&mut *tx) 151 .await?; 152 tx.commit().await?; 153 Ok(()) 154} 155 156pub async fn check_refresh_token_used( 157 pool: &PgPool, 158 refresh_token: &str, 159) -> Result<Option<i32>, OAuthError> { 160 let row = sqlx::query_scalar!( 161 r#" 162 SELECT token_id FROM 
oauth_used_refresh_token WHERE refresh_token = $1 163 "#, 164 refresh_token 165 ) 166 .fetch_optional(pool) 167 .await?; 168 Ok(row) 169} 170 171const REFRESH_GRACE_PERIOD_SECS: i64 = 60; 172 173pub async fn get_token_by_previous_refresh_token( 174 pool: &PgPool, 175 refresh_token: &str, 176) -> Result<Option<(i32, TokenData)>, OAuthError> { 177 let grace_cutoff = Utc::now() - chrono::Duration::seconds(REFRESH_GRACE_PERIOD_SECS); 178 let row = sqlx::query!( 179 r#" 180 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 181 device_id, parameters, details, code, current_refresh_token, scope 182 FROM oauth_token 183 WHERE previous_refresh_token = $1 AND rotated_at > $2 184 "#, 185 refresh_token, 186 grace_cutoff 187 ) 188 .fetch_optional(pool) 189 .await?; 190 match row { 191 Some(r) => Ok(Some(( 192 r.id, 193 TokenData { 194 did: r.did, 195 token_id: r.token_id, 196 created_at: r.created_at, 197 updated_at: r.updated_at, 198 expires_at: r.expires_at, 199 client_id: r.client_id, 200 client_auth: from_json(r.client_auth)?, 201 device_id: r.device_id, 202 parameters: from_json(r.parameters)?, 203 details: r.details, 204 code: r.code, 205 current_refresh_token: r.current_refresh_token, 206 scope: r.scope, 207 }, 208 ))), 209 None => Ok(None), 210 } 211} 212 213pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> { 214 sqlx::query!( 215 r#" 216 DELETE FROM oauth_token WHERE token_id = $1 217 "#, 218 token_id 219 ) 220 .execute(pool) 221 .await?; 222 Ok(()) 223} 224 225pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> { 226 sqlx::query!( 227 r#" 228 DELETE FROM oauth_token WHERE id = $1 229 "#, 230 db_id 231 ) 232 .execute(pool) 233 .await?; 234 Ok(()) 235} 236 237pub async fn list_tokens_for_user(pool: &PgPool, did: &str) -> Result<Vec<TokenData>, OAuthError> { 238 let rows = sqlx::query!( 239 r#" 240 SELECT did, token_id, created_at, updated_at, expires_at, client_id, 
client_auth, 241 device_id, parameters, details, code, current_refresh_token, scope 242 FROM oauth_token 243 WHERE did = $1 244 "#, 245 did 246 ) 247 .fetch_all(pool) 248 .await?; 249 let mut tokens = Vec::with_capacity(rows.len()); 250 for r in rows { 251 tokens.push(TokenData { 252 did: r.did, 253 token_id: r.token_id, 254 created_at: r.created_at, 255 updated_at: r.updated_at, 256 expires_at: r.expires_at, 257 client_id: r.client_id, 258 client_auth: from_json(r.client_auth)?, 259 device_id: r.device_id, 260 parameters: from_json(r.parameters)?, 261 details: r.details, 262 code: r.code, 263 current_refresh_token: r.current_refresh_token, 264 scope: r.scope, 265 }); 266 } 267 Ok(tokens) 268} 269 270pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> { 271 let count = sqlx::query_scalar!( 272 r#" 273 SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1 274 "#, 275 did 276 ) 277 .fetch_one(pool) 278 .await?; 279 Ok(count) 280} 281 282pub async fn delete_oldest_tokens_for_user( 283 pool: &PgPool, 284 did: &str, 285 keep_count: i64, 286) -> Result<u64, OAuthError> { 287 let result = sqlx::query!( 288 r#" 289 DELETE FROM oauth_token 290 WHERE id IN ( 291 SELECT id FROM oauth_token 292 WHERE did = $1 293 ORDER BY updated_at ASC 294 OFFSET $2 295 ) 296 "#, 297 did, 298 keep_count 299 ) 300 .execute(pool) 301 .await?; 302 Ok(result.rows_affected()) 303} 304 305const MAX_TOKENS_PER_USER: i64 = 100; 306 307pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> { 308 let count = count_tokens_for_user(pool, did).await?; 309 if count > MAX_TOKENS_PER_USER { 310 let to_keep = MAX_TOKENS_PER_USER - 1; 311 delete_oldest_tokens_for_user(pool, did, to_keep).await?; 312 } 313 Ok(()) 314} 315 316pub async fn revoke_tokens_for_client( 317 pool: &PgPool, 318 did: &str, 319 client_id: &str, 320) -> Result<u64, OAuthError> { 321 let result = sqlx::query!( 322 "DELETE FROM oauth_token WHERE did = $1 AND 
client_id = $2", 323 did, 324 client_id 325 ) 326 .execute(pool) 327 .await?; 328 Ok(result.rows_affected()) 329}