this repo has no description
1package oauth
2
3import (
4 "fmt"
5 "log"
6 "net/http"
7 "time"
8
9 "github.com/gorilla/sessions"
10 oauth "github.com/haileyok/atproto-oauth-golang"
11 "github.com/haileyok/atproto-oauth-golang/helpers"
12 "tangled.sh/tangled.sh/core/appview"
13 "tangled.sh/tangled.sh/core/appview/db"
14 "tangled.sh/tangled.sh/core/appview/oauth/client"
15 xrpc "tangled.sh/tangled.sh/core/appview/xrpcclient"
16)
17
// OAuthRequest holds the values of a pending OAuth authorization request.
// Its fields mirror the ones SaveSession reads from db.OAuthRequest (issuer,
// DID, PDS URL, DPoP nonce and private JWK), plus the state value and the
// PKCE verifier.
//
// NOTE(review): nothing in this file uses this type; SaveSession takes
// db.OAuthRequest instead (which also carries a Handle). Confirm whether it
// is still referenced elsewhere before relying on it.
type OAuthRequest struct {
	ID uint

	AuthserverIss, State string
	Did, PdsUrl          string

	PkceVerifier, DpopAuthserverNonce, DpopPrivateJwk string
}
28
29type OAuth struct {
30 Store *sessions.CookieStore
31 Db *db.DB
32 Config *appview.Config
33}
34
35func NewOAuth(db *db.DB, config *appview.Config) *OAuth {
36 return &OAuth{
37 Store: sessions.NewCookieStore([]byte(config.Core.CookieSecret)),
38 Db: db,
39 Config: config,
40 }
41}
42
43func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, oreq db.OAuthRequest, oresp *oauth.TokenResponse) error {
44 // first we save the did in the user session
45 userSession, err := o.Store.Get(r, appview.SessionName)
46 if err != nil {
47 return err
48 }
49
50 userSession.Values[appview.SessionDid] = oreq.Did
51 userSession.Values[appview.SessionAuthenticated] = true
52 err = userSession.Save(r, w)
53 if err != nil {
54 return fmt.Errorf("error saving user session: %w", err)
55 }
56
57 // then save the whole thing in the db
58 session := db.OAuthSession{
59 Did: oreq.Did,
60 Handle: oreq.Handle,
61 PdsUrl: oreq.PdsUrl,
62 DpopAuthserverNonce: oreq.DpopAuthserverNonce,
63 AuthServerIss: oreq.AuthserverIss,
64 DpopPrivateJwk: oreq.DpopPrivateJwk,
65 AccessJwt: oresp.AccessToken,
66 RefreshJwt: oresp.RefreshToken,
67 Expiry: time.Now().Add(time.Duration(oresp.ExpiresIn) * time.Second).Format(time.RFC3339),
68 }
69
70 return db.SaveOAuthSession(o.Db, session)
71}
72
73func (o *OAuth) ClearSession(r *http.Request, w http.ResponseWriter) error {
74 userSession, err := o.Store.Get(r, appview.SessionName)
75 if err != nil || userSession.IsNew {
76 return fmt.Errorf("error getting user session (or new session?): %w", err)
77 }
78
79 did := userSession.Values[appview.SessionDid].(string)
80
81 err = db.DeleteOAuthSessionByDid(o.Db, did)
82 if err != nil {
83 return fmt.Errorf("error deleting oauth session: %w", err)
84 }
85
86 userSession.Options.MaxAge = -1
87
88 return userSession.Save(r, w)
89}
90
91func (o *OAuth) GetSession(r *http.Request) (*db.OAuthSession, bool, error) {
92 userSession, err := o.Store.Get(r, appview.SessionName)
93 if err != nil || userSession.IsNew {
94 return nil, false, fmt.Errorf("error getting user session (or new session?): %w", err)
95 }
96
97 did := userSession.Values[appview.SessionDid].(string)
98 auth := userSession.Values[appview.SessionAuthenticated].(bool)
99
100 session, err := db.GetOAuthSessionByDid(o.Db, did)
101 if err != nil {
102 return nil, false, fmt.Errorf("error getting oauth session: %w", err)
103 }
104
105 expiry, err := time.Parse(time.RFC3339, session.Expiry)
106 if err != nil {
107 return nil, false, fmt.Errorf("error parsing expiry time: %w", err)
108 }
109 if expiry.Sub(time.Now()) <= 5*time.Minute {
110 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
111 if err != nil {
112 return nil, false, err
113 }
114 oauthClient, err := client.NewClient(o.Config.OAuth.ServerMetadataUrl,
115 o.Config.OAuth.Jwks,
116 fmt.Sprintf("%s/oauth/callback", o.Config.Core.AppviewHost))
117
118 if err != nil {
119 return nil, false, err
120 }
121
122 resp, err := oauthClient.RefreshTokenRequest(r.Context(), session.RefreshJwt, session.AuthServerIss, session.DpopAuthserverNonce, privateJwk)
123 if err != nil {
124 return nil, false, err
125 }
126
127 newExpiry := time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second).Format(time.RFC3339)
128 err = db.RefreshOAuthSession(o.Db, did, resp.AccessToken, resp.RefreshToken, newExpiry)
129 if err != nil {
130 return nil, false, fmt.Errorf("error refreshing oauth session: %w", err)
131 }
132
133 // update the current session
134 session.AccessJwt = resp.AccessToken
135 session.RefreshJwt = resp.RefreshToken
136 session.DpopAuthserverNonce = resp.DpopAuthserverNonce
137 session.Expiry = newExpiry
138 }
139
140 return session, auth, nil
141}
142
// User is the identity GetUser reads from the cookie session: the user's
// handle, DID and PDS URL.
type User struct {
	Handle, Did, Pds string
}
148
149func (a *OAuth) GetUser(r *http.Request) *User {
150 clientSession, err := a.Store.Get(r, appview.SessionName)
151
152 if err != nil || clientSession.IsNew {
153 return nil
154 }
155
156 return &User{
157 Handle: clientSession.Values[appview.SessionHandle].(string),
158 Did: clientSession.Values[appview.SessionDid].(string),
159 Pds: clientSession.Values[appview.SessionPds].(string),
160 }
161}
162
163func (a *OAuth) GetDid(r *http.Request) string {
164 clientSession, err := a.Store.Get(r, appview.SessionName)
165
166 if err != nil || clientSession.IsNew {
167 return ""
168 }
169
170 return clientSession.Values[appview.SessionDid].(string)
171}
172
173func (o *OAuth) AuthorizedClient(r *http.Request) (*xrpc.Client, error) {
174 session, auth, err := o.GetSession(r)
175 if err != nil {
176 return nil, fmt.Errorf("error getting session: %w", err)
177 }
178 if !auth {
179 return nil, fmt.Errorf("not authorized")
180 }
181
182 client := &oauth.XrpcClient{
183 OnDpopPdsNonceChanged: func(did, newNonce string) {
184 err := db.UpdateDpopPdsNonce(o.Db, did, newNonce)
185 if err != nil {
186 log.Printf("error updating dpop pds nonce: %v", err)
187 }
188 },
189 }
190
191 privateJwk, err := helpers.ParseJWKFromBytes([]byte(session.DpopPrivateJwk))
192 if err != nil {
193 return nil, fmt.Errorf("error parsing private jwk: %w", err)
194 }
195
196 xrpcClient := xrpc.NewClient(client, &oauth.XrpcAuthedRequestArgs{
197 Did: session.Did,
198 PdsUrl: session.PdsUrl,
199 DpopPdsNonce: session.PdsUrl,
200 AccessToken: session.AccessJwt,
201 Issuer: session.AuthServerIss,
202 DpopPrivateJwk: privateJwk,
203 })
204
205 return xrpcClient, nil
206}