···11+package atproto
22+33+// Stolen from https://github.com/haileyok/atproto-oauth-golang/blob/f780d3716e2b8a06c87271a2930894319526550e/cmd/web_server_demo/resolution.go
44+55+import (
66+ "context"
77+ "encoding/json"
88+ "fmt"
99+ "io"
1010+ "net"
1111+ "net/http"
1212+ "strings"
1313+1414+ "github.com/bluesky-social/indigo/atproto/syntax"
1515+ oauth "github.com/haileyok/atproto-oauth-golang"
1616+)
1717+1818+// user information struct
1919+type UserInformation struct {
2020+ AuthService string `json:"authService"`
2121+ AuthServer string `json:"authServer"`
2222+ AuthMeta *oauth.OauthAuthorizationMetadata `json:"authMeta"`
2323+ // do NOT save the current handle permanently!
2424+ Handle string `json:"handle"`
2525+ DID string `json:"did"`
2626+}
2727+2828+type Identity struct {
2929+ AlsoKnownAs []string `json:"alsoKnownAs"`
3030+ Service []struct {
3131+ ID string `json:"id"`
3232+ Type string `json:"type"`
3333+ ServiceEndpoint string `json:"serviceEndpoint"`
3434+ } `json:"service"`
3535+}
3636+3737+func (a *ATprotoAuthService) getUserInformation(ctx context.Context, handleOrDid string) (*UserInformation, error) {
3838+ cli := a.client
3939+4040+ // if we have a did skip this
4141+ did := handleOrDid
4242+ err := error(nil)
4343+ // technically checking SHOULD be more rigorous.
4444+ if !strings.HasPrefix(handleOrDid, "did:") {
4545+ did, err = resolveHandle(ctx, did)
4646+ if err != nil {
4747+ return nil, err
4848+ }
4949+ } else {
5050+ did = handleOrDid
5151+ }
5252+5353+ doc, err := getIdentityDocument(ctx, did)
5454+ if err != nil {
5555+ return nil, err
5656+ }
5757+5858+ service, err := getAtprotoPdsService(doc)
5959+ if err != nil {
6060+ return nil, err
6161+ }
6262+6363+ authserver, err := cli.ResolvePdsAuthServer(ctx, service)
6464+ if err != nil {
6565+ return nil, err
6666+ }
6767+6868+ authmeta, err := cli.FetchAuthServerMetadata(ctx, authserver)
6969+ if err != nil {
7070+ return nil, err
7171+ }
7272+7373+ if len(doc.AlsoKnownAs) == 0 {
7474+ return nil, fmt.Errorf("alsoKnownAs is empty, couldn't acquire handle: %w", err)
7575+7676+ }
7777+ handle := strings.Replace(doc.AlsoKnownAs[0], "at://", "", 1)
7878+7979+ return &UserInformation{
8080+ AuthService: service,
8181+ AuthServer: authserver,
8282+ AuthMeta: authmeta,
8383+ Handle: handle,
8484+ DID: did,
8585+ }, nil
8686+}
8787+8888+func resolveHandle(ctx context.Context, handle string) (string, error) {
8989+ var did string
9090+9191+ _, err := syntax.ParseHandle(handle)
9292+ if err != nil {
9393+ return "", err
9494+ }
9595+9696+ recs, err := net.LookupTXT(fmt.Sprintf("_atproto.%s", handle))
9797+ if err == nil {
9898+ for _, rec := range recs {
9999+ if strings.HasPrefix(rec, "did=") {
100100+ did = strings.Split(rec, "did=")[1]
101101+ break
102102+ }
103103+ }
104104+ }
105105+106106+ if did == "" {
107107+ req, err := http.NewRequestWithContext(
108108+ ctx,
109109+ "GET",
110110+ fmt.Sprintf("https://%s/.well-known/atproto-did", handle),
111111+ nil,
112112+ )
113113+ if err != nil {
114114+ return "", err
115115+ }
116116+117117+ resp, err := http.DefaultClient.Do(req)
118118+ if err != nil {
119119+ return "", err
120120+ }
121121+ defer resp.Body.Close()
122122+123123+ if resp.StatusCode != http.StatusOK {
124124+ io.Copy(io.Discard, resp.Body)
125125+ return "", fmt.Errorf("unable to resolve handle")
126126+ }
127127+128128+ b, err := io.ReadAll(resp.Body)
129129+ if err != nil {
130130+ return "", err
131131+ }
132132+133133+ maybeDid := string(b)
134134+135135+ if _, err := syntax.ParseDID(maybeDid); err != nil {
136136+ return "", fmt.Errorf("unable to resolve handle")
137137+ }
138138+139139+ did = maybeDid
140140+ }
141141+142142+ return did, nil
143143+}
144144+145145+// Get the Identity document for a given DID
146146+func getIdentityDocument(ctx context.Context, did string) (*Identity, error) {
147147+ var ustr string
148148+ if strings.HasPrefix(did, "did:plc:") {
149149+ ustr = fmt.Sprintf("https://plc.directory/%s", did)
150150+ } else if strings.HasPrefix(did, "did:web:") {
151151+ ustr = fmt.Sprintf("https://%s/.well-known/did.json", strings.TrimPrefix(did, "did:web:"))
152152+ } else {
153153+ return nil, fmt.Errorf("did was not a supported did type")
154154+ }
155155+156156+ req, err := http.NewRequestWithContext(ctx, "GET", ustr, nil)
157157+ if err != nil {
158158+ return nil, err
159159+ }
160160+161161+ resp, err := http.DefaultClient.Do(req)
162162+ if err != nil {
163163+ return nil, err
164164+ }
165165+ defer resp.Body.Close()
166166+167167+ if resp.StatusCode != http.StatusOK {
168168+ io.Copy(io.Discard, resp.Body)
169169+ return nil, fmt.Errorf("could not find identity in plc registry")
170170+ }
171171+172172+ var identity Identity
173173+ if err := json.NewDecoder(resp.Body).Decode(&identity); err != nil {
174174+ return nil, err
175175+ }
176176+177177+ return &identity, nil
178178+}
179179+180180+// Get the atproto PDS service endpoint from an Identity document
181181+func getAtprotoPdsService(identity *Identity) (string, error) {
182182+ var service string
183183+ for _, svc := range identity.Service {
184184+ if svc.ID == "#atproto_pds" {
185185+ service = svc.ServiceEndpoint
186186+ break
187187+ }
188188+ }
189189+190190+ if service == "" {
191191+ return "", fmt.Errorf("could not find atproto_pds service in identity services")
192192+ }
193193+194194+ return service, nil
195195+}
196196+197197+func resolveServiceFromDoc(identity *Identity) (string, error) {
198198+ service, err := getAtprotoPdsService(identity)
199199+ if err != nil {
200200+ return "", err
201201+ }
202202+203203+ return service, nil
204204+}
+58-25
oauth/oauth2.go
···11+// Modify piper/oauth/oauth2.go
12package oauth
2334import (
···56 "crypto/rand"
67 "crypto/sha256"
78 "encoding/base64"
99+ "errors"
810 "fmt"
1111+ "log"
912 "net/http"
1013 "strings"
11141515+ "github.com/teal-fm/piper/session"
1216 "golang.org/x/oauth2"
1317 "golang.org/x/oauth2/spotify"
1418)
···1822 state string
1923 codeVerifier string
2024 codeChallenge string
2525+ // Added TokenReceiver field to handle user lookup/creation based on token
2626+ tokenReceiver TokenReceiver
2127}
22282323-func generateRandomState() string {
2929+func GenerateRandomState() string {
2430 b := make([]byte, 16)
2531 rand.Read(b)
2632 return base64.URLEncoding.EncodeToString(b)
2733}
28342929-func NewOAuth2Service(clientID, clientSecret, redirectURI string, scopes []string, provider string) *OAuth2Service {
3535+func NewOAuth2Service(clientID, clientSecret, redirectURI string, scopes []string, provider string, tokenReceiver TokenReceiver) *OAuth2Service {
3036 var endpoint oauth2.Endpoint
31373238 switch strings.ToLower(provider) {
3339 case "spotify":
3440 endpoint = spotify.Endpoint
4141+ // Add other providers like Last.fm here
3542 default:
3636- // TODO: support custom endpoints plus lastfm
4343+ // Placeholder for unconfigured providers
4444+ log.Printf("Warning: OAuth2 provider '%s' not explicitly configured. Using placeholder endpoints.", provider)
3745 endpoint = oauth2.Endpoint{
3838- AuthURL: "https://example.com/auth",
4646+ AuthURL: "https://example.com/auth", // Replace with actual endpoints if needed
3947 TokenURL: "https://example.com/token",
4048 }
4149 }
42504343- codeVerifier := generateCodeVerifier()
4444- codeChallenge := generateCodeChallenge(codeVerifier)
5151+ codeVerifier := GenerateCodeVerifier()
5252+ codeChallenge := GenerateCodeChallenge(codeVerifier)
45534654 return &OAuth2Service{
4755 config: oauth2.Config{
···5159 Scopes: scopes,
5260 Endpoint: endpoint,
5361 },
5454- state: generateRandomState(),
6262+ state: GenerateRandomState(),
5563 codeVerifier: codeVerifier,
5664 codeChallenge: codeChallenge,
6565+ tokenReceiver: tokenReceiver, // Store the token receiver
5766 }
5867}
59686069// generateCodeVerifier creates a random code verifier for PKCE
6161-func generateCodeVerifier() string {
6262- // Generate a random string of 32-96 bytes as per RFC 7636
6363- b := make([]byte, 64) // Using 64 bytes (512 bits)
7070+func GenerateCodeVerifier() string {
7171+ b := make([]byte, 64)
6472 rand.Read(b)
6573 return base64.RawURLEncoding.EncodeToString(b)
6674}
67756876// generateCodeChallenge creates a code challenge from the code verifier using S256 method
6969-func generateCodeChallenge(verifier string) string {
7070- // S256 method: SHA256 hash of the code verifier
7777+func GenerateCodeChallenge(verifier string) string {
7178 h := sha256.New()
7279 h.Write([]byte(verifier))
7380 return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
7481}
75827676-// redirect to auth page
8383+// HandleLogin implements the AuthService interface method.
7784func (o *OAuth2Service) HandleLogin(w http.ResponseWriter, r *http.Request) {
7878- // use pkce here
7985 opts := []oauth2.AuthCodeOption{
8086 oauth2.SetAuthURLParam("code_challenge", o.codeChallenge),
8187 oauth2.SetAuthURLParam("code_challenge_method", "S256"),
···8490 http.Redirect(w, r, authURL, http.StatusSeeOther)
8591}
86928787-func (o *OAuth2Service) HandleCallback(w http.ResponseWriter, r *http.Request, tokenReceiver TokenReceiver) int64 {
8888- // Verify state
9393+func (o *OAuth2Service) HandleCallback(w http.ResponseWriter, r *http.Request) (int64, error) {
8994 state := r.URL.Query().Get("state")
9095 if state != o.state {
9696+ log.Printf("OAuth2 Callback Error: State mismatch. Expected '%s', got '%s'", o.state, state)
9197 http.Error(w, "State mismatch", http.StatusBadRequest)
9292- return 0
9898+ return 0, errors.New("state mismatch")
9399 }
9410095101 code := r.URL.Query().Get("code")
96102 if code == "" {
9797- http.Error(w, "No code provided", http.StatusBadRequest)
9898- return 0
103103+ errMsg := r.URL.Query().Get("error")
104104+ errDesc := r.URL.Query().Get("error_description")
105105+ log.Printf("OAuth2 Callback Error: No code provided. Error: '%s', Description: '%s'", errMsg, errDesc)
106106+ http.Error(w, fmt.Sprintf("Authorization failed: %s (%s)", errMsg, errDesc), http.StatusBadRequest)
107107+ return 0, errors.New("no code provided")
108108+ }
109109+110110+ if o.tokenReceiver == nil {
111111+ log.Printf("OAuth2 Callback Error: TokenReceiver is not configured for this service.")
112112+ http.Error(w, "Internal server configuration error", http.StatusInternalServerError)
113113+ return 0, errors.New("token receiver not configured")
99114 }
100115101116 opts := []oauth2.AuthCodeOption{
102117 oauth2.SetAuthURLParam("code_verifier", o.codeVerifier),
103118 }
104119120120+ log.Println(code)
121121+105122 token, err := o.config.Exchange(context.Background(), code, opts...)
106123 if err != nil {
124124+ log.Printf("OAuth2 Callback Error: Failed to exchange code for token: %v", err)
107125 http.Error(w, fmt.Sprintf("Error exchanging code for token: %v", err), http.StatusInternalServerError)
108108- return 0
126126+ return 0, errors.New("failed to exchange code for token")
109127 }
110128111111- // Store access token
112112- userID := tokenReceiver.SetAccessToken(token.AccessToken)
129129+ userId, hasSession := session.GetUserID(r.Context())
130130+131131+ // Use the token receiver to store the token and get the user ID
132132+ userID, err := o.tokenReceiver.SetAccessToken(token.AccessToken, userId, hasSession)
133133+ if err != nil {
134134+ log.Printf("OAuth2 Callback Info: TokenReceiver did not return a valid user ID for token: %s...", token.AccessToken[:min(10, len(token.AccessToken))])
135135+ }
113136114114- return userID
137137+ log.Printf("OAuth2 Callback Success: Exchanged code for token, UserID: %d", userID)
138138+ return userID, nil
115139}
116140117117-// GetToken returns the OAuth2 token using the authorization code
141141+// GetToken remains unchanged
118142func (o *OAuth2Service) GetToken(code string) (*oauth2.Token, error) {
119143 opts := []oauth2.AuthCodeOption{
120144 oauth2.SetAuthURLParam("code_verifier", o.codeVerifier),
121145 }
122122-123146 return o.config.Exchange(context.Background(), code, opts...)
124147}
125148149149+// GetClient remains unchanged
126150func (o *OAuth2Service) GetClient(token *oauth2.Token) *http.Client {
127151 return o.config.Client(context.Background(), token)
128152}
129153154154+// RefreshToken remains unchanged
130155func (o *OAuth2Service) RefreshToken(token *oauth2.Token) (*oauth2.Token, error) {
131156 source := o.config.TokenSource(context.Background(), token)
132157 return oauth2.ReuseTokenSource(token, source).Token()
133158}
159159+160160+// Helper function
161161+func min(a, b int) int {
162162+ if a < b {
163163+ return a
164164+ }
165165+ return b
166166+}
+40-27
oauth/oauth_manager.go
···11+// Modify piper/oauth/oauth_manager.go
12package oauth
2334import (
···910 "github.com/teal-fm/piper/session"
1011)
11121212-// TokenReceiver interface for services that can receive OAuth tokens
1313-type TokenReceiver interface {
1414- SetAccessToken(token string) int64
1515-}
1616-1717-// manages multiple oauth2 client services
1313+// manages multiple oauth client services
1814type OAuthServiceManager struct {
1919- oauth2Services map[string]*OAuth2Service
1515+ services map[string]AuthService // Changed from *OAuth2Service to AuthService interface
2016 sessionManager *session.SessionManager
2117 mu sync.RWMutex
2218}
23192420func NewOAuthServiceManager() *OAuthServiceManager {
2521 return &OAuthServiceManager{
2626- oauth2Services: make(map[string]*OAuth2Service),
2222+ services: make(map[string]AuthService), // Initialize the new map
2723 sessionManager: session.NewSessionManager(),
2824 }
2925}
30263131-func (m *OAuthServiceManager) RegisterOAuth2Service(name string, service *OAuth2Service) {
2727+// RegisterService registers any service that implements the AuthService interface.
2828+func (m *OAuthServiceManager) RegisterService(name string, service AuthService) {
3229 m.mu.Lock()
3330 defer m.mu.Unlock()
3434- m.oauth2Services[name] = service
3131+ m.services[name] = service
3232+ log.Printf("Registered auth service: %s", name)
3533}
36343737-func (m *OAuthServiceManager) GetOAuth2Service(name string) (*OAuth2Service, bool) {
3535+// GetService retrieves a registered AuthService by name.
3636+func (m *OAuthServiceManager) GetService(name string) (AuthService, bool) {
3837 m.mu.RLock()
3938 defer m.mu.RUnlock()
4040- service, exists := m.oauth2Services[name]
3939+ service, exists := m.services[name]
4140 return service, exists
4241}
43424443func (m *OAuthServiceManager) HandleLogin(serviceName string) http.HandlerFunc {
4544 return func(w http.ResponseWriter, r *http.Request) {
4645 m.mu.RLock()
4747- oauth2Service, oauth2Exists := m.oauth2Services[serviceName]
4646+ service, exists := m.services[serviceName]
4847 m.mu.RUnlock()
49485050- if oauth2Exists {
5151- oauth2Service.HandleLogin(w, r)
4949+ if exists {
5050+ service.HandleLogin(w, r) // Call interface method
5251 return
5352 }
54535555- http.Error(w, fmt.Sprintf("OAuth service '%s' not found", serviceName), http.StatusNotFound)
5454+ log.Printf("Auth service '%s' not found for login request", serviceName)
5555+ http.Error(w, fmt.Sprintf("Auth service '%s' not found", serviceName), http.StatusNotFound)
5656 }
5757}
58585959-func (m *OAuthServiceManager) HandleCallback(serviceName string, tokenReceiver TokenReceiver) http.HandlerFunc {
5959+func (m *OAuthServiceManager) HandleCallback(serviceName string) http.HandlerFunc {
6060 return func(w http.ResponseWriter, r *http.Request) {
6161 m.mu.RLock()
6262- oauth2Service, oauth2Exists := m.oauth2Services[serviceName]
6262+ service, exists := m.services[serviceName]
6363 m.mu.RUnlock()
64646565- var userID int64
6565+ log.Printf("Logging in with service %s", serviceName)
66666767- if oauth2Exists {
6868- // Handle OAuth2 with PKCE callback
6969- userID = oauth2Service.HandleCallback(w, r, tokenReceiver)
7070- } else {
6767+ if !exists {
6868+ log.Printf("Auth service '%s' not found for callback request", serviceName)
7169 http.Error(w, fmt.Sprintf("OAuth service '%s' not found", serviceName), http.StatusNotFound)
7270 return
7371 }
74727373+ // Call the service's HandleCallback, which now returns the user ID
7474+ userID, err := service.HandleCallback(w, r) // Call interface method
7575+7676+ if err != nil {
7777+ log.Printf("Error handling callback for service '%s': %v", serviceName, err)
7878+ http.Error(w, fmt.Sprintf("Error handling callback for service '%s'", serviceName), http.StatusInternalServerError)
7979+ return
8080+ }
8181+7582 if userID > 0 {
7683 // Create session for the user
7784 session := m.sessionManager.CreateSession(userID)
···7986 // Set session cookie
8087 m.sessionManager.SetSessionCookie(w, session)
81888282- log.Printf("Created session for user %d", userID)
8383- }
8989+ log.Printf("Created session for user %d via service %s", userID, serviceName)
84908585- // Redirect to homepage
8686- http.Redirect(w, r, "/", http.StatusSeeOther)
9191+ // Redirect to homepage after successful login and session creation
9292+ http.Redirect(w, r, "/", http.StatusSeeOther)
9393+ } else {
9494+ log.Printf("Callback for service '%s' did not result in a valid user ID.", serviceName)
9595+ // Optionally redirect to an error page or show an error message
9696+ // For now, just redirecting home, but this might hide errors.
9797+ // Consider adding error handling based on why userID might be 0.
9898+ http.Redirect(w, r, "/", http.StatusSeeOther) // Or redirect to a login/error page
9999+ }
87100 }
88101}
+24
oauth/service.go
···11+// Create piper/oauth/auth_service.go
22+package oauth
33+44+import (
55+ "net/http"
66+)
77+88+// AuthService defines the interface for different authentication services
99+// that can be managed by the OAuthServiceManager.
1010+type AuthService interface {
1111+ // HandleLogin initiates the login flow for the specific service.
1212+ HandleLogin(w http.ResponseWriter, r *http.Request)
1313+ // HandleCallback handles the callback from the authentication provider,
1414+ // processes the response (e.g., exchanges code for token), finds or creates
1515+ // the user in the local system, and returns the user ID.
1616+ // Returns 0 if authentication failed or user could not be determined.
1717+ HandleCallback(w http.ResponseWriter, r *http.Request) (int64, error)
1818+}
1919+2020+type TokenReceiver interface {
2121+ // SetAccessToken stores the access token for the user and returns the user ID.
2222+ // If the user is already logged in, the current ID is provided.
2323+ SetAccessToken(token string, currentId int64, hasSession bool) (int64, error)
2424+}
+44-33
service/spotify/spotify.go
···3030 }
3131}
32323333-// SetAccessToken is called from OAuth callback and now identifies the user
3434-// SetAccessToken is called from OAuth callback and now identifies the user
3535-func (s *SpotifyService) SetAccessToken(token string) int64 {
3333+func (s *SpotifyService) SetAccessToken(token string, userId int64, hasSession bool) (int64, error) {
3634 // Identify the user synchronously instead of in a goroutine
3737- userID := s.identifyAndStoreUser(token)
3838- return userID
3535+ userID, err := s.identifyAndStoreUser(token, userId, hasSession)
3636+ if err != nil {
3737+ log.Printf("Error identifying and storing user: %v", err)
3838+ return 0, err
3939+ }
4040+ return userID, nil
3941}
40424141-func (s *SpotifyService) identifyAndStoreUser(token string) int64 {
4343+func (s *SpotifyService) identifyAndStoreUser(token string, userId int64, hasSession bool) (int64, error) {
4244 // Get Spotify user profile
4345 userProfile, err := s.fetchSpotifyProfile(token)
4446 if err != nil {
4547 log.Printf("Error fetching Spotify profile: %v", err)
4646- return 0
4848+ return 0, err
4749 }
5050+5151+ fmt.Printf("uid: %d hasSession: %t", userId, hasSession)
48524953 // Check if user exists
5054 user, err := s.DB.GetUserBySpotifyID(userProfile.ID)
5155 if err != nil {
5252- log.Printf("Error checking for user: %v", err)
5353- return 0
5656+ // This error might mean DB connection issue, not just user not found.
5757+ log.Printf("Error checking for user by Spotify ID %s: %v", userProfile.ID, err)
5858+ return 0, err
5459 }
55605656- // If user doesn't exist, create them
6161+ tokenExpiryTime := time.Now().Add(1 * time.Hour) // Spotify tokens last ~1 hour
6262+6363+ // We don't intend users to log in via spotify!
5764 if user == nil {
5858- user = &models.User{
5959- Username: userProfile.DisplayName,
6060- Email: userProfile.Email,
6161- SpotifyID: userProfile.ID,
6262- AccessToken: token,
6363- TokenExpiry: time.Now().Add(1 * time.Hour), // Spotify tokens last ~1 hour
6464- }
6565-6666- userID, err := s.DB.CreateUser(user)
6767- if err != nil {
6868- log.Printf("Error creating user: %v", err)
6969- return 0
6565+ if !hasSession {
6666+ log.Printf("User does not seem to exist")
6767+ return 0, fmt.Errorf("user does not seem to exist")
6868+ } else {
6969+ // overwrite prev user
7070+ user, err = s.DB.AddSpotifySession(userId, userProfile.DisplayName, userProfile.Email, userProfile.ID, token, "", tokenExpiryTime)
7171+ if err != nil {
7272+ log.Printf("Error adding Spotify session for user ID %d: %v", userId, err)
7373+ return 0, err
7474+ }
7075 }
7171- user.ID = userID
7276 } else {
7373- // Update token
7474- err = s.DB.UpdateUserToken(user.ID, token, "", time.Now().Add(1*time.Hour))
7777+ // Update existing user's token and expiry
7878+ err = s.DB.UpdateUserToken(user.ID, token, "", tokenExpiryTime)
7579 if err != nil {
7676- log.Printf("Error updating user token: %v", err)
8080+ log.Printf("Error updating user token for user ID %d: %v", user.ID, err)
8181+ // Consider if we should return 0 or the user ID even if update fails
8282+ // Sticking to original behavior: log and continue
8383+ } else {
8484+ log.Printf("Updated token for existing user: %s (ID: %d)", user.Username, user.ID)
7785 }
7886 }
8787+ // Keep the local 'user' object consistent (optional but good practice)
8888+ user.AccessToken = &token
8989+ user.TokenExpiry = &tokenExpiryTime
79908080- // Store token in memory
9191+ // Store token in memory cache regardless of new/existing user
8192 s.mu.Lock()
8293 s.userTokens[user.ID] = token
8394 s.mu.Unlock()
84958585- log.Printf("User authenticated: %s (ID: %d)", user.Username, user.ID)
8686- return user.ID
9696+ log.Printf("User authenticated via Spotify: %s (ID: %d)", user.Username, user.ID)
9797+ return user.ID, nil
8798}
889989100type spotifyProfile struct {
···105116 count := 0
106117 for _, user := range users {
107118 // Only load users with valid tokens
108108- if user.AccessToken != "" && user.TokenExpiry.After(time.Now()) {
109109- s.userTokens[user.ID] = user.AccessToken
119119+ if user.AccessToken != nil && user.TokenExpiry.After(time.Now()) {
120120+ s.userTokens[user.ID] = *user.AccessToken
110121 count++
111122 }
112123 }
···124135 return fmt.Errorf("error loading user: %v", err)
125136 }
126137127127- if user.RefreshToken == "" {
138138+ if user.RefreshToken == nil {
128139 return fmt.Errorf("no refresh token for user %s", userID)
129140 }
130141···150161 refreshed := 0
151162 for _, user := range users {
152163 // Skip users without refresh tokens
153153- if user.RefreshToken == "" {
164164+ if user.RefreshToken == nil {
154165 continue
155166 }
156167
+56-4
session/session.go
···1313 "github.com/teal-fm/piper/db/apikey"
1414)
15151616+// session/session.go
1617type Session struct {
1717- ID string
1818- UserID int64
1919- CreatedAt time.Time
2020- ExpiresAt time.Time
1818+ ID string
1919+ UserID int64
2020+ ATprotoDID string
2121+ ATprotoAccessToken string
2222+ ATprotoRefreshToken string
2323+ CreatedAt time.Time
2424+ ExpiresAt time.Time
2125}
22262327type SessionManager struct {
···251255 }
252256}
253257258258+func WithPossibleAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc {
259259+ return func(w http.ResponseWriter, r *http.Request) {
260260+ ctx := r.Context()
261261+ authenticated := false // Default to not authenticated
262262+263263+ // 1. Try API key authentication
264264+ apiKeyStr, apiKeyErr := apikey.ExtractApiKey(r)
265265+ if apiKeyErr == nil && apiKeyStr != "" {
266266+ apiKey, valid := sm.apiKeyMgr.GetApiKey(apiKeyStr)
267267+ if valid {
268268+ // API Key valid: Add UserID, API flag, and set auth status
269269+ ctx = WithUserID(ctx, apiKey.UserID)
270270+ ctx = WithAPIRequest(ctx, true)
271271+ authenticated = true
272272+ // Update request context and call handler
273273+ r = r.WithContext(WithAuthStatus(ctx, authenticated))
274274+ handler(w, r)
275275+ return
276276+ }
277277+ // If API key was provided but invalid, we still proceed without auth
278278+ }
279279+280280+ // 2. If no valid API key, try cookie authentication
281281+ if !authenticated { // Only check cookies if API key didn't authenticate
282282+ cookie, err := r.Cookie("session")
283283+ if err == nil { // Cookie exists
284284+ session, exists := sm.GetSession(cookie.Value)
285285+ if exists {
286286+ // Session valid: Add UserID and set auth status
287287+ ctx = WithUserID(ctx, session.UserID)
288288+ // ctx = WithAPIRequest(ctx, false) // Not strictly needed, default is false
289289+ authenticated = true
290290+ }
291291+ // If session cookie exists but is invalid/expired, we proceed without auth
292292+ }
293293+ }
294294+295295+ // 3. Set final auth status (could be true or false) and call handler
296296+ r = r.WithContext(WithAuthStatus(ctx, authenticated))
297297+ handler(w, r)
298298+ }
299299+}
300300+254301// WithAPIAuth is a middleware specifically for API-only endpoints (no cookie fallback, returns 401 instead of redirect)
255302func WithAPIAuth(handler http.HandlerFunc, sm *SessionManager) http.HandlerFunc {
256303 return func(w http.ResponseWriter, r *http.Request) {
···288335const (
289336 userIDKey contextKey = iota
290337 apiRequestKey
338338+ authStatusKey
291339)
292340293341func WithUserID(ctx context.Context, userID int64) context.Context {
···297345func GetUserID(ctx context.Context) (int64, bool) {
298346 userID, ok := ctx.Value(userIDKey).(int64)
299347 return userID, ok
348348+}
349349+350350+func WithAuthStatus(ctx context.Context, isAuthed bool) context.Context {
351351+ return context.WithValue(ctx, authStatusKey, isAuthed)
300352}
301353302354func WithAPIRequest(ctx context.Context, isAPI bool) context.Context {