this repo has no description
1package state
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "log"
9 "log/slog"
10 "net/http"
11 "strings"
12 "time"
13
14 comatproto "github.com/bluesky-social/indigo/api/atproto"
15 "github.com/bluesky-social/indigo/atproto/syntax"
16 lexutil "github.com/bluesky-social/indigo/lex/util"
17 securejoin "github.com/cyphar/filepath-securejoin"
18 "github.com/go-chi/chi/v5"
19 "github.com/posthog/posthog-go"
20 "tangled.org/core/api/tangled"
21 "tangled.org/core/appview"
22 "tangled.org/core/appview/cache"
23 "tangled.org/core/appview/cache/session"
24 "tangled.org/core/appview/config"
25 "tangled.org/core/appview/db"
26 "tangled.org/core/appview/models"
27 "tangled.org/core/appview/notify"
28 dbnotify "tangled.org/core/appview/notify/db"
29 phnotify "tangled.org/core/appview/notify/posthog"
30 "tangled.org/core/appview/oauth"
31 "tangled.org/core/appview/pages"
32 "tangled.org/core/appview/reporesolver"
33 "tangled.org/core/appview/validator"
34 xrpcclient "tangled.org/core/appview/xrpcclient"
35 "tangled.org/core/eventconsumer"
36 "tangled.org/core/idresolver"
37 "tangled.org/core/jetstream"
38 tlog "tangled.org/core/log"
39 "tangled.org/core/rbac"
40 "tangled.org/core/tid"
41)
42
43type State struct {
44 db *db.DB
45 notifier notify.Notifier
46 oauth *oauth.OAuth
47 enforcer *rbac.Enforcer
48 pages *pages.Pages
49 sess *session.SessionStore
50 idResolver *idresolver.Resolver
51 posthog posthog.Client
52 jc *jetstream.JetstreamClient
53 config *config.Config
54 repoResolver *reporesolver.RepoResolver
55 knotstream *eventconsumer.Consumer
56 spindlestream *eventconsumer.Consumer
57 logger *slog.Logger
58 validator *validator.Validator
59}
60
61func Make(ctx context.Context, config *config.Config) (*State, error) {
62 d, err := db.Make(config.Core.DbPath)
63 if err != nil {
64 return nil, fmt.Errorf("failed to create db: %w", err)
65 }
66
67 enforcer, err := rbac.NewEnforcer(config.Core.DbPath)
68 if err != nil {
69 return nil, fmt.Errorf("failed to create enforcer: %w", err)
70 }
71
72 res, err := idresolver.RedisResolver(config.Redis.ToURL())
73 if err != nil {
74 log.Printf("failed to create redis resolver: %v", err)
75 res = idresolver.DefaultResolver()
76 }
77
78 pgs := pages.NewPages(config, res)
79 cache := cache.New(config.Redis.Addr)
80 sess := session.New(cache)
81 oauth := oauth.NewOAuth(config, sess)
82 validator := validator.New(d, res)
83
84 posthog, err := posthog.NewWithConfig(config.Posthog.ApiKey, posthog.Config{Endpoint: config.Posthog.Endpoint})
85 if err != nil {
86 return nil, fmt.Errorf("failed to create posthog client: %w", err)
87 }
88
89 repoResolver := reporesolver.New(config, enforcer, res, d)
90
91 wrapper := db.DbWrapper{Execer: d}
92 jc, err := jetstream.NewJetstreamClient(
93 config.Jetstream.Endpoint,
94 "appview",
95 []string{
96 tangled.GraphFollowNSID,
97 tangled.FeedStarNSID,
98 tangled.PublicKeyNSID,
99 tangled.RepoArtifactNSID,
100 tangled.ActorProfileNSID,
101 tangled.SpindleMemberNSID,
102 tangled.SpindleNSID,
103 tangled.StringNSID,
104 tangled.RepoIssueNSID,
105 tangled.RepoIssueCommentNSID,
106 tangled.LabelDefinitionNSID,
107 tangled.LabelOpNSID,
108 },
109 nil,
110 slog.Default(),
111 wrapper,
112 false,
113
114 // in-memory filter is inapplicalble to appview so
115 // we'll never log dids anyway.
116 false,
117 )
118 if err != nil {
119 return nil, fmt.Errorf("failed to create jetstream client: %w", err)
120 }
121
122 if err := BackfillDefaultDefs(d, res); err != nil {
123 return nil, fmt.Errorf("failed to backfill default label defs: %w", err)
124 }
125
126 ingester := appview.Ingester{
127 Db: wrapper,
128 Enforcer: enforcer,
129 IdResolver: res,
130 Config: config,
131 Logger: tlog.New("ingester"),
132 Validator: validator,
133 }
134 err = jc.StartJetstream(ctx, ingester.Ingest())
135 if err != nil {
136 return nil, fmt.Errorf("failed to start jetstream watcher: %w", err)
137 }
138
139 knotstream, err := Knotstream(ctx, config, d, enforcer, posthog)
140 if err != nil {
141 return nil, fmt.Errorf("failed to start knotstream consumer: %w", err)
142 }
143 knotstream.Start(ctx)
144
145 spindlestream, err := Spindlestream(ctx, config, d, enforcer)
146 if err != nil {
147 return nil, fmt.Errorf("failed to start spindlestream consumer: %w", err)
148 }
149 spindlestream.Start(ctx)
150
151 var notifiers []notify.Notifier
152
153 // Always add the database notifier
154 notifiers = append(notifiers, dbnotify.NewDatabaseNotifier(d, res))
155
156 // Add other notifiers in production only
157 if !config.Core.Dev {
158 notifiers = append(notifiers, phnotify.NewPosthogNotifier(posthog))
159 }
160 notifier := notify.NewMergedNotifier(notifiers...)
161
162 state := &State{
163 d,
164 notifier,
165 oauth,
166 enforcer,
167 pgs,
168 sess,
169 res,
170 posthog,
171 jc,
172 config,
173 repoResolver,
174 knotstream,
175 spindlestream,
176 slog.Default(),
177 validator,
178 }
179
180 return state, nil
181}
182
183func (s *State) Close() error {
184 // other close up logic goes here
185 return s.db.Close()
186}
187
188func (s *State) Favicon(w http.ResponseWriter, r *http.Request) {
189 w.Header().Set("Content-Type", "image/svg+xml")
190 w.Header().Set("Cache-Control", "public, max-age=31536000") // one year
191 w.Header().Set("ETag", `"favicon-svg-v1"`)
192
193 if match := r.Header.Get("If-None-Match"); match == `"favicon-svg-v1"` {
194 w.WriteHeader(http.StatusNotModified)
195 return
196 }
197
198 s.pages.Favicon(w)
199}
200
201func (s *State) TermsOfService(w http.ResponseWriter, r *http.Request) {
202 user := s.oauth.GetUser(r)
203 s.pages.TermsOfService(w, pages.TermsOfServiceParams{
204 LoggedInUser: user,
205 })
206}
207
208func (s *State) PrivacyPolicy(w http.ResponseWriter, r *http.Request) {
209 user := s.oauth.GetUser(r)
210 s.pages.PrivacyPolicy(w, pages.PrivacyPolicyParams{
211 LoggedInUser: user,
212 })
213}
214
215func (s *State) HomeOrTimeline(w http.ResponseWriter, r *http.Request) {
216 if s.oauth.GetUser(r) != nil {
217 s.Timeline(w, r)
218 return
219 }
220 s.Home(w, r)
221}
222
223func (s *State) Timeline(w http.ResponseWriter, r *http.Request) {
224 user := s.oauth.GetUser(r)
225
226 var userDid string
227 if user != nil {
228 userDid = user.Did
229 }
230 timeline, err := db.MakeTimeline(s.db, 50, userDid)
231 if err != nil {
232 log.Println(err)
233 s.pages.Notice(w, "timeline", "Uh oh! Failed to load timeline.")
234 }
235
236 repos, err := db.GetTopStarredReposLastWeek(s.db)
237 if err != nil {
238 log.Println(err)
239 s.pages.Notice(w, "topstarredrepos", "Unable to load.")
240 return
241 }
242
243 s.pages.Timeline(w, pages.TimelineParams{
244 LoggedInUser: user,
245 Timeline: timeline,
246 Repos: repos,
247 })
248}
249
250func (s *State) UpgradeBanner(w http.ResponseWriter, r *http.Request) {
251 user := s.oauth.GetUser(r)
252 l := s.logger.With("handler", "UpgradeBanner")
253 l = l.With("did", user.Did)
254 l = l.With("handle", user.Handle)
255
256 regs, err := db.GetRegistrations(
257 s.db,
258 db.FilterEq("did", user.Did),
259 db.FilterEq("needs_upgrade", 1),
260 )
261 if err != nil {
262 l.Error("non-fatal: failed to get registrations", "err", err)
263 }
264
265 spindles, err := db.GetSpindles(
266 s.db,
267 db.FilterEq("owner", user.Did),
268 db.FilterEq("needs_upgrade", 1),
269 )
270 if err != nil {
271 l.Error("non-fatal: failed to get spindles", "err", err)
272 }
273
274 if regs == nil && spindles == nil {
275 return
276 }
277
278 s.pages.UpgradeBanner(w, pages.UpgradeBannerParams{
279 Registrations: regs,
280 Spindles: spindles,
281 })
282}
283
284func (s *State) Home(w http.ResponseWriter, r *http.Request) {
285 timeline, err := db.MakeTimeline(s.db, 5, "")
286 if err != nil {
287 log.Println(err)
288 s.pages.Notice(w, "timeline", "Uh oh! Failed to load timeline.")
289 return
290 }
291
292 repos, err := db.GetTopStarredReposLastWeek(s.db)
293 if err != nil {
294 log.Println(err)
295 s.pages.Notice(w, "topstarredrepos", "Unable to load.")
296 return
297 }
298
299 s.pages.Home(w, pages.TimelineParams{
300 LoggedInUser: nil,
301 Timeline: timeline,
302 Repos: repos,
303 })
304}
305
306func (s *State) Keys(w http.ResponseWriter, r *http.Request) {
307 user := chi.URLParam(r, "user")
308 user = strings.TrimPrefix(user, "@")
309
310 if user == "" {
311 w.WriteHeader(http.StatusBadRequest)
312 return
313 }
314
315 id, err := s.idResolver.ResolveIdent(r.Context(), user)
316 if err != nil {
317 w.WriteHeader(http.StatusInternalServerError)
318 return
319 }
320
321 pubKeys, err := db.GetPublicKeysForDid(s.db, id.DID.String())
322 if err != nil {
323 w.WriteHeader(http.StatusNotFound)
324 return
325 }
326
327 if len(pubKeys) == 0 {
328 w.WriteHeader(http.StatusNotFound)
329 return
330 }
331
332 for _, k := range pubKeys {
333 key := strings.TrimRight(k.Key, "\n")
334 fmt.Fprintln(w, key)
335 }
336}
337
// validateRepoName reports whether name is acceptable as a repository name.
//
// Checks run in this order, and the first failure determines the message:
//  1. "." and "..", or any "/" or "\", are rejected as path characters
//  2. "./", "../", or a leading or trailing "." are rejected as a path
//     sequence
//  3. only ASCII letters, digits, '-', '_' and '.' are allowed
//  4. two dots in a row are rejected
//
// The empty string passes every check; callers reject it separately.
func validateRepoName(name string) error {
	switch {
	case name == ".", name == "..", strings.ContainsAny(name, `/\`):
		return fmt.Errorf("Repository name contains invalid path characters")
	case strings.Contains(name, "./"), strings.Contains(name, "../"),
		strings.HasPrefix(name, "."), strings.HasSuffix(name, "."):
		return fmt.Errorf("Repository name contains invalid path sequence")
	}

	// disallowed reports runes outside the permitted ASCII set
	disallowed := func(c rune) bool {
		switch {
		case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9':
			return false
		case c == '-', c == '_', c == '.':
			return false
		}
		return true
	}
	if strings.IndexFunc(name, disallowed) >= 0 {
		return fmt.Errorf("Repository name can only contain alphanumeric characters, periods, hyphens, and underscores")
	}

	if strings.Contains(name, "..") {
		return fmt.Errorf("Repository name cannot contain sequential dots")
	}

	return nil
}
369
// stripGitExt returns name without a single trailing ".git"; names that do
// not end in ".git" are returned unchanged.
func stripGitExt(name string) string {
	const ext = ".git"
	if strings.HasSuffix(name, ext) {
		return name[:len(name)-len(ext)]
	}
	return name
}
373
374func (s *State) NewRepo(w http.ResponseWriter, r *http.Request) {
375 switch r.Method {
376 case http.MethodGet:
377 user := s.oauth.GetUser(r)
378 knots, err := s.enforcer.GetKnotsForUser(user.Did)
379 if err != nil {
380 s.pages.Notice(w, "repo", "Invalid user account.")
381 return
382 }
383
384 s.pages.NewRepo(w, pages.NewRepoParams{
385 LoggedInUser: user,
386 Knots: knots,
387 })
388
389 case http.MethodPost:
390 l := s.logger.With("handler", "NewRepo")
391
392 user := s.oauth.GetUser(r)
393 l = l.With("did", user.Did)
394 l = l.With("handle", user.Handle)
395
396 // form validation
397 domain := r.FormValue("domain")
398 if domain == "" {
399 s.pages.Notice(w, "repo", "Invalid form submission—missing knot domain.")
400 return
401 }
402 l = l.With("knot", domain)
403
404 repoName := r.FormValue("name")
405 if repoName == "" {
406 s.pages.Notice(w, "repo", "Repository name cannot be empty.")
407 return
408 }
409
410 if err := validateRepoName(repoName); err != nil {
411 s.pages.Notice(w, "repo", err.Error())
412 return
413 }
414 repoName = stripGitExt(repoName)
415 l = l.With("repoName", repoName)
416
417 defaultBranch := r.FormValue("branch")
418 if defaultBranch == "" {
419 defaultBranch = "main"
420 }
421 l = l.With("defaultBranch", defaultBranch)
422
423 description := r.FormValue("description")
424
425 // ACL validation
426 ok, err := s.enforcer.E.Enforce(user.Did, domain, domain, "repo:create")
427 if err != nil || !ok {
428 l.Info("unauthorized")
429 s.pages.Notice(w, "repo", "You do not have permission to create a repo in this knot.")
430 return
431 }
432
433 // Check for existing repos
434 existingRepo, err := db.GetRepo(
435 s.db,
436 db.FilterEq("did", user.Did),
437 db.FilterEq("name", repoName),
438 )
439 if err == nil && existingRepo != nil {
440 l.Info("repo exists")
441 s.pages.Notice(w, "repo", fmt.Sprintf("You already have a repository by this name on %s", existingRepo.Knot))
442 return
443 }
444
445 // create atproto record for this repo
446 rkey := tid.TID()
447 repo := &models.Repo{
448 Did: user.Did,
449 Name: repoName,
450 Knot: domain,
451 Rkey: rkey,
452 Description: description,
453 Created: time.Now(),
454 Labels: models.DefaultLabelDefs(),
455 }
456 record := repo.AsRecord()
457
458 xrpcClient, err := s.oauth.AuthorizedClient(r)
459 if err != nil {
460 l.Info("PDS write failed", "err", err)
461 s.pages.Notice(w, "repo", "Failed to write record to PDS.")
462 return
463 }
464
465 atresp, err := xrpcClient.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{
466 Collection: tangled.RepoNSID,
467 Repo: user.Did,
468 Rkey: rkey,
469 Record: &lexutil.LexiconTypeDecoder{
470 Val: &record,
471 },
472 })
473 if err != nil {
474 l.Info("PDS write failed", "err", err)
475 s.pages.Notice(w, "repo", "Failed to announce repository creation.")
476 return
477 }
478
479 aturi := atresp.Uri
480 l = l.With("aturi", aturi)
481 l.Info("wrote to PDS")
482
483 tx, err := s.db.BeginTx(r.Context(), nil)
484 if err != nil {
485 l.Info("txn failed", "err", err)
486 s.pages.Notice(w, "repo", "Failed to save repository information.")
487 return
488 }
489
490 // The rollback function reverts a few things on failure:
491 // - the pending txn
492 // - the ACLs
493 // - the atproto record created
494 rollback := func() {
495 err1 := tx.Rollback()
496 err2 := s.enforcer.E.LoadPolicy()
497 err3 := rollbackRecord(context.Background(), aturi, xrpcClient)
498
499 // ignore txn complete errors, this is okay
500 if errors.Is(err1, sql.ErrTxDone) {
501 err1 = nil
502 }
503
504 if errs := errors.Join(err1, err2, err3); errs != nil {
505 l.Error("failed to rollback changes", "errs", errs)
506 return
507 }
508 }
509 defer rollback()
510
511 client, err := s.oauth.ServiceClient(
512 r,
513 oauth.WithService(domain),
514 oauth.WithLxm(tangled.RepoCreateNSID),
515 oauth.WithDev(s.config.Core.Dev),
516 )
517 if err != nil {
518 l.Error("service auth failed", "err", err)
519 s.pages.Notice(w, "repo", "Failed to reach PDS.")
520 return
521 }
522
523 xe := tangled.RepoCreate(
524 r.Context(),
525 client,
526 &tangled.RepoCreate_Input{
527 Rkey: rkey,
528 },
529 )
530 if err := xrpcclient.HandleXrpcErr(xe); err != nil {
531 l.Error("xrpc error", "xe", xe)
532 s.pages.Notice(w, "repo", err.Error())
533 return
534 }
535
536 err = db.AddRepo(tx, repo)
537 if err != nil {
538 l.Error("db write failed", "err", err)
539 s.pages.Notice(w, "repo", "Failed to save repository information.")
540 return
541 }
542
543 // acls
544 p, _ := securejoin.SecureJoin(user.Did, repoName)
545 err = s.enforcer.AddRepo(user.Did, domain, p)
546 if err != nil {
547 l.Error("acl setup failed", "err", err)
548 s.pages.Notice(w, "repo", "Failed to set up repository permissions.")
549 return
550 }
551
552 err = tx.Commit()
553 if err != nil {
554 l.Error("txn commit failed", "err", err)
555 http.Error(w, err.Error(), http.StatusInternalServerError)
556 return
557 }
558
559 err = s.enforcer.E.SavePolicy()
560 if err != nil {
561 l.Error("acl save failed", "err", err)
562 http.Error(w, err.Error(), http.StatusInternalServerError)
563 return
564 }
565
566 // reset the ATURI because the transaction completed successfully
567 aturi = ""
568
569 s.notifier.NewRepo(r.Context(), repo)
570 s.pages.HxLocation(w, fmt.Sprintf("/@%s/%s", user.Handle, repoName))
571 }
572}
573
574// this is used to rollback changes made to the PDS
575//
576// it is a no-op if the provided ATURI is empty
577func rollbackRecord(ctx context.Context, aturi string, xrpcc *xrpcclient.Client) error {
578 if aturi == "" {
579 return nil
580 }
581
582 parsed := syntax.ATURI(aturi)
583
584 collection := parsed.Collection().String()
585 repo := parsed.Authority().String()
586 rkey := parsed.RecordKey().String()
587
588 _, err := xrpcc.RepoDeleteRecord(ctx, &comatproto.RepoDeleteRecord_Input{
589 Collection: collection,
590 Repo: repo,
591 Rkey: rkey,
592 })
593 return err
594}
595
596func BackfillDefaultDefs(e db.Execer, r *idresolver.Resolver) error {
597 defaults := models.DefaultLabelDefs()
598 defaultLabels, err := db.GetLabelDefinitions(e, db.FilterIn("at_uri", defaults))
599 if err != nil {
600 return err
601 }
602 // already present
603 if len(defaultLabels) == len(defaults) {
604 return nil
605 }
606
607 labelDefs, err := models.FetchDefaultDefs(r)
608 if err != nil {
609 return err
610 }
611
612 // Insert each label definition to the database
613 for _, labelDef := range labelDefs {
614 _, err = db.AddLabelDefinition(e, &labelDef)
615 if err != nil {
616 return fmt.Errorf("failed to add label definition %s: %v", labelDef.Name, err)
617 }
618 }
619
620 return nil
621}