this repo has no description
1package db 2 3import ( 4 "cmp" 5 "database/sql" 6 "errors" 7 "fmt" 8 "maps" 9 "slices" 10 "sort" 11 "strings" 12 "time" 13 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 "tangled.org/core/appview/models" 16) 17 18func NewPull(tx *sql.Tx, pull *models.Pull) error { 19 _, err := tx.Exec(` 20 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 21 values (?, 1) 22 `, pull.RepoAt) 23 if err != nil { 24 return err 25 } 26 27 var nextId int 28 err = tx.QueryRow(` 29 update repo_pull_seqs 30 set next_pull_id = next_pull_id + 1 31 where repo_at = ? 32 returning next_pull_id - 1 33 `, pull.RepoAt).Scan(&nextId) 34 if err != nil { 35 return err 36 } 37 38 pull.PullId = nextId 39 pull.State = models.PullOpen 40 41 var sourceBranch, sourceRepoAt *string 42 if pull.PullSource != nil { 43 sourceBranch = &pull.PullSource.Branch 44 if pull.PullSource.RepoAt != nil { 45 x := pull.PullSource.RepoAt.String() 46 sourceRepoAt = &x 47 } 48 } 49 50 var stackId, changeId, parentChangeId *string 51 if pull.StackId != "" { 52 stackId = &pull.StackId 53 } 54 if pull.ChangeId != "" { 55 changeId = &pull.ChangeId 56 } 57 if pull.ParentChangeId != "" { 58 parentChangeId = &pull.ParentChangeId 59 } 60 61 result, err := tx.Exec( 62 ` 63 insert into pulls ( 64 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 65 ) 66 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 67 pull.RepoAt, 68 pull.OwnerDid, 69 pull.PullId, 70 pull.Title, 71 pull.TargetBranch, 72 pull.Body, 73 pull.Rkey, 74 pull.State, 75 sourceBranch, 76 sourceRepoAt, 77 stackId, 78 changeId, 79 parentChangeId, 80 ) 81 if err != nil { 82 return err 83 } 84 85 // Set the database primary key ID 86 id, err := result.LastInsertId() 87 if err != nil { 88 return err 89 } 90 pull.ID = int(id) 91 92 _, err = tx.Exec(` 93 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 94 values (?, ?, ?, ?, ?) 95 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev) 96 return err 97} 98 99func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 100 pull, err := GetPull(e, repoAt, pullId) 101 if err != nil { 102 return "", err 103 } 104 return pull.AtUri(), err 105} 106 107func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 108 var pullId int 109 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 110 return pullId - 1, err 111} 112 113func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) { 114 pulls := make(map[syntax.ATURI]*models.Pull) 115 116 var conditions []string 117 var args []any 118 for _, filter := range filters { 119 conditions = append(conditions, filter.Condition()) 120 args = append(args, filter.Arg()...) 121 } 122 123 whereClause := "" 124 if conditions != nil { 125 whereClause = " where " + strings.Join(conditions, " and ") 126 } 127 limitClause := "" 128 if limit != 0 { 129 limitClause = fmt.Sprintf(" limit %d ", limit) 130 } 131 132 query := fmt.Sprintf(` 133 select 134 id, 135 owner_did, 136 repo_at, 137 pull_id, 138 created, 139 title, 140 state, 141 target_branch, 142 body, 143 rkey, 144 source_branch, 145 source_repo_at, 146 stack_id, 147 change_id, 148 parent_change_id 149 from 150 pulls 151 %s 152 order by 153 created desc 154 %s 155 `, whereClause, limitClause) 156 157 rows, err := e.Query(query, args...) 158 if err != nil { 159 return nil, err 160 } 161 defer rows.Close() 162 163 for rows.Next() { 164 var pull models.Pull 165 var createdAt string 166 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 167 err := rows.Scan( 168 &pull.ID, 169 &pull.OwnerDid, 170 &pull.RepoAt, 171 &pull.PullId, 172 &createdAt, 173 &pull.Title, 174 &pull.State, 175 &pull.TargetBranch, 176 &pull.Body, 177 &pull.Rkey, 178 &sourceBranch, 179 &sourceRepoAt, 180 &stackId, 181 &changeId, 182 &parentChangeId, 183 ) 184 if err != nil { 185 return nil, err 186 } 187 188 createdTime, err := time.Parse(time.RFC3339, createdAt) 189 if err != nil { 190 return nil, err 191 } 192 pull.Created = createdTime 193 194 if sourceBranch.Valid { 195 pull.PullSource = &models.PullSource{ 196 Branch: sourceBranch.String, 197 } 198 if sourceRepoAt.Valid { 199 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 200 if err != nil { 201 return nil, err 202 } 203 pull.PullSource.RepoAt = &sourceRepoAtParsed 204 } 205 } 206 207 if stackId.Valid { 208 pull.StackId = stackId.String 209 } 210 if changeId.Valid { 211 pull.ChangeId = changeId.String 212 } 213 if parentChangeId.Valid { 214 pull.ParentChangeId = parentChangeId.String 215 } 216 217 pulls[pull.AtUri()] = &pull 218 } 219 220 var pullAts []syntax.ATURI 221 for _, p := range pulls { 222 pullAts = append(pullAts, p.AtUri()) 223 } 224 submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts)) 225 if err != nil { 226 return nil, fmt.Errorf("failed to get submissions: %w", err) 227 } 228 229 for pullAt, submissions := range submissionsMap { 230 if p, ok := pulls[pullAt]; ok { 231 p.Submissions = submissions 232 } 233 } 234 235 // collect allLabels for each issue 236 allLabels, err := GetLabels(e, FilterIn("subject", pullAts)) 237 if err != nil { 238 return nil, fmt.Errorf("failed to query labels: %w", err) 239 } 240 for pullAt, labels := range allLabels { 241 if p, ok := pulls[pullAt]; ok { 242 p.Labels = labels 243 } 244 } 245 246 // collect pull source for all pulls that need it 247 var sourceAts []syntax.ATURI 248 for _, p := range pulls { 249 if p.PullSource != nil && p.PullSource.RepoAt != nil { 250 sourceAts = append(sourceAts, *p.PullSource.RepoAt) 251 } 252 } 253 sourceRepos, err := GetRepos(e, 0, FilterIn("at_uri", sourceAts)) 254 if err != nil && !errors.Is(err, sql.ErrNoRows) { 255 return nil, fmt.Errorf("failed to get source repos: %w", err) 256 } 257 sourceRepoMap := make(map[syntax.ATURI]*models.Repo) 258 for _, r := range sourceRepos { 259 sourceRepoMap[r.RepoAt()] = &r 260 } 261 for _, p := range pulls { 262 if p.PullSource != nil && p.PullSource.RepoAt != nil { 263 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok { 264 p.PullSource.Repo = sourceRepo 265 } 266 } 267 } 268 269 orderedByPullId := []*models.Pull{} 270 for _, p := range pulls { 271 orderedByPullId = append(orderedByPullId, p) 272 } 273 sort.Slice(orderedByPullId, func(i, j int) bool { 274 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 275 }) 276 277 return orderedByPullId, nil 278} 279 280func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) { 281 return GetPullsWithLimit(e, 0, filters...) 282} 283 284func GetPullIDs(e Execer, opts models.PullSearchOptions) ([]int64, error) { 285 var ids []int64 286 287 var filters []filter 288 filters = append(filters, FilterEq("state", opts.State)) 289 if opts.RepoAt != "" { 290 filters = append(filters, FilterEq("repo_at", opts.RepoAt)) 291 } 292 293 var conditions []string 294 var args []any 295 296 for _, filter := range filters { 297 conditions = append(conditions, filter.Condition()) 298 args = append(args, filter.Arg()...) 299 } 300 301 whereClause := "" 302 if conditions != nil { 303 whereClause = " where " + strings.Join(conditions, " and ") 304 } 305 pageClause := "" 306 if opts.Page.Limit != 0 { 307 pageClause = fmt.Sprintf( 308 " limit %d offset %d ", 309 opts.Page.Limit, 310 opts.Page.Offset, 311 ) 312 } 313 314 query := fmt.Sprintf( 315 ` 316 select 317 id 318 from 319 pulls 320 %s 321 %s`, 322 whereClause, 323 pageClause, 324 ) 325 args = append(args, opts.Page.Limit, opts.Page.Offset) 326 rows, err := e.Query(query, args...) 327 if err != nil { 328 return nil, err 329 } 330 defer rows.Close() 331 332 for rows.Next() { 333 var id int64 334 err := rows.Scan(&id) 335 if err != nil { 336 return nil, err 337 } 338 339 ids = append(ids, id) 340 } 341 342 return ids, nil 343} 344 345func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { 346 pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId)) 347 if err != nil { 348 return nil, err 349 } 350 if len(pulls) == 0 { 351 return nil, sql.ErrNoRows 352 } 353 354 return pulls[0], nil 355} 356 357// mapping from pull -> pull submissions 358func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) { 359 var conditions []string 360 var args []any 361 for _, filter := range filters { 362 conditions = append(conditions, filter.Condition()) 363 args = append(args, filter.Arg()...) 364 } 365 366 whereClause := "" 367 if conditions != nil { 368 whereClause = " where " + strings.Join(conditions, " and ") 369 } 370 371 query := fmt.Sprintf(` 372 select 373 id, 374 pull_at, 375 round_number, 376 patch, 377 combined, 378 created, 379 source_rev 380 from 381 pull_submissions 382 %s 383 order by 384 round_number asc 385 `, whereClause) 386 387 rows, err := e.Query(query, args...) 388 if err != nil { 389 return nil, err 390 } 391 defer rows.Close() 392 393 submissionMap := make(map[int]*models.PullSubmission) 394 395 for rows.Next() { 396 var submission models.PullSubmission 397 var submissionCreatedStr string 398 var submissionSourceRev, submissionCombined sql.NullString 399 err := rows.Scan( 400 &submission.ID, 401 &submission.PullAt, 402 &submission.RoundNumber, 403 &submission.Patch, 404 &submissionCombined, 405 &submissionCreatedStr, 406 &submissionSourceRev, 407 ) 408 if err != nil { 409 return nil, err 410 } 411 412 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil { 413 submission.Created = t 414 } 415 416 if submissionSourceRev.Valid { 417 submission.SourceRev = submissionSourceRev.String 418 } 419 420 if submissionCombined.Valid { 421 submission.Combined = submissionCombined.String 422 } 423 424 submissionMap[submission.ID] = &submission 425 } 426 427 if err := rows.Err(); err != nil { 428 return nil, err 429 } 430 431 // Get comments for all submissions using GetPullComments 432 submissionIds := slices.Collect(maps.Keys(submissionMap)) 433 comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds)) 434 if err != nil { 435 return nil, fmt.Errorf("failed to get pull comments: %w", err) 436 } 437 for _, comment := range comments { 438 if submission, ok := submissionMap[comment.SubmissionId]; ok { 439 submission.Comments = append(submission.Comments, comment) 440 } 441 } 442 443 // group the submissions by pull_at 444 m := make(map[syntax.ATURI][]*models.PullSubmission) 445 for _, s := range submissionMap { 446 m[s.PullAt] = append(m[s.PullAt], s) 447 } 448 449 // sort each one by round number 450 for _, s := range m { 451 slices.SortFunc(s, func(a, b *models.PullSubmission) int { 452 return cmp.Compare(a.RoundNumber, b.RoundNumber) 453 }) 454 } 455 456 return m, nil 457} 458 459func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) { 460 var conditions []string 461 var args []any 462 for _, filter := range filters { 463 conditions = append(conditions, filter.Condition()) 464 args = append(args, filter.Arg()...) 465 } 466 467 whereClause := "" 468 if conditions != nil { 469 whereClause = " where " + strings.Join(conditions, " and ") 470 } 471 472 query := fmt.Sprintf(` 473 select 474 id, 475 pull_id, 476 submission_id, 477 repo_at, 478 owner_did, 479 comment_at, 480 body, 481 created 482 from 483 pull_comments 484 %s 485 order by 486 created asc 487 `, whereClause) 488 489 rows, err := e.Query(query, args...) 490 if err != nil { 491 return nil, err 492 } 493 defer rows.Close() 494 495 commentMap := make(map[string]*models.PullComment) 496 for rows.Next() { 497 var comment models.PullComment 498 var createdAt string 499 err := rows.Scan( 500 &comment.ID, 501 &comment.PullId, 502 &comment.SubmissionId, 503 &comment.RepoAt, 504 &comment.OwnerDid, 505 &comment.CommentAt, 506 &comment.Body, 507 &createdAt, 508 ) 509 if err != nil { 510 return nil, err 511 } 512 513 if t, err := time.Parse(time.RFC3339, createdAt); err == nil { 514 comment.Created = t 515 } 516 517 atUri := comment.AtUri().String() 518 commentMap[atUri] = &comment 519 } 520 521 if err := rows.Err(); err != nil { 522 return nil, err 523 } 524 525 // collect references for each comments 526 commentAts := slices.Collect(maps.Keys(commentMap)) 527 allReferencs, err := GetReferencesAll(e, FilterIn("from_at", commentAts)) 528 if err != nil { 529 return nil, fmt.Errorf("failed to query reference_links: %w", err) 530 } 531 for commentAt, references := range allReferencs { 532 if comment, ok := commentMap[commentAt.String()]; ok { 533 comment.References = references 534 } 535 } 536 537 var comments []models.PullComment 538 for _, c := range commentMap { 539 comments = append(comments, *c) 540 } 541 542 sort.Slice(comments, func(i, j int) bool { 543 return comments[i].Created.Before(comments[j].Created) 544 }) 545 546 return comments, nil 547} 548 549// timeframe here is directly passed into the sql query filter, and any 550// timeframe in the past should be negative; e.g.: "-3 months" 551func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) { 552 var pulls []models.Pull 553 554 rows, err := e.Query(` 555 select 556 p.owner_did, 557 p.repo_at, 558 p.pull_id, 559 p.created, 560 p.title, 561 p.state, 562 r.did, 563 r.name, 564 r.knot, 565 r.rkey, 566 r.created 567 from 568 pulls p 569 join 570 repos r on p.repo_at = r.at_uri 571 where 572 p.owner_did = ? and p.created >= date ('now', ?) 573 order by 574 p.created desc`, did, timeframe) 575 if err != nil { 576 return nil, err 577 } 578 defer rows.Close() 579 580 for rows.Next() { 581 var pull models.Pull 582 var repo models.Repo 583 var pullCreatedAt, repoCreatedAt string 584 err := rows.Scan( 585 &pull.OwnerDid, 586 &pull.RepoAt, 587 &pull.PullId, 588 &pullCreatedAt, 589 &pull.Title, 590 &pull.State, 591 &repo.Did, 592 &repo.Name, 593 &repo.Knot, 594 &repo.Rkey, 595 &repoCreatedAt, 596 ) 597 if err != nil { 598 return nil, err 599 } 600 601 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 602 if err != nil { 603 return nil, err 604 } 605 pull.Created = pullCreatedTime 606 607 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 608 if err != nil { 609 return nil, err 610 } 611 repo.Created = repoCreatedTime 612 613 pull.Repo = &repo 614 615 pulls = append(pulls, pull) 616 } 617 618 if err := rows.Err(); err != nil { 619 return nil, err 620 } 621 622 return pulls, nil 623} 624 625func NewPullComment(tx *sql.Tx, comment *models.PullComment) (int64, error) { 626 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 627 res, err := tx.Exec( 628 query, 629 comment.OwnerDid, 630 comment.RepoAt, 631 comment.SubmissionId, 632 comment.CommentAt, 633 comment.PullId, 634 comment.Body, 635 ) 636 if err != nil { 637 return 0, err 638 } 639 640 i, err := res.LastInsertId() 641 if err != nil { 642 return 0, err 643 } 644 645 if err := putReferences(tx, comment.AtUri(), comment.References); err != nil { 646 return 0, fmt.Errorf("put reference_links: %w", err) 647 } 648 649 return i, nil 650} 651 652func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error { 653 _, err := e.Exec( 654 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 655 pullState, 656 repoAt, 657 pullId, 658 models.PullDeleted, // only update state of non-deleted pulls 659 models.PullMerged, // only update state of non-merged pulls 660 ) 661 return err 662} 663 664func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 665 err := SetPullState(e, repoAt, pullId, models.PullClosed) 666 return err 667} 668 669func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 670 err := SetPullState(e, repoAt, pullId, models.PullOpen) 671 return err 672} 673 674func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 675 err := SetPullState(e, repoAt, pullId, models.PullMerged) 676 return err 677} 678 679func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 680 err := SetPullState(e, repoAt, pullId, models.PullDeleted) 681 return err 682} 683 684func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error { 685 _, err := e.Exec(` 686 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 687 values (?, ?, ?, ?, ?) 688 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev) 689 690 return err 691} 692 693func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error { 694 var conditions []string 695 var args []any 696 697 args = append(args, parentChangeId) 698 699 for _, filter := range filters { 700 conditions = append(conditions, filter.Condition()) 701 args = append(args, filter.Arg()...) 702 } 703 704 whereClause := "" 705 if conditions != nil { 706 whereClause = " where " + strings.Join(conditions, " and ") 707 } 708 709 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 710 _, err := e.Exec(query, args...) 711 712 return err 713} 714 715// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 716// otherwise submissions are immutable 717func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error { 718 var conditions []string 719 var args []any 720 721 args = append(args, sourceRev) 722 args = append(args, newPatch) 723 724 for _, filter := range filters { 725 conditions = append(conditions, filter.Condition()) 726 args = append(args, filter.Arg()...) 727 } 728 729 whereClause := "" 730 if conditions != nil { 731 whereClause = " where " + strings.Join(conditions, " and ") 732 } 733 734 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 735 _, err := e.Exec(query, args...) 736 737 return err 738} 739 740func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) { 741 row := e.QueryRow(` 742 select 743 count(case when state = ? then 1 end) as open_count, 744 count(case when state = ? then 1 end) as merged_count, 745 count(case when state = ? then 1 end) as closed_count, 746 count(case when state = ? then 1 end) as deleted_count 747 from pulls 748 where repo_at = ?`, 749 models.PullOpen, 750 models.PullMerged, 751 models.PullClosed, 752 models.PullDeleted, 753 repoAt, 754 ) 755 756 var count models.PullCount 757 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 758 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err 759 } 760 761 return count, nil 762} 763 764// change-id parent-change-id 765// 766// 4 w ,-------- z (TOP) 767// 3 z <----',------- y 768// 2 y <-----',------ x 769// 1 x <------' nil (BOT) 770// 771// `w` is parent of none, so it is the top of the stack 772func GetStack(e Execer, stackId string) (models.Stack, error) { 773 unorderedPulls, err := GetPulls( 774 e, 775 FilterEq("stack_id", stackId), 776 FilterNotEq("state", models.PullDeleted), 777 ) 778 if err != nil { 779 return nil, err 780 } 781 // map of parent-change-id to pull 782 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls)) 783 parentMap := make(map[string]*models.Pull, len(unorderedPulls)) 784 for _, p := range unorderedPulls { 785 changeIdMap[p.ChangeId] = p 786 if p.ParentChangeId != "" { 787 parentMap[p.ParentChangeId] = p 788 } 789 } 790 791 // the top of the stack is the pull that is not a parent of any pull 792 var topPull *models.Pull 793 for _, maybeTop := range unorderedPulls { 794 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 795 topPull = maybeTop 796 break 797 } 798 } 799 800 pulls := []*models.Pull{} 801 for { 802 pulls = append(pulls, topPull) 803 if topPull.ParentChangeId != "" { 804 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 805 topPull = next 806 } else { 807 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 808 } 809 } else { 810 break 811 } 812 } 813 814 return pulls, nil 815} 816 817func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) { 818 pulls, err := GetPulls( 819 e, 820 FilterEq("stack_id", stackId), 821 FilterEq("state", models.PullDeleted), 822 ) 823 if err != nil { 824 return nil, err 825 } 826 827 return pulls, nil 828}