This repository has no description.
1package db 2 3import ( 4 "database/sql" 5 "fmt" 6 "log" 7 "slices" 8 "sort" 9 "strings" 10 "time" 11 12 "github.com/bluesky-social/indigo/atproto/syntax" 13 "tangled.org/core/api/tangled" 14 "tangled.org/core/appview/models" 15 "tangled.org/core/patchutil" 16 "tangled.org/core/types" 17) 18 19type PullState int 20 21const ( 22 PullClosed PullState = iota 23 PullOpen 24 PullMerged 25 PullDeleted 26) 27 28func (p PullState) String() string { 29 switch p { 30 case PullOpen: 31 return "open" 32 case PullMerged: 33 return "merged" 34 case PullClosed: 35 return "closed" 36 case PullDeleted: 37 return "deleted" 38 default: 39 return "closed" 40 } 41} 42 43func (p PullState) IsOpen() bool { 44 return p == PullOpen 45} 46func (p PullState) IsMerged() bool { 47 return p == PullMerged 48} 49func (p PullState) IsClosed() bool { 50 return p == PullClosed 51} 52func (p PullState) IsDeleted() bool { 53 return p == PullDeleted 54} 55 56type Pull struct { 57 // ids 58 ID int 59 PullId int 60 61 // at ids 62 RepoAt syntax.ATURI 63 OwnerDid string 64 Rkey string 65 66 // content 67 Title string 68 Body string 69 TargetBranch string 70 State PullState 71 Submissions []*PullSubmission 72 73 // stacking 74 StackId string // nullable string 75 ChangeId string // nullable string 76 ParentChangeId string // nullable string 77 78 // meta 79 Created time.Time 80 PullSource *PullSource 81 82 // optionally, populate this when querying for reverse mappings 83 Repo *models.Repo 84} 85 86func (p Pull) AsRecord() tangled.RepoPull { 87 var source *tangled.RepoPull_Source 88 if p.PullSource != nil { 89 s := p.PullSource.AsRecord() 90 source = &s 91 source.Sha = p.LatestSha() 92 } 93 94 record := tangled.RepoPull{ 95 Title: p.Title, 96 Body: &p.Body, 97 CreatedAt: p.Created.Format(time.RFC3339), 98 Target: &tangled.RepoPull_Target{ 99 Repo: p.RepoAt.String(), 100 Branch: p.TargetBranch, 101 }, 102 Patch: p.LatestPatch(), 103 Source: source, 104 } 105 return record 106} 107 108type PullSource struct 
{ 109 Branch string 110 RepoAt *syntax.ATURI 111 112 // optionally populate this for reverse mappings 113 Repo *models.Repo 114} 115 116func (p PullSource) AsRecord() tangled.RepoPull_Source { 117 var repoAt *string 118 if p.RepoAt != nil { 119 s := p.RepoAt.String() 120 repoAt = &s 121 } 122 record := tangled.RepoPull_Source{ 123 Branch: p.Branch, 124 Repo: repoAt, 125 } 126 return record 127} 128 129type PullSubmission struct { 130 // ids 131 ID int 132 PullId int 133 134 // at ids 135 RepoAt syntax.ATURI 136 137 // content 138 RoundNumber int 139 Patch string 140 Comments []PullComment 141 SourceRev string // include the rev that was used to create this submission: only for branch/fork PRs 142 143 // meta 144 Created time.Time 145} 146 147type PullComment struct { 148 // ids 149 ID int 150 PullId int 151 SubmissionId int 152 153 // at ids 154 RepoAt string 155 OwnerDid string 156 CommentAt string 157 158 // content 159 Body string 160 161 // meta 162 Created time.Time 163} 164 165func (p *Pull) LatestPatch() string { 166 latestSubmission := p.Submissions[p.LastRoundNumber()] 167 return latestSubmission.Patch 168} 169 170func (p *Pull) LatestSha() string { 171 latestSubmission := p.Submissions[p.LastRoundNumber()] 172 return latestSubmission.SourceRev 173} 174 175func (p *Pull) PullAt() syntax.ATURI { 176 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey)) 177} 178 179func (p *Pull) LastRoundNumber() int { 180 return len(p.Submissions) - 1 181} 182 183func (p *Pull) IsPatchBased() bool { 184 return p.PullSource == nil 185} 186 187func (p *Pull) IsBranchBased() bool { 188 if p.PullSource != nil { 189 if p.PullSource.RepoAt != nil { 190 return p.PullSource.RepoAt == &p.RepoAt 191 } else { 192 // no repo specified 193 return true 194 } 195 } 196 return false 197} 198 199func (p *Pull) IsForkBased() bool { 200 if p.PullSource != nil { 201 if p.PullSource.RepoAt != nil { 202 // make sure repos are different 203 return 
p.PullSource.RepoAt != &p.RepoAt 204 } 205 } 206 return false 207} 208 209func (p *Pull) IsStacked() bool { 210 return p.StackId != "" 211} 212 213func (s PullSubmission) IsFormatPatch() bool { 214 return patchutil.IsFormatPatch(s.Patch) 215} 216 217func (s PullSubmission) AsFormatPatch() []types.FormatPatch { 218 patches, err := patchutil.ExtractPatches(s.Patch) 219 if err != nil { 220 log.Println("error extracting patches from submission:", err) 221 return []types.FormatPatch{} 222 } 223 224 return patches 225} 226 227func NewPull(tx *sql.Tx, pull *Pull) error { 228 _, err := tx.Exec(` 229 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 230 values (?, 1) 231 `, pull.RepoAt) 232 if err != nil { 233 return err 234 } 235 236 var nextId int 237 err = tx.QueryRow(` 238 update repo_pull_seqs 239 set next_pull_id = next_pull_id + 1 240 where repo_at = ? 241 returning next_pull_id - 1 242 `, pull.RepoAt).Scan(&nextId) 243 if err != nil { 244 return err 245 } 246 247 pull.PullId = nextId 248 pull.State = PullOpen 249 250 var sourceBranch, sourceRepoAt *string 251 if pull.PullSource != nil { 252 sourceBranch = &pull.PullSource.Branch 253 if pull.PullSource.RepoAt != nil { 254 x := pull.PullSource.RepoAt.String() 255 sourceRepoAt = &x 256 } 257 } 258 259 var stackId, changeId, parentChangeId *string 260 if pull.StackId != "" { 261 stackId = &pull.StackId 262 } 263 if pull.ChangeId != "" { 264 changeId = &pull.ChangeId 265 } 266 if pull.ParentChangeId != "" { 267 parentChangeId = &pull.ParentChangeId 268 } 269 270 _, err = tx.Exec( 271 ` 272 insert into pulls ( 273 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 274 ) 275 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 276 pull.RepoAt, 277 pull.OwnerDid, 278 pull.PullId, 279 pull.Title, 280 pull.TargetBranch, 281 pull.Body, 282 pull.Rkey, 283 pull.State, 284 sourceBranch, 285 sourceRepoAt, 286 stackId, 287 changeId, 288 
parentChangeId, 289 ) 290 if err != nil { 291 return err 292 } 293 294 _, err = tx.Exec(` 295 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 296 values (?, ?, ?, ?, ?) 297 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev) 298 return err 299} 300 301func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 302 pull, err := GetPull(e, repoAt, pullId) 303 if err != nil { 304 return "", err 305 } 306 return pull.PullAt(), err 307} 308 309func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 310 var pullId int 311 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 312 return pullId - 1, err 313} 314 315func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*Pull, error) { 316 pulls := make(map[int]*Pull) 317 318 var conditions []string 319 var args []any 320 for _, filter := range filters { 321 conditions = append(conditions, filter.Condition()) 322 args = append(args, filter.Arg()...) 323 } 324 325 whereClause := "" 326 if conditions != nil { 327 whereClause = " where " + strings.Join(conditions, " and ") 328 } 329 limitClause := "" 330 if limit != 0 { 331 limitClause = fmt.Sprintf(" limit %d ", limit) 332 } 333 334 query := fmt.Sprintf(` 335 select 336 owner_did, 337 repo_at, 338 pull_id, 339 created, 340 title, 341 state, 342 target_branch, 343 body, 344 rkey, 345 source_branch, 346 source_repo_at, 347 stack_id, 348 change_id, 349 parent_change_id 350 from 351 pulls 352 %s 353 order by 354 created desc 355 %s 356 `, whereClause, limitClause) 357 358 rows, err := e.Query(query, args...) 
359 if err != nil { 360 return nil, err 361 } 362 defer rows.Close() 363 364 for rows.Next() { 365 var pull Pull 366 var createdAt string 367 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 368 err := rows.Scan( 369 &pull.OwnerDid, 370 &pull.RepoAt, 371 &pull.PullId, 372 &createdAt, 373 &pull.Title, 374 &pull.State, 375 &pull.TargetBranch, 376 &pull.Body, 377 &pull.Rkey, 378 &sourceBranch, 379 &sourceRepoAt, 380 &stackId, 381 &changeId, 382 &parentChangeId, 383 ) 384 if err != nil { 385 return nil, err 386 } 387 388 createdTime, err := time.Parse(time.RFC3339, createdAt) 389 if err != nil { 390 return nil, err 391 } 392 pull.Created = createdTime 393 394 if sourceBranch.Valid { 395 pull.PullSource = &PullSource{ 396 Branch: sourceBranch.String, 397 } 398 if sourceRepoAt.Valid { 399 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 400 if err != nil { 401 return nil, err 402 } 403 pull.PullSource.RepoAt = &sourceRepoAtParsed 404 } 405 } 406 407 if stackId.Valid { 408 pull.StackId = stackId.String 409 } 410 if changeId.Valid { 411 pull.ChangeId = changeId.String 412 } 413 if parentChangeId.Valid { 414 pull.ParentChangeId = parentChangeId.String 415 } 416 417 pulls[pull.PullId] = &pull 418 } 419 420 // get latest round no. for each pull 421 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 422 submissionsQuery := fmt.Sprintf(` 423 select 424 id, pull_id, round_number, patch, created, source_rev 425 from 426 pull_submissions 427 where 428 repo_at in (%s) and pull_id in (%s) 429 `, inClause, inClause) 430 431 args = make([]any, len(pulls)*2) 432 idx := 0 433 for _, p := range pulls { 434 args[idx] = p.RepoAt 435 idx += 1 436 } 437 for _, p := range pulls { 438 args[idx] = p.PullId 439 idx += 1 440 } 441 submissionsRows, err := e.Query(submissionsQuery, args...) 
442 if err != nil { 443 return nil, err 444 } 445 defer submissionsRows.Close() 446 447 for submissionsRows.Next() { 448 var s PullSubmission 449 var sourceRev sql.NullString 450 var createdAt string 451 err := submissionsRows.Scan( 452 &s.ID, 453 &s.PullId, 454 &s.RoundNumber, 455 &s.Patch, 456 &createdAt, 457 &sourceRev, 458 ) 459 if err != nil { 460 return nil, err 461 } 462 463 createdTime, err := time.Parse(time.RFC3339, createdAt) 464 if err != nil { 465 return nil, err 466 } 467 s.Created = createdTime 468 469 if sourceRev.Valid { 470 s.SourceRev = sourceRev.String 471 } 472 473 if p, ok := pulls[s.PullId]; ok { 474 p.Submissions = make([]*PullSubmission, s.RoundNumber+1) 475 p.Submissions[s.RoundNumber] = &s 476 } 477 } 478 if err := rows.Err(); err != nil { 479 return nil, err 480 } 481 482 // get comment count on latest submission on each pull 483 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ") 484 commentsQuery := fmt.Sprintf(` 485 select 486 count(id), pull_id 487 from 488 pull_comments 489 where 490 submission_id in (%s) 491 group by 492 submission_id 493 `, inClause) 494 495 args = []any{} 496 for _, p := range pulls { 497 args = append(args, p.Submissions[p.LastRoundNumber()].ID) 498 } 499 commentsRows, err := e.Query(commentsQuery, args...) 
500 if err != nil { 501 return nil, err 502 } 503 defer commentsRows.Close() 504 505 for commentsRows.Next() { 506 var commentCount, pullId int 507 err := commentsRows.Scan( 508 &commentCount, 509 &pullId, 510 ) 511 if err != nil { 512 return nil, err 513 } 514 if p, ok := pulls[pullId]; ok { 515 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount) 516 } 517 } 518 if err := rows.Err(); err != nil { 519 return nil, err 520 } 521 522 orderedByPullId := []*Pull{} 523 for _, p := range pulls { 524 orderedByPullId = append(orderedByPullId, p) 525 } 526 sort.Slice(orderedByPullId, func(i, j int) bool { 527 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 528 }) 529 530 return orderedByPullId, nil 531} 532 533func GetPulls(e Execer, filters ...filter) ([]*Pull, error) { 534 return GetPullsWithLimit(e, 0, filters...) 535} 536 537func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) { 538 query := ` 539 select 540 owner_did, 541 pull_id, 542 created, 543 title, 544 state, 545 target_branch, 546 repo_at, 547 body, 548 rkey, 549 source_branch, 550 source_repo_at, 551 stack_id, 552 change_id, 553 parent_change_id 554 from 555 pulls 556 where 557 repo_at = ? and pull_id = ? 
558 ` 559 row := e.QueryRow(query, repoAt, pullId) 560 561 var pull Pull 562 var createdAt string 563 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 564 err := row.Scan( 565 &pull.OwnerDid, 566 &pull.PullId, 567 &createdAt, 568 &pull.Title, 569 &pull.State, 570 &pull.TargetBranch, 571 &pull.RepoAt, 572 &pull.Body, 573 &pull.Rkey, 574 &sourceBranch, 575 &sourceRepoAt, 576 &stackId, 577 &changeId, 578 &parentChangeId, 579 ) 580 if err != nil { 581 return nil, err 582 } 583 584 createdTime, err := time.Parse(time.RFC3339, createdAt) 585 if err != nil { 586 return nil, err 587 } 588 pull.Created = createdTime 589 590 // populate source 591 if sourceBranch.Valid { 592 pull.PullSource = &PullSource{ 593 Branch: sourceBranch.String, 594 } 595 if sourceRepoAt.Valid { 596 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 597 if err != nil { 598 return nil, err 599 } 600 pull.PullSource.RepoAt = &sourceRepoAtParsed 601 } 602 } 603 604 if stackId.Valid { 605 pull.StackId = stackId.String 606 } 607 if changeId.Valid { 608 pull.ChangeId = changeId.String 609 } 610 if parentChangeId.Valid { 611 pull.ParentChangeId = parentChangeId.String 612 } 613 614 submissionsQuery := ` 615 select 616 id, pull_id, repo_at, round_number, patch, created, source_rev 617 from 618 pull_submissions 619 where 620 repo_at = ? and pull_id = ? 
621 ` 622 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId) 623 if err != nil { 624 return nil, err 625 } 626 defer submissionsRows.Close() 627 628 submissionsMap := make(map[int]*PullSubmission) 629 630 for submissionsRows.Next() { 631 var submission PullSubmission 632 var submissionCreatedStr string 633 var submissionSourceRev sql.NullString 634 err := submissionsRows.Scan( 635 &submission.ID, 636 &submission.PullId, 637 &submission.RepoAt, 638 &submission.RoundNumber, 639 &submission.Patch, 640 &submissionCreatedStr, 641 &submissionSourceRev, 642 ) 643 if err != nil { 644 return nil, err 645 } 646 647 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr) 648 if err != nil { 649 return nil, err 650 } 651 submission.Created = submissionCreatedTime 652 653 if submissionSourceRev.Valid { 654 submission.SourceRev = submissionSourceRev.String 655 } 656 657 submissionsMap[submission.ID] = &submission 658 } 659 if err = submissionsRows.Close(); err != nil { 660 return nil, err 661 } 662 if len(submissionsMap) == 0 { 663 return &pull, nil 664 } 665 666 var args []any 667 for k := range submissionsMap { 668 args = append(args, k) 669 } 670 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ") 671 commentsQuery := fmt.Sprintf(` 672 select 673 id, 674 pull_id, 675 submission_id, 676 repo_at, 677 owner_did, 678 comment_at, 679 body, 680 created 681 from 682 pull_comments 683 where 684 submission_id IN (%s) 685 order by 686 created asc 687 `, inClause) 688 commentsRows, err := e.Query(commentsQuery, args...) 
689 if err != nil { 690 return nil, err 691 } 692 defer commentsRows.Close() 693 694 for commentsRows.Next() { 695 var comment PullComment 696 var commentCreatedStr string 697 err := commentsRows.Scan( 698 &comment.ID, 699 &comment.PullId, 700 &comment.SubmissionId, 701 &comment.RepoAt, 702 &comment.OwnerDid, 703 &comment.CommentAt, 704 &comment.Body, 705 &commentCreatedStr, 706 ) 707 if err != nil { 708 return nil, err 709 } 710 711 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr) 712 if err != nil { 713 return nil, err 714 } 715 comment.Created = commentCreatedTime 716 717 // Add the comment to its submission 718 if submission, ok := submissionsMap[comment.SubmissionId]; ok { 719 submission.Comments = append(submission.Comments, comment) 720 } 721 722 } 723 if err = commentsRows.Err(); err != nil { 724 return nil, err 725 } 726 727 var pullSourceRepo *models.Repo 728 if pull.PullSource != nil { 729 if pull.PullSource.RepoAt != nil { 730 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String()) 731 if err != nil { 732 log.Printf("failed to get repo by at uri: %v", err) 733 } else { 734 pull.PullSource.Repo = pullSourceRepo 735 } 736 } 737 } 738 739 pull.Submissions = make([]*PullSubmission, len(submissionsMap)) 740 for _, submission := range submissionsMap { 741 pull.Submissions[submission.RoundNumber] = submission 742 } 743 744 return &pull, nil 745} 746 747// timeframe here is directly passed into the sql query filter, and any 748// timeframe in the past should be negative; e.g.: "-3 months" 749func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) { 750 var pulls []Pull 751 752 rows, err := e.Query(` 753 select 754 p.owner_did, 755 p.repo_at, 756 p.pull_id, 757 p.created, 758 p.title, 759 p.state, 760 r.did, 761 r.name, 762 r.knot, 763 r.rkey, 764 r.created 765 from 766 pulls p 767 join 768 repos r on p.repo_at = r.at_uri 769 where 770 p.owner_did = ? and p.created >= date ('now', ?) 
771 order by 772 p.created desc`, did, timeframe) 773 if err != nil { 774 return nil, err 775 } 776 defer rows.Close() 777 778 for rows.Next() { 779 var pull Pull 780 var repo models.Repo 781 var pullCreatedAt, repoCreatedAt string 782 err := rows.Scan( 783 &pull.OwnerDid, 784 &pull.RepoAt, 785 &pull.PullId, 786 &pullCreatedAt, 787 &pull.Title, 788 &pull.State, 789 &repo.Did, 790 &repo.Name, 791 &repo.Knot, 792 &repo.Rkey, 793 &repoCreatedAt, 794 ) 795 if err != nil { 796 return nil, err 797 } 798 799 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 800 if err != nil { 801 return nil, err 802 } 803 pull.Created = pullCreatedTime 804 805 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 806 if err != nil { 807 return nil, err 808 } 809 repo.Created = repoCreatedTime 810 811 pull.Repo = &repo 812 813 pulls = append(pulls, pull) 814 } 815 816 if err := rows.Err(); err != nil { 817 return nil, err 818 } 819 820 return pulls, nil 821} 822 823func NewPullComment(e Execer, comment *PullComment) (int64, error) { 824 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 825 res, err := e.Exec( 826 query, 827 comment.OwnerDid, 828 comment.RepoAt, 829 comment.SubmissionId, 830 comment.CommentAt, 831 comment.PullId, 832 comment.Body, 833 ) 834 if err != nil { 835 return 0, err 836 } 837 838 i, err := res.LastInsertId() 839 if err != nil { 840 return 0, err 841 } 842 843 return i, nil 844} 845 846func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error { 847 _, err := e.Exec( 848 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? 
or state <> ?)`, 849 pullState, 850 repoAt, 851 pullId, 852 PullDeleted, // only update state of non-deleted pulls 853 PullMerged, // only update state of non-merged pulls 854 ) 855 return err 856} 857 858func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 859 err := SetPullState(e, repoAt, pullId, PullClosed) 860 return err 861} 862 863func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 864 err := SetPullState(e, repoAt, pullId, PullOpen) 865 return err 866} 867 868func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 869 err := SetPullState(e, repoAt, pullId, PullMerged) 870 return err 871} 872 873func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 874 err := SetPullState(e, repoAt, pullId, PullDeleted) 875 return err 876} 877 878func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error { 879 newRoundNumber := len(pull.Submissions) 880 _, err := e.Exec(` 881 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev) 882 values (?, ?, ?, ?, ?) 883 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev) 884 885 return err 886} 887 888func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error { 889 var conditions []string 890 var args []any 891 892 args = append(args, parentChangeId) 893 894 for _, filter := range filters { 895 conditions = append(conditions, filter.Condition()) 896 args = append(args, filter.Arg()...) 897 } 898 899 whereClause := "" 900 if conditions != nil { 901 whereClause = " where " + strings.Join(conditions, " and ") 902 } 903 904 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 905 _, err := e.Exec(query, args...) 906 907 return err 908} 909 910// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 
911// otherwise submissions are immutable 912func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error { 913 var conditions []string 914 var args []any 915 916 args = append(args, sourceRev) 917 args = append(args, newPatch) 918 919 for _, filter := range filters { 920 conditions = append(conditions, filter.Condition()) 921 args = append(args, filter.Arg()...) 922 } 923 924 whereClause := "" 925 if conditions != nil { 926 whereClause = " where " + strings.Join(conditions, " and ") 927 } 928 929 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 930 _, err := e.Exec(query, args...) 931 932 return err 933} 934 935func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) { 936 row := e.QueryRow(` 937 select 938 count(case when state = ? then 1 end) as open_count, 939 count(case when state = ? then 1 end) as merged_count, 940 count(case when state = ? then 1 end) as closed_count, 941 count(case when state = ? then 1 end) as deleted_count 942 from pulls 943 where repo_at = ?`, 944 PullOpen, 945 PullMerged, 946 PullClosed, 947 PullDeleted, 948 repoAt, 949 ) 950 951 var count models.PullCount 952 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 953 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err 954 } 955 956 return count, nil 957} 958 959type Stack []*Pull 960 961// change-id parent-change-id 962// 963// 4 w ,-------- z (TOP) 964// 3 z <----',------- y 965// 2 y <-----',------ x 966// 1 x <------' nil (BOT) 967// 968// `w` is parent of none, so it is the top of the stack 969func GetStack(e Execer, stackId string) (Stack, error) { 970 unorderedPulls, err := GetPulls( 971 e, 972 FilterEq("stack_id", stackId), 973 FilterNotEq("state", PullDeleted), 974 ) 975 if err != nil { 976 return nil, err 977 } 978 // map of parent-change-id to pull 979 changeIdMap := make(map[string]*Pull, len(unorderedPulls)) 980 parentMap := 
make(map[string]*Pull, len(unorderedPulls)) 981 for _, p := range unorderedPulls { 982 changeIdMap[p.ChangeId] = p 983 if p.ParentChangeId != "" { 984 parentMap[p.ParentChangeId] = p 985 } 986 } 987 988 // the top of the stack is the pull that is not a parent of any pull 989 var topPull *Pull 990 for _, maybeTop := range unorderedPulls { 991 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 992 topPull = maybeTop 993 break 994 } 995 } 996 997 pulls := []*Pull{} 998 for { 999 pulls = append(pulls, topPull) 1000 if topPull.ParentChangeId != "" { 1001 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 1002 topPull = next 1003 } else { 1004 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 1005 } 1006 } else { 1007 break 1008 } 1009 } 1010 1011 return pulls, nil 1012} 1013 1014func GetAbandonedPulls(e Execer, stackId string) ([]*Pull, error) { 1015 pulls, err := GetPulls( 1016 e, 1017 FilterEq("stack_id", stackId), 1018 FilterEq("state", PullDeleted), 1019 ) 1020 if err != nil { 1021 return nil, err 1022 } 1023 1024 return pulls, nil 1025} 1026 1027// position of this pull in the stack 1028func (stack Stack) Position(pull *Pull) int { 1029 return slices.IndexFunc(stack, func(p *Pull) bool { 1030 return p.ChangeId == pull.ChangeId 1031 }) 1032} 1033 1034// all pulls below this pull (including self) in this stack 1035// 1036// nil if this pull does not belong to this stack 1037func (stack Stack) Below(pull *Pull) Stack { 1038 position := stack.Position(pull) 1039 1040 if position < 0 { 1041 return nil 1042 } 1043 1044 return stack[position:] 1045} 1046 1047// all pulls below this pull (excluding self) in this stack 1048func (stack Stack) StrictlyBelow(pull *Pull) Stack { 1049 below := stack.Below(pull) 1050 1051 if len(below) > 0 { 1052 return below[1:] 1053 } 1054 1055 return nil 1056} 1057 1058// all pulls above this pull (including self) in this stack 1059func (stack Stack) Above(pull *Pull) Stack { 1060 position := 
stack.Position(pull) 1061 1062 if position < 0 { 1063 return nil 1064 } 1065 1066 return stack[:position+1] 1067} 1068 1069// all pulls below this pull (excluding self) in this stack 1070func (stack Stack) StrictlyAbove(pull *Pull) Stack { 1071 above := stack.Above(pull) 1072 1073 if len(above) > 0 { 1074 return above[:len(above)-1] 1075 } 1076 1077 return nil 1078} 1079 1080// the combined format-patches of all the newest submissions in this stack 1081func (stack Stack) CombinedPatch() string { 1082 // go in reverse order because the bottom of the stack is the last element in the slice 1083 var combined strings.Builder 1084 for idx := range stack { 1085 pull := stack[len(stack)-1-idx] 1086 combined.WriteString(pull.LatestPatch()) 1087 combined.WriteString("\n") 1088 } 1089 return combined.String() 1090} 1091 1092// filter out PRs that are "active" 1093// 1094// PRs that are still open are active 1095func (stack Stack) Mergeable() Stack { 1096 var mergeable Stack 1097 1098 for _, p := range stack { 1099 // stop at the first merged PR 1100 if p.State == PullMerged || p.State == PullClosed { 1101 break 1102 } 1103 1104 // skip over deleted PRs 1105 if p.State != PullDeleted { 1106 mergeable = append(mergeable, p) 1107 } 1108 } 1109 1110 return mergeable 1111}