A fork of https://github.com/teal-fm/piper
1package db
2
3import (
4 "database/sql"
5 "encoding/json"
6 "fmt"
7 "log"
8 "os"
9 "path/filepath"
10 "time"
11
12 _ "github.com/mattn/go-sqlite3"
13 "github.com/teal-fm/piper/models"
14)
15
// DB is the application's handle to its SQLite database. The embedded
// *sql.DB exposes Exec, Query, QueryRow, Begin, etc. directly on *DB.
type DB struct {
	// DB is the underlying connection pool opened in New.
	*sql.DB

	// logger is the package logger constructed in New (stdout, "db: " prefix).
	logger *log.Logger
}
20
21func New(dbPath string) (*DB, error) {
22 dir := filepath.Dir(dbPath)
23 if dir != "." && dir != "/" {
24 os.MkdirAll(dir, 0755)
25 }
26
27 db, err := sql.Open("sqlite3", dbPath)
28 if err != nil {
29 return nil, err
30 }
31
32 // Test the connection
33 if err = db.Ping(); err != nil {
34 return nil, err
35 }
36 logger := log.New(os.Stdout, "db: ", log.LstdFlags|log.Lmsgprefix)
37
38 return &DB{db, logger}, nil
39}
40
41func (db *DB) Initialize() error {
42 _, err := db.Exec(`
43 CREATE TABLE IF NOT EXISTS users (
44 id INTEGER PRIMARY KEY AUTOINCREMENT,
45 username TEXT, -- Made nullable, might not have username initially
46 email TEXT UNIQUE, -- Made nullable
47 atproto_did TEXT UNIQUE, -- Atproto DID (identifier)
48 most_recent_at_session_id TEXT, -- Most recent oAuth session id
49 spotify_id TEXT UNIQUE, -- Spotify specific ID
50 access_token TEXT, -- Spotify access token
51 refresh_token TEXT, -- Spotify refresh token
52 token_expiry TIMESTAMP, -- Spotify token expiry
53 lastfm_username TEXT, -- Last.fm username
54 applemusic_user_token TEXT, -- Apple Music MusicKit user token
55 created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Use default
56 updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP -- Use default
57 )`)
58 if err != nil {
59 return err
60 }
61
62 // Add missing columns to users table if they don't exist
63 _, err = db.Exec(`ALTER TABLE users ADD COLUMN applemusic_user_token TEXT`)
64 if err != nil && err.Error() != "duplicate column name: applemusic_user_token" {
65 return err
66 }
67
68 _, err = db.Exec(`
69 CREATE TABLE IF NOT EXISTS tracks (
70 id INTEGER PRIMARY KEY AUTOINCREMENT,
71 user_id INTEGER NOT NULL,
72 name TEXT NOT NULL,
73 recording_mbid TEXT, -- Added
74 artist TEXT NOT NULL, -- should be JSONB in PostgreSQL if we ever switch
75 album TEXT NOT NULL,
76 release_mbid TEXT, -- Added
77 url TEXT NOT NULL,
78 timestamp TIMESTAMP,
79 duration_ms INTEGER,
80 progress_ms INTEGER,
81 service_base_url TEXT,
82 isrc TEXT,
83 has_stamped BOOLEAN,
84 FOREIGN KEY (user_id) REFERENCES users(id)
85 )`)
86 if err != nil {
87 return err
88 }
89
90 _, err = db.Exec(`
91 CREATE TABLE IF NOT EXISTS atproto_state (
92 id INTEGER PRIMARY KEY AUTOINCREMENT,
93 state TEXT NOT NULL,
94 authserver_url TEXT,
95 account_did TEXT,
96 scopes TEXT,
97 request_uri TEXT,
98 authserver_token_endpoint TEXT,
99 authserver_revocation_endpoint TEXT,
100 pkce_verifier TEXT,
101 dpop_authserver_nonce TEXT,
102 dpop_privatekey_multibase TEXT,
103 created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
104 );
105 CREATE INDEX IF NOT EXISTS atproto_state_state ON atproto_state(state);
106
107`)
108 if err != nil {
109 return err
110 }
111
112 _, err = db.Exec(`
113 CREATE TABLE IF NOT EXISTS atproto_sessions (
114 id INTEGER PRIMARY KEY AUTOINCREMENT,
115 look_up_key TEXT NOT NULL,
116 account_did TEXT,
117 session_id TEXT,
118 host_url TEXT,
119 authserver_url TEXT,
120 authserver_token_endpoint TEXT,
121 authserver_revocation_endpoint TEXT,
122 scopes TEXT,
123 access_token TEXT,
124 refresh_token TEXT,
125 dpop_authserver_nonce TEXT,
126 dpop_host_nonce TEXT,
127 dpop_privatekey_multibase TEXT,
128 created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
129 );
130 CREATE INDEX IF NOT EXISTS idx_atproto_sessions_look_up_key ON atproto_sessions(look_up_key);
131`)
132 if err != nil {
133 return err
134 }
135
136 // Add columns recording_mbid and release_mbid to tracks table if they don't exist
137 _, err = db.Exec(`ALTER TABLE tracks ADD COLUMN recording_mbid TEXT`)
138 if err != nil && err.Error() != "duplicate column name: recording_mbid" {
139 // Handle errors other than 'duplicate column'
140 return err
141 }
142 _, err = db.Exec(`ALTER TABLE tracks ADD COLUMN release_mbid TEXT`)
143 if err != nil && err.Error() != "duplicate column name: release_mbid" {
144 // Handle errors other than 'duplicate column'
145 return err
146 }
147
148 return nil
149}
150
151// Apple Music developer token persistence
152func (db *DB) ensureAppleMusicTokenTable() error {
153 _, err := db.Exec(`
154 CREATE TABLE IF NOT EXISTS applemusic_token (
155 token TEXT,
156 expires_at TIMESTAMP
157 )`)
158 return err
159}
160
161func (db *DB) GetAppleMusicDeveloperToken() (string, time.Time, bool, error) {
162 if err := db.ensureAppleMusicTokenTable(); err != nil {
163 return "", time.Time{}, false, err
164 }
165 var token string
166 var exp time.Time
167 err := db.QueryRow(`SELECT token, expires_at FROM applemusic_token LIMIT 1`).Scan(&token, &exp)
168 if err == sql.ErrNoRows {
169 return "", time.Time{}, false, nil
170 }
171 if err != nil {
172 return "", time.Time{}, false, err
173 }
174 return token, exp, true, nil
175}
176
177func (db *DB) SaveAppleMusicDeveloperToken(token string, exp time.Time) error {
178 if err := db.ensureAppleMusicTokenTable(); err != nil {
179 return err
180 }
181 // Replace existing single row
182 _, err := db.Exec(`DELETE FROM applemusic_token`)
183 if err != nil {
184 return err
185 }
186 _, err = db.Exec(`INSERT INTO applemusic_token (token, expires_at) VALUES (?, ?)`, token, exp)
187 return err
188}
189
190// create user without spotify id
191func (db *DB) CreateUser(user *models.User) (int64, error) {
192 now := time.Now().UTC()
193
194 result, err := db.Exec(`
195 INSERT INTO users (username, email, created_at, updated_at)
196 VALUES (?, ?, ?, ?)`,
197 user.Username, user.Email, now, now)
198
199 if err != nil {
200 return 0, err
201 }
202
203 return result.LastInsertId()
204}
205
206// add spotify session to user, returning the updated user
207func (db *DB) AddSpotifySession(userID int64, username, email, spotifyId, accessToken, refreshToken string, tokenExpiry time.Time) (*models.User, error) {
208 now := time.Now().UTC()
209
210 _, err := db.Exec(`
211 UPDATE users SET username = ?, email = ?, spotify_id = ?, access_token = ?, refresh_token = ?, token_expiry = ?, created_at = ?, updated_at = ?
212 WHERE id == ?
213 `,
214 username, email, spotifyId, accessToken, refreshToken, tokenExpiry, now, now, userID)
215 if err != nil {
216 return nil, err
217 }
218
219 user, err := db.GetUserByID(userID)
220 if err != nil {
221 return nil, err
222 }
223
224 return user, err
225}
226
227func (db *DB) GetUserByID(ID int64) (*models.User, error) {
228 user := &models.User{}
229
230 err := db.QueryRow(`
231 SELECT id,
232 username,
233 email,
234 atproto_did,
235 most_recent_at_session_id,
236 spotify_id,
237 access_token,
238 refresh_token,
239 token_expiry,
240 lastfm_username,
241 applemusic_user_token,
242 created_at,
243 updated_at
244 FROM users WHERE id = ?`, ID).Scan(
245 &user.ID, &user.Username, &user.Email, &user.ATProtoDID, &user.MostRecentAtProtoSessionID, &user.SpotifyID,
246 &user.AccessToken, &user.RefreshToken, &user.TokenExpiry,
247 &user.LastFMUsername, &user.AppleMusicUserToken,
248 &user.CreatedAt, &user.UpdatedAt)
249
250 if err == sql.ErrNoRows {
251 return nil, nil
252 }
253
254 if err != nil {
255 return nil, err
256 }
257
258 return user, nil
259}
260
261func (db *DB) GetUserBySpotifyID(spotifyID string) (*models.User, error) {
262 user := &models.User{}
263
264 err := db.QueryRow(`
265 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, lastfm_username, applemusic_user_token, created_at, updated_at
266 FROM users WHERE spotify_id = ?`, spotifyID).Scan(
267 &user.ID, &user.Username, &user.Email, &user.SpotifyID,
268 &user.AccessToken, &user.RefreshToken, &user.TokenExpiry,
269 &user.LastFMUsername, &user.AppleMusicUserToken,
270 &user.CreatedAt, &user.UpdatedAt)
271
272 if err == sql.ErrNoRows {
273 return nil, nil
274 }
275
276 if err != nil {
277 return nil, err
278 }
279
280 return user, nil
281}
282
283func (db *DB) UpdateUserToken(userID int64, accessToken, refreshToken string, expiry time.Time) error {
284 now := time.Now().UTC()
285
286 _, err := db.Exec(`
287 UPDATE users
288 SET access_token = ?, refresh_token = ?, token_expiry = ?, updated_at = ?
289 WHERE id = ?`,
290 accessToken, refreshToken, expiry, now, userID)
291
292 return err
293}
294
295func (db *DB) UpdateAppleMusicUserToken(userID int64, userToken string) error {
296 now := time.Now().UTC()
297 _, err := db.Exec(`
298 UPDATE users
299 SET applemusic_user_token = ?, updated_at = ?
300 WHERE id = ?`,
301 userToken, now, userID)
302 return err
303}
304
305// ClearAppleMusicUserToken removes the stored Apple Music user token for a user
306func (db *DB) ClearAppleMusicUserToken(userID int64) error {
307 now := time.Now().UTC()
308 _, err := db.Exec(`
309 UPDATE users
310 SET applemusic_user_token = NULL, updated_at = ?
311 WHERE id = ?`,
312 now, userID)
313 return err
314}
315
316// GetAllAppleMusicLinkedUsers returns users who have an Apple Music user token set
317func (db *DB) GetAllAppleMusicLinkedUsers() ([]*models.User, error) {
318 rows, err := db.Query(`
319 SELECT id, username, email, atproto_did, most_recent_at_session_id,
320 spotify_id, access_token, refresh_token, token_expiry,
321 lastfm_username, applemusic_user_token, created_at, updated_at
322 FROM users
323 WHERE applemusic_user_token IS NOT NULL AND applemusic_user_token != ''
324 ORDER BY id`)
325 if err != nil {
326 return nil, err
327 }
328 defer rows.Close()
329
330 var users []*models.User
331 for rows.Next() {
332 u := &models.User{}
333 if err := rows.Scan(
334 &u.ID, &u.Username, &u.Email, &u.ATProtoDID, &u.MostRecentAtProtoSessionID,
335 &u.SpotifyID, &u.AccessToken, &u.RefreshToken, &u.TokenExpiry,
336 &u.LastFMUsername, &u.AppleMusicUserToken, &u.CreatedAt, &u.UpdatedAt,
337 ); err != nil {
338 return nil, err
339 }
340 users = append(users, u)
341 }
342 if err := rows.Err(); err != nil {
343 return nil, err
344 }
345 return users, nil
346}
347
348func (db *DB) SaveTrack(userID int64, track *models.Track) (int64, error) {
349 // marshal artist json
350 artistString := ""
351 if len(track.Artist) > 0 {
352 bytes, err := json.Marshal(track.Artist)
353 if err != nil {
354 return 0, err
355 } else {
356 artistString = string(bytes)
357 }
358 }
359
360 var trackID int64
361
362 err := db.QueryRow(`
363 INSERT INTO tracks (user_id, name, recording_mbid, artist, album, release_mbid, url, timestamp, duration_ms, progress_ms, service_base_url, isrc, has_stamped)
364 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
365 RETURNING id`,
366 userID, track.Name, track.RecordingMBID, artistString, track.Album, track.ReleaseMBID, track.URL, track.Timestamp,
367 track.DurationMs, track.ProgressMs, track.ServiceBaseUrl, track.ISRC, track.HasStamped).Scan(&trackID)
368
369 return trackID, err
370}
371
372func (db *DB) UpdateTrack(trackID int64, track *models.Track) error {
373 // marshal artist json
374 artistString := ""
375 if len(track.Artist) > 0 {
376 bytes, err := json.Marshal(track.Artist)
377 if err != nil {
378 return err
379 } else {
380 artistString = string(bytes)
381 }
382 }
383
384 _, err := db.Exec(`
385 UPDATE tracks
386 SET name = ?,
387 recording_mbid = ?,
388 artist = ?,
389 album = ?,
390 release_mbid = ?,
391 url = ?,
392 timestamp = ?,
393 duration_ms = ?,
394 progress_ms = ?,
395 service_base_url = ?,
396 isrc = ?,
397 has_stamped = ?
398 WHERE id = ?`,
399 track.Name, track.RecordingMBID, artistString, track.Album, track.ReleaseMBID, track.URL, track.Timestamp,
400 track.DurationMs, track.ProgressMs, track.ServiceBaseUrl, track.ISRC, track.HasStamped,
401 trackID)
402
403 return err
404}
405
406func (db *DB) GetRecentTracks(userID int64, limit int) ([]*models.Track, error) {
407 rows, err := db.Query(`
408 SELECT id, name, recording_mbid, artist, album, release_mbid, url, timestamp, duration_ms, progress_ms, service_base_url, isrc, has_stamped
409 FROM tracks
410 WHERE user_id = ?
411 ORDER BY timestamp DESC
412 LIMIT ?`, userID, limit)
413
414 if err != nil {
415 return nil, err
416 }
417 defer rows.Close()
418
419 var tracks []*models.Track
420
421 for rows.Next() {
422 var artistString string
423 track := &models.Track{}
424 err := rows.Scan(
425 &track.PlayID,
426 &track.Name,
427 &track.RecordingMBID, // Scan new field
428 &artistString, // scan to be unmarshaled later
429 &track.Album,
430 &track.ReleaseMBID, // Scan new field
431 &track.URL,
432 &track.Timestamp,
433 &track.DurationMs,
434 &track.ProgressMs,
435 &track.ServiceBaseUrl,
436 &track.ISRC,
437 &track.HasStamped,
438 )
439
440 if err != nil {
441 return nil, err
442 }
443
444 // unmarshal artist json
445 var artists []models.Artist
446 err = json.Unmarshal([]byte(artistString), &artists)
447 if err != nil {
448 // fallback to previous format
449 artists = []models.Artist{{Name: artistString}}
450 }
451 track.Artist = artists
452 tracks = append(tracks, track)
453 }
454
455 return tracks, nil
456}
457
458// SpotifyQueryMapping maps Spotify sql query results to user structs
459func SpotifyQueryMapping(rows *sql.Rows) ([]*models.User, error) {
460
461 var users []*models.User
462
463 for rows.Next() {
464 user := &models.User{}
465 err := rows.Scan(
466 &user.ID, &user.Username, &user.Email, &user.SpotifyID,
467 &user.AccessToken, &user.RefreshToken, &user.TokenExpiry,
468 &user.CreatedAt, &user.UpdatedAt)
469 if err != nil {
470 return nil, err
471 }
472 users = append(users, user)
473 }
474
475 return users, nil
476}
477
478func (db *DB) GetUsersWithExpiredTokens() ([]*models.User, error) {
479 rows, err := db.Query(`
480 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, created_at, updated_at
481 FROM users
482 WHERE refresh_token IS NOT NULL AND token_expiry < ?
483 ORDER BY id`, time.Now().UTC())
484
485 if err != nil {
486 return nil, err
487 }
488 defer rows.Close()
489
490 return SpotifyQueryMapping(rows)
491
492}
493
494func (db *DB) GetAllActiveUsers() ([]*models.User, error) {
495 rows, err := db.Query(`
496 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, created_at, updated_at
497 FROM users
498 WHERE access_token IS NOT NULL
499 ORDER BY id`)
500
501 if err != nil {
502 return nil, err
503 }
504 defer rows.Close()
505
506 return SpotifyQueryMapping(rows)
507}
508
509func (db *DB) GetAllActiveUsersWithUnExpiredTokens() ([]*models.User, error) {
510 rows, err := db.Query(`
511 SELECT id, username, email, spotify_id, access_token, refresh_token, token_expiry, created_at, updated_at
512 FROM users
513 WHERE access_token IS NOT NULL AND token_expiry > ?
514 ORDER BY id`, time.Now().UTC())
515
516 if err != nil {
517 return nil, err
518 }
519 defer rows.Close()
520
521 return SpotifyQueryMapping(rows)
522}
523
524// debug to view current user's information
525// put everything in an 'any' type
526func (db *DB) DebugViewUserInformation(userID int64) (map[string]any, error) {
527 // Use Query instead of QueryRow to get access to column names and ensure only one row is processed
528 rows, err := db.Query(`
529 SELECT *
530 FROM users
531 WHERE id = ? LIMIT 1`, userID)
532 if err != nil {
533 return nil, fmt.Errorf("query failed: %w", err)
534 }
535 defer rows.Close()
536
537 // Get column names
538 cols, err := rows.Columns()
539 if err != nil {
540 return nil, fmt.Errorf("failed to get columns: %w", err)
541 }
542
543 // Check if there's a row to process
544 if !rows.Next() {
545 if err := rows.Err(); err != nil {
546 // Error during rows.Next() or preparing the result set
547 return nil, fmt.Errorf("error checking for row: %w", err)
548 }
549 // No rows found, which is a valid outcome but might be considered an error in some contexts.
550 // Returning sql.ErrNoRows is conventional.
551 return nil, sql.ErrNoRows
552 }
553
554 // Prepare scan arguments: pointers to interface{} slices
555 values := make([]any, len(cols))
556 scanArgs := make([]any, len(cols))
557 for i := range values {
558 scanArgs[i] = &values[i]
559 }
560
561 // Scan the row values
562 err = rows.Scan(scanArgs...)
563 if err != nil {
564 return nil, fmt.Errorf("failed to scan row: %w", err)
565 }
566
567 // Check for errors that might have occurred during iteration (after Scan)
568 if err := rows.Err(); err != nil {
569 return nil, fmt.Errorf("error after scanning row: %w", err)
570 }
571
572 // Create the result map
573 resultMap := make(map[string]any, len(cols))
574 for i, colName := range cols {
575 val := values[i]
576 // SQLite often returns []byte for TEXT columns, convert to string for usability.
577 // Also handle potential nil values appropriately.
578 if b, ok := val.([]byte); ok {
579 resultMap[colName] = string(b)
580 } else {
581 resultMap[colName] = val // Keep nil as nil, numbers as numbers, etc.
582 }
583 }
584
585 return resultMap, nil
586}
587
588func (db *DB) GetLastKnownTimestamp(userID int64) (*time.Time, error) {
589 var lastTimestamp time.Time
590 err := db.QueryRow(`
591 SELECT timestamp
592 FROM tracks
593 WHERE user_id = ?
594 ORDER BY timestamp DESC
595 LIMIT 1`, userID).Scan(&lastTimestamp)
596
597 if err != nil {
598 if err == sql.ErrNoRows {
599 return nil, nil
600 }
601 return nil, fmt.Errorf("failed to query last scrobble timestamp for user %d: %w", userID, err)
602 }
603
604 return &lastTimestamp, nil
605}
606
607//