this repo has no description
1package oauth
2
3import (
4 "errors"
5 "fmt"
6 "log/slog"
7 "net/http"
8 "time"
9
10 comatproto "github.com/bluesky-social/indigo/api/atproto"
11 "github.com/bluesky-social/indigo/atproto/auth/oauth"
12 atpclient "github.com/bluesky-social/indigo/atproto/client"
13 atcrypto "github.com/bluesky-social/indigo/atproto/crypto"
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 xrpc "github.com/bluesky-social/indigo/xrpc"
16 "github.com/gorilla/sessions"
17 "github.com/posthog/posthog-go"
18 "tangled.org/core/appview/config"
19 "tangled.org/core/appview/db"
20 "tangled.org/core/idresolver"
21 "tangled.org/core/rbac"
22)
23
// OAuth bundles the appview's atproto OAuth client with the browser-session
// cookie store and the shared services its handlers use. It is built by New.
type OAuth struct {
	// ClientApp is used to resume and log out OAuth sessions
	// (see ResumeSession and DeleteSession).
	ClientApp *oauth.ClientApp

	// SessStore holds the cookie session in which SaveSession records the
	// account DID, PDS host URL and OAuth session ID.
	SessStore *sessions.CookieStore

	Config *config.Config

	// JwksUri is the URL of the client's JWKS document; New sets it to
	// ClientUri + "/oauth/jwks.json".
	JwksUri string

	// ClientName is taken from config.Core.AppviewName by New.
	ClientName string

	// ClientUri is the base URL of the client: http://127.0.0.1:3000 in dev
	// mode, config.Core.AppviewHost otherwise.
	ClientUri string

	Posthog    posthog.Client
	Db         *db.DB
	Enforcer   *rbac.Enforcer
	IdResolver *idresolver.Resolver
	Logger     *slog.Logger
}
37
38func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) {
39 var oauthConfig oauth.ClientConfig
40 var clientUri string
41 if config.Core.Dev {
42 clientUri = "http://127.0.0.1:3000"
43 callbackUri := clientUri + "/oauth/callback"
44 oauthConfig = oauth.NewLocalhostConfig(callbackUri, []string{"atproto", "transition:generic"})
45 } else {
46 clientUri = config.Core.AppviewHost
47 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri)
48 callbackUri := clientUri + "/oauth/callback"
49 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, []string{"atproto", "transition:generic"})
50 }
51
52 // configure client secret
53 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret)
54 if err != nil {
55 return nil, err
56 }
57 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil {
58 return nil, err
59 }
60
61 jwksUri := clientUri + "/oauth/jwks.json"
62
63 authStore, err := NewRedisStore(config.Redis.ToURL())
64 if err != nil {
65 return nil, err
66 }
67
68 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret))
69
70 clientApp := oauth.NewClientApp(&oauthConfig, authStore)
71 clientApp.Dir = res.Directory()
72
73 clientName := config.Core.AppviewName
74
75 return &OAuth{
76 ClientApp: clientApp,
77 Config: config,
78 SessStore: sessStore,
79 JwksUri: jwksUri,
80 ClientName: clientName,
81 ClientUri: clientUri,
82 Posthog: ph,
83 Db: db,
84 Enforcer: enforcer,
85 IdResolver: res,
86 Logger: logger,
87 }, nil
88}
89
90func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
91 // first we save the did in the user session
92 userSession, err := o.SessStore.Get(r, SessionName)
93 if err != nil {
94 return err
95 }
96
97 userSession.Values[SessionDid] = sessData.AccountDID.String()
98 userSession.Values[SessionPds] = sessData.HostURL
99 userSession.Values[SessionId] = sessData.SessionID
100 userSession.Values[SessionAuthenticated] = true
101 return userSession.Save(r, w)
102}
103
104func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
105 userSession, err := o.SessStore.Get(r, SessionName)
106 if err != nil {
107 return nil, fmt.Errorf("error getting user session: %w", err)
108 }
109 if userSession.IsNew {
110 return nil, fmt.Errorf("no session available for user")
111 }
112
113 d := userSession.Values[SessionDid].(string)
114 sessDid, err := syntax.ParseDID(d)
115 if err != nil {
116 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
117 }
118
119 sessId := userSession.Values[SessionId].(string)
120
121 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId)
122 if err != nil {
123 return nil, fmt.Errorf("failed to resume session: %w", err)
124 }
125
126 return clientSess, nil
127}
128
129func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error {
130 userSession, err := o.SessStore.Get(r, SessionName)
131 if err != nil {
132 return fmt.Errorf("error getting user session: %w", err)
133 }
134 if userSession.IsNew {
135 return fmt.Errorf("no session available for user")
136 }
137
138 d := userSession.Values[SessionDid].(string)
139 sessDid, err := syntax.ParseDID(d)
140 if err != nil {
141 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
142 }
143
144 sessId := userSession.Values[SessionId].(string)
145
146 // delete the session
147 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId)
148
149 // remove the cookie
150 userSession.Options.MaxAge = -1
151 err2 := o.SessStore.Save(r, w, userSession)
152
153 return errors.Join(err1, err2)
154}
155
// User is the signed-in account as read from the session cookie by GetUser.
type User struct {
	// Did is the account DID, as stored in the session cookie.
	Did string

	// Pds is the PDS host URL stored in the session cookie (SaveSession
	// stores the OAuth session's HostURL there).
	Pds string
}
160
161func (o *OAuth) GetUser(r *http.Request) *User {
162 sess, err := o.SessStore.Get(r, SessionName)
163
164 if err != nil || sess.IsNew {
165 return nil
166 }
167
168 return &User{
169 Did: sess.Values[SessionDid].(string),
170 Pds: sess.Values[SessionPds].(string),
171 }
172}
173
174func (o *OAuth) GetDid(r *http.Request) string {
175 if u := o.GetUser(r); u != nil {
176 return u.Did
177 }
178
179 return ""
180}
181
182func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) {
183 session, err := o.ResumeSession(r)
184 if err != nil {
185 return nil, fmt.Errorf("error getting session: %w", err)
186 }
187 return session.APIClient(), nil
188}
189
// ServiceClientOpts holds the settings for ServiceClient, a higher-level
// wrapper around ServerGetServiceAuth. Populate it with the With* options.
type ServiceClientOpts struct {
	service string // service hostname; used for the did:web audience and the base URL
	exp     int64  // absolute token expiry in Unix seconds (see WithExp)
	lxm     string // lxm argument passed to ServerGetServiceAuth
	dev     bool   // when true, Host uses http:// instead of https://
}

// ServiceClientOpt mutates a ServiceClientOpts; ServiceClient applies them in
// order.
type ServiceClientOpt func(*ServiceClientOpts)

// WithService sets the hostname of the target service.
func WithService(service string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.service = service }
}

// WithExp sets the token lifetime to exp seconds.
//
// The absolute expiry is computed when the option is applied, as
// time.Now().Unix() + exp.
func WithExp(exp int64) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		now := time.Now().Unix()
		opts.exp = now + exp
	}
}

// WithLxm sets the lxm value passed to ServerGetServiceAuth.
func WithLxm(lxm string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.lxm = lxm }
}

// WithDev makes Host use plain http:// when dev is true.
func WithDev(dev bool) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.dev = dev }
}

// Audience returns "did:web:" followed by the configured service hostname.
func (s *ServiceClientOpts) Audience() string {
	return fmt.Sprintf("did:web:%s", s.service)
}

// Host returns the base URL of the target service: http:// in dev mode,
// https:// otherwise.
func (s *ServiceClientOpts) Host() string {
	if s.dev {
		return "http://" + s.service
	}
	return "https://" + s.service
}
239
240func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) {
241 opts := ServiceClientOpts{}
242 for _, o := range os {
243 o(&opts)
244 }
245
246 client, err := o.AuthorizedClient(r)
247 if err != nil {
248 return nil, err
249 }
250
251 // force expiry to atleast 60 seconds in the future
252 sixty := time.Now().Unix() + 60
253 if opts.exp < sixty {
254 opts.exp = sixty
255 }
256
257 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm)
258 if err != nil {
259 return nil, err
260 }
261
262 return &xrpc.Client{
263 Auth: &xrpc.AuthInfo{
264 AccessJwt: resp.Token,
265 },
266 Host: opts.Host(),
267 Client: &http.Client{
268 Timeout: time.Second * 5,
269 },
270 }, nil
271}