this repo has no description
1use chrono::{DateTime, Utc}; 2use sqlx::PgPool; 3use super::super::{OAuthError, TokenData}; 4use super::helpers::{from_json, to_json}; 5 6pub async fn create_token( 7 pool: &PgPool, 8 data: &TokenData, 9) -> Result<i32, OAuthError> { 10 let client_auth_json = to_json(&data.client_auth)?; 11 let parameters_json = to_json(&data.parameters)?; 12 let row = sqlx::query!( 13 r#" 14 INSERT INTO oauth_token 15 (did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 16 device_id, parameters, details, code, current_refresh_token, scope) 17 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) 18 RETURNING id 19 "#, 20 data.did, 21 data.token_id, 22 data.created_at, 23 data.updated_at, 24 data.expires_at, 25 data.client_id, 26 client_auth_json, 27 data.device_id, 28 parameters_json, 29 data.details, 30 data.code, 31 data.current_refresh_token, 32 data.scope, 33 ) 34 .fetch_one(pool) 35 .await?; 36 Ok(row.id) 37} 38 39pub async fn get_token_by_id( 40 pool: &PgPool, 41 token_id: &str, 42) -> Result<Option<TokenData>, OAuthError> { 43 let row = sqlx::query!( 44 r#" 45 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 46 device_id, parameters, details, code, current_refresh_token, scope 47 FROM oauth_token 48 WHERE token_id = $1 49 "#, 50 token_id 51 ) 52 .fetch_optional(pool) 53 .await?; 54 match row { 55 Some(r) => Ok(Some(TokenData { 56 did: r.did, 57 token_id: r.token_id, 58 created_at: r.created_at, 59 updated_at: r.updated_at, 60 expires_at: r.expires_at, 61 client_id: r.client_id, 62 client_auth: from_json(r.client_auth)?, 63 device_id: r.device_id, 64 parameters: from_json(r.parameters)?, 65 details: r.details, 66 code: r.code, 67 current_refresh_token: r.current_refresh_token, 68 scope: r.scope, 69 })), 70 None => Ok(None), 71 } 72} 73 74pub async fn get_token_by_refresh_token( 75 pool: &PgPool, 76 refresh_token: &str, 77) -> Result<Option<(i32, TokenData)>, OAuthError> { 78 let row = sqlx::query!( 79 r#" 80 SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 81 device_id, parameters, details, code, current_refresh_token, scope 82 FROM oauth_token 83 WHERE current_refresh_token = $1 84 "#, 85 refresh_token 86 ) 87 .fetch_optional(pool) 88 .await?; 89 match row { 90 Some(r) => Ok(Some(( 91 r.id, 92 TokenData { 93 did: r.did, 94 token_id: r.token_id, 95 created_at: r.created_at, 96 updated_at: r.updated_at, 97 expires_at: r.expires_at, 98 client_id: r.client_id, 99 client_auth: from_json(r.client_auth)?, 100 device_id: r.device_id, 101 parameters: from_json(r.parameters)?, 102 details: r.details, 103 code: r.code, 104 current_refresh_token: r.current_refresh_token, 105 scope: r.scope, 106 }, 107 ))), 108 None => Ok(None), 109 } 110} 111 112pub async fn rotate_token( 113 pool: &PgPool, 114 old_db_id: i32, 115 new_token_id: &str, 116 new_refresh_token: &str, 117 new_expires_at: DateTime<Utc>, 118) -> Result<(), OAuthError> { 119 let mut tx = pool.begin().await?; 120 let old_refresh = sqlx::query_scalar!( 121 r#" 122 SELECT current_refresh_token FROM oauth_token WHERE id = $1 123 "#, 124 old_db_id 125 ) 126 .fetch_one(&mut *tx) 127 .await?; 128 if let Some(old_rt) = old_refresh { 129 sqlx::query!( 130 r#" 131 INSERT INTO oauth_used_refresh_token (refresh_token, token_id) 132 VALUES ($1, $2) 133 "#, 134 old_rt, 135 old_db_id 136 ) 137 .execute(&mut *tx) 138 .await?; 139 } 140 sqlx::query!( 141 r#" 142 UPDATE oauth_token 143 SET token_id = $2, current_refresh_token = $3, expires_at = $4, updated_at = NOW() 144 WHERE id = $1 145 "#, 146 old_db_id, 147 new_token_id, 148 new_refresh_token, 149 new_expires_at 150 ) 151 .execute(&mut *tx) 152 .await?; 153 tx.commit().await?; 154 Ok(()) 155} 156 157pub async fn check_refresh_token_used( 158 pool: &PgPool, 159 refresh_token: &str, 160) -> Result<Option<i32>, OAuthError> { 161 let row = sqlx::query_scalar!( 162 r#" 163 SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1 164 "#, 165 refresh_token 166 ) 167 .fetch_optional(pool) 168 .await?; 169 Ok(row) 170} 171 172pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> { 173 sqlx::query!( 174 r#" 175 DELETE FROM oauth_token WHERE token_id = $1 176 "#, 177 token_id 178 ) 179 .execute(pool) 180 .await?; 181 Ok(()) 182} 183 184pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> { 185 sqlx::query!( 186 r#" 187 DELETE FROM oauth_token WHERE id = $1 188 "#, 189 db_id 190 ) 191 .execute(pool) 192 .await?; 193 Ok(()) 194} 195 196pub async fn list_tokens_for_user( 197 pool: &PgPool, 198 did: &str, 199) -> Result<Vec<TokenData>, OAuthError> { 200 let rows = sqlx::query!( 201 r#" 202 SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 203 device_id, parameters, details, code, current_refresh_token, scope 204 FROM oauth_token 205 WHERE did = $1 206 "#, 207 did 208 ) 209 .fetch_all(pool) 210 .await?; 211 let mut tokens = Vec::with_capacity(rows.len()); 212 for r in rows { 213 tokens.push(TokenData { 214 did: r.did, 215 token_id: r.token_id, 216 created_at: r.created_at, 217 updated_at: r.updated_at, 218 expires_at: r.expires_at, 219 client_id: r.client_id, 220 client_auth: from_json(r.client_auth)?, 221 device_id: r.device_id, 222 parameters: from_json(r.parameters)?, 223 details: r.details, 224 code: r.code, 225 current_refresh_token: r.current_refresh_token, 226 scope: r.scope, 227 }); 228 } 229 Ok(tokens) 230} 231 232pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> { 233 let count = sqlx::query_scalar!( 234 r#" 235 SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1 236 "#, 237 did 238 ) 239 .fetch_one(pool) 240 .await?; 241 Ok(count) 242} 243 244pub async fn delete_oldest_tokens_for_user( 245 pool: &PgPool, 246 did: &str, 247 keep_count: i64, 248) -> Result<u64, OAuthError> { 249 let result = sqlx::query!( 250 r#" 251 DELETE FROM oauth_token 252 WHERE id IN ( 253 SELECT id FROM oauth_token 254 WHERE did = $1 255 ORDER BY updated_at ASC 256 OFFSET $2 257 ) 258 "#, 259 did, 260 keep_count 261 ) 262 .execute(pool) 263 .await?; 264 Ok(result.rows_affected()) 265} 266 267const MAX_TOKENS_PER_USER: i64 = 100; 268 269pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> { 270 let count = count_tokens_for_user(pool, did).await?; 271 if count > MAX_TOKENS_PER_USER { 272 let to_keep = MAX_TOKENS_PER_USER - 1; 273 delete_oldest_tokens_for_user(pool, did, to_keep).await?; 274 } 275 Ok(()) 276}