Monorepo for Tangled
1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sort"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "tangled.org/core/appview/models"
16 "tangled.org/core/appview/pagination"
17 "tangled.org/core/orm"
18)
19
20func NewPull(tx *sql.Tx, pull *models.Pull) error {
21 var repoDid *string
22 if pull.RepoDid != "" {
23 repoDid = &pull.RepoDid
24 }
25
26 _, err := tx.Exec(`
27 insert into repo_pull_seqs (repo_at, repo_did, next_pull_id)
28 values (?, ?, 1)
29 on conflict(repo_at) do update set repo_did = coalesce(excluded.repo_did, repo_did)
30 `, pull.RepoAt, repoDid)
31 if err != nil {
32 return err
33 }
34
35 var nextId int
36 err = tx.QueryRow(`
37 update repo_pull_seqs
38 set next_pull_id = next_pull_id + 1
39 where repo_at = ?
40 returning next_pull_id - 1
41 `, pull.RepoAt).Scan(&nextId)
42 if err != nil {
43 return err
44 }
45
46 pull.PullId = nextId
47 pull.State = models.PullOpen
48
49 var sourceBranch, sourceRepoAt, sourceRepoDid *string
50 if pull.PullSource != nil {
51 sourceBranch = &pull.PullSource.Branch
52 if pull.PullSource.RepoAt != nil {
53 x := pull.PullSource.RepoAt.String()
54 sourceRepoAt = &x
55 }
56 if pull.PullSource.RepoDid != "" {
57 sourceRepoDid = &pull.PullSource.RepoDid
58 }
59 }
60
61 var stackId, changeId, parentChangeId *string
62 if pull.StackId != "" {
63 stackId = &pull.StackId
64 }
65 if pull.ChangeId != "" {
66 changeId = &pull.ChangeId
67 }
68 if pull.ParentChangeId != "" {
69 parentChangeId = &pull.ParentChangeId
70 }
71
72 result, err := tx.Exec(
73 `
74 insert into pulls (
75 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id, repo_did, source_repo_did
76 )
77 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
78 pull.RepoAt,
79 pull.OwnerDid,
80 pull.PullId,
81 pull.Title,
82 pull.TargetBranch,
83 pull.Body,
84 pull.Rkey,
85 pull.State,
86 sourceBranch,
87 sourceRepoAt,
88 stackId,
89 changeId,
90 parentChangeId,
91 repoDid,
92 sourceRepoDid,
93 )
94 if err != nil {
95 return err
96 }
97
98 // Set the database primary key ID
99 id, err := result.LastInsertId()
100 if err != nil {
101 return err
102 }
103 pull.ID = int(id)
104
105 _, err = tx.Exec(`
106 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
107 values (?, ?, ?, ?, ?)
108 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev)
109 if err != nil {
110 return err
111 }
112
113 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil {
114 return fmt.Errorf("put reference_links: %w", err)
115 }
116
117 return nil
118}
119
120func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
121 pull, err := GetPull(e, repoAt, pullId)
122 if err != nil {
123 return "", err
124 }
125 return pull.AtUri(), err
126}
127
128func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
129 var pullId int
130 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
131 return pullId - 1, err
132}
133
134func GetPullsPaginated(e Execer, page pagination.Page, filters ...orm.Filter) ([]*models.Pull, error) {
135 pulls := make(map[syntax.ATURI]*models.Pull)
136
137 var conditions []string
138 var args []any
139 for _, filter := range filters {
140 conditions = append(conditions, filter.Condition())
141 args = append(args, filter.Arg()...)
142 }
143
144 whereClause := ""
145 if conditions != nil {
146 whereClause = " where " + strings.Join(conditions, " and ")
147 }
148 pageClause := ""
149 if page.Limit != 0 {
150 pageClause = fmt.Sprintf(
151 " limit %d offset %d ",
152 page.Limit,
153 page.Offset,
154 )
155 }
156
157 query := fmt.Sprintf(`
158 select
159 id,
160 owner_did,
161 repo_at,
162 pull_id,
163 created,
164 title,
165 state,
166 target_branch,
167 body,
168 rkey,
169 source_branch,
170 source_repo_at,
171 stack_id,
172 change_id,
173 parent_change_id,
174 repo_did,
175 source_repo_did
176 from
177 pulls
178 %s
179 order by
180 created desc
181 %s
182 `, whereClause, pageClause)
183
184 rows, err := e.Query(query, args...)
185 if err != nil {
186 return nil, err
187 }
188 defer rows.Close()
189
190 for rows.Next() {
191 var pull models.Pull
192 var createdAt string
193 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId, repoDid, sourceRepoDid sql.NullString
194 err := rows.Scan(
195 &pull.ID,
196 &pull.OwnerDid,
197 &pull.RepoAt,
198 &pull.PullId,
199 &createdAt,
200 &pull.Title,
201 &pull.State,
202 &pull.TargetBranch,
203 &pull.Body,
204 &pull.Rkey,
205 &sourceBranch,
206 &sourceRepoAt,
207 &stackId,
208 &changeId,
209 &parentChangeId,
210 &repoDid,
211 &sourceRepoDid,
212 )
213 if err != nil {
214 return nil, err
215 }
216
217 createdTime, err := time.Parse(time.RFC3339, createdAt)
218 if err != nil {
219 return nil, err
220 }
221 pull.Created = createdTime
222
223 if sourceBranch.Valid {
224 pull.PullSource = &models.PullSource{
225 Branch: sourceBranch.String,
226 }
227 if sourceRepoAt.Valid {
228 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
229 if err != nil {
230 return nil, err
231 }
232 pull.PullSource.RepoAt = &sourceRepoAtParsed
233 }
234 if sourceRepoDid.Valid {
235 pull.PullSource.RepoDid = sourceRepoDid.String
236 }
237 }
238
239 if stackId.Valid {
240 pull.StackId = stackId.String
241 }
242 if changeId.Valid {
243 pull.ChangeId = changeId.String
244 }
245 if parentChangeId.Valid {
246 pull.ParentChangeId = parentChangeId.String
247 }
248 if repoDid.Valid {
249 pull.RepoDid = repoDid.String
250 }
251
252 pulls[pull.AtUri()] = &pull
253 }
254
255 var pullAts []syntax.ATURI
256 for _, p := range pulls {
257 pullAts = append(pullAts, p.AtUri())
258 }
259 submissionsMap, err := GetPullSubmissions(e, orm.FilterIn("pull_at", pullAts))
260 if err != nil {
261 return nil, fmt.Errorf("failed to get submissions: %w", err)
262 }
263
264 for pullAt, submissions := range submissionsMap {
265 if p, ok := pulls[pullAt]; ok {
266 p.Submissions = submissions
267 }
268 }
269
270 // collect allLabels for each issue
271 allLabels, err := GetLabels(e, orm.FilterIn("subject", pullAts))
272 if err != nil {
273 return nil, fmt.Errorf("failed to query labels: %w", err)
274 }
275 for pullAt, labels := range allLabels {
276 if p, ok := pulls[pullAt]; ok {
277 p.Labels = labels
278 }
279 }
280
281 // collect pull source for all pulls that need it
282 var sourceAts []syntax.ATURI
283 for _, p := range pulls {
284 if p.PullSource != nil && p.PullSource.RepoAt != nil {
285 sourceAts = append(sourceAts, *p.PullSource.RepoAt)
286 }
287 }
288 sourceRepos, err := GetRepos(e, 0, orm.FilterIn("at_uri", sourceAts))
289 if err != nil && !errors.Is(err, sql.ErrNoRows) {
290 return nil, fmt.Errorf("failed to get source repos: %w", err)
291 }
292 sourceRepoMap := make(map[syntax.ATURI]*models.Repo)
293 for _, r := range sourceRepos {
294 sourceRepoMap[r.RepoAt()] = &r
295 }
296 for _, p := range pulls {
297 if p.PullSource != nil && p.PullSource.RepoAt != nil {
298 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok {
299 p.PullSource.Repo = sourceRepo
300 }
301 }
302 }
303
304 allReferences, err := GetReferencesAll(e, orm.FilterIn("from_at", pullAts))
305 if err != nil {
306 return nil, fmt.Errorf("failed to query reference_links: %w", err)
307 }
308 for pullAt, references := range allReferences {
309 if pull, ok := pulls[pullAt]; ok {
310 pull.References = references
311 }
312 }
313
314 orderedByPullId := []*models.Pull{}
315 for _, p := range pulls {
316 orderedByPullId = append(orderedByPullId, p)
317 }
318 sort.Slice(orderedByPullId, func(i, j int) bool {
319 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
320 })
321
322 return orderedByPullId, nil
323}
324
325func GetPulls(e Execer, filters ...orm.Filter) ([]*models.Pull, error) {
326 return GetPullsPaginated(e, pagination.Page{}, filters...)
327}
328
329func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
330 pulls, err := GetPullsPaginated(e, pagination.Page{Limit: 1}, orm.FilterEq("repo_at", repoAt), orm.FilterEq("pull_id", pullId))
331 if err != nil {
332 return nil, err
333 }
334 if len(pulls) == 0 {
335 return nil, sql.ErrNoRows
336 }
337
338 return pulls[0], nil
339}
340
341// mapping from pull -> pull submissions
342func GetPullSubmissions(e Execer, filters ...orm.Filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
343 var conditions []string
344 var args []any
345 for _, filter := range filters {
346 conditions = append(conditions, filter.Condition())
347 args = append(args, filter.Arg()...)
348 }
349
350 whereClause := ""
351 if conditions != nil {
352 whereClause = " where " + strings.Join(conditions, " and ")
353 }
354
355 query := fmt.Sprintf(`
356 select
357 id,
358 pull_at,
359 round_number,
360 patch,
361 combined,
362 created,
363 source_rev
364 from
365 pull_submissions
366 %s
367 order by
368 round_number asc
369 `, whereClause)
370
371 rows, err := e.Query(query, args...)
372 if err != nil {
373 return nil, err
374 }
375 defer rows.Close()
376
377 submissionMap := make(map[int]*models.PullSubmission)
378
379 for rows.Next() {
380 var submission models.PullSubmission
381 var submissionCreatedStr string
382 var submissionSourceRev, submissionCombined sql.NullString
383 err := rows.Scan(
384 &submission.ID,
385 &submission.PullAt,
386 &submission.RoundNumber,
387 &submission.Patch,
388 &submissionCombined,
389 &submissionCreatedStr,
390 &submissionSourceRev,
391 )
392 if err != nil {
393 return nil, err
394 }
395
396 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil {
397 submission.Created = t
398 }
399
400 if submissionSourceRev.Valid {
401 submission.SourceRev = submissionSourceRev.String
402 }
403
404 if submissionCombined.Valid {
405 submission.Combined = submissionCombined.String
406 }
407
408 submissionMap[submission.ID] = &submission
409 }
410
411 if err := rows.Err(); err != nil {
412 return nil, err
413 }
414
415 // Get comments for all submissions using GetPullComments
416 submissionIds := slices.Collect(maps.Keys(submissionMap))
417 comments, err := GetPullComments(e, orm.FilterIn("submission_id", submissionIds))
418 if err != nil {
419 return nil, fmt.Errorf("failed to get pull comments: %w", err)
420 }
421 for _, comment := range comments {
422 if submission, ok := submissionMap[comment.SubmissionId]; ok {
423 submission.Comments = append(submission.Comments, comment)
424 }
425 }
426
427 // group the submissions by pull_at
428 m := make(map[syntax.ATURI][]*models.PullSubmission)
429 for _, s := range submissionMap {
430 m[s.PullAt] = append(m[s.PullAt], s)
431 }
432
433 // sort each one by round number
434 for _, s := range m {
435 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
436 return cmp.Compare(a.RoundNumber, b.RoundNumber)
437 })
438 }
439
440 return m, nil
441}
442
443func GetPullComments(e Execer, filters ...orm.Filter) ([]models.PullComment, error) {
444 var conditions []string
445 var args []any
446 for _, filter := range filters {
447 conditions = append(conditions, filter.Condition())
448 args = append(args, filter.Arg()...)
449 }
450
451 whereClause := ""
452 if conditions != nil {
453 whereClause = " where " + strings.Join(conditions, " and ")
454 }
455
456 query := fmt.Sprintf(`
457 select
458 id,
459 pull_id,
460 submission_id,
461 repo_at,
462 owner_did,
463 comment_at,
464 body,
465 created,
466 repo_did
467 from
468 pull_comments
469 %s
470 order by
471 created asc
472 `, whereClause)
473
474 rows, err := e.Query(query, args...)
475 if err != nil {
476 return nil, err
477 }
478 defer rows.Close()
479
480 commentMap := make(map[string]*models.PullComment)
481 for rows.Next() {
482 var comment models.PullComment
483 var createdAt string
484 var commentRepoDid sql.NullString
485 err := rows.Scan(
486 &comment.ID,
487 &comment.PullId,
488 &comment.SubmissionId,
489 &comment.RepoAt,
490 &comment.OwnerDid,
491 &comment.CommentAt,
492 &comment.Body,
493 &createdAt,
494 &commentRepoDid,
495 )
496 if err != nil {
497 return nil, err
498 }
499
500 if t, err := time.Parse(time.RFC3339, createdAt); err == nil {
501 comment.Created = t
502 }
503 if commentRepoDid.Valid {
504 comment.RepoDid = commentRepoDid.String
505 }
506
507 atUri := comment.AtUri().String()
508 commentMap[atUri] = &comment
509 }
510
511 if err := rows.Err(); err != nil {
512 return nil, err
513 }
514
515 // collect references for each comments
516 commentAts := slices.Collect(maps.Keys(commentMap))
517 allReferencs, err := GetReferencesAll(e, orm.FilterIn("from_at", commentAts))
518 if err != nil {
519 return nil, fmt.Errorf("failed to query reference_links: %w", err)
520 }
521 for commentAt, references := range allReferencs {
522 if comment, ok := commentMap[commentAt.String()]; ok {
523 comment.References = references
524 }
525 }
526
527 var comments []models.PullComment
528 for _, c := range commentMap {
529 comments = append(comments, *c)
530 }
531
532 sort.Slice(comments, func(i, j int) bool {
533 return comments[i].Created.Before(comments[j].Created)
534 })
535
536 return comments, nil
537}
538
539// timeframe here is directly passed into the sql query filter, and any
540// timeframe in the past should be negative; e.g.: "-3 months"
541func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
542 var pulls []models.Pull
543
544 rows, err := e.Query(`
545 select
546 p.owner_did,
547 p.repo_at,
548 p.pull_id,
549 p.created,
550 p.title,
551 p.state,
552 p.repo_did,
553 r.did,
554 r.name,
555 r.knot,
556 r.rkey,
557 r.created,
558 r.repo_did
559 from
560 pulls p
561 join
562 repos r on p.repo_at = r.at_uri
563 where
564 p.owner_did = ? and p.created >= date ('now', ?)
565 order by
566 p.created desc`, did, timeframe)
567 if err != nil {
568 return nil, err
569 }
570 defer rows.Close()
571
572 for rows.Next() {
573 var pull models.Pull
574 var repo models.Repo
575 var pullCreatedAt, repoCreatedAt string
576 var pullRepoDid, repoRepoDid sql.NullString
577 err := rows.Scan(
578 &pull.OwnerDid,
579 &pull.RepoAt,
580 &pull.PullId,
581 &pullCreatedAt,
582 &pull.Title,
583 &pull.State,
584 &pullRepoDid,
585 &repo.Did,
586 &repo.Name,
587 &repo.Knot,
588 &repo.Rkey,
589 &repoCreatedAt,
590 &repoRepoDid,
591 )
592 if err != nil {
593 return nil, err
594 }
595
596 if pullRepoDid.Valid {
597 pull.RepoDid = pullRepoDid.String
598 }
599
600 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
601 if err != nil {
602 return nil, err
603 }
604 pull.Created = pullCreatedTime
605
606 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
607 if err != nil {
608 return nil, err
609 }
610 repo.Created = repoCreatedTime
611 if repoRepoDid.Valid {
612 repo.RepoDid = repoRepoDid.String
613 }
614
615 pull.Repo = &repo
616
617 pulls = append(pulls, pull)
618 }
619
620 if err := rows.Err(); err != nil {
621 return nil, err
622 }
623
624 return pulls, nil
625}
626
627func NewPullComment(tx *sql.Tx, comment *models.PullComment) (int64, error) {
628 var repoDid *string
629 if comment.RepoDid != "" {
630 repoDid = &comment.RepoDid
631 }
632
633 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body, repo_did) values (?, ?, ?, ?, ?, ?, ?)`
634 res, err := tx.Exec(
635 query,
636 comment.OwnerDid,
637 comment.RepoAt,
638 comment.SubmissionId,
639 comment.CommentAt,
640 comment.PullId,
641 comment.Body,
642 repoDid,
643 )
644 if err != nil {
645 return 0, err
646 }
647
648 i, err := res.LastInsertId()
649 if err != nil {
650 return 0, err
651 }
652
653 if err := putReferences(tx, comment.AtUri(), comment.References); err != nil {
654 return 0, fmt.Errorf("put reference_links: %w", err)
655 }
656
657 return i, nil
658}
659
660func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
661 _, err := e.Exec(
662 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? and state <> ?)`,
663 pullState,
664 repoAt,
665 pullId,
666 models.PullDeleted, // only update state of non-deleted pulls
667 models.PullMerged, // only update state of non-merged pulls
668 )
669 return err
670}
671
672func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
673 err := SetPullState(e, repoAt, pullId, models.PullClosed)
674 return err
675}
676
677func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
678 err := SetPullState(e, repoAt, pullId, models.PullOpen)
679 return err
680}
681
682func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
683 err := SetPullState(e, repoAt, pullId, models.PullMerged)
684 return err
685}
686
687func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
688 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
689 return err
690}
691
692func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error {
693 _, err := e.Exec(`
694 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
695 values (?, ?, ?, ?, ?)
696 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev)
697
698 return err
699}
700
701func SetPullParentChangeId(e Execer, parentChangeId string, filters ...orm.Filter) error {
702 var conditions []string
703 var args []any
704
705 args = append(args, parentChangeId)
706
707 for _, filter := range filters {
708 conditions = append(conditions, filter.Condition())
709 args = append(args, filter.Arg()...)
710 }
711
712 whereClause := ""
713 if conditions != nil {
714 whereClause = " where " + strings.Join(conditions, " and ")
715 }
716
717 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
718 _, err := e.Exec(query, args...)
719
720 return err
721}
722
723// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
724// otherwise submissions are immutable
725func UpdatePull(e Execer, newPatch, sourceRev string, filters ...orm.Filter) error {
726 var conditions []string
727 var args []any
728
729 args = append(args, sourceRev)
730 args = append(args, newPatch)
731
732 for _, filter := range filters {
733 conditions = append(conditions, filter.Condition())
734 args = append(args, filter.Arg()...)
735 }
736
737 whereClause := ""
738 if conditions != nil {
739 whereClause = " where " + strings.Join(conditions, " and ")
740 }
741
742 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
743 _, err := e.Exec(query, args...)
744
745 return err
746}
747
748func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
749 row := e.QueryRow(`
750 select
751 count(case when state = ? then 1 end) as open_count,
752 count(case when state = ? then 1 end) as merged_count,
753 count(case when state = ? then 1 end) as closed_count,
754 count(case when state = ? then 1 end) as deleted_count
755 from pulls
756 where repo_at = ?`,
757 models.PullOpen,
758 models.PullMerged,
759 models.PullClosed,
760 models.PullDeleted,
761 repoAt,
762 )
763
764 var count models.PullCount
765 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
766 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
767 }
768
769 return count, nil
770}
771
772// change-id parent-change-id
773//
774// 4 w ,-------- z (TOP)
775// 3 z <----',------- y
776// 2 y <-----',------ x
777// 1 x <------' nil (BOT)
778//
779// `w` is parent of none, so it is the top of the stack
780func GetStack(e Execer, stackId string) (models.Stack, error) {
781 unorderedPulls, err := GetPulls(
782 e,
783 orm.FilterEq("stack_id", stackId),
784 orm.FilterNotEq("state", models.PullDeleted),
785 )
786 if err != nil {
787 return nil, err
788 }
789 // map of parent-change-id to pull
790 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
791 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
792 for _, p := range unorderedPulls {
793 changeIdMap[p.ChangeId] = p
794 if p.ParentChangeId != "" {
795 parentMap[p.ParentChangeId] = p
796 }
797 }
798
799 // the top of the stack is the pull that is not a parent of any pull
800 var topPull *models.Pull
801 for _, maybeTop := range unorderedPulls {
802 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
803 topPull = maybeTop
804 break
805 }
806 }
807
808 pulls := []*models.Pull{}
809 for {
810 pulls = append(pulls, topPull)
811 if topPull.ParentChangeId != "" {
812 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
813 topPull = next
814 } else {
815 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
816 }
817 } else {
818 break
819 }
820 }
821
822 return pulls, nil
823}
824
825func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
826 pulls, err := GetPulls(
827 e,
828 orm.FilterEq("stack_id", stackId),
829 orm.FilterEq("state", models.PullDeleted),
830 )
831 if err != nil {
832 return nil, err
833 }
834
835 return pulls, nil
836}