this repo has no description
1package state
2
3import (
4 "context"
5 "log"
6 "net/http"
7 "strings"
8 "time"
9
10 comatproto "github.com/bluesky-social/indigo/api/atproto"
11 "github.com/bluesky-social/indigo/atproto/identity"
12 "github.com/bluesky-social/indigo/xrpc"
13 "github.com/go-chi/chi/v5"
14 "github.com/sotangled/tangled/appview"
15 "github.com/sotangled/tangled/appview/auth"
16 "github.com/sotangled/tangled/appview/db"
17)
18
// Middleware wraps an http.Handler with additional behaviour and returns the
// wrapped handler. It has the same shape as the middleware accepted by chi's
// Use/With.
type Middleware func(http.Handler) http.Handler
20
21func AuthMiddleware(s *State) Middleware {
22 return func(next http.Handler) http.Handler {
23 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
24 redirectFunc := func(w http.ResponseWriter, r *http.Request) {
25 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
26 }
27 if r.Header.Get("HX-Request") == "true" {
28 redirectFunc = func(w http.ResponseWriter, _ *http.Request) {
29 w.Header().Set("HX-Redirect", "/login")
30 w.WriteHeader(http.StatusOK)
31 }
32 }
33
34 session, err := s.auth.GetSession(r)
35 if session.IsNew || err != nil {
36 log.Printf("not logged in, redirecting")
37 redirectFunc(w, r)
38 return
39 }
40
41 authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
42 if !ok || !authorized {
43 log.Printf("not logged in, redirecting")
44 redirectFunc(w, r)
45 return
46 }
47
48 // refresh if nearing expiry
49 // TODO: dedup with /login
50 expiryStr := session.Values[appview.SessionExpiry].(string)
51 expiry, err := time.Parse(time.RFC3339, expiryStr)
52 if err != nil {
53 log.Println("invalid expiry time", err)
54 redirectFunc(w, r)
55 return
56 }
57 pdsUrl, ok1 := session.Values[appview.SessionPds].(string)
58 did, ok2 := session.Values[appview.SessionDid].(string)
59 refreshJwt, ok3 := session.Values[appview.SessionRefreshJwt].(string)
60
61 if !ok1 || !ok2 || !ok3 {
62 log.Println("invalid expiry time", err)
63 redirectFunc(w, r)
64 return
65 }
66
67 if time.Now().After(expiry) {
68 log.Println("token expired, refreshing ...")
69
70 client := xrpc.Client{
71 Host: pdsUrl,
72 Auth: &xrpc.AuthInfo{
73 Did: did,
74 AccessJwt: refreshJwt,
75 RefreshJwt: refreshJwt,
76 },
77 }
78 atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
79 if err != nil {
80 log.Println("failed to refresh session", err)
81 redirectFunc(w, r)
82 return
83 }
84
85 sessionish := auth.RefreshSessionWrapper{atSession}
86
87 err = s.auth.StoreSession(r, w, &sessionish, pdsUrl)
88 if err != nil {
89 log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err)
90 return
91 }
92
93 log.Println("successfully refreshed token")
94 }
95
96 next.ServeHTTP(w, r)
97 })
98 }
99}
100
101func RoleMiddleware(s *State, group string) Middleware {
102 return func(next http.Handler) http.Handler {
103 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
104 // requires auth also
105 actor := s.auth.GetUser(r)
106 if actor == nil {
107 // we need a logged in user
108 log.Printf("not logged in, redirecting")
109 http.Error(w, "Forbiden", http.StatusUnauthorized)
110 return
111 }
112 domain := chi.URLParam(r, "domain")
113 if domain == "" {
114 http.Error(w, "malformed url", http.StatusBadRequest)
115 return
116 }
117
118 ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
119 if err != nil || !ok {
120 // we need a logged in user
121 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
122 http.Error(w, "Forbiden", http.StatusUnauthorized)
123 return
124 }
125
126 next.ServeHTTP(w, r)
127 })
128 }
129}
130
131func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware {
132 return func(next http.Handler) http.Handler {
133 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
134 // requires auth also
135 actor := s.auth.GetUser(r)
136 if actor == nil {
137 // we need a logged in user
138 log.Printf("not logged in, redirecting")
139 http.Error(w, "Forbiden", http.StatusUnauthorized)
140 return
141 }
142 f, err := fullyResolvedRepo(r)
143 if err != nil {
144 http.Error(w, "malformed url", http.StatusBadRequest)
145 return
146 }
147
148 ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.OwnerSlashRepo(), requiredPerm)
149 if err != nil || !ok {
150 // we need a logged in user
151 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
152 http.Error(w, "Forbiden", http.StatusUnauthorized)
153 return
154 }
155
156 next.ServeHTTP(w, r)
157 })
158 }
159}
160
// StripLeadingAt rewrites a request path of the form "/@rest" to "/rest"
// before calling next. Paths without the "/@" prefix pass through untouched.
//
// The rewrite is applied to a clone of the request, because handlers must not
// modify the request they were given. URL.RawPath (set when the path contains
// non-default escapes such as %2F) is rewritten in step with URL.Path, because
// routers such as chi route on RawPath when it is present.
func StripLeadingAt(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if !strings.HasPrefix(req.URL.Path, "/@") {
			next.ServeHTTP(w, req)
			return
		}

		stripped := req.Clone(req.Context())
		stripped.URL.Path = "/" + strings.TrimPrefix(req.URL.Path, "/@")

		if raw := stripped.URL.RawPath; raw != "" {
			switch {
			case strings.HasPrefix(raw, "/@"):
				stripped.URL.RawPath = "/" + strings.TrimPrefix(raw, "/@")
			case strings.HasPrefix(raw, "/%40"):
				// The "@" itself arrived percent-encoded.
				stripped.URL.RawPath = "/" + strings.TrimPrefix(raw, "/%40")
			default:
				// Unexpected encoding: drop RawPath so Path is authoritative.
				stripped.URL.RawPath = ""
			}
		}

		next.ServeHTTP(w, stripped)
	})
}
170
171func ResolveIdent(s *State) Middleware {
172 return func(next http.Handler) http.Handler {
173 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
174 didOrHandle := chi.URLParam(req, "user")
175
176 id, err := s.resolver.ResolveIdent(req.Context(), didOrHandle)
177 if err != nil {
178 // invalid did or handle
179 log.Println("failed to resolve did/handle:", err)
180 w.WriteHeader(http.StatusNotFound)
181 return
182 }
183
184 ctx := context.WithValue(req.Context(), "resolvedId", *id)
185
186 next.ServeHTTP(w, req.WithContext(ctx))
187 })
188 }
189}
190
191func ResolveRepoKnot(s *State) Middleware {
192 return func(next http.Handler) http.Handler {
193 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
194 repoName := chi.URLParam(req, "repo")
195 id, ok := req.Context().Value("resolvedId").(identity.Identity)
196 if !ok {
197 log.Println("malformed middleware")
198 w.WriteHeader(http.StatusInternalServerError)
199 return
200 }
201
202 repo, err := db.GetRepo(s.db, id.DID.String(), repoName)
203 if err != nil {
204 // invalid did or handle
205 log.Println("failed to resolve repo")
206 w.WriteHeader(http.StatusNotFound)
207 return
208 }
209
210 ctx := context.WithValue(req.Context(), "knot", repo.Knot)
211 ctx = context.WithValue(ctx, "repoAt", repo.AtUri)
212 ctx = context.WithValue(ctx, "repoDescription", repo.Description)
213 ctx = context.WithValue(ctx, "repoAddedAt", repo.Created.Format(time.RFC3339))
214 next.ServeHTTP(w, req.WithContext(ctx))
215 })
216 }
217}