this repo has no description
1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sort"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "tangled.org/core/appview/models"
16)
17
18func NewPull(tx *sql.Tx, pull *models.Pull) error {
19 _, err := tx.Exec(`
20 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
21 values (?, 1)
22 `, pull.RepoAt)
23 if err != nil {
24 return err
25 }
26
27 var nextId int
28 err = tx.QueryRow(`
29 update repo_pull_seqs
30 set next_pull_id = next_pull_id + 1
31 where repo_at = ?
32 returning next_pull_id - 1
33 `, pull.RepoAt).Scan(&nextId)
34 if err != nil {
35 return err
36 }
37
38 pull.PullId = nextId
39 pull.State = models.PullOpen
40
41 var sourceBranch, sourceRepoAt *string
42 if pull.PullSource != nil {
43 sourceBranch = &pull.PullSource.Branch
44 if pull.PullSource.RepoAt != nil {
45 x := pull.PullSource.RepoAt.String()
46 sourceRepoAt = &x
47 }
48 }
49
50 var stackId, changeId, parentChangeId *string
51 if pull.StackId != "" {
52 stackId = &pull.StackId
53 }
54 if pull.ChangeId != "" {
55 changeId = &pull.ChangeId
56 }
57 if pull.ParentChangeId != "" {
58 parentChangeId = &pull.ParentChangeId
59 }
60
61 result, err := tx.Exec(
62 `
63 insert into pulls (
64 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
65 )
66 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
67 pull.RepoAt,
68 pull.OwnerDid,
69 pull.PullId,
70 pull.Title,
71 pull.TargetBranch,
72 pull.Body,
73 pull.Rkey,
74 pull.State,
75 sourceBranch,
76 sourceRepoAt,
77 stackId,
78 changeId,
79 parentChangeId,
80 )
81 if err != nil {
82 return err
83 }
84
85 // Set the database primary key ID
86 id, err := result.LastInsertId()
87 if err != nil {
88 return err
89 }
90 pull.ID = int(id)
91
92 _, err = tx.Exec(`
93 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
94 values (?, ?, ?, ?, ?)
95 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev)
96 return err
97}
98
99func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
100 pull, err := GetPull(e, repoAt, pullId)
101 if err != nil {
102 return "", err
103 }
104 return pull.AtUri(), err
105}
106
107func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
108 var pullId int
109 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
110 return pullId - 1, err
111}
112
113func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) {
114 pulls := make(map[syntax.ATURI]*models.Pull)
115
116 var conditions []string
117 var args []any
118 for _, filter := range filters {
119 conditions = append(conditions, filter.Condition())
120 args = append(args, filter.Arg()...)
121 }
122
123 whereClause := ""
124 if conditions != nil {
125 whereClause = " where " + strings.Join(conditions, " and ")
126 }
127 limitClause := ""
128 if limit != 0 {
129 limitClause = fmt.Sprintf(" limit %d ", limit)
130 }
131
132 query := fmt.Sprintf(`
133 select
134 id,
135 owner_did,
136 repo_at,
137 pull_id,
138 created,
139 title,
140 state,
141 target_branch,
142 body,
143 rkey,
144 source_branch,
145 source_repo_at,
146 stack_id,
147 change_id,
148 parent_change_id
149 from
150 pulls
151 %s
152 order by
153 created desc
154 %s
155 `, whereClause, limitClause)
156
157 rows, err := e.Query(query, args...)
158 if err != nil {
159 return nil, err
160 }
161 defer rows.Close()
162
163 for rows.Next() {
164 var pull models.Pull
165 var createdAt string
166 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
167 err := rows.Scan(
168 &pull.ID,
169 &pull.OwnerDid,
170 &pull.RepoAt,
171 &pull.PullId,
172 &createdAt,
173 &pull.Title,
174 &pull.State,
175 &pull.TargetBranch,
176 &pull.Body,
177 &pull.Rkey,
178 &sourceBranch,
179 &sourceRepoAt,
180 &stackId,
181 &changeId,
182 &parentChangeId,
183 )
184 if err != nil {
185 return nil, err
186 }
187
188 createdTime, err := time.Parse(time.RFC3339, createdAt)
189 if err != nil {
190 return nil, err
191 }
192 pull.Created = createdTime
193
194 if sourceBranch.Valid {
195 pull.PullSource = &models.PullSource{
196 Branch: sourceBranch.String,
197 }
198 if sourceRepoAt.Valid {
199 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
200 if err != nil {
201 return nil, err
202 }
203 pull.PullSource.RepoAt = &sourceRepoAtParsed
204 }
205 }
206
207 if stackId.Valid {
208 pull.StackId = stackId.String
209 }
210 if changeId.Valid {
211 pull.ChangeId = changeId.String
212 }
213 if parentChangeId.Valid {
214 pull.ParentChangeId = parentChangeId.String
215 }
216
217 pulls[pull.AtUri()] = &pull
218 }
219
220 var pullAts []syntax.ATURI
221 for _, p := range pulls {
222 pullAts = append(pullAts, p.AtUri())
223 }
224 submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts))
225 if err != nil {
226 return nil, fmt.Errorf("failed to get submissions: %w", err)
227 }
228
229 for pullAt, submissions := range submissionsMap {
230 if p, ok := pulls[pullAt]; ok {
231 p.Submissions = submissions
232 }
233 }
234
235 // collect allLabels for each issue
236 allLabels, err := GetLabels(e, FilterIn("subject", pullAts))
237 if err != nil {
238 return nil, fmt.Errorf("failed to query labels: %w", err)
239 }
240 for pullAt, labels := range allLabels {
241 if p, ok := pulls[pullAt]; ok {
242 p.Labels = labels
243 }
244 }
245
246 // collect pull source for all pulls that need it
247 var sourceAts []syntax.ATURI
248 for _, p := range pulls {
249 if p.PullSource != nil && p.PullSource.RepoAt != nil {
250 sourceAts = append(sourceAts, *p.PullSource.RepoAt)
251 }
252 }
253 sourceRepos, err := GetRepos(e, 0, FilterIn("at_uri", sourceAts))
254 if err != nil && !errors.Is(err, sql.ErrNoRows) {
255 return nil, fmt.Errorf("failed to get source repos: %w", err)
256 }
257 sourceRepoMap := make(map[syntax.ATURI]*models.Repo)
258 for _, r := range sourceRepos {
259 sourceRepoMap[r.RepoAt()] = &r
260 }
261 for _, p := range pulls {
262 if p.PullSource != nil && p.PullSource.RepoAt != nil {
263 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok {
264 p.PullSource.Repo = sourceRepo
265 }
266 }
267 }
268
269 orderedByPullId := []*models.Pull{}
270 for _, p := range pulls {
271 orderedByPullId = append(orderedByPullId, p)
272 }
273 sort.Slice(orderedByPullId, func(i, j int) bool {
274 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
275 })
276
277 return orderedByPullId, nil
278}
279
280func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) {
281 return GetPullsWithLimit(e, 0, filters...)
282}
283
284func GetPullIDs(e Execer, opts models.PullSearchOptions) ([]int64, error) {
285 var ids []int64
286
287 var filters []filter
288 filters = append(filters, FilterEq("state", opts.State))
289 if opts.RepoAt != "" {
290 filters = append(filters, FilterEq("repo_at", opts.RepoAt))
291 }
292
293 var conditions []string
294 var args []any
295
296 for _, filter := range filters {
297 conditions = append(conditions, filter.Condition())
298 args = append(args, filter.Arg()...)
299 }
300
301 whereClause := ""
302 if conditions != nil {
303 whereClause = " where " + strings.Join(conditions, " and ")
304 }
305 pageClause := ""
306 if opts.Page.Limit != 0 {
307 pageClause = fmt.Sprintf(
308 " limit %d offset %d ",
309 opts.Page.Limit,
310 opts.Page.Offset,
311 )
312 }
313
314 query := fmt.Sprintf(
315 `
316 select
317 id
318 from
319 pulls
320 %s
321 %s`,
322 whereClause,
323 pageClause,
324 )
325 args = append(args, opts.Page.Limit, opts.Page.Offset)
326 rows, err := e.Query(query, args...)
327 if err != nil {
328 return nil, err
329 }
330 defer rows.Close()
331
332 for rows.Next() {
333 var id int64
334 err := rows.Scan(&id)
335 if err != nil {
336 return nil, err
337 }
338
339 ids = append(ids, id)
340 }
341
342 return ids, nil
343}
344
345func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
346 pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId))
347 if err != nil {
348 return nil, err
349 }
350 if len(pulls) == 0 {
351 return nil, sql.ErrNoRows
352 }
353
354 return pulls[0], nil
355}
356
357// mapping from pull -> pull submissions
358func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
359 var conditions []string
360 var args []any
361 for _, filter := range filters {
362 conditions = append(conditions, filter.Condition())
363 args = append(args, filter.Arg()...)
364 }
365
366 whereClause := ""
367 if conditions != nil {
368 whereClause = " where " + strings.Join(conditions, " and ")
369 }
370
371 query := fmt.Sprintf(`
372 select
373 id,
374 pull_at,
375 round_number,
376 patch,
377 combined,
378 created,
379 source_rev
380 from
381 pull_submissions
382 %s
383 order by
384 round_number asc
385 `, whereClause)
386
387 rows, err := e.Query(query, args...)
388 if err != nil {
389 return nil, err
390 }
391 defer rows.Close()
392
393 submissionMap := make(map[int]*models.PullSubmission)
394
395 for rows.Next() {
396 var submission models.PullSubmission
397 var submissionCreatedStr string
398 var submissionSourceRev, submissionCombined sql.NullString
399 err := rows.Scan(
400 &submission.ID,
401 &submission.PullAt,
402 &submission.RoundNumber,
403 &submission.Patch,
404 &submissionCombined,
405 &submissionCreatedStr,
406 &submissionSourceRev,
407 )
408 if err != nil {
409 return nil, err
410 }
411
412 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil {
413 submission.Created = t
414 }
415
416 if submissionSourceRev.Valid {
417 submission.SourceRev = submissionSourceRev.String
418 }
419
420 if submissionCombined.Valid {
421 submission.Combined = submissionCombined.String
422 }
423
424 submissionMap[submission.ID] = &submission
425 }
426
427 if err := rows.Err(); err != nil {
428 return nil, err
429 }
430
431 // Get comments for all submissions using GetPullComments
432 submissionIds := slices.Collect(maps.Keys(submissionMap))
433 comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds))
434 if err != nil {
435 return nil, fmt.Errorf("failed to get pull comments: %w", err)
436 }
437 for _, comment := range comments {
438 if submission, ok := submissionMap[comment.SubmissionId]; ok {
439 submission.Comments = append(submission.Comments, comment)
440 }
441 }
442
443 // group the submissions by pull_at
444 m := make(map[syntax.ATURI][]*models.PullSubmission)
445 for _, s := range submissionMap {
446 m[s.PullAt] = append(m[s.PullAt], s)
447 }
448
449 // sort each one by round number
450 for _, s := range m {
451 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
452 return cmp.Compare(a.RoundNumber, b.RoundNumber)
453 })
454 }
455
456 return m, nil
457}
458
459func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) {
460 var conditions []string
461 var args []any
462 for _, filter := range filters {
463 conditions = append(conditions, filter.Condition())
464 args = append(args, filter.Arg()...)
465 }
466
467 whereClause := ""
468 if conditions != nil {
469 whereClause = " where " + strings.Join(conditions, " and ")
470 }
471
472 query := fmt.Sprintf(`
473 select
474 id,
475 pull_id,
476 submission_id,
477 repo_at,
478 owner_did,
479 comment_at,
480 body,
481 created
482 from
483 pull_comments
484 %s
485 order by
486 created asc
487 `, whereClause)
488
489 rows, err := e.Query(query, args...)
490 if err != nil {
491 return nil, err
492 }
493 defer rows.Close()
494
495 commentMap := make(map[string]*models.PullComment)
496 for rows.Next() {
497 var comment models.PullComment
498 var createdAt string
499 err := rows.Scan(
500 &comment.ID,
501 &comment.PullId,
502 &comment.SubmissionId,
503 &comment.RepoAt,
504 &comment.OwnerDid,
505 &comment.CommentAt,
506 &comment.Body,
507 &createdAt,
508 )
509 if err != nil {
510 return nil, err
511 }
512
513 if t, err := time.Parse(time.RFC3339, createdAt); err == nil {
514 comment.Created = t
515 }
516
517 atUri := comment.AtUri().String()
518 commentMap[atUri] = &comment
519 }
520
521 if err := rows.Err(); err != nil {
522 return nil, err
523 }
524
525 // collect references for each comments
526 commentAts := slices.Collect(maps.Keys(commentMap))
527 allReferencs, err := GetReferencesAll(e, FilterIn("from_at", commentAts))
528 if err != nil {
529 return nil, fmt.Errorf("failed to query reference_links: %w", err)
530 }
531 for commentAt, references := range allReferencs {
532 if comment, ok := commentMap[commentAt.String()]; ok {
533 comment.References = references
534 }
535 }
536
537 var comments []models.PullComment
538 for _, c := range commentMap {
539 comments = append(comments, *c)
540 }
541
542 sort.Slice(comments, func(i, j int) bool {
543 return comments[i].Created.Before(comments[j].Created)
544 })
545
546 return comments, nil
547}
548
549// timeframe here is directly passed into the sql query filter, and any
550// timeframe in the past should be negative; e.g.: "-3 months"
551func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
552 var pulls []models.Pull
553
554 rows, err := e.Query(`
555 select
556 p.owner_did,
557 p.repo_at,
558 p.pull_id,
559 p.created,
560 p.title,
561 p.state,
562 r.did,
563 r.name,
564 r.knot,
565 r.rkey,
566 r.created
567 from
568 pulls p
569 join
570 repos r on p.repo_at = r.at_uri
571 where
572 p.owner_did = ? and p.created >= date ('now', ?)
573 order by
574 p.created desc`, did, timeframe)
575 if err != nil {
576 return nil, err
577 }
578 defer rows.Close()
579
580 for rows.Next() {
581 var pull models.Pull
582 var repo models.Repo
583 var pullCreatedAt, repoCreatedAt string
584 err := rows.Scan(
585 &pull.OwnerDid,
586 &pull.RepoAt,
587 &pull.PullId,
588 &pullCreatedAt,
589 &pull.Title,
590 &pull.State,
591 &repo.Did,
592 &repo.Name,
593 &repo.Knot,
594 &repo.Rkey,
595 &repoCreatedAt,
596 )
597 if err != nil {
598 return nil, err
599 }
600
601 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
602 if err != nil {
603 return nil, err
604 }
605 pull.Created = pullCreatedTime
606
607 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
608 if err != nil {
609 return nil, err
610 }
611 repo.Created = repoCreatedTime
612
613 pull.Repo = &repo
614
615 pulls = append(pulls, pull)
616 }
617
618 if err := rows.Err(); err != nil {
619 return nil, err
620 }
621
622 return pulls, nil
623}
624
625func NewPullComment(tx *sql.Tx, comment *models.PullComment) (int64, error) {
626 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
627 res, err := tx.Exec(
628 query,
629 comment.OwnerDid,
630 comment.RepoAt,
631 comment.SubmissionId,
632 comment.CommentAt,
633 comment.PullId,
634 comment.Body,
635 )
636 if err != nil {
637 return 0, err
638 }
639
640 i, err := res.LastInsertId()
641 if err != nil {
642 return 0, err
643 }
644
645 if err := putReferences(tx, comment.AtUri(), comment.References); err != nil {
646 return 0, fmt.Errorf("put reference_links: %w", err)
647 }
648
649 return i, nil
650}
651
652func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
653 _, err := e.Exec(
654 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
655 pullState,
656 repoAt,
657 pullId,
658 models.PullDeleted, // only update state of non-deleted pulls
659 models.PullMerged, // only update state of non-merged pulls
660 )
661 return err
662}
663
664func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
665 err := SetPullState(e, repoAt, pullId, models.PullClosed)
666 return err
667}
668
669func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
670 err := SetPullState(e, repoAt, pullId, models.PullOpen)
671 return err
672}
673
674func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
675 err := SetPullState(e, repoAt, pullId, models.PullMerged)
676 return err
677}
678
679func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
680 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
681 return err
682}
683
684func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error {
685 _, err := e.Exec(`
686 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
687 values (?, ?, ?, ?, ?)
688 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev)
689
690 return err
691}
692
693func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
694 var conditions []string
695 var args []any
696
697 args = append(args, parentChangeId)
698
699 for _, filter := range filters {
700 conditions = append(conditions, filter.Condition())
701 args = append(args, filter.Arg()...)
702 }
703
704 whereClause := ""
705 if conditions != nil {
706 whereClause = " where " + strings.Join(conditions, " and ")
707 }
708
709 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
710 _, err := e.Exec(query, args...)
711
712 return err
713}
714
715// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
716// otherwise submissions are immutable
717func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
718 var conditions []string
719 var args []any
720
721 args = append(args, sourceRev)
722 args = append(args, newPatch)
723
724 for _, filter := range filters {
725 conditions = append(conditions, filter.Condition())
726 args = append(args, filter.Arg()...)
727 }
728
729 whereClause := ""
730 if conditions != nil {
731 whereClause = " where " + strings.Join(conditions, " and ")
732 }
733
734 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
735 _, err := e.Exec(query, args...)
736
737 return err
738}
739
740func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
741 row := e.QueryRow(`
742 select
743 count(case when state = ? then 1 end) as open_count,
744 count(case when state = ? then 1 end) as merged_count,
745 count(case when state = ? then 1 end) as closed_count,
746 count(case when state = ? then 1 end) as deleted_count
747 from pulls
748 where repo_at = ?`,
749 models.PullOpen,
750 models.PullMerged,
751 models.PullClosed,
752 models.PullDeleted,
753 repoAt,
754 )
755
756 var count models.PullCount
757 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
758 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
759 }
760
761 return count, nil
762}
763
764// change-id parent-change-id
765//
766// 4 w ,-------- z (TOP)
767// 3 z <----',------- y
768// 2 y <-----',------ x
769// 1 x <------' nil (BOT)
770//
771// `w` is parent of none, so it is the top of the stack
772func GetStack(e Execer, stackId string) (models.Stack, error) {
773 unorderedPulls, err := GetPulls(
774 e,
775 FilterEq("stack_id", stackId),
776 FilterNotEq("state", models.PullDeleted),
777 )
778 if err != nil {
779 return nil, err
780 }
781 // map of parent-change-id to pull
782 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
783 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
784 for _, p := range unorderedPulls {
785 changeIdMap[p.ChangeId] = p
786 if p.ParentChangeId != "" {
787 parentMap[p.ParentChangeId] = p
788 }
789 }
790
791 // the top of the stack is the pull that is not a parent of any pull
792 var topPull *models.Pull
793 for _, maybeTop := range unorderedPulls {
794 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
795 topPull = maybeTop
796 break
797 }
798 }
799
800 pulls := []*models.Pull{}
801 for {
802 pulls = append(pulls, topPull)
803 if topPull.ParentChangeId != "" {
804 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
805 topPull = next
806 } else {
807 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
808 }
809 } else {
810 break
811 }
812 }
813
814 return pulls, nil
815}
816
817func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
818 pulls, err := GetPulls(
819 e,
820 FilterEq("stack_id", stackId),
821 FilterEq("state", models.PullDeleted),
822 )
823 if err != nil {
824 return nil, err
825 }
826
827 return pulls, nil
828}