this repo has no description
1package db 2 3import ( 4 "cmp" 5 "database/sql" 6 "errors" 7 "fmt" 8 "maps" 9 "slices" 10 "sort" 11 "strings" 12 "time" 13 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 "tangled.org/core/appview/models" 16 "tangled.org/core/orm" 17) 18 19func NewPull(tx *sql.Tx, pull *models.Pull) error { 20 _, err := tx.Exec(` 21 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 22 values (?, 1) 23 `, pull.RepoAt) 24 if err != nil { 25 return err 26 } 27 28 var nextId int 29 err = tx.QueryRow(` 30 update repo_pull_seqs 31 set next_pull_id = next_pull_id + 1 32 where repo_at = ? 33 returning next_pull_id - 1 34 `, pull.RepoAt).Scan(&nextId) 35 if err != nil { 36 return err 37 } 38 39 pull.PullId = nextId 40 pull.State = models.PullOpen 41 42 var sourceBranch, sourceRepoAt *string 43 if pull.PullSource != nil { 44 sourceBranch = &pull.PullSource.Branch 45 if pull.PullSource.RepoAt != nil { 46 x := pull.PullSource.RepoAt.String() 47 sourceRepoAt = &x 48 } 49 } 50 51 var stackId, changeId, parentChangeId *string 52 if pull.StackId != "" { 53 stackId = &pull.StackId 54 } 55 if pull.ChangeId != "" { 56 changeId = &pull.ChangeId 57 } 58 if pull.ParentChangeId != "" { 59 parentChangeId = &pull.ParentChangeId 60 } 61 62 result, err := tx.Exec( 63 ` 64 insert into pulls ( 65 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 66 ) 67 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 68 pull.RepoAt, 69 pull.OwnerDid, 70 pull.PullId, 71 pull.Title, 72 pull.TargetBranch, 73 pull.Body, 74 pull.Rkey, 75 pull.State, 76 sourceBranch, 77 sourceRepoAt, 78 stackId, 79 changeId, 80 parentChangeId, 81 ) 82 if err != nil { 83 return err 84 } 85 86 // Set the database primary key ID 87 id, err := result.LastInsertId() 88 if err != nil { 89 return err 90 } 91 pull.ID = int(id) 92 93 _, err = tx.Exec(` 94 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 95 values (?, ?, ?, ?, ?) 96 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev) 97 if err != nil { 98 return err 99 } 100 101 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil { 102 return fmt.Errorf("put reference_links: %w", err) 103 } 104 105 return nil 106} 107 108func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 109 pull, err := GetPull(e, repoAt, pullId) 110 if err != nil { 111 return "", err 112 } 113 return pull.AtUri(), err 114} 115 116func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 117 var pullId int 118 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 119 return pullId - 1, err 120} 121 122func GetPullsWithLimit(e Execer, limit int, filters ...orm.Filter) ([]*models.Pull, error) { 123 pulls := make(map[syntax.ATURI]*models.Pull) 124 125 var conditions []string 126 var args []any 127 for _, filter := range filters { 128 conditions = append(conditions, filter.Condition()) 129 args = append(args, filter.Arg()...) 130 } 131 132 whereClause := "" 133 if conditions != nil { 134 whereClause = " where " + strings.Join(conditions, " and ") 135 } 136 limitClause := "" 137 if limit != 0 { 138 limitClause = fmt.Sprintf(" limit %d ", limit) 139 } 140 141 query := fmt.Sprintf(` 142 select 143 id, 144 owner_did, 145 repo_at, 146 pull_id, 147 created, 148 title, 149 state, 150 target_branch, 151 body, 152 rkey, 153 source_branch, 154 source_repo_at, 155 stack_id, 156 change_id, 157 parent_change_id 158 from 159 pulls 160 %s 161 order by 162 created desc 163 %s 164 `, whereClause, limitClause) 165 166 rows, err := e.Query(query, args...) 167 if err != nil { 168 return nil, err 169 } 170 defer rows.Close() 171 172 for rows.Next() { 173 var pull models.Pull 174 var createdAt string 175 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 176 err := rows.Scan( 177 &pull.ID, 178 &pull.OwnerDid, 179 &pull.RepoAt, 180 &pull.PullId, 181 &createdAt, 182 &pull.Title, 183 &pull.State, 184 &pull.TargetBranch, 185 &pull.Body, 186 &pull.Rkey, 187 &sourceBranch, 188 &sourceRepoAt, 189 &stackId, 190 &changeId, 191 &parentChangeId, 192 ) 193 if err != nil { 194 return nil, err 195 } 196 197 createdTime, err := time.Parse(time.RFC3339, createdAt) 198 if err != nil { 199 return nil, err 200 } 201 pull.Created = createdTime 202 203 if sourceBranch.Valid { 204 pull.PullSource = &models.PullSource{ 205 Branch: sourceBranch.String, 206 } 207 if sourceRepoAt.Valid { 208 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 209 if err != nil { 210 return nil, err 211 } 212 pull.PullSource.RepoAt = &sourceRepoAtParsed 213 } 214 } 215 216 if stackId.Valid { 217 pull.StackId = stackId.String 218 } 219 if changeId.Valid { 220 pull.ChangeId = changeId.String 221 } 222 if parentChangeId.Valid { 223 pull.ParentChangeId = parentChangeId.String 224 } 225 226 pulls[pull.AtUri()] = &pull 227 } 228 229 var pullAts []syntax.ATURI 230 for _, p := range pulls { 231 pullAts = append(pullAts, p.AtUri()) 232 } 233 submissionsMap, err := GetPullSubmissions(e, orm.FilterIn("pull_at", pullAts)) 234 if err != nil { 235 return nil, fmt.Errorf("failed to get submissions: %w", err) 236 } 237 238 for pullAt, submissions := range submissionsMap { 239 if p, ok := pulls[pullAt]; ok { 240 p.Submissions = submissions 241 } 242 } 243 244 // collect allLabels for each issue 245 allLabels, err := GetLabels(e, orm.FilterIn("subject", pullAts)) 246 if err != nil { 247 return nil, fmt.Errorf("failed to query labels: %w", err) 248 } 249 for pullAt, labels := range allLabels { 250 if p, ok := pulls[pullAt]; ok { 251 p.Labels = labels 252 } 253 } 254 255 // collect pull source for all pulls that need it 256 var sourceAts []syntax.ATURI 257 for _, p := range pulls { 258 if p.PullSource != nil && p.PullSource.RepoAt != nil { 259 sourceAts = append(sourceAts, *p.PullSource.RepoAt) 260 } 261 } 262 sourceRepos, err := GetRepos(e, 0, orm.FilterIn("at_uri", sourceAts)) 263 if err != nil && !errors.Is(err, sql.ErrNoRows) { 264 return nil, fmt.Errorf("failed to get source repos: %w", err) 265 } 266 sourceRepoMap := make(map[syntax.ATURI]*models.Repo) 267 for _, r := range sourceRepos { 268 sourceRepoMap[r.RepoAt()] = &r 269 } 270 for _, p := range pulls { 271 if p.PullSource != nil && p.PullSource.RepoAt != nil { 272 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok { 273 p.PullSource.Repo = sourceRepo 274 } 275 } 276 } 277 278 allReferences, err := GetReferencesAll(e, orm.FilterIn("from_at", pullAts)) 279 if err != nil { 280 return nil, fmt.Errorf("failed to query reference_links: %w", err) 281 } 282 for pullAt, references := range allReferences { 283 if pull, ok := pulls[pullAt]; ok { 284 pull.References = references 285 } 286 } 287 288 orderedByPullId := []*models.Pull{} 289 for _, p := range pulls { 290 orderedByPullId = append(orderedByPullId, p) 291 } 292 sort.Slice(orderedByPullId, func(i, j int) bool { 293 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 294 }) 295 296 return orderedByPullId, nil 297} 298 299func GetPulls(e Execer, filters ...orm.Filter) ([]*models.Pull, error) { 300 return GetPullsWithLimit(e, 0, filters...) 301} 302 303func GetPullIDs(e Execer, opts models.PullSearchOptions) ([]int64, error) { 304 var ids []int64 305 306 var filters []orm.Filter 307 filters = append(filters, orm.FilterEq("state", opts.State)) 308 if opts.RepoAt != "" { 309 filters = append(filters, orm.FilterEq("repo_at", opts.RepoAt)) 310 } 311 312 var conditions []string 313 var args []any 314 315 for _, filter := range filters { 316 conditions = append(conditions, filter.Condition()) 317 args = append(args, filter.Arg()...) 318 } 319 320 whereClause := "" 321 if conditions != nil { 322 whereClause = " where " + strings.Join(conditions, " and ") 323 } 324 pageClause := "" 325 if opts.Page.Limit != 0 { 326 pageClause = fmt.Sprintf( 327 " limit %d offset %d ", 328 opts.Page.Limit, 329 opts.Page.Offset, 330 ) 331 } 332 333 query := fmt.Sprintf( 334 ` 335 select 336 id 337 from 338 pulls 339 %s 340 %s`, 341 whereClause, 342 pageClause, 343 ) 344 args = append(args, opts.Page.Limit, opts.Page.Offset) 345 rows, err := e.Query(query, args...) 346 if err != nil { 347 return nil, err 348 } 349 defer rows.Close() 350 351 for rows.Next() { 352 var id int64 353 err := rows.Scan(&id) 354 if err != nil { 355 return nil, err 356 } 357 358 ids = append(ids, id) 359 } 360 361 return ids, nil 362} 363 364func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { 365 pulls, err := GetPullsWithLimit(e, 1, orm.FilterEq("repo_at", repoAt), orm.FilterEq("pull_id", pullId)) 366 if err != nil { 367 return nil, err 368 } 369 if len(pulls) == 0 { 370 return nil, sql.ErrNoRows 371 } 372 373 return pulls[0], nil 374} 375 376// mapping from pull -> pull submissions 377func GetPullSubmissions(e Execer, filters ...orm.Filter) (map[syntax.ATURI][]*models.PullSubmission, error) { 378 var conditions []string 379 var args []any 380 for _, filter := range filters { 381 conditions = append(conditions, filter.Condition()) 382 args = append(args, filter.Arg()...) 383 } 384 385 whereClause := "" 386 if conditions != nil { 387 whereClause = " where " + strings.Join(conditions, " and ") 388 } 389 390 query := fmt.Sprintf(` 391 select 392 id, 393 pull_at, 394 round_number, 395 patch, 396 combined, 397 created, 398 source_rev 399 from 400 pull_submissions 401 %s 402 order by 403 round_number asc 404 `, whereClause) 405 406 rows, err := e.Query(query, args...) 407 if err != nil { 408 return nil, err 409 } 410 defer rows.Close() 411 412 submissionMap := make(map[int]*models.PullSubmission) 413 414 for rows.Next() { 415 var submission models.PullSubmission 416 var submissionCreatedStr string 417 var submissionSourceRev, submissionCombined sql.NullString 418 err := rows.Scan( 419 &submission.ID, 420 &submission.PullAt, 421 &submission.RoundNumber, 422 &submission.Patch, 423 &submissionCombined, 424 &submissionCreatedStr, 425 &submissionSourceRev, 426 ) 427 if err != nil { 428 return nil, err 429 } 430 431 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil { 432 submission.Created = t 433 } 434 435 if submissionSourceRev.Valid { 436 submission.SourceRev = submissionSourceRev.String 437 } 438 439 if submissionCombined.Valid { 440 submission.Combined = submissionCombined.String 441 } 442 443 submissionMap[submission.ID] = &submission 444 } 445 446 if err := rows.Err(); err != nil { 447 return nil, err 448 } 449 450 // Get comments for all submissions using GetComments 451 submissionIds := slices.Collect(maps.Keys(submissionMap)) 452 comments, err := GetComments(e, orm.FilterIn("pull_submission_id", submissionIds)) 453 if err != nil { 454 return nil, fmt.Errorf("failed to get pull comments: %w", err) 455 } 456 for _, comment := range comments { 457 if comment.PullSubmissionId != nil { 458 if submission, ok := submissionMap[*comment.PullSubmissionId]; ok { 459 submission.Comments = append(submission.Comments, comment) 460 } 461 } 462 } 463 464 // group the submissions by pull_at 465 m := make(map[syntax.ATURI][]*models.PullSubmission) 466 for _, s := range submissionMap { 467 m[s.PullAt] = append(m[s.PullAt], s) 468 } 469 470 // sort each one by round number 471 for _, s := range m { 472 slices.SortFunc(s, func(a, b *models.PullSubmission) int { 473 return cmp.Compare(a.RoundNumber, b.RoundNumber) 474 }) 475 } 476 477 return m, nil 478} 479 480// timeframe here is directly passed into the sql query filter, and any 481// timeframe in the past should be negative; e.g.: "-3 months" 482func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) { 483 var pulls []models.Pull 484 485 rows, err := e.Query(` 486 select 487 p.owner_did, 488 p.repo_at, 489 p.pull_id, 490 p.created, 491 p.title, 492 p.state, 493 r.did, 494 r.name, 495 r.knot, 496 r.rkey, 497 r.created 498 from 499 pulls p 500 join 501 repos r on p.repo_at = r.at_uri 502 where 503 p.owner_did = ? and p.created >= date ('now', ?) 504 order by 505 p.created desc`, did, timeframe) 506 if err != nil { 507 return nil, err 508 } 509 defer rows.Close() 510 511 for rows.Next() { 512 var pull models.Pull 513 var repo models.Repo 514 var pullCreatedAt, repoCreatedAt string 515 err := rows.Scan( 516 &pull.OwnerDid, 517 &pull.RepoAt, 518 &pull.PullId, 519 &pullCreatedAt, 520 &pull.Title, 521 &pull.State, 522 &repo.Did, 523 &repo.Name, 524 &repo.Knot, 525 &repo.Rkey, 526 &repoCreatedAt, 527 ) 528 if err != nil { 529 return nil, err 530 } 531 532 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 533 if err != nil { 534 return nil, err 535 } 536 pull.Created = pullCreatedTime 537 538 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 539 if err != nil { 540 return nil, err 541 } 542 repo.Created = repoCreatedTime 543 544 pull.Repo = &repo 545 546 pulls = append(pulls, pull) 547 } 548 549 if err := rows.Err(); err != nil { 550 return nil, err 551 } 552 553 return pulls, nil 554} 555 556func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error { 557 _, err := e.Exec( 558 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 559 pullState, 560 repoAt, 561 pullId, 562 models.PullDeleted, // only update state of non-deleted pulls 563 models.PullMerged, // only update state of non-merged pulls 564 ) 565 return err 566} 567 568func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 569 err := SetPullState(e, repoAt, pullId, models.PullClosed) 570 return err 571} 572 573func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 574 err := SetPullState(e, repoAt, pullId, models.PullOpen) 575 return err 576} 577 578func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 579 err := SetPullState(e, repoAt, pullId, models.PullMerged) 580 return err 581} 582 583func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 584 err := SetPullState(e, repoAt, pullId, models.PullDeleted) 585 return err 586} 587 588func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error { 589 _, err := e.Exec(` 590 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 591 values (?, ?, ?, ?, ?) 592 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev) 593 594 return err 595} 596 597func SetPullParentChangeId(e Execer, parentChangeId string, filters ...orm.Filter) error { 598 var conditions []string 599 var args []any 600 601 args = append(args, parentChangeId) 602 603 for _, filter := range filters { 604 conditions = append(conditions, filter.Condition()) 605 args = append(args, filter.Arg()...) 606 } 607 608 whereClause := "" 609 if conditions != nil { 610 whereClause = " where " + strings.Join(conditions, " and ") 611 } 612 613 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 614 _, err := e.Exec(query, args...) 615 616 return err 617} 618 619// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 620// otherwise submissions are immutable 621func UpdatePull(e Execer, newPatch, sourceRev string, filters ...orm.Filter) error { 622 var conditions []string 623 var args []any 624 625 args = append(args, sourceRev) 626 args = append(args, newPatch) 627 628 for _, filter := range filters { 629 conditions = append(conditions, filter.Condition()) 630 args = append(args, filter.Arg()...) 631 } 632 633 whereClause := "" 634 if conditions != nil { 635 whereClause = " where " + strings.Join(conditions, " and ") 636 } 637 638 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 639 _, err := e.Exec(query, args...) 640 641 return err 642} 643 644func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) { 645 row := e.QueryRow(` 646 select 647 count(case when state = ? then 1 end) as open_count, 648 count(case when state = ? then 1 end) as merged_count, 649 count(case when state = ? then 1 end) as closed_count, 650 count(case when state = ? then 1 end) as deleted_count 651 from pulls 652 where repo_at = ?`, 653 models.PullOpen, 654 models.PullMerged, 655 models.PullClosed, 656 models.PullDeleted, 657 repoAt, 658 ) 659 660 var count models.PullCount 661 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 662 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err 663 } 664 665 return count, nil 666} 667 668// change-id parent-change-id 669// 670// 4 w ,-------- z (TOP) 671// 3 z <----',------- y 672// 2 y <-----',------ x 673// 1 x <------' nil (BOT) 674// 675// `w` is parent of none, so it is the top of the stack 676func GetStack(e Execer, stackId string) (models.Stack, error) { 677 unorderedPulls, err := GetPulls( 678 e, 679 orm.FilterEq("stack_id", stackId), 680 orm.FilterNotEq("state", models.PullDeleted), 681 ) 682 if err != nil { 683 return nil, err 684 } 685 // map of parent-change-id to pull 686 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls)) 687 parentMap := make(map[string]*models.Pull, len(unorderedPulls)) 688 for _, p := range unorderedPulls { 689 changeIdMap[p.ChangeId] = p 690 if p.ParentChangeId != "" { 691 parentMap[p.ParentChangeId] = p 692 } 693 } 694 695 // the top of the stack is the pull that is not a parent of any pull 696 var topPull *models.Pull 697 for _, maybeTop := range unorderedPulls { 698 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 699 topPull = maybeTop 700 break 701 } 702 } 703 704 pulls := []*models.Pull{} 705 for { 706 pulls = append(pulls, topPull) 707 if topPull.ParentChangeId != "" { 708 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 709 topPull = next 710 } else { 711 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 712 } 713 } else { 714 break 715 } 716 } 717 718 return pulls, nil 719} 720 721func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) { 722 pulls, err := GetPulls( 723 e, 724 orm.FilterEq("stack_id", stackId), 725 orm.FilterEq("state", models.PullDeleted), 726 ) 727 if err != nil { 728 return nil, err 729 } 730 731 return pulls, nil 732}