// A fork of https://github.com/teal-fm/piper (viewed at ref `fly`; original file: 607 lines, 17 kB).
1package db 2 3import ( 4 "database/sql" 5 "encoding/json" 6 "fmt" 7 "log" 8 "os" 9 "path/filepath" 10 "time" 11 12 _ "github.com/mattn/go-sqlite3" 13 "github.com/teal-fm/piper/models" 14) 15 16type DB struct { 17 *sql.DB 18 logger *log.Logger 19} 20 21func New(dbPath string) (*DB, error) { 22 dir := filepath.Dir(dbPath) 23 if dir != "." && dir != "/" { 24 os.MkdirAll(dir, 0755) 25 } 26 27 db, err := sql.Open("sqlite3", dbPath) 28 if err != nil { 29 return nil, err 30 } 31 32 // Test the connection 33 if err = db.Ping(); err != nil { 34 return nil, err 35 } 36 logger := log.New(os.Stdout, "db: ", log.LstdFlags|log.Lmsgprefix) 37 38 return &DB{db, logger}, nil 39} 40 41func (db *DB) Initialize() error { 42 _, err := db.Exec(` 43 CREATE TABLE IF NOT EXISTS users ( 44 id INTEGER PRIMARY KEY AUTOINCREMENT, 45 username TEXT, -- Made nullable, might not have username initially 46 email TEXT UNIQUE, -- Made nullable 47 atproto_did TEXT UNIQUE, -- Atproto DID (identifier) 48 most_recent_at_session_id TEXT, -- Most recent oAuth session id 49 spotify_id TEXT UNIQUE, -- Spotify specific ID 50 access_token TEXT, -- Spotify access token 51 refresh_token TEXT, -- Spotify refresh token 52 token_expiry TIMESTAMP, -- Spotify token expiry 53 lastfm_username TEXT, -- Last.fm username 54 applemusic_user_token TEXT, -- Apple Music MusicKit user token 55 created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Use default 56 updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP -- Use default 57 )`) 58 if err != nil { 59 return err 60 } 61 62 // Add missing columns to users table if they don't exist 63 _, err = db.Exec(`ALTER TABLE users ADD COLUMN applemusic_user_token TEXT`) 64 if err != nil && err.Error() != "duplicate column name: applemusic_user_token" { 65 return err 66 } 67 68 _, err = db.Exec(` 69 CREATE TABLE IF NOT EXISTS tracks ( 70 id INTEGER PRIMARY KEY AUTOINCREMENT, 71 user_id INTEGER NOT NULL, 72 name TEXT NOT NULL, 73 recording_mbid TEXT, -- Added 74 artist TEXT 
NOT NULL, -- should be JSONB in PostgreSQL if we ever switch 75 album TEXT NOT NULL, 76 release_mbid TEXT, -- Added 77 url TEXT NOT NULL, 78 timestamp TIMESTAMP, 79 duration_ms INTEGER, 80 progress_ms INTEGER, 81 service_base_url TEXT, 82 isrc TEXT, 83 has_stamped BOOLEAN, 84 FOREIGN KEY (user_id) REFERENCES users(id) 85 )`) 86 if err != nil { 87 return err 88 } 89 90 _, err = db.Exec(` 91 CREATE TABLE IF NOT EXISTS atproto_state ( 92 id INTEGER PRIMARY KEY AUTOINCREMENT, 93 state TEXT NOT NULL, 94 authserver_url TEXT, 95 account_did TEXT, 96 scopes TEXT, 97 request_uri TEXT, 98 authserver_token_endpoint TEXT, 99 authserver_revocation_endpoint TEXT, 100 pkce_verifier TEXT, 101 dpop_authserver_nonce TEXT, 102 dpop_privatekey_multibase TEXT, 103 created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP 104 ); 105 CREATE INDEX IF NOT EXISTS atproto_state_state ON atproto_state(state); 106 107`) 108 if err != nil { 109 return err 110 } 111 112 _, err = db.Exec(` 113 CREATE TABLE IF NOT EXISTS atproto_sessions ( 114 id INTEGER PRIMARY KEY AUTOINCREMENT, 115 look_up_key TEXT NOT NULL, 116 account_did TEXT, 117 session_id TEXT, 118 host_url TEXT, 119 authserver_url TEXT, 120 authserver_token_endpoint TEXT, 121 authserver_revocation_endpoint TEXT, 122 scopes TEXT, 123 access_token TEXT, 124 refresh_token TEXT, 125 dpop_authserver_nonce TEXT, 126 dpop_host_nonce TEXT, 127 dpop_privatekey_multibase TEXT, 128 created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP 129 ); 130 CREATE INDEX IF NOT EXISTS idx_atproto_sessions_look_up_key ON atproto_sessions(look_up_key); 131`) 132 if err != nil { 133 return err 134 } 135 136 // Add columns recording_mbid and release_mbid to tracks table if they don't exist 137 _, err = db.Exec(`ALTER TABLE tracks ADD COLUMN recording_mbid TEXT`) 138 if err != nil && err.Error() != "duplicate column name: recording_mbid" { 139 // Handle errors other than 'duplicate column' 140 return err 141 } 142 _, err = db.Exec(`ALTER TABLE tracks ADD COLUMN 
release_mbid TEXT`) 143 if err != nil && err.Error() != "duplicate column name: release_mbid" { 144 // Handle errors other than 'duplicate column' 145 return err 146 } 147 148 return nil 149} 150 151// Apple Music developer token persistence 152func (db *DB) ensureAppleMusicTokenTable() error { 153 _, err := db.Exec(` 154 CREATE TABLE IF NOT EXISTS applemusic_token ( 155 token TEXT, 156 expires_at TIMESTAMP 157 )`) 158 return err 159} 160 161func (db *DB) GetAppleMusicDeveloperToken() (string, time.Time, bool, error) { 162 if err := db.ensureAppleMusicTokenTable(); err != nil { 163 return "", time.Time{}, false, err 164 } 165 var token string 166 var exp time.Time 167 err := db.QueryRow(`SELECT token, expires_at FROM applemusic_token LIMIT 1`).Scan(&token, &exp) 168 if err == sql.ErrNoRows { 169 return "", time.Time{}, false, nil 170 } 171 if err != nil { 172 return "", time.Time{}, false, err 173 } 174 return token, exp, true, nil 175} 176 177func (db *DB) SaveAppleMusicDeveloperToken(token string, exp time.Time) error { 178 if err := db.ensureAppleMusicTokenTable(); err != nil { 179 return err 180 } 181 // Replace existing single row 182 _, err := db.Exec(`DELETE FROM applemusic_token`) 183 if err != nil { 184 return err 185 } 186 _, err = db.Exec(`INSERT INTO applemusic_token (token, expires_at) VALUES (?, ?)`, token, exp) 187 return err 188} 189 190// create user without spotify id 191func (db *DB) CreateUser(user *models.User) (int64, error) { 192 now := time.Now().UTC() 193 194 result, err := db.Exec(` 195 INSERT INTO users (username, email, created_at, updated_at) 196 VALUES (?, ?, ?, ?)`, 197 user.Username, user.Email, now, now) 198 199 if err != nil { 200 return 0, err 201 } 202 203 return result.LastInsertId() 204} 205 206// add spotify session to user, returning the updated user 207func (db *DB) AddSpotifySession(userID int64, username, email, spotifyId, accessToken, refreshToken string, tokenExpiry time.Time) (*models.User, error) { 208 now := 
time.Now().UTC() 209 210 _, err := db.Exec(` 211 UPDATE users SET username = ?, email = ?, spotify_id = ?, access_token = ?, refresh_token = ?, token_expiry = ?, created_at = ?, updated_at = ? 212 WHERE id == ? 213 `, 214 username, email, spotifyId, accessToken, refreshToken, tokenExpiry, now, now, userID) 215 if err != nil { 216 return nil, err 217 } 218 219 user, err := db.GetUserByID(userID) 220 if err != nil { 221 return nil, err 222 } 223 224 return user, err 225} 226 227func (db *DB) GetUserByID(ID int64) (*models.User, error) { 228 user := &models.User{} 229 230 err := db.QueryRow(` 231 SELECT id, 232 username, 233 email, 234 atproto_did, 235 most_recent_at_session_id, 236 spotify_id, 237 access_token, 238 refresh_token, 239 token_expiry, 240 lastfm_username, 241 applemusic_user_token, 242 created_at, 243 updated_at 244 FROM users WHERE id = ?`, ID).Scan( 245 &user.ID, &user.Username, &user.Email, &user.ATProtoDID, &user.MostRecentAtProtoSessionID, &user.SpotifyID, 246 &user.AccessToken, &user.RefreshToken, &user.TokenExpiry, 247 &user.LastFMUsername, &user.AppleMusicUserToken, 248 &user.CreatedAt, &user.UpdatedAt) 249 250 if err == sql.ErrNoRows { 251 return nil, nil 252 } 253 254 if err != nil { 255 return nil, err 256 } 257 258 return user, nil 259} 260 261func (db *DB) GetUserBySpotifyID(spotifyID string) (*models.User, error) { 262 user := &models.User{} 263 264 err := db.QueryRow(` 265 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, lastfm_username, applemusic_user_token, created_at, updated_at 266 FROM users WHERE spotify_id = ?`, spotifyID).Scan( 267 &user.ID, &user.Username, &user.Email, &user.SpotifyID, 268 &user.AccessToken, &user.RefreshToken, &user.TokenExpiry, 269 &user.LastFMUsername, &user.AppleMusicUserToken, 270 &user.CreatedAt, &user.UpdatedAt) 271 272 if err == sql.ErrNoRows { 273 return nil, nil 274 } 275 276 if err != nil { 277 return nil, err 278 } 279 280 return user, nil 281} 282 283func (db *DB) 
UpdateUserToken(userID int64, accessToken, refreshToken string, expiry time.Time) error { 284 now := time.Now().UTC() 285 286 _, err := db.Exec(` 287 UPDATE users 288 SET access_token = ?, refresh_token = ?, token_expiry = ?, updated_at = ? 289 WHERE id = ?`, 290 accessToken, refreshToken, expiry, now, userID) 291 292 return err 293} 294 295func (db *DB) UpdateAppleMusicUserToken(userID int64, userToken string) error { 296 now := time.Now().UTC() 297 _, err := db.Exec(` 298 UPDATE users 299 SET applemusic_user_token = ?, updated_at = ? 300 WHERE id = ?`, 301 userToken, now, userID) 302 return err 303} 304 305// ClearAppleMusicUserToken removes the stored Apple Music user token for a user 306func (db *DB) ClearAppleMusicUserToken(userID int64) error { 307 now := time.Now().UTC() 308 _, err := db.Exec(` 309 UPDATE users 310 SET applemusic_user_token = NULL, updated_at = ? 311 WHERE id = ?`, 312 now, userID) 313 return err 314} 315 316// GetAllAppleMusicLinkedUsers returns users who have an Apple Music user token set 317func (db *DB) GetAllAppleMusicLinkedUsers() ([]*models.User, error) { 318 rows, err := db.Query(` 319 SELECT id, username, email, atproto_did, most_recent_at_session_id, 320 spotify_id, access_token, refresh_token, token_expiry, 321 lastfm_username, applemusic_user_token, created_at, updated_at 322 FROM users 323 WHERE applemusic_user_token IS NOT NULL AND applemusic_user_token != '' 324 ORDER BY id`) 325 if err != nil { 326 return nil, err 327 } 328 defer rows.Close() 329 330 var users []*models.User 331 for rows.Next() { 332 u := &models.User{} 333 if err := rows.Scan( 334 &u.ID, &u.Username, &u.Email, &u.ATProtoDID, &u.MostRecentAtProtoSessionID, 335 &u.SpotifyID, &u.AccessToken, &u.RefreshToken, &u.TokenExpiry, 336 &u.LastFMUsername, &u.AppleMusicUserToken, &u.CreatedAt, &u.UpdatedAt, 337 ); err != nil { 338 return nil, err 339 } 340 users = append(users, u) 341 } 342 if err := rows.Err(); err != nil { 343 return nil, err 344 } 345 return users, 
nil 346} 347 348func (db *DB) SaveTrack(userID int64, track *models.Track) (int64, error) { 349 // marshal artist json 350 artistString := "" 351 if len(track.Artist) > 0 { 352 bytes, err := json.Marshal(track.Artist) 353 if err != nil { 354 return 0, err 355 } else { 356 artistString = string(bytes) 357 } 358 } 359 360 var trackID int64 361 362 err := db.QueryRow(` 363 INSERT INTO tracks (user_id, name, recording_mbid, artist, album, release_mbid, url, timestamp, duration_ms, progress_ms, service_base_url, isrc, has_stamped) 364 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 365 RETURNING id`, 366 userID, track.Name, track.RecordingMBID, artistString, track.Album, track.ReleaseMBID, track.URL, track.Timestamp, 367 track.DurationMs, track.ProgressMs, track.ServiceBaseUrl, track.ISRC, track.HasStamped).Scan(&trackID) 368 369 return trackID, err 370} 371 372func (db *DB) UpdateTrack(trackID int64, track *models.Track) error { 373 // marshal artist json 374 artistString := "" 375 if len(track.Artist) > 0 { 376 bytes, err := json.Marshal(track.Artist) 377 if err != nil { 378 return err 379 } else { 380 artistString = string(bytes) 381 } 382 } 383 384 _, err := db.Exec(` 385 UPDATE tracks 386 SET name = ?, 387 recording_mbid = ?, 388 artist = ?, 389 album = ?, 390 release_mbid = ?, 391 url = ?, 392 timestamp = ?, 393 duration_ms = ?, 394 progress_ms = ?, 395 service_base_url = ?, 396 isrc = ?, 397 has_stamped = ? 398 WHERE id = ?`, 399 track.Name, track.RecordingMBID, artistString, track.Album, track.ReleaseMBID, track.URL, track.Timestamp, 400 track.DurationMs, track.ProgressMs, track.ServiceBaseUrl, track.ISRC, track.HasStamped, 401 trackID) 402 403 return err 404} 405 406func (db *DB) GetRecentTracks(userID int64, limit int) ([]*models.Track, error) { 407 rows, err := db.Query(` 408 SELECT id, name, recording_mbid, artist, album, release_mbid, url, timestamp, duration_ms, progress_ms, service_base_url, isrc, has_stamped 409 FROM tracks 410 WHERE user_id = ? 
411 ORDER BY timestamp DESC 412 LIMIT ?`, userID, limit) 413 414 if err != nil { 415 return nil, err 416 } 417 defer rows.Close() 418 419 var tracks []*models.Track 420 421 for rows.Next() { 422 var artistString string 423 track := &models.Track{} 424 err := rows.Scan( 425 &track.PlayID, 426 &track.Name, 427 &track.RecordingMBID, // Scan new field 428 &artistString, // scan to be unmarshaled later 429 &track.Album, 430 &track.ReleaseMBID, // Scan new field 431 &track.URL, 432 &track.Timestamp, 433 &track.DurationMs, 434 &track.ProgressMs, 435 &track.ServiceBaseUrl, 436 &track.ISRC, 437 &track.HasStamped, 438 ) 439 440 if err != nil { 441 return nil, err 442 } 443 444 // unmarshal artist json 445 var artists []models.Artist 446 err = json.Unmarshal([]byte(artistString), &artists) 447 if err != nil { 448 // fallback to previous format 449 artists = []models.Artist{{Name: artistString}} 450 } 451 track.Artist = artists 452 tracks = append(tracks, track) 453 } 454 455 return tracks, nil 456} 457 458// SpotifyQueryMapping maps Spotify sql query results to user structs 459func SpotifyQueryMapping(rows *sql.Rows) ([]*models.User, error) { 460 461 var users []*models.User 462 463 for rows.Next() { 464 user := &models.User{} 465 err := rows.Scan( 466 &user.ID, &user.Username, &user.Email, &user.SpotifyID, 467 &user.AccessToken, &user.RefreshToken, &user.TokenExpiry, 468 &user.CreatedAt, &user.UpdatedAt) 469 if err != nil { 470 return nil, err 471 } 472 users = append(users, user) 473 } 474 475 return users, nil 476} 477 478func (db *DB) GetUsersWithExpiredTokens() ([]*models.User, error) { 479 rows, err := db.Query(` 480 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, created_at, updated_at 481 FROM users 482 WHERE refresh_token IS NOT NULL AND token_expiry < ? 
483 ORDER BY id`, time.Now().UTC()) 484 485 if err != nil { 486 return nil, err 487 } 488 defer rows.Close() 489 490 return SpotifyQueryMapping(rows) 491 492} 493 494func (db *DB) GetAllActiveUsers() ([]*models.User, error) { 495 rows, err := db.Query(` 496 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, created_at, updated_at 497 FROM users 498 WHERE access_token IS NOT NULL 499 ORDER BY id`) 500 501 if err != nil { 502 return nil, err 503 } 504 defer rows.Close() 505 506 return SpotifyQueryMapping(rows) 507} 508 509func (db *DB) GetAllActiveUsersWithUnExpiredTokens() ([]*models.User, error) { 510 rows, err := db.Query(` 511 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, created_at, updated_at 512 FROM users 513 WHERE access_token IS NOT NULL AND token_expiry > ? 514 ORDER BY id`, time.Now().UTC()) 515 516 if err != nil { 517 return nil, err 518 } 519 defer rows.Close() 520 521 return SpotifyQueryMapping(rows) 522} 523 524// debug to view current user's information 525// put everything in an 'any' type 526func (db *DB) DebugViewUserInformation(userID int64) (map[string]any, error) { 527 // Use Query instead of QueryRow to get access to column names and ensure only one row is processed 528 rows, err := db.Query(` 529 SELECT * 530 FROM users 531 WHERE id = ? LIMIT 1`, userID) 532 if err != nil { 533 return nil, fmt.Errorf("query failed: %w", err) 534 } 535 defer rows.Close() 536 537 // Get column names 538 cols, err := rows.Columns() 539 if err != nil { 540 return nil, fmt.Errorf("failed to get columns: %w", err) 541 } 542 543 // Check if there's a row to process 544 if !rows.Next() { 545 if err := rows.Err(); err != nil { 546 // Error during rows.Next() or preparing the result set 547 return nil, fmt.Errorf("error checking for row: %w", err) 548 } 549 // No rows found, which is a valid outcome but might be considered an error in some contexts. 550 // Returning sql.ErrNoRows is conventional. 
551 return nil, sql.ErrNoRows 552 } 553 554 // Prepare scan arguments: pointers to interface{} slices 555 values := make([]any, len(cols)) 556 scanArgs := make([]any, len(cols)) 557 for i := range values { 558 scanArgs[i] = &values[i] 559 } 560 561 // Scan the row values 562 err = rows.Scan(scanArgs...) 563 if err != nil { 564 return nil, fmt.Errorf("failed to scan row: %w", err) 565 } 566 567 // Check for errors that might have occurred during iteration (after Scan) 568 if err := rows.Err(); err != nil { 569 return nil, fmt.Errorf("error after scanning row: %w", err) 570 } 571 572 // Create the result map 573 resultMap := make(map[string]any, len(cols)) 574 for i, colName := range cols { 575 val := values[i] 576 // SQLite often returns []byte for TEXT columns, convert to string for usability. 577 // Also handle potential nil values appropriately. 578 if b, ok := val.([]byte); ok { 579 resultMap[colName] = string(b) 580 } else { 581 resultMap[colName] = val // Keep nil as nil, numbers as numbers, etc. 582 } 583 } 584 585 return resultMap, nil 586} 587 588func (db *DB) GetLastKnownTimestamp(userID int64) (*time.Time, error) { 589 var lastTimestamp time.Time 590 err := db.QueryRow(` 591 SELECT timestamp 592 FROM tracks 593 WHERE user_id = ? 594 ORDER BY timestamp DESC 595 LIMIT 1`, userID).Scan(&lastTimestamp) 596 597 if err != nil { 598 if err == sql.ErrNoRows { 599 return nil, nil 600 } 601 return nil, fmt.Errorf("failed to query last scrobble timestamp for user %d: %w", userID, err) 602 } 603 604 return &lastTimestamp, nil 605} 606 607//