this repo has no description
1package state
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "log"
9 "log/slog"
10 "net/http"
11 "strings"
12 "time"
13
14 comatproto "github.com/bluesky-social/indigo/api/atproto"
15 "github.com/bluesky-social/indigo/atproto/syntax"
16 lexutil "github.com/bluesky-social/indigo/lex/util"
17 securejoin "github.com/cyphar/filepath-securejoin"
18 "github.com/go-chi/chi/v5"
19 "github.com/posthog/posthog-go"
20 "tangled.org/core/api/tangled"
21 "tangled.org/core/appview"
22 "tangled.org/core/appview/cache"
23 "tangled.org/core/appview/cache/session"
24 "tangled.org/core/appview/config"
25 "tangled.org/core/appview/db"
26 "tangled.org/core/appview/models"
27 "tangled.org/core/appview/notify"
28 phnotify "tangled.org/core/appview/notify/posthog"
29 "tangled.org/core/appview/oauth"
30 "tangled.org/core/appview/pages"
31 "tangled.org/core/appview/reporesolver"
32 "tangled.org/core/appview/validator"
33 xrpcclient "tangled.org/core/appview/xrpcclient"
34 "tangled.org/core/eventconsumer"
35 "tangled.org/core/idresolver"
36 "tangled.org/core/jetstream"
37 tlog "tangled.org/core/log"
38 "tangled.org/core/rbac"
39 "tangled.org/core/tid"
40)
41
42type State struct {
43 db *db.DB
44 notifier notify.Notifier
45 oauth *oauth.OAuth
46 enforcer *rbac.Enforcer
47 pages *pages.Pages
48 sess *session.SessionStore
49 idResolver *idresolver.Resolver
50 posthog posthog.Client
51 jc *jetstream.JetstreamClient
52 config *config.Config
53 repoResolver *reporesolver.RepoResolver
54 knotstream *eventconsumer.Consumer
55 spindlestream *eventconsumer.Consumer
56 logger *slog.Logger
57 validator *validator.Validator
58}
59
60func Make(ctx context.Context, config *config.Config) (*State, error) {
61 d, err := db.Make(config.Core.DbPath)
62 if err != nil {
63 return nil, fmt.Errorf("failed to create db: %w", err)
64 }
65
66 enforcer, err := rbac.NewEnforcer(config.Core.DbPath)
67 if err != nil {
68 return nil, fmt.Errorf("failed to create enforcer: %w", err)
69 }
70
71 res, err := idresolver.RedisResolver(config.Redis.ToURL())
72 if err != nil {
73 log.Printf("failed to create redis resolver: %v", err)
74 res = idresolver.DefaultResolver()
75 }
76
77 pgs := pages.NewPages(config, res)
78 cache := cache.New(config.Redis.Addr)
79 sess := session.New(cache)
80 oauth := oauth.NewOAuth(config, sess)
81 validator := validator.New(d, res)
82
83 posthog, err := posthog.NewWithConfig(config.Posthog.ApiKey, posthog.Config{Endpoint: config.Posthog.Endpoint})
84 if err != nil {
85 return nil, fmt.Errorf("failed to create posthog client: %w", err)
86 }
87
88 repoResolver := reporesolver.New(config, enforcer, res, d)
89
90 wrapper := db.DbWrapper{Execer: d}
91 jc, err := jetstream.NewJetstreamClient(
92 config.Jetstream.Endpoint,
93 "appview",
94 []string{
95 tangled.GraphFollowNSID,
96 tangled.FeedStarNSID,
97 tangled.PublicKeyNSID,
98 tangled.RepoArtifactNSID,
99 tangled.ActorProfileNSID,
100 tangled.SpindleMemberNSID,
101 tangled.SpindleNSID,
102 tangled.StringNSID,
103 tangled.RepoIssueNSID,
104 tangled.RepoIssueCommentNSID,
105 tangled.LabelDefinitionNSID,
106 tangled.LabelOpNSID,
107 },
108 nil,
109 slog.Default(),
110 wrapper,
111 false,
112
113 // in-memory filter is inapplicalble to appview so
114 // we'll never log dids anyway.
115 false,
116 )
117 if err != nil {
118 return nil, fmt.Errorf("failed to create jetstream client: %w", err)
119 }
120
121 if err := BackfillDefaultDefs(d, res); err != nil {
122 return nil, fmt.Errorf("failed to backfill default label defs: %w", err)
123 }
124
125 ingester := appview.Ingester{
126 Db: wrapper,
127 Enforcer: enforcer,
128 IdResolver: res,
129 Config: config,
130 Logger: tlog.New("ingester"),
131 Validator: validator,
132 }
133 err = jc.StartJetstream(ctx, ingester.Ingest())
134 if err != nil {
135 return nil, fmt.Errorf("failed to start jetstream watcher: %w", err)
136 }
137
138 knotstream, err := Knotstream(ctx, config, d, enforcer, posthog)
139 if err != nil {
140 return nil, fmt.Errorf("failed to start knotstream consumer: %w", err)
141 }
142 knotstream.Start(ctx)
143
144 spindlestream, err := Spindlestream(ctx, config, d, enforcer)
145 if err != nil {
146 return nil, fmt.Errorf("failed to start spindlestream consumer: %w", err)
147 }
148 spindlestream.Start(ctx)
149
150 var notifiers []notify.Notifier
151
152 // Always add the database notifier
153 notifiers = append(notifiers, dbnotify.NewDatabaseNotifier(d, res))
154
155 // Add other notifiers in production only
156 if !config.Core.Dev {
157 notifiers = append(notifiers, phnotify.NewPosthogNotifier(posthog))
158 }
159 notifier := notify.NewMergedNotifier(notifiers...)
160
161 state := &State{
162 d,
163 notifier,
164 oauth,
165 enforcer,
166 pgs,
167 sess,
168 res,
169 posthog,
170 jc,
171 config,
172 repoResolver,
173 knotstream,
174 spindlestream,
175 slog.Default(),
176 validator,
177 }
178
179 return state, nil
180}
181
182func (s *State) Close() error {
183 // other close up logic goes here
184 return s.db.Close()
185}
186
187func (s *State) Favicon(w http.ResponseWriter, r *http.Request) {
188 w.Header().Set("Content-Type", "image/svg+xml")
189 w.Header().Set("Cache-Control", "public, max-age=31536000") // one year
190 w.Header().Set("ETag", `"favicon-svg-v1"`)
191
192 if match := r.Header.Get("If-None-Match"); match == `"favicon-svg-v1"` {
193 w.WriteHeader(http.StatusNotModified)
194 return
195 }
196
197 s.pages.Favicon(w)
198}
199
200func (s *State) TermsOfService(w http.ResponseWriter, r *http.Request) {
201 user := s.oauth.GetUser(r)
202 s.pages.TermsOfService(w, pages.TermsOfServiceParams{
203 LoggedInUser: user,
204 })
205}
206
207func (s *State) PrivacyPolicy(w http.ResponseWriter, r *http.Request) {
208 user := s.oauth.GetUser(r)
209 s.pages.PrivacyPolicy(w, pages.PrivacyPolicyParams{
210 LoggedInUser: user,
211 })
212}
213
214func (s *State) HomeOrTimeline(w http.ResponseWriter, r *http.Request) {
215 if s.oauth.GetUser(r) != nil {
216 s.Timeline(w, r)
217 return
218 }
219 s.Home(w, r)
220}
221
222func (s *State) Timeline(w http.ResponseWriter, r *http.Request) {
223 user := s.oauth.GetUser(r)
224
225 var userDid string
226 if user != nil {
227 userDid = user.Did
228 }
229 timeline, err := db.MakeTimeline(s.db, 50, userDid)
230 if err != nil {
231 log.Println(err)
232 s.pages.Notice(w, "timeline", "Uh oh! Failed to load timeline.")
233 }
234
235 repos, err := db.GetTopStarredReposLastWeek(s.db)
236 if err != nil {
237 log.Println(err)
238 s.pages.Notice(w, "topstarredrepos", "Unable to load.")
239 return
240 }
241
242 s.pages.Timeline(w, pages.TimelineParams{
243 LoggedInUser: user,
244 Timeline: timeline,
245 Repos: repos,
246 })
247}
248
249func (s *State) UpgradeBanner(w http.ResponseWriter, r *http.Request) {
250 user := s.oauth.GetUser(r)
251 l := s.logger.With("handler", "UpgradeBanner")
252 l = l.With("did", user.Did)
253 l = l.With("handle", user.Handle)
254
255 regs, err := db.GetRegistrations(
256 s.db,
257 db.FilterEq("did", user.Did),
258 db.FilterEq("needs_upgrade", 1),
259 )
260 if err != nil {
261 l.Error("non-fatal: failed to get registrations", "err", err)
262 }
263
264 spindles, err := db.GetSpindles(
265 s.db,
266 db.FilterEq("owner", user.Did),
267 db.FilterEq("needs_upgrade", 1),
268 )
269 if err != nil {
270 l.Error("non-fatal: failed to get spindles", "err", err)
271 }
272
273 if regs == nil && spindles == nil {
274 return
275 }
276
277 s.pages.UpgradeBanner(w, pages.UpgradeBannerParams{
278 Registrations: regs,
279 Spindles: spindles,
280 })
281}
282
283func (s *State) Home(w http.ResponseWriter, r *http.Request) {
284 timeline, err := db.MakeTimeline(s.db, 5, "")
285 if err != nil {
286 log.Println(err)
287 s.pages.Notice(w, "timeline", "Uh oh! Failed to load timeline.")
288 return
289 }
290
291 repos, err := db.GetTopStarredReposLastWeek(s.db)
292 if err != nil {
293 log.Println(err)
294 s.pages.Notice(w, "topstarredrepos", "Unable to load.")
295 return
296 }
297
298 s.pages.Home(w, pages.TimelineParams{
299 LoggedInUser: nil,
300 Timeline: timeline,
301 Repos: repos,
302 })
303}
304
305func (s *State) Keys(w http.ResponseWriter, r *http.Request) {
306 user := chi.URLParam(r, "user")
307 user = strings.TrimPrefix(user, "@")
308
309 if user == "" {
310 w.WriteHeader(http.StatusBadRequest)
311 return
312 }
313
314 id, err := s.idResolver.ResolveIdent(r.Context(), user)
315 if err != nil {
316 w.WriteHeader(http.StatusInternalServerError)
317 return
318 }
319
320 pubKeys, err := db.GetPublicKeysForDid(s.db, id.DID.String())
321 if err != nil {
322 w.WriteHeader(http.StatusNotFound)
323 return
324 }
325
326 if len(pubKeys) == 0 {
327 w.WriteHeader(http.StatusNotFound)
328 return
329 }
330
331 for _, k := range pubKeys {
332 key := strings.TrimRight(k.Key, "\n")
333 fmt.Fprintln(w, key)
334 }
335}
336
// validateRepoName reports whether name is acceptable as a repository name.
//
// It returns nil when the name passes every check, otherwise an error whose
// message is suitable for showing to the user. Checks run in this order:
//  1. the name is not "." or ".." and contains no "/" or "\";
//  2. the name neither starts nor ends with ".";
//  3. every character is an ASCII letter, digit, '-', '_' or '.';
//  4. the name contains no "..".
//
// An empty name passes all checks; callers reject it separately.
func validateRepoName(name string) error {
	switch {
	// Path traversal attempts: special directory entries and separators.
	case name == ".", name == "..", strings.ContainsAny(name, `/\`):
		return fmt.Errorf("Repository name contains invalid path characters")

	// Leading/trailing dots. ("./" and "../" cannot occur here because any
	// name containing "/" was rejected by the previous case.)
	case strings.HasPrefix(name, "."), strings.HasSuffix(name, "."):
		return fmt.Errorf("Repository name contains invalid path sequence")
	}

	// Character whitelist.
	disallowed := func(c rune) bool {
		switch {
		case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9':
			return false
		case c == '-', c == '_', c == '.':
			return false
		}
		return true
	}
	if strings.IndexFunc(name, disallowed) >= 0 {
		return fmt.Errorf("Repository name can only contain alphanumeric characters, periods, hyphens, and underscores")
	}

	// Dots are allowed individually but never back to back.
	if strings.Contains(name, "..") {
		return fmt.Errorf("Repository name cannot contain sequential dots")
	}

	return nil
}
368
// stripGitExt returns name with a single trailing ".git" removed; names
// without that suffix are returned unchanged.
func stripGitExt(name string) string {
	const ext = ".git"
	if strings.HasSuffix(name, ext) {
		return name[:len(name)-len(ext)]
	}
	return name
}
372
373func (s *State) NewRepo(w http.ResponseWriter, r *http.Request) {
374 switch r.Method {
375 case http.MethodGet:
376 user := s.oauth.GetUser(r)
377 knots, err := s.enforcer.GetKnotsForUser(user.Did)
378 if err != nil {
379 s.pages.Notice(w, "repo", "Invalid user account.")
380 return
381 }
382
383 s.pages.NewRepo(w, pages.NewRepoParams{
384 LoggedInUser: user,
385 Knots: knots,
386 })
387
388 case http.MethodPost:
389 l := s.logger.With("handler", "NewRepo")
390
391 user := s.oauth.GetUser(r)
392 l = l.With("did", user.Did)
393 l = l.With("handle", user.Handle)
394
395 // form validation
396 domain := r.FormValue("domain")
397 if domain == "" {
398 s.pages.Notice(w, "repo", "Invalid form submission—missing knot domain.")
399 return
400 }
401 l = l.With("knot", domain)
402
403 repoName := r.FormValue("name")
404 if repoName == "" {
405 s.pages.Notice(w, "repo", "Repository name cannot be empty.")
406 return
407 }
408
409 if err := validateRepoName(repoName); err != nil {
410 s.pages.Notice(w, "repo", err.Error())
411 return
412 }
413 repoName = stripGitExt(repoName)
414 l = l.With("repoName", repoName)
415
416 defaultBranch := r.FormValue("branch")
417 if defaultBranch == "" {
418 defaultBranch = "main"
419 }
420 l = l.With("defaultBranch", defaultBranch)
421
422 description := r.FormValue("description")
423
424 // ACL validation
425 ok, err := s.enforcer.E.Enforce(user.Did, domain, domain, "repo:create")
426 if err != nil || !ok {
427 l.Info("unauthorized")
428 s.pages.Notice(w, "repo", "You do not have permission to create a repo in this knot.")
429 return
430 }
431
432 // Check for existing repos
433 existingRepo, err := db.GetRepo(
434 s.db,
435 db.FilterEq("did", user.Did),
436 db.FilterEq("name", repoName),
437 )
438 if err == nil && existingRepo != nil {
439 l.Info("repo exists")
440 s.pages.Notice(w, "repo", fmt.Sprintf("You already have a repository by this name on %s", existingRepo.Knot))
441 return
442 }
443
444 // create atproto record for this repo
445 rkey := tid.TID()
446 repo := &models.Repo{
447 Did: user.Did,
448 Name: repoName,
449 Knot: domain,
450 Rkey: rkey,
451 Description: description,
452 Created: time.Now(),
453 Labels: models.DefaultLabelDefs(),
454 }
455 record := repo.AsRecord()
456
457 xrpcClient, err := s.oauth.AuthorizedClient(r)
458 if err != nil {
459 l.Info("PDS write failed", "err", err)
460 s.pages.Notice(w, "repo", "Failed to write record to PDS.")
461 return
462 }
463
464 atresp, err := xrpcClient.RepoPutRecord(r.Context(), &comatproto.RepoPutRecord_Input{
465 Collection: tangled.RepoNSID,
466 Repo: user.Did,
467 Rkey: rkey,
468 Record: &lexutil.LexiconTypeDecoder{
469 Val: &record,
470 },
471 })
472 if err != nil {
473 l.Info("PDS write failed", "err", err)
474 s.pages.Notice(w, "repo", "Failed to announce repository creation.")
475 return
476 }
477
478 aturi := atresp.Uri
479 l = l.With("aturi", aturi)
480 l.Info("wrote to PDS")
481
482 tx, err := s.db.BeginTx(r.Context(), nil)
483 if err != nil {
484 l.Info("txn failed", "err", err)
485 s.pages.Notice(w, "repo", "Failed to save repository information.")
486 return
487 }
488
489 // The rollback function reverts a few things on failure:
490 // - the pending txn
491 // - the ACLs
492 // - the atproto record created
493 rollback := func() {
494 err1 := tx.Rollback()
495 err2 := s.enforcer.E.LoadPolicy()
496 err3 := rollbackRecord(context.Background(), aturi, xrpcClient)
497
498 // ignore txn complete errors, this is okay
499 if errors.Is(err1, sql.ErrTxDone) {
500 err1 = nil
501 }
502
503 if errs := errors.Join(err1, err2, err3); errs != nil {
504 l.Error("failed to rollback changes", "errs", errs)
505 return
506 }
507 }
508 defer rollback()
509
510 client, err := s.oauth.ServiceClient(
511 r,
512 oauth.WithService(domain),
513 oauth.WithLxm(tangled.RepoCreateNSID),
514 oauth.WithDev(s.config.Core.Dev),
515 )
516 if err != nil {
517 l.Error("service auth failed", "err", err)
518 s.pages.Notice(w, "repo", "Failed to reach PDS.")
519 return
520 }
521
522 xe := tangled.RepoCreate(
523 r.Context(),
524 client,
525 &tangled.RepoCreate_Input{
526 Rkey: rkey,
527 },
528 )
529 if err := xrpcclient.HandleXrpcErr(xe); err != nil {
530 l.Error("xrpc error", "xe", xe)
531 s.pages.Notice(w, "repo", err.Error())
532 return
533 }
534
535 err = db.AddRepo(tx, repo)
536 if err != nil {
537 l.Error("db write failed", "err", err)
538 s.pages.Notice(w, "repo", "Failed to save repository information.")
539 return
540 }
541
542 // acls
543 p, _ := securejoin.SecureJoin(user.Did, repoName)
544 err = s.enforcer.AddRepo(user.Did, domain, p)
545 if err != nil {
546 l.Error("acl setup failed", "err", err)
547 s.pages.Notice(w, "repo", "Failed to set up repository permissions.")
548 return
549 }
550
551 err = tx.Commit()
552 if err != nil {
553 l.Error("txn commit failed", "err", err)
554 http.Error(w, err.Error(), http.StatusInternalServerError)
555 return
556 }
557
558 err = s.enforcer.E.SavePolicy()
559 if err != nil {
560 l.Error("acl save failed", "err", err)
561 http.Error(w, err.Error(), http.StatusInternalServerError)
562 return
563 }
564
565 // reset the ATURI because the transaction completed successfully
566 aturi = ""
567
568 s.notifier.NewRepo(r.Context(), repo)
569 s.pages.HxLocation(w, fmt.Sprintf("/@%s/%s", user.Handle, repoName))
570 }
571}
572
573// this is used to rollback changes made to the PDS
574//
575// it is a no-op if the provided ATURI is empty
576func rollbackRecord(ctx context.Context, aturi string, xrpcc *xrpcclient.Client) error {
577 if aturi == "" {
578 return nil
579 }
580
581 parsed := syntax.ATURI(aturi)
582
583 collection := parsed.Collection().String()
584 repo := parsed.Authority().String()
585 rkey := parsed.RecordKey().String()
586
587 _, err := xrpcc.RepoDeleteRecord(ctx, &comatproto.RepoDeleteRecord_Input{
588 Collection: collection,
589 Repo: repo,
590 Rkey: rkey,
591 })
592 return err
593}
594
595func BackfillDefaultDefs(e db.Execer, r *idresolver.Resolver) error {
596 defaults := models.DefaultLabelDefs()
597 defaultLabels, err := db.GetLabelDefinitions(e, db.FilterIn("at_uri", defaults))
598 if err != nil {
599 return err
600 }
601 // already present
602 if len(defaultLabels) == len(defaults) {
603 return nil
604 }
605
606 labelDefs, err := models.FetchDefaultDefs(r)
607 if err != nil {
608 return err
609 }
610
611 // Insert each label definition to the database
612 for _, labelDef := range labelDefs {
613 _, err = db.AddLabelDefinition(e, &labelDef)
614 if err != nil {
615 return fmt.Errorf("failed to add label definition %s: %v", labelDef.Name, err)
616 }
617 }
618
619 return nil
620}