this repo has no description
1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "slices"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/bluekeyes/go-gitdiff/gitdiff"
13 "github.com/bluesky-social/indigo/atproto/syntax"
14 "tangled.sh/tangled.sh/core/api/tangled"
15 "tangled.sh/tangled.sh/core/patchutil"
16 "tangled.sh/tangled.sh/core/types"
17)
18
// PullState is the lifecycle state of a pull request.
type PullState int

// The numeric values follow declaration order (closed = 0, open = 1,
// merged = 2); they are passed as-is as SQL arguments elsewhere in this file.
const (
	PullClosed PullState = iota
	PullOpen
	PullMerged
)

// String returns the lowercase name of the state. Every value other than
// PullOpen and PullMerged, including unknown ones, is rendered as "closed".
func (p PullState) String() string {
	switch p {
	case PullOpen:
		return "open"
	case PullMerged:
		return "merged"
	}
	return "closed"
}

// IsOpen reports whether the state is PullOpen.
func (p PullState) IsOpen() bool { return p == PullOpen }

// IsMerged reports whether the state is PullMerged.
func (p PullState) IsMerged() bool { return p == PullMerged }

// IsClosed reports whether the state is exactly PullClosed.
func (p PullState) IsClosed() bool { return p == PullClosed }
49
50type Pull struct {
51 // ids
52 ID int
53 PullId int
54
55 // at ids
56 RepoAt syntax.ATURI
57 OwnerDid string
58 Rkey string
59
60 // content
61 Title string
62 Body string
63 TargetBranch string
64 State PullState
65 Submissions []*PullSubmission
66
67 // stacking
68 StackId string // nullable string
69 ChangeId string // nullable string
70 ParentChangeId string // nullable string
71
72 // meta
73 Created time.Time
74 PullSource *PullSource
75
76 // optionally, populate this when querying for reverse mappings
77 Repo *Repo
78}
79
80type PullSource struct {
81 Branch string
82 RepoAt *syntax.ATURI
83
84 // optionally populate this for reverse mappings
85 Repo *Repo
86}
87
88type PullSubmission struct {
89 // ids
90 ID int
91 PullId int
92
93 // at ids
94 RepoAt syntax.ATURI
95
96 // content
97 RoundNumber int
98 Patch string
99 Comments []PullComment
100 SourceRev string // include the rev that was used to create this submission: only for branch PRs
101
102 // meta
103 Created time.Time
104}
105
106type PullComment struct {
107 // ids
108 ID int
109 PullId int
110 SubmissionId int
111
112 // at ids
113 RepoAt string
114 OwnerDid string
115 CommentAt string
116
117 // content
118 Body string
119
120 // meta
121 Created time.Time
122}
123
124func (p *Pull) LatestPatch() string {
125 latestSubmission := p.Submissions[p.LastRoundNumber()]
126 return latestSubmission.Patch
127}
128
129func (p *Pull) PullAt() syntax.ATURI {
130 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey))
131}
132
133func (p *Pull) LastRoundNumber() int {
134 return len(p.Submissions) - 1
135}
136
137func (p *Pull) IsPatchBased() bool {
138 return p.PullSource == nil
139}
140
141func (p *Pull) IsBranchBased() bool {
142 if p.PullSource != nil {
143 if p.PullSource.RepoAt != nil {
144 return p.PullSource.RepoAt == &p.RepoAt
145 } else {
146 // no repo specified
147 return true
148 }
149 }
150 return false
151}
152
153func (p *Pull) IsForkBased() bool {
154 if p.PullSource != nil {
155 if p.PullSource.RepoAt != nil {
156 // make sure repos are different
157 return p.PullSource.RepoAt != &p.RepoAt
158 }
159 }
160 return false
161}
162
163func (p *Pull) IsStacked() bool {
164 return p.StackId != ""
165}
166
167func (s PullSubmission) AsDiff(targetBranch string) ([]*gitdiff.File, error) {
168 patch := s.Patch
169
170 // if format-patch; then extract each patch
171 var diffs []*gitdiff.File
172 if patchutil.IsFormatPatch(patch) {
173 patches, err := patchutil.ExtractPatches(patch)
174 if err != nil {
175 return nil, err
176 }
177 var ps [][]*gitdiff.File
178 for _, p := range patches {
179 ps = append(ps, p.Files)
180 }
181
182 diffs = patchutil.CombineDiff(ps...)
183 } else {
184 d, _, err := gitdiff.Parse(strings.NewReader(patch))
185 if err != nil {
186 return nil, err
187 }
188 diffs = d
189 }
190
191 return diffs, nil
192}
193
194func (s PullSubmission) AsNiceDiff(targetBranch string) types.NiceDiff {
195 diffs, err := s.AsDiff(targetBranch)
196 if err != nil {
197 log.Println(err)
198 }
199
200 nd := types.NiceDiff{}
201 nd.Commit.Parent = targetBranch
202
203 for _, d := range diffs {
204 ndiff := types.Diff{}
205 ndiff.Name.New = d.NewName
206 ndiff.Name.Old = d.OldName
207 ndiff.IsBinary = d.IsBinary
208 ndiff.IsNew = d.IsNew
209 ndiff.IsDelete = d.IsDelete
210 ndiff.IsCopy = d.IsCopy
211 ndiff.IsRename = d.IsRename
212
213 for _, tf := range d.TextFragments {
214 ndiff.TextFragments = append(ndiff.TextFragments, *tf)
215 for _, l := range tf.Lines {
216 switch l.Op {
217 case gitdiff.OpAdd:
218 nd.Stat.Insertions += 1
219 case gitdiff.OpDelete:
220 nd.Stat.Deletions += 1
221 }
222 }
223 }
224
225 nd.Diff = append(nd.Diff, ndiff)
226 }
227
228 nd.Stat.FilesChanged = len(diffs)
229
230 return nd
231}
232
233func (s PullSubmission) IsFormatPatch() bool {
234 return patchutil.IsFormatPatch(s.Patch)
235}
236
237func (s PullSubmission) AsFormatPatch() []patchutil.FormatPatch {
238 patches, err := patchutil.ExtractPatches(s.Patch)
239 if err != nil {
240 log.Println("error extracting patches from submission:", err)
241 return []patchutil.FormatPatch{}
242 }
243
244 return patches
245}
246
247func NewPull(tx *sql.Tx, pull *Pull) error {
248 _, err := tx.Exec(`
249 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
250 values (?, 1)
251 `, pull.RepoAt)
252 if err != nil {
253 return err
254 }
255
256 var nextId int
257 err = tx.QueryRow(`
258 update repo_pull_seqs
259 set next_pull_id = next_pull_id + 1
260 where repo_at = ?
261 returning next_pull_id - 1
262 `, pull.RepoAt).Scan(&nextId)
263 if err != nil {
264 return err
265 }
266
267 pull.PullId = nextId
268 pull.State = PullOpen
269
270 var sourceBranch, sourceRepoAt *string
271 if pull.PullSource != nil {
272 sourceBranch = &pull.PullSource.Branch
273 if pull.PullSource.RepoAt != nil {
274 x := pull.PullSource.RepoAt.String()
275 sourceRepoAt = &x
276 }
277 }
278
279 var stackId, changeId, parentChangeId *string
280 if pull.StackId != "" {
281 stackId = &pull.StackId
282 }
283 if pull.ChangeId != "" {
284 changeId = &pull.ChangeId
285 }
286 if pull.ParentChangeId != "" {
287 parentChangeId = &pull.ParentChangeId
288 }
289
290 _, err = tx.Exec(
291 `
292 insert into pulls (
293 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
294 )
295 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
296 pull.RepoAt,
297 pull.OwnerDid,
298 pull.PullId,
299 pull.Title,
300 pull.TargetBranch,
301 pull.Body,
302 pull.Rkey,
303 pull.State,
304 sourceBranch,
305 sourceRepoAt,
306 stackId,
307 changeId,
308 parentChangeId,
309 )
310 if err != nil {
311 return err
312 }
313
314 _, err = tx.Exec(`
315 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
316 values (?, ?, ?, ?, ?)
317 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
318 return err
319}
320
321func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
322 pull, err := GetPull(e, repoAt, pullId)
323 if err != nil {
324 return "", err
325 }
326 return pull.PullAt(), err
327}
328
329func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
330 var pullId int
331 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
332 return pullId - 1, err
333}
334
335func GetPulls(e Execer, filters ...filter) ([]*Pull, error) {
336 pulls := make(map[int]*Pull)
337
338 var conditions []string
339 var args []any
340 for _, filter := range filters {
341 conditions = append(conditions, filter.Condition())
342 args = append(args, filter.arg)
343 }
344
345 whereClause := ""
346 if conditions != nil {
347 whereClause = " where " + strings.Join(conditions, " and ")
348 }
349
350 query := fmt.Sprintf(`
351 select
352 owner_did,
353 repo_at,
354 pull_id,
355 created,
356 title,
357 state,
358 target_branch,
359 body,
360 rkey,
361 source_branch,
362 source_repo_at,
363 stack_id,
364 change_id,
365 parent_change_id
366 from
367 pulls
368 %s
369 `, whereClause)
370
371 rows, err := e.Query(query, args...)
372 if err != nil {
373 return nil, err
374 }
375 defer rows.Close()
376
377 for rows.Next() {
378 var pull Pull
379 var createdAt string
380 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
381 err := rows.Scan(
382 &pull.OwnerDid,
383 &pull.RepoAt,
384 &pull.PullId,
385 &createdAt,
386 &pull.Title,
387 &pull.State,
388 &pull.TargetBranch,
389 &pull.Body,
390 &pull.Rkey,
391 &sourceBranch,
392 &sourceRepoAt,
393 &stackId,
394 &changeId,
395 &parentChangeId,
396 )
397 if err != nil {
398 return nil, err
399 }
400
401 createdTime, err := time.Parse(time.RFC3339, createdAt)
402 if err != nil {
403 return nil, err
404 }
405 pull.Created = createdTime
406
407 if sourceBranch.Valid {
408 pull.PullSource = &PullSource{
409 Branch: sourceBranch.String,
410 }
411 if sourceRepoAt.Valid {
412 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
413 if err != nil {
414 return nil, err
415 }
416 pull.PullSource.RepoAt = &sourceRepoAtParsed
417 }
418 }
419
420 if stackId.Valid {
421 pull.StackId = stackId.String
422 }
423 if changeId.Valid {
424 pull.ChangeId = changeId.String
425 }
426 if parentChangeId.Valid {
427 pull.ParentChangeId = parentChangeId.String
428 }
429
430 pulls[pull.PullId] = &pull
431 }
432
433 // get latest round no. for each pull
434 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
435 submissionsQuery := fmt.Sprintf(`
436 select
437 id, pull_id, round_number, patch
438 from
439 pull_submissions
440 where
441 repo_at in (%s) and pull_id in (%s)
442 `, inClause, inClause)
443
444 args = make([]any, len(pulls)*2)
445 idx := 0
446 for _, p := range pulls {
447 args[idx] = p.RepoAt
448 idx += 1
449 }
450 for _, p := range pulls {
451 args[idx] = p.PullId
452 idx += 1
453 }
454 submissionsRows, err := e.Query(submissionsQuery, args...)
455 if err != nil {
456 return nil, err
457 }
458 defer submissionsRows.Close()
459
460 for submissionsRows.Next() {
461 var s PullSubmission
462 err := submissionsRows.Scan(
463 &s.ID,
464 &s.PullId,
465 &s.RoundNumber,
466 &s.Patch,
467 )
468 if err != nil {
469 return nil, err
470 }
471
472 if p, ok := pulls[s.PullId]; ok {
473 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
474 p.Submissions[s.RoundNumber] = &s
475 }
476 }
477 if err := rows.Err(); err != nil {
478 return nil, err
479 }
480
481 // get comment count on latest submission on each pull
482 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
483 commentsQuery := fmt.Sprintf(`
484 select
485 count(id), pull_id
486 from
487 pull_comments
488 where
489 submission_id in (%s)
490 group by
491 submission_id
492 `, inClause)
493
494 args = []any{}
495 for _, p := range pulls {
496 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
497 }
498 commentsRows, err := e.Query(commentsQuery, args...)
499 if err != nil {
500 return nil, err
501 }
502 defer commentsRows.Close()
503
504 for commentsRows.Next() {
505 var commentCount, pullId int
506 err := commentsRows.Scan(
507 &commentCount,
508 &pullId,
509 )
510 if err != nil {
511 return nil, err
512 }
513 if p, ok := pulls[pullId]; ok {
514 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
515 }
516 }
517 if err := rows.Err(); err != nil {
518 return nil, err
519 }
520
521 orderedByPullId := []*Pull{}
522 for _, p := range pulls {
523 orderedByPullId = append(orderedByPullId, p)
524 }
525 sort.Slice(orderedByPullId, func(i, j int) bool {
526 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
527 })
528
529 return orderedByPullId, nil
530}
531
532func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
533 query := `
534 select
535 owner_did,
536 pull_id,
537 created,
538 title,
539 state,
540 target_branch,
541 repo_at,
542 body,
543 rkey,
544 source_branch,
545 source_repo_at,
546 stack_id,
547 change_id,
548 parent_change_id
549 from
550 pulls
551 where
552 repo_at = ? and pull_id = ?
553 `
554 row := e.QueryRow(query, repoAt, pullId)
555
556 var pull Pull
557 var createdAt string
558 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
559 err := row.Scan(
560 &pull.OwnerDid,
561 &pull.PullId,
562 &createdAt,
563 &pull.Title,
564 &pull.State,
565 &pull.TargetBranch,
566 &pull.RepoAt,
567 &pull.Body,
568 &pull.Rkey,
569 &sourceBranch,
570 &sourceRepoAt,
571 &stackId,
572 &changeId,
573 &parentChangeId,
574 )
575 if err != nil {
576 return nil, err
577 }
578
579 createdTime, err := time.Parse(time.RFC3339, createdAt)
580 if err != nil {
581 return nil, err
582 }
583 pull.Created = createdTime
584
585 // populate source
586 if sourceBranch.Valid {
587 pull.PullSource = &PullSource{
588 Branch: sourceBranch.String,
589 }
590 if sourceRepoAt.Valid {
591 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
592 if err != nil {
593 return nil, err
594 }
595 pull.PullSource.RepoAt = &sourceRepoAtParsed
596 }
597 }
598
599 if stackId.Valid {
600 pull.StackId = stackId.String
601 }
602 if changeId.Valid {
603 pull.ChangeId = changeId.String
604 }
605 if parentChangeId.Valid {
606 pull.ParentChangeId = parentChangeId.String
607 }
608
609 submissionsQuery := `
610 select
611 id, pull_id, repo_at, round_number, patch, created, source_rev
612 from
613 pull_submissions
614 where
615 repo_at = ? and pull_id = ?
616 `
617 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
618 if err != nil {
619 return nil, err
620 }
621 defer submissionsRows.Close()
622
623 submissionsMap := make(map[int]*PullSubmission)
624
625 for submissionsRows.Next() {
626 var submission PullSubmission
627 var submissionCreatedStr string
628 var submissionSourceRev sql.NullString
629 err := submissionsRows.Scan(
630 &submission.ID,
631 &submission.PullId,
632 &submission.RepoAt,
633 &submission.RoundNumber,
634 &submission.Patch,
635 &submissionCreatedStr,
636 &submissionSourceRev,
637 )
638 if err != nil {
639 return nil, err
640 }
641
642 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
643 if err != nil {
644 return nil, err
645 }
646 submission.Created = submissionCreatedTime
647
648 if submissionSourceRev.Valid {
649 submission.SourceRev = submissionSourceRev.String
650 }
651
652 submissionsMap[submission.ID] = &submission
653 }
654 if err = submissionsRows.Close(); err != nil {
655 return nil, err
656 }
657 if len(submissionsMap) == 0 {
658 return &pull, nil
659 }
660
661 var args []any
662 for k := range submissionsMap {
663 args = append(args, k)
664 }
665 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
666 commentsQuery := fmt.Sprintf(`
667 select
668 id,
669 pull_id,
670 submission_id,
671 repo_at,
672 owner_did,
673 comment_at,
674 body,
675 created
676 from
677 pull_comments
678 where
679 submission_id IN (%s)
680 order by
681 created asc
682 `, inClause)
683 commentsRows, err := e.Query(commentsQuery, args...)
684 if err != nil {
685 return nil, err
686 }
687 defer commentsRows.Close()
688
689 for commentsRows.Next() {
690 var comment PullComment
691 var commentCreatedStr string
692 err := commentsRows.Scan(
693 &comment.ID,
694 &comment.PullId,
695 &comment.SubmissionId,
696 &comment.RepoAt,
697 &comment.OwnerDid,
698 &comment.CommentAt,
699 &comment.Body,
700 &commentCreatedStr,
701 )
702 if err != nil {
703 return nil, err
704 }
705
706 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
707 if err != nil {
708 return nil, err
709 }
710 comment.Created = commentCreatedTime
711
712 // Add the comment to its submission
713 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
714 submission.Comments = append(submission.Comments, comment)
715 }
716
717 }
718 if err = commentsRows.Err(); err != nil {
719 return nil, err
720 }
721
722 var pullSourceRepo *Repo
723 if pull.PullSource != nil {
724 if pull.PullSource.RepoAt != nil {
725 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
726 if err != nil {
727 log.Printf("failed to get repo by at uri: %v", err)
728 } else {
729 pull.PullSource.Repo = pullSourceRepo
730 }
731 }
732 }
733
734 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
735 for _, submission := range submissionsMap {
736 pull.Submissions[submission.RoundNumber] = submission
737 }
738
739 return &pull, nil
740}
741
742// timeframe here is directly passed into the sql query filter, and any
743// timeframe in the past should be negative; e.g.: "-3 months"
744func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
745 var pulls []Pull
746
747 rows, err := e.Query(`
748 select
749 p.owner_did,
750 p.repo_at,
751 p.pull_id,
752 p.created,
753 p.title,
754 p.state,
755 r.did,
756 r.name,
757 r.knot,
758 r.rkey,
759 r.created
760 from
761 pulls p
762 join
763 repos r on p.repo_at = r.at_uri
764 where
765 p.owner_did = ? and p.created >= date ('now', ?)
766 order by
767 p.created desc`, did, timeframe)
768 if err != nil {
769 return nil, err
770 }
771 defer rows.Close()
772
773 for rows.Next() {
774 var pull Pull
775 var repo Repo
776 var pullCreatedAt, repoCreatedAt string
777 err := rows.Scan(
778 &pull.OwnerDid,
779 &pull.RepoAt,
780 &pull.PullId,
781 &pullCreatedAt,
782 &pull.Title,
783 &pull.State,
784 &repo.Did,
785 &repo.Name,
786 &repo.Knot,
787 &repo.Rkey,
788 &repoCreatedAt,
789 )
790 if err != nil {
791 return nil, err
792 }
793
794 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
795 if err != nil {
796 return nil, err
797 }
798 pull.Created = pullCreatedTime
799
800 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
801 if err != nil {
802 return nil, err
803 }
804 repo.Created = repoCreatedTime
805
806 pull.Repo = &repo
807
808 pulls = append(pulls, pull)
809 }
810
811 if err := rows.Err(); err != nil {
812 return nil, err
813 }
814
815 return pulls, nil
816}
817
818func NewPullComment(e Execer, comment *PullComment) (int64, error) {
819 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
820 res, err := e.Exec(
821 query,
822 comment.OwnerDid,
823 comment.RepoAt,
824 comment.SubmissionId,
825 comment.CommentAt,
826 comment.PullId,
827 comment.Body,
828 )
829 if err != nil {
830 return 0, err
831 }
832
833 i, err := res.LastInsertId()
834 if err != nil {
835 return 0, err
836 }
837
838 return i, nil
839}
840
841func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
842 _, err := e.Exec(`update pulls set state = ? where repo_at = ? and pull_id = ?`, pullState, repoAt, pullId)
843 return err
844}
845
846func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
847 err := SetPullState(e, repoAt, pullId, PullClosed)
848 return err
849}
850
851func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
852 err := SetPullState(e, repoAt, pullId, PullOpen)
853 return err
854}
855
856func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
857 err := SetPullState(e, repoAt, pullId, PullMerged)
858 return err
859}
860
861func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
862 newRoundNumber := len(pull.Submissions)
863 _, err := e.Exec(`
864 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
865 values (?, ?, ?, ?, ?)
866 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
867
868 return err
869}
870
// PullCount holds the number of pulls in each state.
type PullCount struct {
	Open, Merged, Closed int
}
876
877func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
878 row := e.QueryRow(`
879 select
880 count(case when state = ? then 1 end) as open_count,
881 count(case when state = ? then 1 end) as merged_count,
882 count(case when state = ? then 1 end) as closed_count
883 from pulls
884 where repo_at = ?`,
885 PullOpen,
886 PullMerged,
887 PullClosed,
888 repoAt,
889 )
890
891 var count PullCount
892 if err := row.Scan(&count.Open, &count.Merged, &count.Closed); err != nil {
893 return PullCount{0, 0, 0}, err
894 }
895
896 return count, nil
897}
898
899type Stack []*Pull
900
901// change-id parent-change-id
902//
903// 4 w ,-------- z (TOP)
904// 3 z <----',------- y
905// 2 y <-----',------ x
906// 1 x <------' nil (BOT)
907//
908// `w` is parent of none, so it is the top of the stack
909func GetStack(e Execer, stackId string) (Stack, error) {
910 unorderedPulls, err := GetPulls(e, Filter("stack_id", stackId))
911 if err != nil {
912 return nil, err
913 }
914 // map of parent-change-id to pull
915 changeIdMap := make(map[string]*Pull, len(unorderedPulls))
916 parentMap := make(map[string]*Pull, len(unorderedPulls))
917 for _, p := range unorderedPulls {
918 changeIdMap[p.ChangeId] = p
919 if p.ParentChangeId != "" {
920 parentMap[p.ParentChangeId] = p
921 }
922 }
923
924 // the top of the stack is the pull that is not a parent of any pull
925 var topPull *Pull
926 for _, maybeTop := range unorderedPulls {
927 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
928 topPull = maybeTop
929 break
930 }
931 }
932
933 pulls := []*Pull{}
934 for {
935 pulls = append(pulls, topPull)
936 if topPull.ParentChangeId != "" {
937 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
938 topPull = next
939 } else {
940 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
941 }
942 } else {
943 break
944 }
945 }
946
947 return pulls, nil
948}
949
950// position of this pull in the stack
951func (stack Stack) Position(pull *Pull) int {
952 return slices.IndexFunc(stack, func(p *Pull) bool {
953 return p.ChangeId == pull.ChangeId
954 })
955}
956
957// all pulls below this pull (including self) in this stack
958//
959// nil if this pull does not belong to this stack
960func (stack Stack) Below(pull *Pull) Stack {
961 position := stack.Position(pull)
962
963 if position < 0 {
964 return nil
965 }
966
967 return stack[position:]
968}
969
970// all pulls below this pull (excluding self) in this stack
971func (stack Stack) StrictlyBelow(pull *Pull) Stack {
972 below := stack.Below(pull)
973
974 if len(below) > 0 {
975 return below[1:]
976 }
977
978 return nil
979}
980
981// the combined format-patches of all the newest submissions in this stack
982func (stack Stack) CombinedPatch() string {
983 // go in reverse order because the bottom of the stack is the last element in the slice
984 var combined strings.Builder
985 for idx := range stack {
986 pull := stack[len(stack)-1-idx]
987 combined.WriteString(pull.LatestPatch())
988 combined.WriteString("\n")
989 }
990 return combined.String()
991}