this repo has no description
1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "slices"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.org/core/api/tangled"
14 "tangled.org/core/appview/models"
15 "tangled.org/core/patchutil"
16 "tangled.org/core/types"
17)
18
// PullState is the lifecycle state of a pull request. The numeric values are
// written to and compared against the pulls.state column, so they must never
// be renumbered.
type PullState int

const (
	PullClosed  PullState = 0
	PullOpen    PullState = 1
	PullMerged  PullState = 2
	PullDeleted PullState = 3
)

// String returns the lowercase name of the state. PullClosed and any
// unrecognised value both render as "closed".
func (s PullState) String() string {
	switch s {
	case PullOpen:
		return "open"
	case PullMerged:
		return "merged"
	case PullDeleted:
		return "deleted"
	}
	return "closed"
}

// IsOpen reports whether s is PullOpen.
func (s PullState) IsOpen() bool { return s == PullOpen }

// IsMerged reports whether s is PullMerged.
func (s PullState) IsMerged() bool { return s == PullMerged }

// IsClosed reports whether s is exactly PullClosed.
func (s PullState) IsClosed() bool { return s == PullClosed }

// IsDeleted reports whether s is PullDeleted.
func (s PullState) IsDeleted() bool { return s == PullDeleted }
55
// Pull is a pull request against a repository, as stored in the pulls table.
type Pull struct {
	// ids
	ID     int // row id; NOTE(review): none of the queries in this file select it
	PullId int // per-repo number handed out from repo_pull_seqs by NewPull

	// at ids
	RepoAt   syntax.ATURI // target repository (sent as Target.Repo by AsRecord)
	OwnerDid string       // DID owning the pull record; authority of PullAt
	Rkey     string       // record key of the pull record; last segment of PullAt

	// content
	Title        string
	Body         string
	TargetBranch string // sent as Target.Branch by AsRecord
	State        PullState
	// Submissions are ordered by round; the last entry is the newest (see
	// LastRoundNumber). Queries that load only the newest round leave the
	// earlier entries nil.
	Submissions []*PullSubmission

	// stacking
	StackId        string // nullable string; "" (NULL in the db) when not stacked
	ChangeId       string // nullable string
	ParentChangeId string // nullable string; "" for the bottom of a stack

	// meta
	Created    time.Time
	PullSource *PullSource // nil for patch-based pulls (see IsPatchBased)

	// optionally, populate this when querying for reverse mappings
	Repo *models.Repo
}
85
86func (p Pull) AsRecord() tangled.RepoPull {
87 var source *tangled.RepoPull_Source
88 if p.PullSource != nil {
89 s := p.PullSource.AsRecord()
90 source = &s
91 source.Sha = p.LatestSha()
92 }
93
94 record := tangled.RepoPull{
95 Title: p.Title,
96 Body: &p.Body,
97 CreatedAt: p.Created.Format(time.RFC3339),
98 Target: &tangled.RepoPull_Target{
99 Repo: p.RepoAt.String(),
100 Branch: p.TargetBranch,
101 },
102 Patch: p.LatestPatch(),
103 Source: source,
104 }
105 return record
106}
107
// PullSource records where a non-patch pull's changes come from. A Pull with
// a nil PullSource is patch-based.
type PullSource struct {
	Branch string        // source branch name
	RepoAt *syntax.ATURI // source repository; nil when none was recorded

	// optionally populate this for reverse mappings
	Repo *models.Repo
}
115
116func (p PullSource) AsRecord() tangled.RepoPull_Source {
117 var repoAt *string
118 if p.RepoAt != nil {
119 s := p.RepoAt.String()
120 repoAt = &s
121 }
122 record := tangled.RepoPull_Source{
123 Branch: p.Branch,
124 Repo: repoAt,
125 }
126 return record
127}
128
// PullSubmission is one round of a pull request, as stored in the
// pull_submissions table.
type PullSubmission struct {
	// ids
	ID     int // pull_submissions row id; pull_comments.submission_id refers to it
	PullId int // per-repo pull number (Pull.PullId), not a row id

	// at ids
	RepoAt syntax.ATURI

	// content
	RoundNumber int // 0 for the submission created by NewPull; ResubmitPull appends len(Submissions)
	Patch       string
	Comments    []PullComment
	SourceRev   string // include the rev that was used to create this submission: only for branch/fork PRs

	// meta
	Created time.Time
}
146
// PullComment is a comment attached to a single submission round, as stored
// in the pull_comments table.
type PullComment struct {
	// ids
	ID           int
	PullId       int // per-repo pull number
	SubmissionId int // row id of the PullSubmission this comment belongs to

	// at ids
	RepoAt    string
	OwnerDid  string
	CommentAt string // presumably the AT URI of the comment record — TODO confirm

	// content
	Body string

	// meta
	Created time.Time
}
164
165func (p *Pull) LatestPatch() string {
166 latestSubmission := p.Submissions[p.LastRoundNumber()]
167 return latestSubmission.Patch
168}
169
170func (p *Pull) LatestSha() string {
171 latestSubmission := p.Submissions[p.LastRoundNumber()]
172 return latestSubmission.SourceRev
173}
174
175func (p *Pull) PullAt() syntax.ATURI {
176 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey))
177}
178
179func (p *Pull) LastRoundNumber() int {
180 return len(p.Submissions) - 1
181}
182
183func (p *Pull) IsPatchBased() bool {
184 return p.PullSource == nil
185}
186
187func (p *Pull) IsBranchBased() bool {
188 if p.PullSource != nil {
189 if p.PullSource.RepoAt != nil {
190 return p.PullSource.RepoAt == &p.RepoAt
191 } else {
192 // no repo specified
193 return true
194 }
195 }
196 return false
197}
198
199func (p *Pull) IsForkBased() bool {
200 if p.PullSource != nil {
201 if p.PullSource.RepoAt != nil {
202 // make sure repos are different
203 return p.PullSource.RepoAt != &p.RepoAt
204 }
205 }
206 return false
207}
208
209func (p *Pull) IsStacked() bool {
210 return p.StackId != ""
211}
212
213func (s PullSubmission) IsFormatPatch() bool {
214 return patchutil.IsFormatPatch(s.Patch)
215}
216
217func (s PullSubmission) AsFormatPatch() []types.FormatPatch {
218 patches, err := patchutil.ExtractPatches(s.Patch)
219 if err != nil {
220 log.Println("error extracting patches from submission:", err)
221 return []types.FormatPatch{}
222 }
223
224 return patches
225}
226
227func NewPull(tx *sql.Tx, pull *Pull) error {
228 _, err := tx.Exec(`
229 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
230 values (?, 1)
231 `, pull.RepoAt)
232 if err != nil {
233 return err
234 }
235
236 var nextId int
237 err = tx.QueryRow(`
238 update repo_pull_seqs
239 set next_pull_id = next_pull_id + 1
240 where repo_at = ?
241 returning next_pull_id - 1
242 `, pull.RepoAt).Scan(&nextId)
243 if err != nil {
244 return err
245 }
246
247 pull.PullId = nextId
248 pull.State = PullOpen
249
250 var sourceBranch, sourceRepoAt *string
251 if pull.PullSource != nil {
252 sourceBranch = &pull.PullSource.Branch
253 if pull.PullSource.RepoAt != nil {
254 x := pull.PullSource.RepoAt.String()
255 sourceRepoAt = &x
256 }
257 }
258
259 var stackId, changeId, parentChangeId *string
260 if pull.StackId != "" {
261 stackId = &pull.StackId
262 }
263 if pull.ChangeId != "" {
264 changeId = &pull.ChangeId
265 }
266 if pull.ParentChangeId != "" {
267 parentChangeId = &pull.ParentChangeId
268 }
269
270 _, err = tx.Exec(
271 `
272 insert into pulls (
273 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
274 )
275 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
276 pull.RepoAt,
277 pull.OwnerDid,
278 pull.PullId,
279 pull.Title,
280 pull.TargetBranch,
281 pull.Body,
282 pull.Rkey,
283 pull.State,
284 sourceBranch,
285 sourceRepoAt,
286 stackId,
287 changeId,
288 parentChangeId,
289 )
290 if err != nil {
291 return err
292 }
293
294 _, err = tx.Exec(`
295 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
296 values (?, ?, ?, ?, ?)
297 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
298 return err
299}
300
301func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
302 pull, err := GetPull(e, repoAt, pullId)
303 if err != nil {
304 return "", err
305 }
306 return pull.PullAt(), err
307}
308
309func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
310 var pullId int
311 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
312 return pullId - 1, err
313}
314
315func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*Pull, error) {
316 pulls := make(map[int]*Pull)
317
318 var conditions []string
319 var args []any
320 for _, filter := range filters {
321 conditions = append(conditions, filter.Condition())
322 args = append(args, filter.Arg()...)
323 }
324
325 whereClause := ""
326 if conditions != nil {
327 whereClause = " where " + strings.Join(conditions, " and ")
328 }
329 limitClause := ""
330 if limit != 0 {
331 limitClause = fmt.Sprintf(" limit %d ", limit)
332 }
333
334 query := fmt.Sprintf(`
335 select
336 owner_did,
337 repo_at,
338 pull_id,
339 created,
340 title,
341 state,
342 target_branch,
343 body,
344 rkey,
345 source_branch,
346 source_repo_at,
347 stack_id,
348 change_id,
349 parent_change_id
350 from
351 pulls
352 %s
353 order by
354 created desc
355 %s
356 `, whereClause, limitClause)
357
358 rows, err := e.Query(query, args...)
359 if err != nil {
360 return nil, err
361 }
362 defer rows.Close()
363
364 for rows.Next() {
365 var pull Pull
366 var createdAt string
367 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
368 err := rows.Scan(
369 &pull.OwnerDid,
370 &pull.RepoAt,
371 &pull.PullId,
372 &createdAt,
373 &pull.Title,
374 &pull.State,
375 &pull.TargetBranch,
376 &pull.Body,
377 &pull.Rkey,
378 &sourceBranch,
379 &sourceRepoAt,
380 &stackId,
381 &changeId,
382 &parentChangeId,
383 )
384 if err != nil {
385 return nil, err
386 }
387
388 createdTime, err := time.Parse(time.RFC3339, createdAt)
389 if err != nil {
390 return nil, err
391 }
392 pull.Created = createdTime
393
394 if sourceBranch.Valid {
395 pull.PullSource = &PullSource{
396 Branch: sourceBranch.String,
397 }
398 if sourceRepoAt.Valid {
399 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
400 if err != nil {
401 return nil, err
402 }
403 pull.PullSource.RepoAt = &sourceRepoAtParsed
404 }
405 }
406
407 if stackId.Valid {
408 pull.StackId = stackId.String
409 }
410 if changeId.Valid {
411 pull.ChangeId = changeId.String
412 }
413 if parentChangeId.Valid {
414 pull.ParentChangeId = parentChangeId.String
415 }
416
417 pulls[pull.PullId] = &pull
418 }
419
420 // get latest round no. for each pull
421 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
422 submissionsQuery := fmt.Sprintf(`
423 select
424 id, pull_id, round_number, patch, created, source_rev
425 from
426 pull_submissions
427 where
428 repo_at in (%s) and pull_id in (%s)
429 `, inClause, inClause)
430
431 args = make([]any, len(pulls)*2)
432 idx := 0
433 for _, p := range pulls {
434 args[idx] = p.RepoAt
435 idx += 1
436 }
437 for _, p := range pulls {
438 args[idx] = p.PullId
439 idx += 1
440 }
441 submissionsRows, err := e.Query(submissionsQuery, args...)
442 if err != nil {
443 return nil, err
444 }
445 defer submissionsRows.Close()
446
447 for submissionsRows.Next() {
448 var s PullSubmission
449 var sourceRev sql.NullString
450 var createdAt string
451 err := submissionsRows.Scan(
452 &s.ID,
453 &s.PullId,
454 &s.RoundNumber,
455 &s.Patch,
456 &createdAt,
457 &sourceRev,
458 )
459 if err != nil {
460 return nil, err
461 }
462
463 createdTime, err := time.Parse(time.RFC3339, createdAt)
464 if err != nil {
465 return nil, err
466 }
467 s.Created = createdTime
468
469 if sourceRev.Valid {
470 s.SourceRev = sourceRev.String
471 }
472
473 if p, ok := pulls[s.PullId]; ok {
474 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
475 p.Submissions[s.RoundNumber] = &s
476 }
477 }
478 if err := rows.Err(); err != nil {
479 return nil, err
480 }
481
482 // get comment count on latest submission on each pull
483 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
484 commentsQuery := fmt.Sprintf(`
485 select
486 count(id), pull_id
487 from
488 pull_comments
489 where
490 submission_id in (%s)
491 group by
492 submission_id
493 `, inClause)
494
495 args = []any{}
496 for _, p := range pulls {
497 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
498 }
499 commentsRows, err := e.Query(commentsQuery, args...)
500 if err != nil {
501 return nil, err
502 }
503 defer commentsRows.Close()
504
505 for commentsRows.Next() {
506 var commentCount, pullId int
507 err := commentsRows.Scan(
508 &commentCount,
509 &pullId,
510 )
511 if err != nil {
512 return nil, err
513 }
514 if p, ok := pulls[pullId]; ok {
515 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
516 }
517 }
518 if err := rows.Err(); err != nil {
519 return nil, err
520 }
521
522 orderedByPullId := []*Pull{}
523 for _, p := range pulls {
524 orderedByPullId = append(orderedByPullId, p)
525 }
526 sort.Slice(orderedByPullId, func(i, j int) bool {
527 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
528 })
529
530 return orderedByPullId, nil
531}
532
533func GetPulls(e Execer, filters ...filter) ([]*Pull, error) {
534 return GetPullsWithLimit(e, 0, filters...)
535}
536
537func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
538 query := `
539 select
540 owner_did,
541 pull_id,
542 created,
543 title,
544 state,
545 target_branch,
546 repo_at,
547 body,
548 rkey,
549 source_branch,
550 source_repo_at,
551 stack_id,
552 change_id,
553 parent_change_id
554 from
555 pulls
556 where
557 repo_at = ? and pull_id = ?
558 `
559 row := e.QueryRow(query, repoAt, pullId)
560
561 var pull Pull
562 var createdAt string
563 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
564 err := row.Scan(
565 &pull.OwnerDid,
566 &pull.PullId,
567 &createdAt,
568 &pull.Title,
569 &pull.State,
570 &pull.TargetBranch,
571 &pull.RepoAt,
572 &pull.Body,
573 &pull.Rkey,
574 &sourceBranch,
575 &sourceRepoAt,
576 &stackId,
577 &changeId,
578 &parentChangeId,
579 )
580 if err != nil {
581 return nil, err
582 }
583
584 createdTime, err := time.Parse(time.RFC3339, createdAt)
585 if err != nil {
586 return nil, err
587 }
588 pull.Created = createdTime
589
590 // populate source
591 if sourceBranch.Valid {
592 pull.PullSource = &PullSource{
593 Branch: sourceBranch.String,
594 }
595 if sourceRepoAt.Valid {
596 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
597 if err != nil {
598 return nil, err
599 }
600 pull.PullSource.RepoAt = &sourceRepoAtParsed
601 }
602 }
603
604 if stackId.Valid {
605 pull.StackId = stackId.String
606 }
607 if changeId.Valid {
608 pull.ChangeId = changeId.String
609 }
610 if parentChangeId.Valid {
611 pull.ParentChangeId = parentChangeId.String
612 }
613
614 submissionsQuery := `
615 select
616 id, pull_id, repo_at, round_number, patch, created, source_rev
617 from
618 pull_submissions
619 where
620 repo_at = ? and pull_id = ?
621 `
622 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
623 if err != nil {
624 return nil, err
625 }
626 defer submissionsRows.Close()
627
628 submissionsMap := make(map[int]*PullSubmission)
629
630 for submissionsRows.Next() {
631 var submission PullSubmission
632 var submissionCreatedStr string
633 var submissionSourceRev sql.NullString
634 err := submissionsRows.Scan(
635 &submission.ID,
636 &submission.PullId,
637 &submission.RepoAt,
638 &submission.RoundNumber,
639 &submission.Patch,
640 &submissionCreatedStr,
641 &submissionSourceRev,
642 )
643 if err != nil {
644 return nil, err
645 }
646
647 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
648 if err != nil {
649 return nil, err
650 }
651 submission.Created = submissionCreatedTime
652
653 if submissionSourceRev.Valid {
654 submission.SourceRev = submissionSourceRev.String
655 }
656
657 submissionsMap[submission.ID] = &submission
658 }
659 if err = submissionsRows.Close(); err != nil {
660 return nil, err
661 }
662 if len(submissionsMap) == 0 {
663 return &pull, nil
664 }
665
666 var args []any
667 for k := range submissionsMap {
668 args = append(args, k)
669 }
670 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
671 commentsQuery := fmt.Sprintf(`
672 select
673 id,
674 pull_id,
675 submission_id,
676 repo_at,
677 owner_did,
678 comment_at,
679 body,
680 created
681 from
682 pull_comments
683 where
684 submission_id IN (%s)
685 order by
686 created asc
687 `, inClause)
688 commentsRows, err := e.Query(commentsQuery, args...)
689 if err != nil {
690 return nil, err
691 }
692 defer commentsRows.Close()
693
694 for commentsRows.Next() {
695 var comment PullComment
696 var commentCreatedStr string
697 err := commentsRows.Scan(
698 &comment.ID,
699 &comment.PullId,
700 &comment.SubmissionId,
701 &comment.RepoAt,
702 &comment.OwnerDid,
703 &comment.CommentAt,
704 &comment.Body,
705 &commentCreatedStr,
706 )
707 if err != nil {
708 return nil, err
709 }
710
711 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
712 if err != nil {
713 return nil, err
714 }
715 comment.Created = commentCreatedTime
716
717 // Add the comment to its submission
718 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
719 submission.Comments = append(submission.Comments, comment)
720 }
721
722 }
723 if err = commentsRows.Err(); err != nil {
724 return nil, err
725 }
726
727 var pullSourceRepo *models.Repo
728 if pull.PullSource != nil {
729 if pull.PullSource.RepoAt != nil {
730 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
731 if err != nil {
732 log.Printf("failed to get repo by at uri: %v", err)
733 } else {
734 pull.PullSource.Repo = pullSourceRepo
735 }
736 }
737 }
738
739 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
740 for _, submission := range submissionsMap {
741 pull.Submissions[submission.RoundNumber] = submission
742 }
743
744 return &pull, nil
745}
746
747// timeframe here is directly passed into the sql query filter, and any
748// timeframe in the past should be negative; e.g.: "-3 months"
749func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
750 var pulls []Pull
751
752 rows, err := e.Query(`
753 select
754 p.owner_did,
755 p.repo_at,
756 p.pull_id,
757 p.created,
758 p.title,
759 p.state,
760 r.did,
761 r.name,
762 r.knot,
763 r.rkey,
764 r.created
765 from
766 pulls p
767 join
768 repos r on p.repo_at = r.at_uri
769 where
770 p.owner_did = ? and p.created >= date ('now', ?)
771 order by
772 p.created desc`, did, timeframe)
773 if err != nil {
774 return nil, err
775 }
776 defer rows.Close()
777
778 for rows.Next() {
779 var pull Pull
780 var repo models.Repo
781 var pullCreatedAt, repoCreatedAt string
782 err := rows.Scan(
783 &pull.OwnerDid,
784 &pull.RepoAt,
785 &pull.PullId,
786 &pullCreatedAt,
787 &pull.Title,
788 &pull.State,
789 &repo.Did,
790 &repo.Name,
791 &repo.Knot,
792 &repo.Rkey,
793 &repoCreatedAt,
794 )
795 if err != nil {
796 return nil, err
797 }
798
799 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
800 if err != nil {
801 return nil, err
802 }
803 pull.Created = pullCreatedTime
804
805 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
806 if err != nil {
807 return nil, err
808 }
809 repo.Created = repoCreatedTime
810
811 pull.Repo = &repo
812
813 pulls = append(pulls, pull)
814 }
815
816 if err := rows.Err(); err != nil {
817 return nil, err
818 }
819
820 return pulls, nil
821}
822
823func NewPullComment(e Execer, comment *PullComment) (int64, error) {
824 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
825 res, err := e.Exec(
826 query,
827 comment.OwnerDid,
828 comment.RepoAt,
829 comment.SubmissionId,
830 comment.CommentAt,
831 comment.PullId,
832 comment.Body,
833 )
834 if err != nil {
835 return 0, err
836 }
837
838 i, err := res.LastInsertId()
839 if err != nil {
840 return 0, err
841 }
842
843 return i, nil
844}
845
846func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
847 _, err := e.Exec(
848 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
849 pullState,
850 repoAt,
851 pullId,
852 PullDeleted, // only update state of non-deleted pulls
853 PullMerged, // only update state of non-merged pulls
854 )
855 return err
856}
857
858func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
859 err := SetPullState(e, repoAt, pullId, PullClosed)
860 return err
861}
862
863func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
864 err := SetPullState(e, repoAt, pullId, PullOpen)
865 return err
866}
867
868func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
869 err := SetPullState(e, repoAt, pullId, PullMerged)
870 return err
871}
872
873func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
874 err := SetPullState(e, repoAt, pullId, PullDeleted)
875 return err
876}
877
878func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
879 newRoundNumber := len(pull.Submissions)
880 _, err := e.Exec(`
881 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
882 values (?, ?, ?, ?, ?)
883 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
884
885 return err
886}
887
888func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
889 var conditions []string
890 var args []any
891
892 args = append(args, parentChangeId)
893
894 for _, filter := range filters {
895 conditions = append(conditions, filter.Condition())
896 args = append(args, filter.Arg()...)
897 }
898
899 whereClause := ""
900 if conditions != nil {
901 whereClause = " where " + strings.Join(conditions, " and ")
902 }
903
904 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
905 _, err := e.Exec(query, args...)
906
907 return err
908}
909
910// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
911// otherwise submissions are immutable
912func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
913 var conditions []string
914 var args []any
915
916 args = append(args, sourceRev)
917 args = append(args, newPatch)
918
919 for _, filter := range filters {
920 conditions = append(conditions, filter.Condition())
921 args = append(args, filter.Arg()...)
922 }
923
924 whereClause := ""
925 if conditions != nil {
926 whereClause = " where " + strings.Join(conditions, " and ")
927 }
928
929 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
930 _, err := e.Exec(query, args...)
931
932 return err
933}
934
935func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
936 row := e.QueryRow(`
937 select
938 count(case when state = ? then 1 end) as open_count,
939 count(case when state = ? then 1 end) as merged_count,
940 count(case when state = ? then 1 end) as closed_count,
941 count(case when state = ? then 1 end) as deleted_count
942 from pulls
943 where repo_at = ?`,
944 PullOpen,
945 PullMerged,
946 PullClosed,
947 PullDeleted,
948 repoAt,
949 )
950
951 var count models.PullCount
952 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
953 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
954 }
955
956 return count, nil
957}
958
959type Stack []*Pull
960
961// change-id parent-change-id
962//
963// 4 w ,-------- z (TOP)
964// 3 z <----',------- y
965// 2 y <-----',------ x
966// 1 x <------' nil (BOT)
967//
968// `w` is parent of none, so it is the top of the stack
969func GetStack(e Execer, stackId string) (Stack, error) {
970 unorderedPulls, err := GetPulls(
971 e,
972 FilterEq("stack_id", stackId),
973 FilterNotEq("state", PullDeleted),
974 )
975 if err != nil {
976 return nil, err
977 }
978 // map of parent-change-id to pull
979 changeIdMap := make(map[string]*Pull, len(unorderedPulls))
980 parentMap := make(map[string]*Pull, len(unorderedPulls))
981 for _, p := range unorderedPulls {
982 changeIdMap[p.ChangeId] = p
983 if p.ParentChangeId != "" {
984 parentMap[p.ParentChangeId] = p
985 }
986 }
987
988 // the top of the stack is the pull that is not a parent of any pull
989 var topPull *Pull
990 for _, maybeTop := range unorderedPulls {
991 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
992 topPull = maybeTop
993 break
994 }
995 }
996
997 pulls := []*Pull{}
998 for {
999 pulls = append(pulls, topPull)
1000 if topPull.ParentChangeId != "" {
1001 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
1002 topPull = next
1003 } else {
1004 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
1005 }
1006 } else {
1007 break
1008 }
1009 }
1010
1011 return pulls, nil
1012}
1013
1014func GetAbandonedPulls(e Execer, stackId string) ([]*Pull, error) {
1015 pulls, err := GetPulls(
1016 e,
1017 FilterEq("stack_id", stackId),
1018 FilterEq("state", PullDeleted),
1019 )
1020 if err != nil {
1021 return nil, err
1022 }
1023
1024 return pulls, nil
1025}
1026
1027// position of this pull in the stack
1028func (stack Stack) Position(pull *Pull) int {
1029 return slices.IndexFunc(stack, func(p *Pull) bool {
1030 return p.ChangeId == pull.ChangeId
1031 })
1032}
1033
1034// all pulls below this pull (including self) in this stack
1035//
1036// nil if this pull does not belong to this stack
1037func (stack Stack) Below(pull *Pull) Stack {
1038 position := stack.Position(pull)
1039
1040 if position < 0 {
1041 return nil
1042 }
1043
1044 return stack[position:]
1045}
1046
1047// all pulls below this pull (excluding self) in this stack
1048func (stack Stack) StrictlyBelow(pull *Pull) Stack {
1049 below := stack.Below(pull)
1050
1051 if len(below) > 0 {
1052 return below[1:]
1053 }
1054
1055 return nil
1056}
1057
1058// all pulls above this pull (including self) in this stack
1059func (stack Stack) Above(pull *Pull) Stack {
1060 position := stack.Position(pull)
1061
1062 if position < 0 {
1063 return nil
1064 }
1065
1066 return stack[:position+1]
1067}
1068
1069// all pulls below this pull (excluding self) in this stack
1070func (stack Stack) StrictlyAbove(pull *Pull) Stack {
1071 above := stack.Above(pull)
1072
1073 if len(above) > 0 {
1074 return above[:len(above)-1]
1075 }
1076
1077 return nil
1078}
1079
1080// the combined format-patches of all the newest submissions in this stack
1081func (stack Stack) CombinedPatch() string {
1082 // go in reverse order because the bottom of the stack is the last element in the slice
1083 var combined strings.Builder
1084 for idx := range stack {
1085 pull := stack[len(stack)-1-idx]
1086 combined.WriteString(pull.LatestPatch())
1087 combined.WriteString("\n")
1088 }
1089 return combined.String()
1090}
1091
1092// filter out PRs that are "active"
1093//
1094// PRs that are still open are active
1095func (stack Stack) Mergeable() Stack {
1096 var mergeable Stack
1097
1098 for _, p := range stack {
1099 // stop at the first merged PR
1100 if p.State == PullMerged || p.State == PullClosed {
1101 break
1102 }
1103
1104 // skip over deleted PRs
1105 if p.State != PullDeleted {
1106 mergeable = append(mergeable, p)
1107 }
1108 }
1109
1110 return mergeable
1111}