this repo has no description
1package oauth
2
3import (
4 "errors"
5 "fmt"
6 "log/slog"
7 "net/http"
8 "time"
9
10 comatproto "github.com/bluesky-social/indigo/api/atproto"
11 "github.com/bluesky-social/indigo/atproto/auth/oauth"
12 atpclient "github.com/bluesky-social/indigo/atproto/client"
13 atcrypto "github.com/bluesky-social/indigo/atproto/crypto"
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 xrpc "github.com/bluesky-social/indigo/xrpc"
16 "github.com/gorilla/sessions"
17 "github.com/posthog/posthog-go"
18 "tangled.org/core/appview/config"
19 "tangled.org/core/appview/db"
20 "tangled.org/core/idresolver"
21 "tangled.org/core/rbac"
22)
23
24type OAuth struct {
25 ClientApp *oauth.ClientApp
26 SessStore *sessions.CookieStore
27 Config *config.Config
28 JwksUri string
29 Posthog posthog.Client
30 Db *db.DB
31 Enforcer *rbac.Enforcer
32 IdResolver *idresolver.Resolver
33 Logger *slog.Logger
34}
35
36func New(config *config.Config, ph posthog.Client, db *db.DB, enforcer *rbac.Enforcer, res *idresolver.Resolver, logger *slog.Logger) (*OAuth, error) {
37
38 var oauthConfig oauth.ClientConfig
39 var clientUri string
40
41 if config.Core.Dev {
42 clientUri = "http://127.0.0.1:3000"
43 callbackUri := clientUri + "/oauth/callback"
44 oauthConfig = oauth.NewLocalhostConfig(callbackUri, []string{"atproto", "transition:generic"})
45 } else {
46 clientUri = config.Core.AppviewHost
47 clientId := fmt.Sprintf("%s/oauth/client-metadata.json", clientUri)
48 callbackUri := clientUri + "/oauth/callback"
49 oauthConfig = oauth.NewPublicConfig(clientId, callbackUri, []string{"atproto", "transition:generic"})
50 }
51
52 // configure client secret
53 priv, err := atcrypto.ParsePrivateMultibase(config.OAuth.ClientSecret)
54 if err != nil {
55 return nil, err
56 }
57 if err := oauthConfig.SetClientSecret(priv, config.OAuth.ClientKid); err != nil {
58 return nil, err
59 }
60
61 jwksUri := clientUri + "/oauth/jwks.json"
62
63 authStore, err := NewRedisStore(config.Redis.ToURL())
64 if err != nil {
65 return nil, err
66 }
67
68 sessStore := sessions.NewCookieStore([]byte(config.Core.CookieSecret))
69
70 clientApp := oauth.NewClientApp(&oauthConfig, authStore)
71 clientApp.Dir = res.Directory()
72
73 return &OAuth{
74 ClientApp: clientApp,
75 Config: config,
76 SessStore: sessStore,
77 JwksUri: jwksUri,
78 Posthog: ph,
79 Db: db,
80 Enforcer: enforcer,
81 IdResolver: res,
82 Logger: logger,
83 }, nil
84}
85
86func (o *OAuth) SaveSession(w http.ResponseWriter, r *http.Request, sessData *oauth.ClientSessionData) error {
87 // first we save the did in the user session
88 userSession, err := o.SessStore.Get(r, SessionName)
89 if err != nil {
90 return err
91 }
92
93 userSession.Values[SessionDid] = sessData.AccountDID.String()
94 userSession.Values[SessionPds] = sessData.HostURL
95 userSession.Values[SessionId] = sessData.SessionID
96 userSession.Values[SessionAuthenticated] = true
97 return userSession.Save(r, w)
98}
99
100func (o *OAuth) ResumeSession(r *http.Request) (*oauth.ClientSession, error) {
101 userSession, err := o.SessStore.Get(r, SessionName)
102 if err != nil {
103 return nil, fmt.Errorf("error getting user session: %w", err)
104 }
105 if userSession.IsNew {
106 return nil, fmt.Errorf("no session available for user")
107 }
108
109 d := userSession.Values[SessionDid].(string)
110 sessDid, err := syntax.ParseDID(d)
111 if err != nil {
112 return nil, fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
113 }
114
115 sessId := userSession.Values[SessionId].(string)
116
117 clientSess, err := o.ClientApp.ResumeSession(r.Context(), sessDid, sessId)
118 if err != nil {
119 return nil, fmt.Errorf("failed to resume session: %w", err)
120 }
121
122 return clientSess, nil
123}
124
125func (o *OAuth) DeleteSession(w http.ResponseWriter, r *http.Request) error {
126 userSession, err := o.SessStore.Get(r, SessionName)
127 if err != nil {
128 return fmt.Errorf("error getting user session: %w", err)
129 }
130 if userSession.IsNew {
131 return fmt.Errorf("no session available for user")
132 }
133
134 d := userSession.Values[SessionDid].(string)
135 sessDid, err := syntax.ParseDID(d)
136 if err != nil {
137 return fmt.Errorf("malformed DID in session cookie '%s': %w", d, err)
138 }
139
140 sessId := userSession.Values[SessionId].(string)
141
142 // delete the session
143 err1 := o.ClientApp.Logout(r.Context(), sessDid, sessId)
144
145 // remove the cookie
146 userSession.Options.MaxAge = -1
147 err2 := o.SessStore.Save(r, w, userSession)
148
149 return errors.Join(err1, err2)
150}
151
// User is the identity GetUser reads from the session cookie.
type User struct {
	// Did is the value stored under SessionDid and Pds the value stored under
	// SessionPds.
	Did, Pds string
}
156
157func (o *OAuth) GetUser(r *http.Request) *User {
158 sess, err := o.SessStore.Get(r, SessionName)
159
160 if err != nil || sess.IsNew {
161 return nil
162 }
163
164 return &User{
165 Did: sess.Values[SessionDid].(string),
166 Pds: sess.Values[SessionPds].(string),
167 }
168}
169
170func (o *OAuth) GetDid(r *http.Request) string {
171 if u := o.GetUser(r); u != nil {
172 return u.Did
173 }
174
175 return ""
176}
177
178func (o *OAuth) AuthorizedClient(r *http.Request) (*atpclient.APIClient, error) {
179 session, err := o.ResumeSession(r)
180 if err != nil {
181 return nil, fmt.Errorf("error getting session: %w", err)
182 }
183 return session.APIClient(), nil
184}
185
// ServiceClientOpts holds the settings ServiceClient uses when it requests a
// token through ServerGetServiceAuth and builds the client that carries it.
// It is a higher level abstraction on ServerGetServiceAuth.
type ServiceClientOpts struct {
	service string // host of the target service; see Audience and Host
	exp     int64  // token expiry as a Unix timestamp in seconds
	lxm     string // passed to ServerGetServiceAuth as the lxm argument
	dev     bool   // when true, Host uses http:// instead of https://
}

// ServiceClientOpt is a functional option that mutates a ServiceClientOpts.
type ServiceClientOpt func(*ServiceClientOpts)

// WithService sets the host of the service the client talks to.
func WithService(service string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.service = service }
}

// WithExp sets the token expiry to exp seconds from now.
//
// The clock is read when the option is applied, not when WithExp is called:
// the stored expiry is time.Now().Unix() + exp.
func WithExp(exp int64) ServiceClientOpt {
	return func(opts *ServiceClientOpts) {
		now := time.Now().Unix()
		opts.exp = now + exp
	}
}

// WithLxm sets the lxm value passed along with the token request.
func WithLxm(lxm string) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.lxm = lxm }
}

// WithDev selects plain HTTP (true) or HTTPS (false) for Host.
func WithDev(dev bool) ServiceClientOpt {
	return func(opts *ServiceClientOpts) { opts.dev = dev }
}

// Audience returns the configured service prefixed with "did:web:".
func (s *ServiceClientOpts) Audience() string {
	return "did:web:" + s.service
}

// Host returns the configured service prefixed with its URL scheme: "http://"
// when the dev option is set, "https://" otherwise.
func (s *ServiceClientOpts) Host() string {
	if s.dev {
		return "http://" + s.service
	}
	return "https://" + s.service
}
235
236func (o *OAuth) ServiceClient(r *http.Request, os ...ServiceClientOpt) (*xrpc.Client, error) {
237 opts := ServiceClientOpts{}
238 for _, o := range os {
239 o(&opts)
240 }
241
242 client, err := o.AuthorizedClient(r)
243 if err != nil {
244 return nil, err
245 }
246
247 // force expiry to atleast 60 seconds in the future
248 sixty := time.Now().Unix() + 60
249 if opts.exp < sixty {
250 opts.exp = sixty
251 }
252
253 resp, err := comatproto.ServerGetServiceAuth(r.Context(), client, opts.Audience(), opts.exp, opts.lxm)
254 if err != nil {
255 return nil, err
256 }
257
258 return &xrpc.Client{
259 Auth: &xrpc.AuthInfo{
260 AccessJwt: resp.Token,
261 },
262 Host: opts.Host(),
263 Client: &http.Client{
264 Timeout: time.Second * 5,
265 },
266 }, nil
267}