this repo has no description
1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "slices"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.sh/tangled.sh/core/api/tangled"
14 "tangled.sh/tangled.sh/core/patchutil"
15 "tangled.sh/tangled.sh/core/types"
16)
17
// PullState is the lifecycle state of a pull request. The numeric values are
// what gets written to and compared against the pulls.state column, so the
// order of the constants below must not change.
type PullState int

const (
	PullClosed PullState = iota
	PullOpen
	PullMerged
	PullDeleted
)

// String returns the lowercase name of the state. PullClosed and any
// unrecognised value both render as "closed".
func (p PullState) String() string {
	switch p {
	case PullOpen:
		return "open"
	case PullMerged:
		return "merged"
	case PullDeleted:
		return "deleted"
	default:
		return "closed"
	}
}

// IsOpen reports whether the state is PullOpen.
func (p PullState) IsOpen() bool { return p == PullOpen }

// IsMerged reports whether the state is PullMerged.
func (p PullState) IsMerged() bool { return p == PullMerged }

// IsClosed reports whether the state is PullClosed.
func (p PullState) IsClosed() bool { return p == PullClosed }

// IsDeleted reports whether the state is PullDeleted.
func (p PullState) IsDeleted() bool { return p == PullDeleted }
54
55type Pull struct {
56 // ids
57 ID int
58 PullId int
59
60 // at ids
61 RepoAt syntax.ATURI
62 OwnerDid string
63 Rkey string
64
65 // content
66 Title string
67 Body string
68 TargetBranch string
69 State PullState
70 Submissions []*PullSubmission
71
72 // stacking
73 StackId string // nullable string
74 ChangeId string // nullable string
75 ParentChangeId string // nullable string
76
77 // meta
78 Created time.Time
79 PullSource *PullSource
80
81 // optionally, populate this when querying for reverse mappings
82 Repo *Repo
83}
84
85func (p Pull) AsRecord() tangled.RepoPull {
86 var source *tangled.RepoPull_Source
87 if p.PullSource != nil {
88 s := p.PullSource.AsRecord()
89 source = &s
90 source.Sha = p.LatestSha()
91 }
92
93 record := tangled.RepoPull{
94 Title: p.Title,
95 Body: &p.Body,
96 CreatedAt: p.Created.Format(time.RFC3339),
97 TargetRepo: p.RepoAt.String(),
98 TargetBranch: p.TargetBranch,
99 Patch: p.LatestPatch(),
100 Source: source,
101 }
102 return record
103}
104
105type PullSource struct {
106 Branch string
107 RepoAt *syntax.ATURI
108
109 // optionally populate this for reverse mappings
110 Repo *Repo
111}
112
113func (p PullSource) AsRecord() tangled.RepoPull_Source {
114 var repoAt *string
115 if p.RepoAt != nil {
116 s := p.RepoAt.String()
117 repoAt = &s
118 }
119 record := tangled.RepoPull_Source{
120 Branch: p.Branch,
121 Repo: repoAt,
122 }
123 return record
124}
125
126type PullSubmission struct {
127 // ids
128 ID int
129 PullId int
130
131 // at ids
132 RepoAt syntax.ATURI
133
134 // content
135 RoundNumber int
136 Patch string
137 Comments []PullComment
138 SourceRev string // include the rev that was used to create this submission: only for branch/fork PRs
139
140 // meta
141 Created time.Time
142}
143
// PullComment is a comment left on a specific submission of a pull.
type PullComment struct {
	// Identifiers.
	ID, PullId, SubmissionId int

	// AT Protocol identifiers.
	RepoAt, OwnerDid, CommentAt string

	// Content and metadata.
	Body    string
	Created time.Time
}
161
162func (p *Pull) LatestPatch() string {
163 latestSubmission := p.Submissions[p.LastRoundNumber()]
164 return latestSubmission.Patch
165}
166
167func (p *Pull) LatestSha() string {
168 latestSubmission := p.Submissions[p.LastRoundNumber()]
169 return latestSubmission.SourceRev
170}
171
172func (p *Pull) PullAt() syntax.ATURI {
173 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey))
174}
175
176func (p *Pull) LastRoundNumber() int {
177 return len(p.Submissions) - 1
178}
179
180func (p *Pull) IsPatchBased() bool {
181 return p.PullSource == nil
182}
183
184func (p *Pull) IsBranchBased() bool {
185 if p.PullSource != nil {
186 if p.PullSource.RepoAt != nil {
187 return p.PullSource.RepoAt == &p.RepoAt
188 } else {
189 // no repo specified
190 return true
191 }
192 }
193 return false
194}
195
196func (p *Pull) IsForkBased() bool {
197 if p.PullSource != nil {
198 if p.PullSource.RepoAt != nil {
199 // make sure repos are different
200 return p.PullSource.RepoAt != &p.RepoAt
201 }
202 }
203 return false
204}
205
206func (p *Pull) IsStacked() bool {
207 return p.StackId != ""
208}
209
210func (s PullSubmission) IsFormatPatch() bool {
211 return patchutil.IsFormatPatch(s.Patch)
212}
213
214func (s PullSubmission) AsFormatPatch() []types.FormatPatch {
215 patches, err := patchutil.ExtractPatches(s.Patch)
216 if err != nil {
217 log.Println("error extracting patches from submission:", err)
218 return []types.FormatPatch{}
219 }
220
221 return patches
222}
223
224func NewPull(tx *sql.Tx, pull *Pull) error {
225 _, err := tx.Exec(`
226 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
227 values (?, 1)
228 `, pull.RepoAt)
229 if err != nil {
230 return err
231 }
232
233 var nextId int
234 err = tx.QueryRow(`
235 update repo_pull_seqs
236 set next_pull_id = next_pull_id + 1
237 where repo_at = ?
238 returning next_pull_id - 1
239 `, pull.RepoAt).Scan(&nextId)
240 if err != nil {
241 return err
242 }
243
244 pull.PullId = nextId
245 pull.State = PullOpen
246
247 var sourceBranch, sourceRepoAt *string
248 if pull.PullSource != nil {
249 sourceBranch = &pull.PullSource.Branch
250 if pull.PullSource.RepoAt != nil {
251 x := pull.PullSource.RepoAt.String()
252 sourceRepoAt = &x
253 }
254 }
255
256 var stackId, changeId, parentChangeId *string
257 if pull.StackId != "" {
258 stackId = &pull.StackId
259 }
260 if pull.ChangeId != "" {
261 changeId = &pull.ChangeId
262 }
263 if pull.ParentChangeId != "" {
264 parentChangeId = &pull.ParentChangeId
265 }
266
267 _, err = tx.Exec(
268 `
269 insert into pulls (
270 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
271 )
272 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
273 pull.RepoAt,
274 pull.OwnerDid,
275 pull.PullId,
276 pull.Title,
277 pull.TargetBranch,
278 pull.Body,
279 pull.Rkey,
280 pull.State,
281 sourceBranch,
282 sourceRepoAt,
283 stackId,
284 changeId,
285 parentChangeId,
286 )
287 if err != nil {
288 return err
289 }
290
291 _, err = tx.Exec(`
292 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
293 values (?, ?, ?, ?, ?)
294 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
295 return err
296}
297
298func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
299 pull, err := GetPull(e, repoAt, pullId)
300 if err != nil {
301 return "", err
302 }
303 return pull.PullAt(), err
304}
305
306func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
307 var pullId int
308 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
309 return pullId - 1, err
310}
311
312func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*Pull, error) {
313 pulls := make(map[int]*Pull)
314
315 var conditions []string
316 var args []any
317 for _, filter := range filters {
318 conditions = append(conditions, filter.Condition())
319 args = append(args, filter.Arg()...)
320 }
321
322 whereClause := ""
323 if conditions != nil {
324 whereClause = " where " + strings.Join(conditions, " and ")
325 }
326 limitClause := ""
327 if limit != 0 {
328 limitClause = fmt.Sprintf(" limit %d ", limit)
329 }
330
331 query := fmt.Sprintf(`
332 select
333 owner_did,
334 repo_at,
335 pull_id,
336 created,
337 title,
338 state,
339 target_branch,
340 body,
341 rkey,
342 source_branch,
343 source_repo_at,
344 stack_id,
345 change_id,
346 parent_change_id
347 from
348 pulls
349 %s
350 order by
351 created desc
352 %s
353 `, whereClause, limitClause)
354
355 rows, err := e.Query(query, args...)
356 if err != nil {
357 return nil, err
358 }
359 defer rows.Close()
360
361 for rows.Next() {
362 var pull Pull
363 var createdAt string
364 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
365 err := rows.Scan(
366 &pull.OwnerDid,
367 &pull.RepoAt,
368 &pull.PullId,
369 &createdAt,
370 &pull.Title,
371 &pull.State,
372 &pull.TargetBranch,
373 &pull.Body,
374 &pull.Rkey,
375 &sourceBranch,
376 &sourceRepoAt,
377 &stackId,
378 &changeId,
379 &parentChangeId,
380 )
381 if err != nil {
382 return nil, err
383 }
384
385 createdTime, err := time.Parse(time.RFC3339, createdAt)
386 if err != nil {
387 return nil, err
388 }
389 pull.Created = createdTime
390
391 if sourceBranch.Valid {
392 pull.PullSource = &PullSource{
393 Branch: sourceBranch.String,
394 }
395 if sourceRepoAt.Valid {
396 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
397 if err != nil {
398 return nil, err
399 }
400 pull.PullSource.RepoAt = &sourceRepoAtParsed
401 }
402 }
403
404 if stackId.Valid {
405 pull.StackId = stackId.String
406 }
407 if changeId.Valid {
408 pull.ChangeId = changeId.String
409 }
410 if parentChangeId.Valid {
411 pull.ParentChangeId = parentChangeId.String
412 }
413
414 pulls[pull.PullId] = &pull
415 }
416
417 // get latest round no. for each pull
418 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
419 submissionsQuery := fmt.Sprintf(`
420 select
421 id, pull_id, round_number, patch, created, source_rev
422 from
423 pull_submissions
424 where
425 repo_at in (%s) and pull_id in (%s)
426 `, inClause, inClause)
427
428 args = make([]any, len(pulls)*2)
429 idx := 0
430 for _, p := range pulls {
431 args[idx] = p.RepoAt
432 idx += 1
433 }
434 for _, p := range pulls {
435 args[idx] = p.PullId
436 idx += 1
437 }
438 submissionsRows, err := e.Query(submissionsQuery, args...)
439 if err != nil {
440 return nil, err
441 }
442 defer submissionsRows.Close()
443
444 for submissionsRows.Next() {
445 var s PullSubmission
446 var sourceRev sql.NullString
447 var createdAt string
448 err := submissionsRows.Scan(
449 &s.ID,
450 &s.PullId,
451 &s.RoundNumber,
452 &s.Patch,
453 &createdAt,
454 &sourceRev,
455 )
456 if err != nil {
457 return nil, err
458 }
459
460 createdTime, err := time.Parse(time.RFC3339, createdAt)
461 if err != nil {
462 return nil, err
463 }
464 s.Created = createdTime
465
466 if sourceRev.Valid {
467 s.SourceRev = sourceRev.String
468 }
469
470 if p, ok := pulls[s.PullId]; ok {
471 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
472 p.Submissions[s.RoundNumber] = &s
473 }
474 }
475 if err := rows.Err(); err != nil {
476 return nil, err
477 }
478
479 // get comment count on latest submission on each pull
480 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
481 commentsQuery := fmt.Sprintf(`
482 select
483 count(id), pull_id
484 from
485 pull_comments
486 where
487 submission_id in (%s)
488 group by
489 submission_id
490 `, inClause)
491
492 args = []any{}
493 for _, p := range pulls {
494 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
495 }
496 commentsRows, err := e.Query(commentsQuery, args...)
497 if err != nil {
498 return nil, err
499 }
500 defer commentsRows.Close()
501
502 for commentsRows.Next() {
503 var commentCount, pullId int
504 err := commentsRows.Scan(
505 &commentCount,
506 &pullId,
507 )
508 if err != nil {
509 return nil, err
510 }
511 if p, ok := pulls[pullId]; ok {
512 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
513 }
514 }
515 if err := rows.Err(); err != nil {
516 return nil, err
517 }
518
519 orderedByPullId := []*Pull{}
520 for _, p := range pulls {
521 orderedByPullId = append(orderedByPullId, p)
522 }
523 sort.Slice(orderedByPullId, func(i, j int) bool {
524 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
525 })
526
527 return orderedByPullId, nil
528}
529
530func GetPulls(e Execer, filters ...filter) ([]*Pull, error) {
531 return GetPullsWithLimit(e, 0, filters...)
532}
533
534func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
535 query := `
536 select
537 owner_did,
538 pull_id,
539 created,
540 title,
541 state,
542 target_branch,
543 repo_at,
544 body,
545 rkey,
546 source_branch,
547 source_repo_at,
548 stack_id,
549 change_id,
550 parent_change_id
551 from
552 pulls
553 where
554 repo_at = ? and pull_id = ?
555 `
556 row := e.QueryRow(query, repoAt, pullId)
557
558 var pull Pull
559 var createdAt string
560 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
561 err := row.Scan(
562 &pull.OwnerDid,
563 &pull.PullId,
564 &createdAt,
565 &pull.Title,
566 &pull.State,
567 &pull.TargetBranch,
568 &pull.RepoAt,
569 &pull.Body,
570 &pull.Rkey,
571 &sourceBranch,
572 &sourceRepoAt,
573 &stackId,
574 &changeId,
575 &parentChangeId,
576 )
577 if err != nil {
578 return nil, err
579 }
580
581 createdTime, err := time.Parse(time.RFC3339, createdAt)
582 if err != nil {
583 return nil, err
584 }
585 pull.Created = createdTime
586
587 // populate source
588 if sourceBranch.Valid {
589 pull.PullSource = &PullSource{
590 Branch: sourceBranch.String,
591 }
592 if sourceRepoAt.Valid {
593 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
594 if err != nil {
595 return nil, err
596 }
597 pull.PullSource.RepoAt = &sourceRepoAtParsed
598 }
599 }
600
601 if stackId.Valid {
602 pull.StackId = stackId.String
603 }
604 if changeId.Valid {
605 pull.ChangeId = changeId.String
606 }
607 if parentChangeId.Valid {
608 pull.ParentChangeId = parentChangeId.String
609 }
610
611 submissionsQuery := `
612 select
613 id, pull_id, repo_at, round_number, patch, created, source_rev
614 from
615 pull_submissions
616 where
617 repo_at = ? and pull_id = ?
618 `
619 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
620 if err != nil {
621 return nil, err
622 }
623 defer submissionsRows.Close()
624
625 submissionsMap := make(map[int]*PullSubmission)
626
627 for submissionsRows.Next() {
628 var submission PullSubmission
629 var submissionCreatedStr string
630 var submissionSourceRev sql.NullString
631 err := submissionsRows.Scan(
632 &submission.ID,
633 &submission.PullId,
634 &submission.RepoAt,
635 &submission.RoundNumber,
636 &submission.Patch,
637 &submissionCreatedStr,
638 &submissionSourceRev,
639 )
640 if err != nil {
641 return nil, err
642 }
643
644 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
645 if err != nil {
646 return nil, err
647 }
648 submission.Created = submissionCreatedTime
649
650 if submissionSourceRev.Valid {
651 submission.SourceRev = submissionSourceRev.String
652 }
653
654 submissionsMap[submission.ID] = &submission
655 }
656 if err = submissionsRows.Close(); err != nil {
657 return nil, err
658 }
659 if len(submissionsMap) == 0 {
660 return &pull, nil
661 }
662
663 var args []any
664 for k := range submissionsMap {
665 args = append(args, k)
666 }
667 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
668 commentsQuery := fmt.Sprintf(`
669 select
670 id,
671 pull_id,
672 submission_id,
673 repo_at,
674 owner_did,
675 comment_at,
676 body,
677 created
678 from
679 pull_comments
680 where
681 submission_id IN (%s)
682 order by
683 created asc
684 `, inClause)
685 commentsRows, err := e.Query(commentsQuery, args...)
686 if err != nil {
687 return nil, err
688 }
689 defer commentsRows.Close()
690
691 for commentsRows.Next() {
692 var comment PullComment
693 var commentCreatedStr string
694 err := commentsRows.Scan(
695 &comment.ID,
696 &comment.PullId,
697 &comment.SubmissionId,
698 &comment.RepoAt,
699 &comment.OwnerDid,
700 &comment.CommentAt,
701 &comment.Body,
702 &commentCreatedStr,
703 )
704 if err != nil {
705 return nil, err
706 }
707
708 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
709 if err != nil {
710 return nil, err
711 }
712 comment.Created = commentCreatedTime
713
714 // Add the comment to its submission
715 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
716 submission.Comments = append(submission.Comments, comment)
717 }
718
719 }
720 if err = commentsRows.Err(); err != nil {
721 return nil, err
722 }
723
724 var pullSourceRepo *Repo
725 if pull.PullSource != nil {
726 if pull.PullSource.RepoAt != nil {
727 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
728 if err != nil {
729 log.Printf("failed to get repo by at uri: %v", err)
730 } else {
731 pull.PullSource.Repo = pullSourceRepo
732 }
733 }
734 }
735
736 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
737 for _, submission := range submissionsMap {
738 pull.Submissions[submission.RoundNumber] = submission
739 }
740
741 return &pull, nil
742}
743
744// timeframe here is directly passed into the sql query filter, and any
745// timeframe in the past should be negative; e.g.: "-3 months"
746func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
747 var pulls []Pull
748
749 rows, err := e.Query(`
750 select
751 p.owner_did,
752 p.repo_at,
753 p.pull_id,
754 p.created,
755 p.title,
756 p.state,
757 r.did,
758 r.name,
759 r.knot,
760 r.rkey,
761 r.created
762 from
763 pulls p
764 join
765 repos r on p.repo_at = r.at_uri
766 where
767 p.owner_did = ? and p.created >= date ('now', ?)
768 order by
769 p.created desc`, did, timeframe)
770 if err != nil {
771 return nil, err
772 }
773 defer rows.Close()
774
775 for rows.Next() {
776 var pull Pull
777 var repo Repo
778 var pullCreatedAt, repoCreatedAt string
779 err := rows.Scan(
780 &pull.OwnerDid,
781 &pull.RepoAt,
782 &pull.PullId,
783 &pullCreatedAt,
784 &pull.Title,
785 &pull.State,
786 &repo.Did,
787 &repo.Name,
788 &repo.Knot,
789 &repo.Rkey,
790 &repoCreatedAt,
791 )
792 if err != nil {
793 return nil, err
794 }
795
796 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
797 if err != nil {
798 return nil, err
799 }
800 pull.Created = pullCreatedTime
801
802 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
803 if err != nil {
804 return nil, err
805 }
806 repo.Created = repoCreatedTime
807
808 pull.Repo = &repo
809
810 pulls = append(pulls, pull)
811 }
812
813 if err := rows.Err(); err != nil {
814 return nil, err
815 }
816
817 return pulls, nil
818}
819
820func NewPullComment(e Execer, comment *PullComment) (int64, error) {
821 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
822 res, err := e.Exec(
823 query,
824 comment.OwnerDid,
825 comment.RepoAt,
826 comment.SubmissionId,
827 comment.CommentAt,
828 comment.PullId,
829 comment.Body,
830 )
831 if err != nil {
832 return 0, err
833 }
834
835 i, err := res.LastInsertId()
836 if err != nil {
837 return 0, err
838 }
839
840 return i, nil
841}
842
843func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
844 _, err := e.Exec(
845 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
846 pullState,
847 repoAt,
848 pullId,
849 PullDeleted, // only update state of non-deleted pulls
850 PullMerged, // only update state of non-merged pulls
851 )
852 return err
853}
854
855func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
856 err := SetPullState(e, repoAt, pullId, PullClosed)
857 return err
858}
859
860func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
861 err := SetPullState(e, repoAt, pullId, PullOpen)
862 return err
863}
864
865func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
866 err := SetPullState(e, repoAt, pullId, PullMerged)
867 return err
868}
869
870func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
871 err := SetPullState(e, repoAt, pullId, PullDeleted)
872 return err
873}
874
875func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
876 newRoundNumber := len(pull.Submissions)
877 _, err := e.Exec(`
878 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
879 values (?, ?, ?, ?, ?)
880 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
881
882 return err
883}
884
885func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
886 var conditions []string
887 var args []any
888
889 args = append(args, parentChangeId)
890
891 for _, filter := range filters {
892 conditions = append(conditions, filter.Condition())
893 args = append(args, filter.Arg()...)
894 }
895
896 whereClause := ""
897 if conditions != nil {
898 whereClause = " where " + strings.Join(conditions, " and ")
899 }
900
901 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
902 _, err := e.Exec(query, args...)
903
904 return err
905}
906
907// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
908// otherwise submissions are immutable
909func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
910 var conditions []string
911 var args []any
912
913 args = append(args, sourceRev)
914 args = append(args, newPatch)
915
916 for _, filter := range filters {
917 conditions = append(conditions, filter.Condition())
918 args = append(args, filter.Arg()...)
919 }
920
921 whereClause := ""
922 if conditions != nil {
923 whereClause = " where " + strings.Join(conditions, " and ")
924 }
925
926 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
927 _, err := e.Exec(query, args...)
928
929 return err
930}
931
// PullCount holds the number of pulls of a repository in each state.
type PullCount struct {
	Open, Merged, Closed, Deleted int
}
938
939func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
940 row := e.QueryRow(`
941 select
942 count(case when state = ? then 1 end) as open_count,
943 count(case when state = ? then 1 end) as merged_count,
944 count(case when state = ? then 1 end) as closed_count,
945 count(case when state = ? then 1 end) as deleted_count
946 from pulls
947 where repo_at = ?`,
948 PullOpen,
949 PullMerged,
950 PullClosed,
951 PullDeleted,
952 repoAt,
953 )
954
955 var count PullCount
956 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
957 return PullCount{0, 0, 0, 0}, err
958 }
959
960 return count, nil
961}
962
963type Stack []*Pull
964
965// change-id parent-change-id
966//
967// 4 w ,-------- z (TOP)
968// 3 z <----',------- y
969// 2 y <-----',------ x
970// 1 x <------' nil (BOT)
971//
972// `w` is parent of none, so it is the top of the stack
973func GetStack(e Execer, stackId string) (Stack, error) {
974 unorderedPulls, err := GetPulls(
975 e,
976 FilterEq("stack_id", stackId),
977 FilterNotEq("state", PullDeleted),
978 )
979 if err != nil {
980 return nil, err
981 }
982 // map of parent-change-id to pull
983 changeIdMap := make(map[string]*Pull, len(unorderedPulls))
984 parentMap := make(map[string]*Pull, len(unorderedPulls))
985 for _, p := range unorderedPulls {
986 changeIdMap[p.ChangeId] = p
987 if p.ParentChangeId != "" {
988 parentMap[p.ParentChangeId] = p
989 }
990 }
991
992 // the top of the stack is the pull that is not a parent of any pull
993 var topPull *Pull
994 for _, maybeTop := range unorderedPulls {
995 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
996 topPull = maybeTop
997 break
998 }
999 }
1000
1001 pulls := []*Pull{}
1002 for {
1003 pulls = append(pulls, topPull)
1004 if topPull.ParentChangeId != "" {
1005 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
1006 topPull = next
1007 } else {
1008 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
1009 }
1010 } else {
1011 break
1012 }
1013 }
1014
1015 return pulls, nil
1016}
1017
1018func GetAbandonedPulls(e Execer, stackId string) ([]*Pull, error) {
1019 pulls, err := GetPulls(
1020 e,
1021 FilterEq("stack_id", stackId),
1022 FilterEq("state", PullDeleted),
1023 )
1024 if err != nil {
1025 return nil, err
1026 }
1027
1028 return pulls, nil
1029}
1030
1031// position of this pull in the stack
1032func (stack Stack) Position(pull *Pull) int {
1033 return slices.IndexFunc(stack, func(p *Pull) bool {
1034 return p.ChangeId == pull.ChangeId
1035 })
1036}
1037
1038// all pulls below this pull (including self) in this stack
1039//
1040// nil if this pull does not belong to this stack
1041func (stack Stack) Below(pull *Pull) Stack {
1042 position := stack.Position(pull)
1043
1044 if position < 0 {
1045 return nil
1046 }
1047
1048 return stack[position:]
1049}
1050
1051// all pulls below this pull (excluding self) in this stack
1052func (stack Stack) StrictlyBelow(pull *Pull) Stack {
1053 below := stack.Below(pull)
1054
1055 if len(below) > 0 {
1056 return below[1:]
1057 }
1058
1059 return nil
1060}
1061
1062// all pulls above this pull (including self) in this stack
1063func (stack Stack) Above(pull *Pull) Stack {
1064 position := stack.Position(pull)
1065
1066 if position < 0 {
1067 return nil
1068 }
1069
1070 return stack[:position+1]
1071}
1072
1073// all pulls below this pull (excluding self) in this stack
1074func (stack Stack) StrictlyAbove(pull *Pull) Stack {
1075 above := stack.Above(pull)
1076
1077 if len(above) > 0 {
1078 return above[:len(above)-1]
1079 }
1080
1081 return nil
1082}
1083
1084// the combined format-patches of all the newest submissions in this stack
1085func (stack Stack) CombinedPatch() string {
1086 // go in reverse order because the bottom of the stack is the last element in the slice
1087 var combined strings.Builder
1088 for idx := range stack {
1089 pull := stack[len(stack)-1-idx]
1090 combined.WriteString(pull.LatestPatch())
1091 combined.WriteString("\n")
1092 }
1093 return combined.String()
1094}
1095
1096// filter out PRs that are "active"
1097//
1098// PRs that are still open are active
1099func (stack Stack) Mergeable() Stack {
1100 var mergeable Stack
1101
1102 for _, p := range stack {
1103 // stop at the first merged PR
1104 if p.State == PullMerged || p.State == PullClosed {
1105 break
1106 }
1107
1108 // skip over deleted PRs
1109 if p.State != PullDeleted {
1110 mergeable = append(mergeable, p)
1111 }
1112 }
1113
1114 return mergeable
1115}