// This repo has no description.
1use chrono::{DateTime, Utc}; 2use sqlx::PgPool; 3 4use super::super::{OAuthError, TokenData}; 5use super::helpers::{from_json, to_json}; 6 7pub async fn create_token( 8 pool: &PgPool, 9 data: &TokenData, 10) -> Result<i32, OAuthError> { 11 let client_auth_json = to_json(&data.client_auth)?; 12 let parameters_json = to_json(&data.parameters)?; 13 14 let row = sqlx::query!( 15 r#" 16 INSERT INTO oauth_token 17 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 18 device_id, parameters, details, code, current_refresh_token, scope) 19 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) 20 RETURNING id 21 "#, 22 data.did, 23 data.token_id, 24 data.created_at, 25 data.updated_at, 26 data.expires_at, 27 data.client_id, 28 client_auth_json, 29 data.device_id, 30 parameters_json, 31 data.details, 32 data.code, 33 data.current_refresh_token, 34 data.scope, 35 ) 36 .fetch_one(pool) 37 .await?; 38 39 Ok(row.id) 40} 41 42pub async fn get_token_by_id( 43 pool: &PgPool, 44 token_id: &str, 45) -> Result<Option<TokenData>, OAuthError> { 46 let row = sqlx::query!( 47 r#" 48 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 49 device_id, parameters, details, code, current_refresh_token, scope 50 FROM oauth_token 51 WHERE token_id = $1 52 "#, 53 token_id 54 ) 55 .fetch_optional(pool) 56 .await?; 57 58 match row { 59 Some(r) => Ok(Some(TokenData { 60 did: r.did, 61 token_id: r.token_id, 62 created_at: r.created_at, 63 updated_at: r.updated_at, 64 expires_at: r.expires_at, 65 client_id: r.client_id, 66 client_auth: from_json(r.client_auth)?, 67 device_id: r.device_id, 68 parameters: from_json(r.parameters)?, 69 details: r.details, 70 code: r.code, 71 current_refresh_token: r.current_refresh_token, 72 scope: r.scope, 73 })), 74 None => Ok(None), 75 } 76} 77 78pub async fn get_token_by_refresh_token( 79 pool: &PgPool, 80 refresh_token: &str, 81) -> Result<Option<(i32, TokenData)>, OAuthError> { 82 let row = 
sqlx::query!( 83 r#" 84 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 85 device_id, parameters, details, code, current_refresh_token, scope 86 FROM oauth_token 87 WHERE current_refresh_token = $1 88 "#, 89 refresh_token 90 ) 91 .fetch_optional(pool) 92 .await?; 93 94 match row { 95 Some(r) => Ok(Some(( 96 r.id, 97 TokenData { 98 did: r.did, 99 token_id: r.token_id, 100 created_at: r.created_at, 101 updated_at: r.updated_at, 102 expires_at: r.expires_at, 103 client_id: r.client_id, 104 client_auth: from_json(r.client_auth)?, 105 device_id: r.device_id, 106 parameters: from_json(r.parameters)?, 107 details: r.details, 108 code: r.code, 109 current_refresh_token: r.current_refresh_token, 110 scope: r.scope, 111 }, 112 ))), 113 None => Ok(None), 114 } 115} 116 117pub async fn rotate_token( 118 pool: &PgPool, 119 old_db_id: i32, 120 new_token_id: &str, 121 new_refresh_token: &str, 122 new_expires_at: DateTime<Utc>, 123) -> Result<(), OAuthError> { 124 let mut tx = pool.begin().await?; 125 126 let old_refresh = sqlx::query_scalar!( 127 r#" 128 SELECT current_refresh_token FROM oauth_token WHERE id = $1 129 "#, 130 old_db_id 131 ) 132 .fetch_one(&mut *tx) 133 .await?; 134 135 if let Some(old_rt) = old_refresh { 136 sqlx::query!( 137 r#" 138 INSERT INTO oauth_used_refresh_token (refresh_token, token_id) 139 VALUES ($1, $2) 140 "#, 141 old_rt, 142 old_db_id 143 ) 144 .execute(&mut *tx) 145 .await?; 146 } 147 148 sqlx::query!( 149 r#" 150 UPDATE oauth_token 151 SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW() 152 WHERE id = $1 153 "#, 154 old_db_id, 155 new_token_id, 156 new_refresh_token, 157 new_expires_at 158 ) 159 .execute(&mut *tx) 160 .await?; 161 162 tx.commit().await?; 163 Ok(()) 164} 165 166pub async fn check_refresh_token_used( 167 pool: &PgPool, 168 refresh_token: &str, 169) -> Result<Option<i32>, OAuthError> { 170 let row = sqlx::query_scalar!( 171 r#" 172 SELECT token_id FROM 
oauth_used_refresh_token WHERE refresh_token = $1 173 "#, 174 refresh_token 175 ) 176 .fetch_optional(pool) 177 .await?; 178 179 Ok(row) 180} 181 182pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> { 183 sqlx::query!( 184 r#" 185 DELETE FROM oauth_token WHERE token_id = $1 186 "#, 187 token_id 188 ) 189 .execute(pool) 190 .await?; 191 192 Ok(()) 193} 194 195pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> { 196 sqlx::query!( 197 r#" 198 DELETE FROM oauth_token WHERE id = $1 199 "#, 200 db_id 201 ) 202 .execute(pool) 203 .await?; 204 205 Ok(()) 206} 207 208pub async fn list_tokens_for_user( 209 pool: &PgPool, 210 did: &str, 211) -> Result<Vec<TokenData>, OAuthError> { 212 let rows = sqlx::query!( 213 r#" 214 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 215 device_id, parameters, details, code, current_refresh_token, scope 216 FROM oauth_token 217 WHERE did = $1 218 "#, 219 did 220 ) 221 .fetch_all(pool) 222 .await?; 223 224 let mut tokens = Vec::with_capacity(rows.len()); 225 for r in rows { 226 tokens.push(TokenData { 227 did: r.did, 228 token_id: r.token_id, 229 created_at: r.created_at, 230 updated_at: r.updated_at, 231 expires_at: r.expires_at, 232 client_id: r.client_id, 233 client_auth: from_json(r.client_auth)?, 234 device_id: r.device_id, 235 parameters: from_json(r.parameters)?, 236 details: r.details, 237 code: r.code, 238 current_refresh_token: r.current_refresh_token, 239 scope: r.scope, 240 }); 241 } 242 Ok(tokens) 243} 244 245pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> { 246 let count = sqlx::query_scalar!( 247 r#" 248 SELECT COUNT(*) as "count!" 
FROM oauth_token WHERE did = $1 249 "#, 250 did 251 ) 252 .fetch_one(pool) 253 .await?; 254 255 Ok(count) 256} 257 258pub async fn delete_oldest_tokens_for_user( 259 pool: &PgPool, 260 did: &str, 261 keep_count: i64, 262) -> Result<u64, OAuthError> { 263 let result = sqlx::query!( 264 r#" 265 DELETE FROM oauth_token 266 WHERE id IN ( 267 SELECT id FROM oauth_token 268 WHERE did = $1 269 ORDER BY updated_at ASC 270 OFFSET $2 271 ) 272 "#, 273 did, 274 keep_count 275 ) 276 .execute(pool) 277 .await?; 278 279 Ok(result.rows_affected()) 280} 281 282const MAX_TOKENS_PER_USER: i64 = 100; 283 284pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> { 285 let count = count_tokens_for_user(pool, did).await?; 286 if count > MAX_TOKENS_PER_USER { 287 let to_keep = MAX_TOKENS_PER_USER - 1; 288 delete_oldest_tokens_for_user(pool, did, to_keep).await?; 289 } 290 Ok(()) 291}