this repo has no description
1package db 2 3import ( 4 "database/sql" 5 "fmt" 6 "log" 7 "slices" 8 "sort" 9 "strings" 10 "time" 11 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 "tangled.sh/tangled.sh/core/api/tangled" 14 "tangled.sh/tangled.sh/core/patchutil" 15 "tangled.sh/tangled.sh/core/types" 16) 17 18type PullState int 19 20const ( 21 PullClosed PullState = iota 22 PullOpen 23 PullMerged 24 PullDeleted 25) 26 27func (p PullState) String() string { 28 switch p { 29 case PullOpen: 30 return "open" 31 case PullMerged: 32 return "merged" 33 case PullClosed: 34 return "closed" 35 case PullDeleted: 36 return "deleted" 37 default: 38 return "closed" 39 } 40} 41 42func (p PullState) IsOpen() bool { 43 return p == PullOpen 44} 45func (p PullState) IsMerged() bool { 46 return p == PullMerged 47} 48func (p PullState) IsClosed() bool { 49 return p == PullClosed 50} 51func (p PullState) IsDeleted() bool { 52 return p == PullDeleted 53} 54 55type Pull struct { 56 // ids 57 ID int 58 PullId int 59 60 // at ids 61 RepoAt syntax.ATURI 62 OwnerDid string 63 Rkey string 64 65 // content 66 Title string 67 Body string 68 TargetBranch string 69 State PullState 70 Submissions []*PullSubmission 71 72 // stacking 73 StackId string // nullable string 74 ChangeId string // nullable string 75 ParentChangeId string // nullable string 76 77 // meta 78 Created time.Time 79 PullSource *PullSource 80 81 // optionally, populate this when querying for reverse mappings 82 Repo *Repo 83} 84 85func (p Pull) AsRecord() tangled.RepoPull { 86 var source *tangled.RepoPull_Source 87 if p.PullSource != nil { 88 s := p.PullSource.AsRecord() 89 source = &s 90 source.Sha = p.LatestSha() 91 } 92 93 record := tangled.RepoPull{ 94 Title: p.Title, 95 Body: &p.Body, 96 CreatedAt: p.Created.Format(time.RFC3339), 97 TargetRepo: p.RepoAt.String(), 98 TargetBranch: p.TargetBranch, 99 Patch: p.LatestPatch(), 100 Source: source, 101 } 102 return record 103} 104 105type PullSource struct { 106 Branch string 107 RepoAt *syntax.ATURI 108 109 // optionally populate this for reverse mappings 110 Repo *Repo 111} 112 113func (p PullSource) AsRecord() tangled.RepoPull_Source { 114 var repoAt *string 115 if p.RepoAt != nil { 116 s := p.RepoAt.String() 117 repoAt = &s 118 } 119 record := tangled.RepoPull_Source{ 120 Branch: p.Branch, 121 Repo: repoAt, 122 } 123 return record 124} 125 126type PullSubmission struct { 127 // ids 128 ID int 129 PullId int 130 131 // at ids 132 RepoAt syntax.ATURI 133 134 // content 135 RoundNumber int 136 Patch string 137 Comments []PullComment 138 SourceRev string // include the rev that was used to create this submission: only for branch/fork PRs 139 140 // meta 141 Created time.Time 142} 143 144type PullComment struct { 145 // ids 146 ID int 147 PullId int 148 SubmissionId int 149 150 // at ids 151 RepoAt string 152 OwnerDid string 153 CommentAt string 154 155 // content 156 Body string 157 158 // meta 159 Created time.Time 160} 161 162func (p *Pull) LatestPatch() string { 163 latestSubmission := p.Submissions[p.LastRoundNumber()] 164 return latestSubmission.Patch 165} 166 167func (p *Pull) LatestSha() string { 168 latestSubmission := p.Submissions[p.LastRoundNumber()] 169 return latestSubmission.SourceRev 170} 171 172func (p *Pull) PullAt() syntax.ATURI { 173 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey)) 174} 175 176func (p *Pull) LastRoundNumber() int { 177 return len(p.Submissions) - 1 178} 179 180func (p *Pull) IsPatchBased() bool { 181 return p.PullSource == nil 182} 183 184func (p *Pull) IsBranchBased() bool { 185 if p.PullSource != nil { 186 if p.PullSource.RepoAt != nil { 187 return p.PullSource.RepoAt == &p.RepoAt 188 } else { 189 // no repo specified 190 return true 191 } 192 } 193 return false 194} 195 196func (p *Pull) IsForkBased() bool { 197 if p.PullSource != nil { 198 if p.PullSource.RepoAt != nil { 199 // make sure repos are different 200 return p.PullSource.RepoAt != &p.RepoAt 201 } 202 } 203 return false 204} 205 206func (p *Pull) IsStacked() bool { 207 return p.StackId != "" 208} 209 210func (s PullSubmission) IsFormatPatch() bool { 211 return patchutil.IsFormatPatch(s.Patch) 212} 213 214func (s PullSubmission) AsFormatPatch() []types.FormatPatch { 215 patches, err := patchutil.ExtractPatches(s.Patch) 216 if err != nil { 217 log.Println("error extracting patches from submission:", err) 218 return []types.FormatPatch{} 219 } 220 221 return patches 222} 223 224func NewPull(tx *sql.Tx, pull *Pull) error { 225 _, err := tx.Exec(` 226 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 227 values (?, 1) 228 `, pull.RepoAt) 229 if err != nil { 230 return err 231 } 232 233 var nextId int 234 err = tx.QueryRow(` 235 update repo_pull_seqs 236 set next_pull_id = next_pull_id + 1 237 where repo_at = ? 238 returning next_pull_id - 1 239 `, pull.RepoAt).Scan(&nextId) 240 if err != nil { 241 return err 242 } 243 244 pull.PullId = nextId 245 pull.State = PullOpen 246 247 var sourceBranch, sourceRepoAt *string 248 if pull.PullSource != nil { 249 sourceBranch = &pull.PullSource.Branch 250 if pull.PullSource.RepoAt != nil { 251 x := pull.PullSource.RepoAt.String() 252 sourceRepoAt = &x 253 } 254 } 255 256 var stackId, changeId, parentChangeId *string 257 if pull.StackId != "" { 258 stackId = &pull.StackId 259 } 260 if pull.ChangeId != "" { 261 changeId = &pull.ChangeId 262 } 263 if pull.ParentChangeId != "" { 264 parentChangeId = &pull.ParentChangeId 265 } 266 267 _, err = tx.Exec( 268 ` 269 insert into pulls ( 270 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 271 ) 272 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 273 pull.RepoAt, 274 pull.OwnerDid, 275 pull.PullId, 276 pull.Title, 277 pull.TargetBranch, 278 pull.Body, 279 pull.Rkey, 280 pull.State, 281 sourceBranch, 282 sourceRepoAt, 283 stackId, 284 changeId, 285 parentChangeId, 286 ) 287 if err != nil { 288 return err 289 } 290 291 _, err = tx.Exec(` 292 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 293 values (?, ?, ?, ?, ?) 294 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 295 return err 296} 297 298func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 299 pull, err := GetPull(e, repoAt, pullId) 300 if err != nil { 301 return "", err 302 } 303 return pull.PullAt(), err 304} 305 306func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 307 var pullId int 308 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 309 return pullId - 1, err 310} 311 312func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*Pull, error) { 313 pulls := make(map[int]*Pull) 314 315 var conditions []string 316 var args []any 317 for _, filter := range filters { 318 conditions = append(conditions, filter.Condition()) 319 args = append(args, filter.Arg()...) 320 } 321 322 whereClause := "" 323 if conditions != nil { 324 whereClause = " where " + strings.Join(conditions, " and ") 325 } 326 limitClause := "" 327 if limit != 0 { 328 limitClause = fmt.Sprintf(" limit %d ", limit) 329 } 330 331 query := fmt.Sprintf(` 332 select 333 owner_did, 334 repo_at, 335 pull_id, 336 created, 337 title, 338 state, 339 target_branch, 340 body, 341 rkey, 342 source_branch, 343 source_repo_at, 344 stack_id, 345 change_id, 346 parent_change_id 347 from 348 pulls 349 %s 350 order by 351 created desc 352 %s 353 `, whereClause, limitClause) 354 355 rows, err := e.Query(query, args...) 356 if err != nil { 357 return nil, err 358 } 359 defer rows.Close() 360 361 for rows.Next() { 362 var pull Pull 363 var createdAt string 364 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 365 err := rows.Scan( 366 &pull.OwnerDid, 367 &pull.RepoAt, 368 &pull.PullId, 369 &createdAt, 370 &pull.Title, 371 &pull.State, 372 &pull.TargetBranch, 373 &pull.Body, 374 &pull.Rkey, 375 &sourceBranch, 376 &sourceRepoAt, 377 &stackId, 378 &changeId, 379 &parentChangeId, 380 ) 381 if err != nil { 382 return nil, err 383 } 384 385 createdTime, err := time.Parse(time.RFC3339, createdAt) 386 if err != nil { 387 return nil, err 388 } 389 pull.Created = createdTime 390 391 if sourceBranch.Valid { 392 pull.PullSource = &PullSource{ 393 Branch: sourceBranch.String, 394 } 395 if sourceRepoAt.Valid { 396 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 397 if err != nil { 398 return nil, err 399 } 400 pull.PullSource.RepoAt = &sourceRepoAtParsed 401 } 402 } 403 404 if stackId.Valid { 405 pull.StackId = stackId.String 406 } 407 if changeId.Valid { 408 pull.ChangeId = changeId.String 409 } 410 if parentChangeId.Valid { 411 pull.ParentChangeId = parentChangeId.String 412 } 413 414 pulls[pull.PullId] = &pull 415 } 416 417 // get latest round no. for each pull 418 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 419 submissionsQuery := fmt.Sprintf(` 420 select 421 id, pull_id, round_number, patch, created, source_rev 422 from 423 pull_submissions 424 where 425 repo_at in (%s) and pull_id in (%s) 426 `, inClause, inClause) 427 428 args = make([]any, len(pulls)*2) 429 idx := 0 430 for _, p := range pulls { 431 args[idx] = p.RepoAt 432 idx += 1 433 } 434 for _, p := range pulls { 435 args[idx] = p.PullId 436 idx += 1 437 } 438 submissionsRows, err := e.Query(submissionsQuery, args...) 439 if err != nil { 440 return nil, err 441 } 442 defer submissionsRows.Close() 443 444 for submissionsRows.Next() { 445 var s PullSubmission 446 var sourceRev sql.NullString 447 var createdAt string 448 err := submissionsRows.Scan( 449 &s.ID, 450 &s.PullId, 451 &s.RoundNumber, 452 &s.Patch, 453 &createdAt, 454 &sourceRev, 455 ) 456 if err != nil { 457 return nil, err 458 } 459 460 createdTime, err := time.Parse(time.RFC3339, createdAt) 461 if err != nil { 462 return nil, err 463 } 464 s.Created = createdTime 465 466 if sourceRev.Valid { 467 s.SourceRev = sourceRev.String 468 } 469 470 if p, ok := pulls[s.PullId]; ok { 471 p.Submissions = make([]*PullSubmission, s.RoundNumber+1) 472 p.Submissions[s.RoundNumber] = &s 473 } 474 } 475 if err := rows.Err(); err != nil { 476 return nil, err 477 } 478 479 // get comment count on latest submission on each pull 480 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 481 commentsQuery := fmt.Sprintf(` 482 select 483 count(id), pull_id 484 from 485 pull_comments 486 where 487 submission_id in (%s) 488 group by 489 submission_id 490 `, inClause) 491 492 args = []any{} 493 for _, p := range pulls { 494 args = append(args, p.Submissions[p.LastRoundNumber()].ID) 495 } 496 commentsRows, err := e.Query(commentsQuery, args...) 497 if err != nil { 498 return nil, err 499 } 500 defer commentsRows.Close() 501 502 for commentsRows.Next() { 503 var commentCount, pullId int 504 err := commentsRows.Scan( 505 &commentCount, 506 &pullId, 507 ) 508 if err != nil { 509 return nil, err 510 } 511 if p, ok := pulls[pullId]; ok { 512 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount) 513 } 514 } 515 if err := rows.Err(); err != nil { 516 return nil, err 517 } 518 519 orderedByPullId := []*Pull{} 520 for _, p := range pulls { 521 orderedByPullId = append(orderedByPullId, p) 522 } 523 sort.Slice(orderedByPullId, func(i, j int) bool { 524 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 525 }) 526 527 return orderedByPullId, nil 528} 529 530func GetPulls(e Execer, filters ...filter) ([]*Pull, error) { 531 return GetPullsWithLimit(e, 0, filters...) 532} 533 534func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) { 535 query := ` 536 select 537 owner_did, 538 pull_id, 539 created, 540 title, 541 state, 542 target_branch, 543 repo_at, 544 body, 545 rkey, 546 source_branch, 547 source_repo_at, 548 stack_id, 549 change_id, 550 parent_change_id 551 from 552 pulls 553 where 554 repo_at = ? and pull_id = ? 555 ` 556 row := e.QueryRow(query, repoAt, pullId) 557 558 var pull Pull 559 var createdAt string 560 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 561 err := row.Scan( 562 &pull.OwnerDid, 563 &pull.PullId, 564 &createdAt, 565 &pull.Title, 566 &pull.State, 567 &pull.TargetBranch, 568 &pull.RepoAt, 569 &pull.Body, 570 &pull.Rkey, 571 &sourceBranch, 572 &sourceRepoAt, 573 &stackId, 574 &changeId, 575 &parentChangeId, 576 ) 577 if err != nil { 578 return nil, err 579 } 580 581 createdTime, err := time.Parse(time.RFC3339, createdAt) 582 if err != nil { 583 return nil, err 584 } 585 pull.Created = createdTime 586 587 // populate source 588 if sourceBranch.Valid { 589 pull.PullSource = &PullSource{ 590 Branch: sourceBranch.String, 591 } 592 if sourceRepoAt.Valid { 593 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 594 if err != nil { 595 return nil, err 596 } 597 pull.PullSource.RepoAt = &sourceRepoAtParsed 598 } 599 } 600 601 if stackId.Valid { 602 pull.StackId = stackId.String 603 } 604 if changeId.Valid { 605 pull.ChangeId = changeId.String 606 } 607 if parentChangeId.Valid { 608 pull.ParentChangeId = parentChangeId.String 609 } 610 611 submissionsQuery := ` 612 select 613 id, pull_id, repo_at, round_number, patch, created, source_rev 614 from 615 pull_submissions 616 where 617 repo_at = ? and pull_id = ? 618 ` 619 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId) 620 if err != nil { 621 return nil, err 622 } 623 defer submissionsRows.Close() 624 625 submissionsMap := make(map[int]*PullSubmission) 626 627 for submissionsRows.Next() { 628 var submission PullSubmission 629 var submissionCreatedStr string 630 var submissionSourceRev sql.NullString 631 err := submissionsRows.Scan( 632 &submission.ID, 633 &submission.PullId, 634 &submission.RepoAt, 635 &submission.RoundNumber, 636 &submission.Patch, 637 &submissionCreatedStr, 638 &submissionSourceRev, 639 ) 640 if err != nil { 641 return nil, err 642 } 643 644 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr) 645 if err != nil { 646 return nil, err 647 } 648 submission.Created = submissionCreatedTime 649 650 if submissionSourceRev.Valid { 651 submission.SourceRev = submissionSourceRev.String 652 } 653 654 submissionsMap[submission.ID] = &submission 655 } 656 if err = submissionsRows.Close(); err != nil { 657 return nil, err 658 } 659 if len(submissionsMap) == 0 { 660 return &pull, nil 661 } 662 663 var args []any 664 for k := range submissionsMap { 665 args = append(args, k) 666 } 667 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ") 668 commentsQuery := fmt.Sprintf(` 669 select 670 id, 671 pull_id, 672 submission_id, 673 repo_at, 674 owner_did, 675 comment_at, 676 body, 677 created 678 from 679 pull_comments 680 where 681 submission_id IN (%s) 682 order by 683 created asc 684 `, inClause) 685 commentsRows, err := e.Query(commentsQuery, args...) 686 if err != nil { 687 return nil, err 688 } 689 defer commentsRows.Close() 690 691 for commentsRows.Next() { 692 var comment PullComment 693 var commentCreatedStr string 694 err := commentsRows.Scan( 695 &comment.ID, 696 &comment.PullId, 697 &comment.SubmissionId, 698 &comment.RepoAt, 699 &comment.OwnerDid, 700 &comment.CommentAt, 701 &comment.Body, 702 &commentCreatedStr, 703 ) 704 if err != nil { 705 return nil, err 706 } 707 708 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr) 709 if err != nil { 710 return nil, err 711 } 712 comment.Created = commentCreatedTime 713 714 // Add the comment to its submission 715 if submission, ok := submissionsMap[comment.SubmissionId]; ok { 716 submission.Comments = append(submission.Comments, comment) 717 } 718 719 } 720 if err = commentsRows.Err(); err != nil { 721 return nil, err 722 } 723 724 var pullSourceRepo *Repo 725 if pull.PullSource != nil { 726 if pull.PullSource.RepoAt != nil { 727 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String()) 728 if err != nil { 729 log.Printf("failed to get repo by at uri: %v", err) 730 } else { 731 pull.PullSource.Repo = pullSourceRepo 732 } 733 } 734 } 735 736 pull.Submissions = make([]*PullSubmission, len(submissionsMap)) 737 for _, submission := range submissionsMap { 738 pull.Submissions[submission.RoundNumber] = submission 739 } 740 741 return &pull, nil 742} 743 744// timeframe here is directly passed into the sql query filter, and any 745// timeframe in the past should be negative; e.g.: "-3 months" 746func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) { 747 var pulls []Pull 748 749 rows, err := e.Query(` 750 select 751 p.owner_did, 752 p.repo_at, 753 p.pull_id, 754 p.created, 755 p.title, 756 p.state, 757 r.did, 758 r.name, 759 r.knot, 760 r.rkey, 761 r.created 762 from 763 pulls p 764 join 765 repos r on p.repo_at = r.at_uri 766 where 767 p.owner_did = ? and p.created >= date ('now', ?) 768 order by 769 p.created desc`, did, timeframe) 770 if err != nil { 771 return nil, err 772 } 773 defer rows.Close() 774 775 for rows.Next() { 776 var pull Pull 777 var repo Repo 778 var pullCreatedAt, repoCreatedAt string 779 err := rows.Scan( 780 &pull.OwnerDid, 781 &pull.RepoAt, 782 &pull.PullId, 783 &pullCreatedAt, 784 &pull.Title, 785 &pull.State, 786 &repo.Did, 787 &repo.Name, 788 &repo.Knot, 789 &repo.Rkey, 790 &repoCreatedAt, 791 ) 792 if err != nil { 793 return nil, err 794 } 795 796 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 797 if err != nil { 798 return nil, err 799 } 800 pull.Created = pullCreatedTime 801 802 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 803 if err != nil { 804 return nil, err 805 } 806 repo.Created = repoCreatedTime 807 808 pull.Repo = &repo 809 810 pulls = append(pulls, pull) 811 } 812 813 if err := rows.Err(); err != nil { 814 return nil, err 815 } 816 817 return pulls, nil 818} 819 820func NewPullComment(e Execer, comment *PullComment) (int64, error) { 821 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 822 res, err := e.Exec( 823 query, 824 comment.OwnerDid, 825 comment.RepoAt, 826 comment.SubmissionId, 827 comment.CommentAt, 828 comment.PullId, 829 comment.Body, 830 ) 831 if err != nil { 832 return 0, err 833 } 834 835 i, err := res.LastInsertId() 836 if err != nil { 837 return 0, err 838 } 839 840 return i, nil 841} 842 843func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error { 844 _, err := e.Exec( 845 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 846 pullState, 847 repoAt, 848 pullId, 849 PullDeleted, // only update state of non-deleted pulls 850 PullMerged, // only update state of non-merged pulls 851 ) 852 return err 853} 854 855func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 856 err := SetPullState(e, repoAt, pullId, PullClosed) 857 return err 858} 859 860func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 861 err := SetPullState(e, repoAt, pullId, PullOpen) 862 return err 863} 864 865func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 866 err := SetPullState(e, repoAt, pullId, PullMerged) 867 return err 868} 869 870func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 871 err := SetPullState(e, repoAt, pullId, PullDeleted) 872 return err 873} 874 875func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error { 876 newRoundNumber := len(pull.Submissions) 877 _, err := e.Exec(` 878 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 879 values (?, ?, ?, ?, ?) 880 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev) 881 882 return err 883} 884 885func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error { 886 var conditions []string 887 var args []any 888 889 args = append(args, parentChangeId) 890 891 for _, filter := range filters { 892 conditions = append(conditions, filter.Condition()) 893 args = append(args, filter.Arg()...) 894 } 895 896 whereClause := "" 897 if conditions != nil { 898 whereClause = " where " + strings.Join(conditions, " and ") 899 } 900 901 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 902 _, err := e.Exec(query, args...) 903 904 return err 905} 906 907// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 908// otherwise submissions are immutable 909func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error { 910 var conditions []string 911 var args []any 912 913 args = append(args, sourceRev) 914 args = append(args, newPatch) 915 916 for _, filter := range filters { 917 conditions = append(conditions, filter.Condition()) 918 args = append(args, filter.Arg()...) 919 } 920 921 whereClause := "" 922 if conditions != nil { 923 whereClause = " where " + strings.Join(conditions, " and ") 924 } 925 926 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 927 _, err := e.Exec(query, args...) 928 929 return err 930} 931 932type PullCount struct { 933 Open int 934 Merged int 935 Closed int 936 Deleted int 937} 938 939func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) { 940 row := e.QueryRow(` 941 select 942 count(case when state = ? then 1 end) as open_count, 943 count(case when state = ? then 1 end) as merged_count, 944 count(case when state = ? then 1 end) as closed_count, 945 count(case when state = ? then 1 end) as deleted_count 946 from pulls 947 where repo_at = ?`, 948 PullOpen, 949 PullMerged, 950 PullClosed, 951 PullDeleted, 952 repoAt, 953 ) 954 955 var count PullCount 956 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 957 return PullCount{0, 0, 0, 0}, err 958 } 959 960 return count, nil 961} 962 963type Stack []*Pull 964 965// change-id parent-change-id 966// 967// 4 w ,-------- z (TOP) 968// 3 z <----',------- y 969// 2 y <-----',------ x 970// 1 x <------' nil (BOT) 971// 972// `w` is parent of none, so it is the top of the stack 973func GetStack(e Execer, stackId string) (Stack, error) { 974 unorderedPulls, err := GetPulls( 975 e, 976 FilterEq("stack_id", stackId), 977 FilterNotEq("state", PullDeleted), 978 ) 979 if err != nil { 980 return nil, err 981 } 982 // map of parent-change-id to pull 983 changeIdMap := make(map[string]*Pull, len(unorderedPulls)) 984 parentMap := make(map[string]*Pull, len(unorderedPulls)) 985 for _, p := range unorderedPulls { 986 changeIdMap[p.ChangeId] = p 987 if p.ParentChangeId != "" { 988 parentMap[p.ParentChangeId] = p 989 } 990 } 991 992 // the top of the stack is the pull that is not a parent of any pull 993 var topPull *Pull 994 for _, maybeTop := range unorderedPulls { 995 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 996 topPull = maybeTop 997 break 998 } 999 } 1000 1001 pulls := []*Pull{} 1002 for { 1003 pulls = append(pulls, topPull) 1004 if topPull.ParentChangeId != "" { 1005 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 1006 topPull = next 1007 } else { 1008 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 1009 } 1010 } else { 1011 break 1012 } 1013 } 1014 1015 return pulls, nil 1016} 1017 1018func GetAbandonedPulls(e Execer, stackId string) ([]*Pull, error) { 1019 pulls, err := GetPulls( 1020 e, 1021 FilterEq("stack_id", stackId), 1022 FilterEq("state", PullDeleted), 1023 ) 1024 if err != nil { 1025 return nil, err 1026 } 1027 1028 return pulls, nil 1029} 1030 1031// position of this pull in the stack 1032func (stack Stack) Position(pull *Pull) int { 1033 return slices.IndexFunc(stack, func(p *Pull) bool { 1034 return p.ChangeId == pull.ChangeId 1035 }) 1036} 1037 1038// all pulls below this pull (including self) in this stack 1039// 1040// nil if this pull does not belong to this stack 1041func (stack Stack) Below(pull *Pull) Stack { 1042 position := stack.Position(pull) 1043 1044 if position < 0 { 1045 return nil 1046 } 1047 1048 return stack[position:] 1049} 1050 1051// all pulls below this pull (excluding self) in this stack 1052func (stack Stack) StrictlyBelow(pull *Pull) Stack { 1053 below := stack.Below(pull) 1054 1055 if len(below) > 0 { 1056 return below[1:] 1057 } 1058 1059 return nil 1060} 1061 1062// all pulls above this pull (including self) in this stack 1063func (stack Stack) Above(pull *Pull) Stack { 1064 position := stack.Position(pull) 1065 1066 if position < 0 { 1067 return nil 1068 } 1069 1070 return stack[:position+1] 1071} 1072 1073// all pulls below this pull (excluding self) in this stack 1074func (stack Stack) StrictlyAbove(pull *Pull) Stack { 1075 above := stack.Above(pull) 1076 1077 if len(above) > 0 { 1078 return above[:len(above)-1] 1079 } 1080 1081 return nil 1082} 1083 1084// the combined format-patches of all the newest submissions in this stack 1085func (stack Stack) CombinedPatch() string { 1086 // go in reverse order because the bottom of the stack is the last element in the slice 1087 var combined strings.Builder 1088 for idx := range stack { 1089 pull := stack[len(stack)-1-idx] 1090 combined.WriteString(pull.LatestPatch()) 1091 combined.WriteString("\n") 1092 } 1093 return combined.String() 1094} 1095 1096// filter out PRs that are "active" 1097// 1098// PRs that are still open are active 1099func (stack Stack) Mergeable() Stack { 1100 var mergeable Stack 1101 1102 for _, p := range stack { 1103 // stop at the first merged PR 1104 if p.State == PullMerged || p.State == PullClosed { 1105 break 1106 } 1107 1108 // skip over deleted PRs 1109 if p.State != PullDeleted { 1110 mergeable = append(mergeable, p) 1111 } 1112 } 1113 1114 return mergeable 1115}