this repo has no description
1package oauth
2
3import (
4 "encoding/json"
5 "fmt"
6 "log"
7 "net/http"
8 "net/url"
9 "strings"
10
11 "github.com/go-chi/chi/v5"
12 "github.com/gorilla/sessions"
13 "github.com/haileyok/atproto-oauth-golang/helpers"
14 "github.com/lestrrat-go/jwx/v2/jwk"
15 "tangled.sh/tangled.sh/core/appview"
16 "tangled.sh/tangled.sh/core/appview/db"
17 "tangled.sh/tangled.sh/core/appview/middleware"
18 "tangled.sh/tangled.sh/core/appview/oauth"
19 "tangled.sh/tangled.sh/core/appview/oauth/client"
20 "tangled.sh/tangled.sh/core/appview/pages"
21)
22
// oauthScope is the space-separated scope string sent with the pushed
// authorization request and compared against the scope of the token response.
const oauthScope = "atproto transition:generic"
26
27type OAuthHandler struct {
28 Config *appview.Config
29 Pages *pages.Pages
30 Resolver *appview.Resolver
31 Db *db.DB
32 Store *sessions.CookieStore
33 OAuth *oauth.OAuth
34}
35
36func (o *OAuthHandler) Router() http.Handler {
37 r := chi.NewRouter()
38
39 r.Get("/login", o.login)
40 r.Post("/login", o.login)
41
42 r.With(middleware.AuthMiddleware(o.OAuth)).Post("/logout", o.logout)
43
44 r.Get("/oauth/client-metadata.json", o.clientMetadata)
45 r.Get("/oauth/jwks.json", o.jwks)
46 r.Get("/oauth/callback", o.callback)
47 return r
48}
49
50func (o *OAuthHandler) clientMetadata(w http.ResponseWriter, r *http.Request) {
51 metadata := map[string]any{
52 "client_id": o.Config.OAuth.ServerMetadataUrl,
53 "client_name": "Tangled",
54 "subject_type": "public",
55 "client_uri": o.Config.Core.AppviewHost,
56 "redirect_uris": []string{fmt.Sprintf("%s/oauth/callback", o.Config.Core.AppviewHost)},
57 "grant_types": []string{"authorization_code", "refresh_token"},
58 "response_types": []string{"code"},
59 "application_type": "web",
60 "dpop_bound_access_tokens": true,
61 "jwks_uri": fmt.Sprintf("%s/oauth/jwks.json", o.Config.Core.AppviewHost),
62 "scope": "atproto transition:generic",
63 "token_endpoint_auth_method": "private_key_jwt",
64 "token_endpoint_auth_signing_alg": "ES256",
65 }
66
67 w.Header().Set("Content-Type", "application/json")
68 w.WriteHeader(http.StatusOK)
69 json.NewEncoder(w).Encode(metadata)
70}
71
72func (o *OAuthHandler) jwks(w http.ResponseWriter, r *http.Request) {
73 jwks := o.Config.OAuth.Jwks
74 pubKey, err := pubKeyFromJwk(jwks)
75 if err != nil {
76 log.Printf("error parsing public key: %v", err)
77 http.Error(w, err.Error(), http.StatusInternalServerError)
78 return
79 }
80
81 response := helpers.CreateJwksResponseObject(pubKey)
82
83 w.Header().Set("Content-Type", "application/json")
84 w.WriteHeader(http.StatusOK)
85 json.NewEncoder(w).Encode(response)
86}
87
88func (o *OAuthHandler) login(w http.ResponseWriter, r *http.Request) {
89 switch r.Method {
90 case http.MethodGet:
91 o.Pages.Login(w, pages.LoginParams{})
92 case http.MethodPost:
93 handle := strings.TrimPrefix(r.FormValue("handle"), "@")
94
95 resolved, err := o.Resolver.ResolveIdent(r.Context(), handle)
96 if err != nil {
97 log.Println("failed to resolve handle:", err)
98 o.Pages.Notice(w, "login-msg", fmt.Sprintf("\"%s\" is an invalid handle.", handle))
99 return
100 }
101 oauthClient, err := client.NewClient(
102 o.Config.OAuth.ServerMetadataUrl,
103 o.Config.OAuth.Jwks,
104 fmt.Sprintf("%s/oauth/callback", o.Config.Core.AppviewHost))
105
106 if err != nil {
107 log.Println("failed to create oauth client:", err)
108 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
109 return
110 }
111
112 authServer, err := oauthClient.ResolvePdsAuthServer(r.Context(), resolved.PDSEndpoint())
113 if err != nil {
114 log.Println("failed to resolve auth server:", err)
115 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
116 return
117 }
118
119 authMeta, err := oauthClient.FetchAuthServerMetadata(r.Context(), authServer)
120 if err != nil {
121 log.Println("failed to fetch auth server metadata:", err)
122 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
123 return
124 }
125
126 dpopKey, err := helpers.GenerateKey(nil)
127 if err != nil {
128 log.Println("failed to generate dpop key:", err)
129 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
130 return
131 }
132
133 dpopKeyJson, err := json.Marshal(dpopKey)
134 if err != nil {
135 log.Println("failed to marshal dpop key:", err)
136 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
137 return
138 }
139
140 parResp, err := oauthClient.SendParAuthRequest(r.Context(), authServer, authMeta, handle, oauthScope, dpopKey)
141 if err != nil {
142 log.Println("failed to send par auth request:", err)
143 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
144 return
145 }
146
147 err = db.SaveOAuthRequest(o.Db, db.OAuthRequest{
148 Did: resolved.DID.String(),
149 PdsUrl: resolved.PDSEndpoint(),
150 Handle: handle,
151 AuthserverIss: authMeta.Issuer,
152 PkceVerifier: parResp.PkceVerifier,
153 DpopAuthserverNonce: parResp.DpopAuthserverNonce,
154 DpopPrivateJwk: string(dpopKeyJson),
155 State: parResp.State,
156 })
157 if err != nil {
158 log.Println("failed to save oauth request:", err)
159 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
160 return
161 }
162
163 u, _ := url.Parse(authMeta.AuthorizationEndpoint)
164 u.RawQuery = fmt.Sprintf("client_id=%s&request_uri=%s", url.QueryEscape(o.Config.OAuth.ServerMetadataUrl), parResp.RequestUri)
165 o.Pages.HxRedirect(w, u.String())
166 }
167}
168
169func (o *OAuthHandler) callback(w http.ResponseWriter, r *http.Request) {
170 state := r.FormValue("state")
171
172 oauthRequest, err := db.GetOAuthRequestByState(o.Db, state)
173 if err != nil {
174 log.Println("failed to get oauth request:", err)
175 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
176 return
177 }
178
179 defer func() {
180 err := db.DeleteOAuthRequestByState(o.Db, state)
181 if err != nil {
182 log.Println("failed to delete oauth request for state:", state, err)
183 }
184 }()
185
186 code := r.FormValue("code")
187 if code == "" {
188 log.Println("missing code for state: ", state)
189 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
190 return
191 }
192
193 iss := r.FormValue("iss")
194 if iss == "" {
195 log.Println("missing iss for state: ", state)
196 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
197 return
198 }
199
200 oauthClient, err := client.NewClient(
201 o.Config.OAuth.ServerMetadataUrl,
202 o.Config.OAuth.Jwks,
203 fmt.Sprintf("%s/oauth/callback", o.Config.Core.AppviewHost))
204
205 if err != nil {
206 log.Println("failed to create oauth client:", err)
207 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
208 return
209 }
210
211 jwk, err := helpers.ParseJWKFromBytes([]byte(oauthRequest.DpopPrivateJwk))
212 if err != nil {
213 log.Println("failed to parse jwk:", err)
214 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
215 return
216 }
217
218 tokenResp, err := oauthClient.InitialTokenRequest(
219 r.Context(),
220 code,
221 oauthRequest.AuthserverIss,
222 oauthRequest.PkceVerifier,
223 oauthRequest.DpopAuthserverNonce,
224 jwk,
225 )
226 if err != nil {
227 log.Println("failed to get token:", err)
228 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
229 return
230 }
231
232 if tokenResp.Scope != oauthScope {
233 log.Println("scope doesn't match:", tokenResp.Scope)
234 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
235 return
236 }
237
238 err = o.OAuth.SaveSession(w, r, oauthRequest, tokenResp)
239 if err != nil {
240 log.Println("failed to save session:", err)
241 o.Pages.Notice(w, "login-msg", "Failed to authenticate. Try again later.")
242 return
243 }
244
245 log.Println("session saved successfully")
246
247 http.Redirect(w, r, "/", http.StatusFound)
248}
249
250func (o *OAuthHandler) logout(w http.ResponseWriter, r *http.Request) {
251 err := o.OAuth.ClearSession(r, w)
252 if err != nil {
253 log.Println("failed to clear session:", err)
254 http.Redirect(w, r, "/", http.StatusFound)
255 return
256 }
257
258 log.Println("session cleared successfully")
259 http.Redirect(w, r, "/", http.StatusFound)
260}
261
262func pubKeyFromJwk(jwks string) (jwk.Key, error) {
263 k, err := helpers.ParseJWKFromBytes([]byte(jwks))
264 if err != nil {
265 return nil, err
266 }
267 pubKey, err := k.PublicKey()
268 if err != nil {
269 return nil, err
270 }
271 return pubKey, nil
272}