forked from
tangled.org/core
Monorepo for Tangled
1package oauth
2
3import (
4 "errors"
5 "fmt"
6 "net/http"
7 "time"
8
9 comatproto "github.com/bluesky-social/indigo/api/atproto"
10 "github.com/bluesky-social/indigo/atproto/auth/oauth"
11 atpclient "github.com/bluesky-social/indigo/atproto/client"
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 xrpc "github.com/bluesky-social/indigo/xrpc"
14 "github.com/gorilla/sessions"
15 "github.com/lestrrat-go/jwx/v2/jwk"
16 "github.com/posthog/posthog-go"
17 "tangled.org/core/appview/config"
18)
19
// New builds the appview's OAuth helper from the application config.
//
// In dev mode (config.Core.Dev) the client is served from
// http://127.0.0.1:3000 and uses indigo's localhost client config. Otherwise
// it is a public client rooted at config.Core.AppviewHost whose metadata
// document is <AppviewHost>/oauth/client-metadata.json. Both variants request
// the "atproto" and "transition:generic" scopes and redirect to
// <client URI>/oauth/callback.
//
// OAuth auth state goes to the Redis store at config.Redis, and user sessions
// live in a cookie store keyed by config.Core.CookieSecret. The only error
// returned is one from creating the Redis store.
func New(config *config.Config, ph posthog.Client) (*OAuth, error) {
	scopes := []string{"atproto", "transition:generic"}

	var (
		clientUri   string
		oauthConfig oauth.ClientConfig
	)
	switch {
	case config.Core.Dev:
		clientUri = "http://127.0.0.1:3000"
		oauthConfig = oauth.NewLocalhostConfig(clientUri+"/oauth/callback", scopes)
	default:
		clientUri = config.Core.AppviewHost
		clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri)
		oauthConfig = oauth.NewPublicConfig(clientId, clientUri+"/oauth/callback", scopes)
	}

	authStore, err := NewRedisStore(config.Redis.ToURL())
	if err != nil {
		return nil, err
	}

	return &OAuth{
		ClientApp: oauth.NewClientApp(&oauthConfig, authStore),
		Config:    config,
		SessStore: sessions.NewCookieStore([]byte(config.Core.CookieSecret)),
		JwksUri:   clientUri + "/oauth/jwks.json",
		Posthog:   ph,
	}, nil
}
53
// OAuth bundles the appview's OAuth client application with the cookie store
// that holds the user's session and the settings they were built from.
type OAuth struct {
	// ClientApp is the indigo OAuth client application (as built by New).
	ClientApp *oauth.ClientApp
	// SessStore holds the user session cookie; SaveSession writes the account
	// DID, PDS host URL, OAuth session ID and an authenticated flag into it.
	SessStore *sessions.CookieStore
	// Config is the appview configuration passed to New.
	Config *config.Config
	// JwksUri is the URL of the client's JWKS document
	// (<client URI>/oauth/jwks.json).
	JwksUri string
	// Posthog is the PostHog client passed to New.
	Posthog posthog.Client
}
61
// SaveSession records a freshly established OAuth session in the user's
// session cookie. It stores the account DID, the PDS host URL, the OAuth
// session ID and an "authenticated" flag. ResumeSession later reads the DID
// and session ID back to restore the OAuth session.
//
// It returns an error only if no session object could be obtained or if
// writing the cookie fails.
func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
	// gorilla/sessions hands back a new, empty session together with the
	// decode error when an existing cookie cannot be decoded (for example a
	// cookie signed with an old secret). That fresh session is exactly what we
	// are about to populate, so keep going. Bailing out here would leave a
	// user with a stale cookie unable to ever log in again.
	userSession, err := o.SessStore.Get(r, SessionName)
	if err != nil && userSession == nil {
		return err
	}

	userSession.Values[SessionDid] = sessData.AccountDID.String()
	userSession.Values[SessionPds] = sessData.HostURL
	userSession.Values[SessionId] = sessData.SessionID
	userSession.Values[SessionAuthenticated] = true
	return userSession.Save(r, w)
}
75
// ResumeSession restores the user's OAuth session from the DID and session ID
// that SaveSession recorded in the session cookie.
//
// It returns an error in any of these cases:
//   - the cookie cannot be read;
//   - there is no session cookie;
//   - the cookie lacks a well-formed DID or a session ID;
//   - the OAuth client cannot resume the session.
func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
	userSession, err := o.SessStore.Get(r, SessionName)
	if err != nil {
		return nil, fmt.Errorf("error getting user session: %w", err)
	}
	if userSession.IsNew {
		return nil, fmt.Errorf("no session available for user")
	}

	// Checked assertions: a cookie without these values (or holding values of
	// another type) must yield an error, not panic the request handler.
	d, ok := userSession.Values[SessionDid].(string)
	if !ok {
		return nil, errors.New("no DID in session cookie")
	}
	sessDid, err := syntax.ParseDID(d)
	if err != nil {
		return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
	}

	sessId, ok := userSession.Values[SessionId].(string)
	if !ok {
		return nil, errors.New("no session ID in session cookie")
	}

	clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId)
	if err != nil {
		return nil, fmt.Errorf("failed to resume session: %w", err)
	}

	return clientSess, nil
}
100
101func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error {
102 userSession, err := o.SessStore.Get(r, SessionName)
103 if err != nil {
104 return fmt.Errorf("error getting user session: %w", err)
105 }
106 if userSession.IsNew {
107 return fmt.Errorf("no session available for user")
108 }
109
110 d := userSession.Values[SessionDid].(string)
111 sessDid, err := syntax.ParseDID(d)
112 if err != nil {
113 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
114 }
115
116 sessId := userSession.Values[SessionId].(string)
117
118 // delete the session
119 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId)
120
121 // remove the cookie
122 userSession.Options.MaxAge = -1
123 err2 := o.SessStore.Save(r, w, userSession)
124
125 return errors.Join(err1, err2)
126}
127
128func pubKeyFromJwk(jwks string) (jwk.Key, error) {
129 k, err := jwk.ParseKey([]byte(jwks))
130 if err != nil {
131 return nil, err
132 }
133 pubKey, err := k.PublicKey()
134 if err != nil {
135 return nil, err
136 }
137 return pubKey, nil
138}
139
// User is the identity read from the session cookie: the account DID and the
// PDS host URL recorded at login.
type User struct {
	Did, Pds string
}
144
// GetUser returns the DID and PDS host URL stored in the session cookie. It
// returns nil when there is no usable session, which covers:
//   - no cookie at all;
//   - a cookie that failed to decode;
//   - a cookie that lacks a non-empty DID or a PDS value.
func (o *OAuth) GetUser(r *http.Request) *User {
	sess, err := o.SessStore.Get(r, SessionName)
	if err != nil || sess.IsNew {
		return nil
	}

	// Checked assertions: a cookie missing these values (or holding other
	// types) is treated as "not logged in" instead of panicking the handler.
	did, okDid := sess.Values[SessionDid].(string)
	pds, okPds := sess.Values[SessionPds].(string)
	if !okDid || !okPds || did == "" {
		return nil
	}

	return &User{
		Did: did,
		Pds: pds,
	}
}
157
// GetDid returns the DID of the logged-in user, or "" when GetUser finds no
// session for the request.
func (o *OAuth) GetDid(r *http.Request) string {
	user := o.GetUser(r)
	if user == nil {
		return ""
	}
	return user.Did
}
165
// AuthorizedClient returns an API client that acts on behalf of the user. It
// uses the OAuth session that ResumeSession restores from the request's
// session cookie. Failures to resume are wrapped and returned.
func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) {
	sess, err := o.ResumeSession(r)
	if err != nil {
		return nil, fmt.Errorf("error getting session: %w", err)
	}

	apiClient := sess.APIClient()
	return apiClient, nil
}
173
174// this is a higher level abstraction on ServerGetServiceAuth
175type ServiceClientOpts struct {
176 service string
177 exp int64
178 lxm string
179 dev bool
180}
181
182type ServiceClientOpt func(*ServiceClientOpts)
183
184func WithService(service string) ServiceClientOpt {
185 return func(s *ServiceClientOpts) {
186 s.service = service
187 }
188}
189
190// Specify the Duration in seconds for the expiry of this token
191//
192// The time of expiry is calculated as time.Now().Unix() + exp
193func WithExp(exp int64) ServiceClientOpt {
194 return func(s *ServiceClientOpts) {
195 s.exp = time.Now().Unix() + exp
196 }
197}
198
199func WithLxm(lxm string) ServiceClientOpt {
200 return func(s *ServiceClientOpts) {
201 s.lxm = lxm
202 }
203}
204
205func WithDev(dev bool) ServiceClientOpt {
206 return func(s *ServiceClientOpts) {
207 s.dev = dev
208 }
209}
210
211func (s *ServiceClientOpts) Audience() string {
212 return fmt.Sprintf("did:web:%s", s.service)
213}
214
215func (s *ServiceClientOpts) Host() string {
216 scheme := "https://"
217 if s.dev {
218 scheme = "http://"
219 }
220
221 return scheme + s.service
222}
223
224func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) {
225 opts := ServiceClientOpts{}
226 for _, o := range os {
227 o(&opts)
228 }
229
230 client, err := o.AuthorizedClient(r)
231 if err != nil {
232 return nil, err
233 }
234
235 // force expiry to atleast 60 seconds in the future
236 sixty := time.Now().Unix() + 60
237 if opts.exp < sixty {
238 opts.exp = sixty
239 }
240
241 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm)
242 if err != nil {
243 return nil, err
244 }
245
246 return &xrpc.Client{
247 Auth: &xrpc.AuthInfo{
248 AccessJwt: resp.Token,
249 },
250 Host: opts.Host(),
251 Client: &http.Client{
252 Timeout: time.Second * 5,
253 },
254 }, nil
255}