A fork of https://github.com/teal-fm/piper
at main 507 lines 13 kB view raw
1package db 2 3import ( 4 "database/sql" 5 "encoding/json" 6 "fmt" 7 "log" 8 "os" 9 "path/filepath" 10 "time" 11 12 _ "github.com/mattn/go-sqlite3" 13 "github.com/teal-fm/piper/models" 14) 15 16type DB struct { 17 *sql.DB 18 logger *log.Logger 19} 20 21func New(dbPath string) (*DB, error) { 22 dir := filepath.Dir(dbPath) 23 if dir != "." && dir != "/" { 24 os.MkdirAll(dir, 755) 25 } 26 27 db, err := sql.Open("sqlite3", dbPath) 28 if err != nil { 29 return nil, err 30 } 31 32 // Test the connection 33 if err = db.Ping(); err != nil { 34 return nil, err 35 } 36 logger := log.New(os.Stdout, "db: ", log.LstdFlags|log.Lmsgprefix) 37 38 return &DB{db, logger}, nil 39} 40 41func (db *DB) Initialize() error { 42 _, err := db.Exec(` 43 CREATE TABLE IF NOT EXISTS users ( 44 id INTEGER PRIMARY KEY AUTOINCREMENT, 45 username TEXT, -- Made nullable, might not have username initially 46 email TEXT UNIQUE, -- Made nullable 47 atproto_did TEXT UNIQUE, -- Atproto DID (identifier) 48 most_recent_at_session_id TEXT, -- Most recent oAuth session id 49 spotify_id TEXT UNIQUE, -- Spotify specific ID 50 access_token TEXT, -- Spotify access token 51 refresh_token TEXT, -- Spotify refresh token 52 token_expiry TIMESTAMP, -- Spotify token expiry 53 lastfm_username TEXT, -- Last.fm username 54 created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Use default 55 updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP -- Use default 56 )`) 57 if err != nil { 58 return err 59 } 60 61 _, err = db.Exec(` 62 CREATE TABLE IF NOT EXISTS tracks ( 63 id INTEGER PRIMARY KEY AUTOINCREMENT, 64 user_id INTEGER NOT NULL, 65 name TEXT NOT NULL, 66 recording_mbid TEXT, -- Added 67 artist TEXT NOT NULL, -- should be JSONB in PostgreSQL if we ever switch 68 album TEXT NOT NULL, 69 release_mbid TEXT, -- Added 70 url TEXT NOT NULL, 71 timestamp TIMESTAMP, 72 duration_ms INTEGER, 73 progress_ms INTEGER, 74 service_base_url TEXT, 75 isrc TEXT, 76 has_stamped BOOLEAN, 77 FOREIGN KEY (user_id) REFERENCES users(id) 78 )`) 79 if err != nil { 80 return err 81 } 82 83 _, err = db.Exec(` 84 CREATE TABLE IF NOT EXISTS atproto_state ( 85 id INTEGER PRIMARY KEY AUTOINCREMENT, 86 state TEXT NOT NULL, 87 authserver_url TEXT, 88 account_did TEXT, 89 scopes TEXT, 90 request_uri TEXT, 91 authserver_token_endpoint TEXT, 92 authserver_revocation_endpoint TEXT, 93 pkce_verifier TEXT, 94 dpop_authserver_nonce TEXT, 95 dpop_privatekey_multibase TEXT, 96 created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP 97 ); 98 CREATE INDEX IF NOT EXISTS atproto_state_state ON atproto_state(state); 99 100`) 101 if err != nil { 102 return err 103 } 104 105 _, err = db.Exec(` 106 CREATE TABLE IF NOT EXISTS atproto_sessions ( 107 id INTEGER PRIMARY KEY AUTOINCREMENT, 108 look_up_key TEXT NOT NULL, 109 account_did TEXT, 110 session_id TEXT, 111 host_url TEXT, 112 authserver_url TEXT, 113 authserver_token_endpoint TEXT, 114 authserver_revocation_endpoint TEXT, 115 scopes TEXT, 116 access_token TEXT, 117 refresh_token TEXT, 118 dpop_authserver_nonce TEXT, 119 dpop_host_nonce TEXT, 120 dpop_privatekey_multibase TEXT, 121 created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP 122 ); 123 CREATE INDEX IF NOT EXISTS idx_atproto_sessions_look_up_key ON atproto_sessions(look_up_key); 124`) 125 if err != nil { 126 return err 127 } 128 129 // Add columns recording_mbid and release_mbid to tracks table if they don't exist 130 _, err = db.Exec(`ALTER TABLE tracks ADD COLUMN recording_mbid TEXT`) 131 if err != nil && err.Error() != "duplicate column name: recording_mbid" { 132 // Handle errors other than 'duplicate column' 133 return err 134 } 135 _, err = db.Exec(`ALTER TABLE tracks ADD COLUMN release_mbid TEXT`) 136 if err != nil && err.Error() != "duplicate column name: release_mbid" { 137 // Handle errors other than 'duplicate column' 138 return err 139 } 140 141 return nil 142} 143 144// create user without spotify id 145func (db *DB) CreateUser(user *models.User) (int64, error) { 146 now := time.Now().UTC() 147 148 result, err := db.Exec(` 149 INSERT INTO users (username, email, created_at, updated_at) 150 VALUES (?, ?, ?, ?)`, 151 user.Username, user.Email, now, now) 152 153 if err != nil { 154 return 0, err 155 } 156 157 return result.LastInsertId() 158} 159 160// add spotify session to user, returning the updated user 161func (db *DB) AddSpotifySession(userID int64, username, email, spotifyId, accessToken, refreshToken string, tokenExpiry time.Time) (*models.User, error) { 162 now := time.Now().UTC() 163 164 _, err := db.Exec(` 165 UPDATE users SET username = ?, email = ?, spotify_id = ?, access_token = ?, refresh_token = ?, token_expiry = ?, created_at = ?, updated_at = ? 166 WHERE id == ? 167 `, 168 username, email, spotifyId, accessToken, refreshToken, tokenExpiry, now, now, userID) 169 if err != nil { 170 return nil, err 171 } 172 173 user, err := db.GetUserByID(userID) 174 if err != nil { 175 return nil, err 176 } 177 178 return user, err 179} 180 181func (db *DB) GetUserByID(ID int64) (*models.User, error) { 182 user := &models.User{} 183 184 err := db.QueryRow(` 185 SELECT id, 186 username, 187 email, 188 atproto_did, 189 most_recent_at_session_id, 190 spotify_id, 191 access_token, 192 refresh_token, 193 token_expiry, 194 lastfm_username, 195 created_at, 196 updated_at 197 FROM users WHERE id = ?`, ID).Scan( 198 &user.ID, &user.Username, &user.Email, &user.ATProtoDID, &user.MostRecentAtProtoSessionID, &user.SpotifyID, 199 &user.AccessToken, &user.RefreshToken, &user.TokenExpiry, 200 &user.LastFMUsername, 201 &user.CreatedAt, &user.UpdatedAt) 202 203 if err == sql.ErrNoRows { 204 return nil, nil 205 } 206 207 if err != nil { 208 return nil, err 209 } 210 211 return user, nil 212} 213 214func (db *DB) GetUserBySpotifyID(spotifyID string) (*models.User, error) { 215 user := &models.User{} 216 217 err := db.QueryRow(` 218 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, lastfm_username, created_at, updated_at 219 FROM users WHERE spotify_id = ?`, spotifyID).Scan( 220 &user.ID, &user.Username, &user.Email, &user.SpotifyID, 221 &user.AccessToken, &user.RefreshToken, &user.TokenExpiry, 222 &user.LastFMUsername, 223 &user.CreatedAt, &user.UpdatedAt) 224 225 if err == sql.ErrNoRows { 226 return nil, nil 227 } 228 229 if err != nil { 230 return nil, err 231 } 232 233 return user, nil 234} 235 236func (db *DB) UpdateUserToken(userID int64, accessToken, refreshToken string, expiry time.Time) error { 237 now := time.Now().UTC() 238 239 _, err := db.Exec(` 240 UPDATE users 241 SET access_token = ?, refresh_token = ?, token_expiry = ?, updated_at = ? 242 WHERE id = ?`, 243 accessToken, refreshToken, expiry, now, userID) 244 245 return err 246} 247 248func (db *DB) SaveTrack(userID int64, track *models.Track) (int64, error) { 249 // marshal artist json 250 artistString := "" 251 if len(track.Artist) > 0 { 252 bytes, err := json.Marshal(track.Artist) 253 if err != nil { 254 return 0, err 255 } else { 256 artistString = string(bytes) 257 } 258 } 259 260 var trackID int64 261 262 err := db.QueryRow(` 263 INSERT INTO tracks (user_id, name, recording_mbid, artist, album, release_mbid, url, timestamp, duration_ms, progress_ms, service_base_url, isrc, has_stamped) 264 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 265 RETURNING id`, 266 userID, track.Name, track.RecordingMBID, artistString, track.Album, track.ReleaseMBID, track.URL, track.Timestamp, 267 track.DurationMs, track.ProgressMs, track.ServiceBaseUrl, track.ISRC, track.HasStamped).Scan(&trackID) 268 269 return trackID, err 270} 271 272func (db *DB) UpdateTrack(trackID int64, track *models.Track) error { 273 // marshal artist json 274 artistString := "" 275 if len(track.Artist) > 0 { 276 bytes, err := json.Marshal(track.Artist) 277 if err != nil { 278 return err 279 } else { 280 artistString = string(bytes) 281 } 282 } 283 284 _, err := db.Exec(` 285 UPDATE tracks 286 SET name = ?, 287 recording_mbid = ?, 288 artist = ?, 289 album = ?, 290 release_mbid = ?, 291 url = ?, 292 timestamp = ?, 293 duration_ms = ?, 294 progress_ms = ?, 295 service_base_url = ?, 296 isrc = ?, 297 has_stamped = ? 298 WHERE id = ?`, 299 track.Name, track.RecordingMBID, artistString, track.Album, track.ReleaseMBID, track.URL, track.Timestamp, 300 track.DurationMs, track.ProgressMs, track.ServiceBaseUrl, track.ISRC, track.HasStamped, 301 trackID) 302 303 return err 304} 305 306func (db *DB) GetRecentTracks(userID int64, limit int) ([]*models.Track, error) { 307 rows, err := db.Query(` 308 SELECT id, name, recording_mbid, artist, album, release_mbid, url, timestamp, duration_ms, progress_ms, service_base_url, isrc, has_stamped 309 FROM tracks 310 WHERE user_id = ? 311 ORDER BY timestamp DESC 312 LIMIT ?`, userID, limit) 313 314 if err != nil { 315 return nil, err 316 } 317 defer rows.Close() 318 319 var tracks []*models.Track 320 321 for rows.Next() { 322 var artistString string 323 track := &models.Track{} 324 err := rows.Scan( 325 &track.PlayID, 326 &track.Name, 327 &track.RecordingMBID, // Scan new field 328 &artistString, // scan to be unmarshaled later 329 &track.Album, 330 &track.ReleaseMBID, // Scan new field 331 &track.URL, 332 &track.Timestamp, 333 &track.DurationMs, 334 &track.ProgressMs, 335 &track.ServiceBaseUrl, 336 &track.ISRC, 337 &track.HasStamped, 338 ) 339 340 if err != nil { 341 return nil, err 342 } 343 344 // unmarshal artist json 345 var artists []models.Artist 346 err = json.Unmarshal([]byte(artistString), &artists) 347 if err != nil { 348 // fallback to previous format 349 artists = []models.Artist{{Name: artistString}} 350 } 351 track.Artist = artists 352 tracks = append(tracks, track) 353 } 354 355 return tracks, nil 356} 357 358// SpotifyQueryMapping maps Spotify sql query results to user structs 359func SpotifyQueryMapping(rows *sql.Rows) ([]*models.User, error) { 360 361 var users []*models.User 362 363 for rows.Next() { 364 user := &models.User{} 365 err := rows.Scan( 366 &user.ID, &user.Username, &user.Email, &user.SpotifyID, 367 &user.AccessToken, &user.RefreshToken, &user.TokenExpiry, 368 &user.CreatedAt, &user.UpdatedAt) 369 if err != nil { 370 return nil, err 371 } 372 users = append(users, user) 373 } 374 375 return users, nil 376} 377 378func (db *DB) GetUsersWithExpiredTokens() ([]*models.User, error) { 379 rows, err := db.Query(` 380 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, created_at, updated_at 381 FROM users 382 WHERE refresh_token IS NOT NULL AND token_expiry < ? 383 ORDER BY id`, time.Now().UTC()) 384 385 if err != nil { 386 return nil, err 387 } 388 defer rows.Close() 389 390 return SpotifyQueryMapping(rows) 391 392} 393 394func (db *DB) GetAllActiveUsers() ([]*models.User, error) { 395 rows, err := db.Query(` 396 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, created_at, updated_at 397 FROM users 398 WHERE access_token IS NOT NULL 399 ORDER BY id`) 400 401 if err != nil { 402 return nil, err 403 } 404 defer rows.Close() 405 406 return SpotifyQueryMapping(rows) 407} 408 409func (db *DB) GetAllActiveUsersWithUnExpiredTokens() ([]*models.User, error) { 410 rows, err := db.Query(` 411 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, created_at, updated_at 412 FROM users 413 WHERE access_token IS NOT NULL AND token_expiry > ? 414 ORDER BY id`, time.Now().UTC()) 415 416 if err != nil { 417 return nil, err 418 } 419 defer rows.Close() 420 421 return SpotifyQueryMapping(rows) 422} 423 424// debug to view current user's information 425// put everything in an 'any' type 426func (db *DB) DebugViewUserInformation(userID int64) (map[string]any, error) { 427 // Use Query instead of QueryRow to get access to column names and ensure only one row is processed 428 rows, err := db.Query(` 429 SELECT * 430 FROM users 431 WHERE id = ? LIMIT 1`, userID) 432 if err != nil { 433 return nil, fmt.Errorf("query failed: %w", err) 434 } 435 defer rows.Close() 436 437 // Get column names 438 cols, err := rows.Columns() 439 if err != nil { 440 return nil, fmt.Errorf("failed to get columns: %w", err) 441 } 442 443 // Check if there's a row to process 444 if !rows.Next() { 445 if err := rows.Err(); err != nil { 446 // Error during rows.Next() or preparing the result set 447 return nil, fmt.Errorf("error checking for row: %w", err) 448 } 449 // No rows found, which is a valid outcome but might be considered an error in some contexts. 450 // Returning sql.ErrNoRows is conventional. 451 return nil, sql.ErrNoRows 452 } 453 454 // Prepare scan arguments: pointers to interface{} slices 455 values := make([]any, len(cols)) 456 scanArgs := make([]any, len(cols)) 457 for i := range values { 458 scanArgs[i] = &values[i] 459 } 460 461 // Scan the row values 462 err = rows.Scan(scanArgs...) 463 if err != nil { 464 return nil, fmt.Errorf("failed to scan row: %w", err) 465 } 466 467 // Check for errors that might have occurred during iteration (after Scan) 468 if err := rows.Err(); err != nil { 469 return nil, fmt.Errorf("error after scanning row: %w", err) 470 } 471 472 // Create the result map 473 resultMap := make(map[string]any, len(cols)) 474 for i, colName := range cols { 475 val := values[i] 476 // SQLite often returns []byte for TEXT columns, convert to string for usability. 477 // Also handle potential nil values appropriately. 478 if b, ok := val.([]byte); ok { 479 resultMap[colName] = string(b) 480 } else { 481 resultMap[colName] = val // Keep nil as nil, numbers as numbers, etc. 482 } 483 } 484 485 return resultMap, nil 486} 487 488func (db *DB) GetLastKnownTimestamp(userID int64) (*time.Time, error) { 489 var lastTimestamp time.Time 490 err := db.QueryRow(` 491 SELECT timestamp 492 FROM tracks 493 WHERE user_id = ? 494 ORDER BY timestamp DESC 495 LIMIT 1`, userID).Scan(&lastTimestamp) 496 497 if err != nil { 498 if err == sql.ErrNoRows { 499 return nil, nil 500 } 501 return nil, fmt.Errorf("failed to query last scrobble timestamp for user %d: %w", userID, err) 502 } 503 504 return &lastTimestamp, nil 505} 506 507//