this repo has no description

appview/db: transactions all the things

Changed files
+182 -47
appview
+2 -2
appview/db/db.go
··· 7 7 ) 8 8 9 9 type DB struct { 10 - db *sql.DB 10 + Db *sql.DB 11 11 } 12 12 13 13 func Make(dbPath string) (*DB, error) { ··· 43 43 if err != nil { 44 44 return nil, err 45 45 } 46 - return &DB{db: db}, nil 46 + return &DB{Db: db}, nil 47 47 }
+11 -4
appview/db/pubkeys.go
··· 1 1 package db 2 2 3 3 import ( 4 + "database/sql" 4 5 "encoding/json" 5 6 "time" 6 7 ) 7 8 8 9 func (d *DB) AddPublicKey(did, name, key string) error { 9 10 query := `insert into public_keys (did, name, key) values (?, ?, ?)` 10 - _, err := d.db.Exec(query, did, name, key) 11 + _, err := d.Db.Exec(query, did, name, key) 12 + return err 13 + } 14 + 15 + func (d *DB) AddPublicKeyTx(tx *sql.Tx, did, name, key string) error { 16 + query := `insert into public_keys (did, name, key) values (?, ?, ?)` 17 + _, err := tx.Exec(query, did, name, key) 11 18 return err 12 19 } 13 20 14 21 func (d *DB) RemovePublicKey(did string) error { 15 22 query := `delete from public_keys where did = ?` 16 - _, err := d.db.Exec(query, did) 23 + _, err := d.Db.Exec(query, did) 17 24 return err 18 25 } 19 26 ··· 38 45 func (d *DB) GetAllPublicKeys() ([]PublicKey, error) { 39 46 var keys []PublicKey 40 47 41 - rows, err := d.db.Query(`select key, name, did, created from public_keys`) 48 + rows, err := d.Db.Query(`select key, name, did, created from public_keys`) 42 49 if err != nil { 43 50 return nil, err 44 51 } ··· 64 71 func (d *DB) GetPublicKeys(did string) ([]PublicKey, error) { 65 72 var keys []PublicKey 66 73 67 - rows, err := d.db.Query(`select did, key, name, created from public_keys where did = ?`, did) 74 + rows, err := d.Db.Query(`select did, key, name, created from public_keys where did = ?`, did) 68 75 if err != nil { 69 76 return nil, err 70 77 }
+58 -5
appview/db/registration.go
··· 35 35 func (d *DB) RegistrationsByDid(did string) ([]Registration, error) { 36 36 var registrations []Registration 37 37 38 - rows, err := d.db.Query(` 38 + rows, err := d.Db.Query(` 39 39 select domain, did, created, registered from registrations 40 40 where did = ? 41 41 `, did) ··· 75 75 var registeredAt *int64 76 76 var registration Registration 77 77 78 - err := d.db.QueryRow(` 78 + err := d.Db.QueryRow(` 79 + select domain, did, created, registered from registrations 80 + where domain = ? 81 + `, domain).Scan(&registration.Domain, &registration.ByDid, &createdAt, &registeredAt) 82 + 83 + if err != nil { 84 + if err == sql.ErrNoRows { 85 + return nil, nil 86 + } else { 87 + return nil, err 88 + } 89 + } 90 + 91 + createdAtTime := time.Unix(*createdAt, 0) 92 + var registeredAtTime *time.Time 93 + if registeredAt != nil { 94 + x := time.Unix(*registeredAt, 0) 95 + registeredAtTime = &x 96 + } 97 + 98 + registration.Created = &createdAtTime 99 + registration.Registered = registeredAtTime 100 + 101 + return &registration, nil 102 + } 103 + 104 + func (d *DB) RegistrationByDomainTx(tx *sql.Tx, domain string) (*Registration, error) { 105 + var createdAt *int64 106 + var registeredAt *int64 107 + var registration Registration 108 + 109 + err := tx.QueryRow(` 79 110 select domain, did, created, registered from registrations 80 111 where domain = ? 81 112 `, domain).Scan(&registration.Domain, &registration.ByDid, &createdAt, &registeredAt) ··· 122 153 123 154 secret := uuid.New().String() 124 155 125 - _, err = d.db.Exec(` 156 + _, err = d.Db.Exec(` 126 157 insert into registrations (domain, did, secret) 127 158 values (?, ?, ?) 128 159 on conflict(domain) do update set did = excluded.did, secret = excluded.secret ··· 136 167 } 137 168 138 169 func (d *DB) GetRegistrationKey(domain string) (string, error) { 139 - res := d.db.QueryRow(`select secret from registrations where domain = ?`, domain) 170 + res := d.Db.QueryRow(`select secret from registrations where domain = ?`, domain) 171 + 172 + var secret string 173 + err := res.Scan(&secret) 174 + if err != nil || secret == "" { 175 + return "", err 176 + } 177 + 178 + return secret, nil 179 + } 180 + 181 + func (d *DB) GetRegistrationKeyTx(tx *sql.Tx, domain string) (string, error) { 182 + res := tx.QueryRow(`select secret from registrations where domain = ?`, domain) 140 183 141 184 var secret string 142 185 err := res.Scan(&secret) ··· 148 191 } 149 192 150 193 func (d *DB) Register(domain string) error { 151 - _, err := d.db.Exec(` 194 + _, err := d.Db.Exec(` 195 + update registrations 196 + set registered = strftime('%s', 'now') 197 + where domain = ?; 198 + `, domain) 199 + 200 + return err 201 + } 202 + 203 + func (d *DB) RegisterTx(tx *sql.Tx, domain string) error { 204 + _, err := tx.Exec(` 152 205 update registrations 153 206 set registered = strftime('%s', 'now') 154 207 where domain = ?;
+11 -4
appview/db/repos.go
··· 1 1 package db 2 2 3 + import "database/sql" 4 + 3 5 type Repo struct { 4 6 Did string 5 7 Name string ··· 10 12 func (d *DB) GetAllReposByDid(did string) ([]Repo, error) { 11 13 var repos []Repo 12 14 13 - rows, err := d.db.Query(`select did, name, knot, created from repos where did = ?`, did) 15 + rows, err := d.Db.Query(`select did, name, knot, created from repos where did = ?`, did) 14 16 if err != nil { 15 17 return nil, err 16 18 } ··· 36 38 func (d *DB) GetRepo(did, name string) (*Repo, error) { 37 39 var repo Repo 38 40 39 - row := d.db.QueryRow(`select did, name, knot, created from repos where did = ? and name = ?`, did, name) 41 + row := d.Db.QueryRow(`select did, name, knot, created from repos where did = ? and name = ?`, did, name) 40 42 var createdAt *int64 41 43 if err := row.Scan(&repo.Did, &repo.Name, &repo.Knot, &createdAt); err != nil { 42 44 return nil, err ··· 47 49 } 48 50 49 51 func (d *DB) AddRepo(repo *Repo) error { 50 - _, err := d.db.Exec(`insert into repos (did, name, knot) values (?, ?, ?)`, repo.Did, repo.Name, repo.Knot) 52 + _, err := d.Db.Exec(`insert into repos (did, name, knot) values (?, ?, ?)`, repo.Did, repo.Name, repo.Knot) 53 + return err 54 + } 55 + 56 + func (d *DB) AddRepoTx(tx *sql.Tx, repo *Repo) error { 57 + _, err := tx.Exec(`insert into repos (did, name, knot) values (?, ?, ?)`, repo.Did, repo.Name, repo.Knot) 51 58 return err 52 59 } 53 60 54 61 func (d *DB) RemoveRepo(did, name, knot string) error { 55 - _, err := d.db.Exec(`delete from repos where did = ? and name = ? and knot = ?`, did, name, knot) 62 + _, err := d.Db.Exec(`delete from repos where did = ? and name = ? and knot = ?`, did, name, knot) 56 63 return err 57 64 }
+100 -32
appview/state/state.go
··· 200 200 return 201 201 } 202 202 203 - if err := s.db.AddPublicKey(did, name, key); err != nil { 203 + // Start transaction 204 + tx, err := s.db.Db.Begin() 205 + if err != nil { 206 + log.Printf("failed to start transaction: %s", err) 207 + http.Error(w, "Internal server error", http.StatusInternalServerError) 208 + return 209 + } 210 + defer tx.Rollback() // Will rollback if not committed 211 + 212 + if err := s.db.AddPublicKeyTx(tx, did, name, key); err != nil { 204 213 log.Printf("adding public key: %s", err) 205 214 return 206 215 } ··· 223 232 return 224 233 } 225 234 235 + // If everything succeeded, commit the transaction 236 + if err := tx.Commit(); err != nil { 237 + log.Printf("failed to commit transaction: %s", err) 238 + http.Error(w, "Internal server error", http.StatusInternalServerError) 239 + return 240 + } 241 + 226 242 log.Println("created atproto record: ", resp.Uri) 227 243 228 244 return ··· 240 256 } 241 257 log.Println("checking ", domain) 242 258 243 - secret, err := s.db.GetRegistrationKey(domain) 259 + // Start transaction 260 + tx, err := s.db.Db.Begin() 261 + if err != nil { 262 + log.Printf("failed to start transaction: %s", err) 263 + http.Error(w, "Internal server error", http.StatusInternalServerError) 264 + return 265 + } 266 + defer tx.Rollback() // Will rollback if not committed 267 + 268 + secret, err := s.db.GetRegistrationKeyTx(tx, domain) 244 269 if err != nil { 245 270 log.Printf("no key found for domain %s: %s\n", domain, err) 246 271 return ··· 285 310 return 286 311 } 287 312 288 - // mark as registered 289 - err = s.db.Register(domain) 313 + // mark as registered within transaction 314 + err = s.db.RegisterTx(tx, domain) 290 315 if err != nil { 291 316 log.Println("failed to register domain", err) 292 317 http.Error(w, err.Error(), http.StatusInternalServerError) 293 318 return 294 319 } 295 320 296 - // set permissions for this did as owner 297 - reg, err := s.db.RegistrationByDomain(domain) 321 + // set permissions for this did as owner within transaction 322 + reg, err := s.db.RegistrationByDomainTx(tx, domain) 298 323 if err != nil { 299 324 log.Println("failed to register domain", err) 300 325 http.Error(w, err.Error(), http.StatusInternalServerError) 301 326 return 302 327 } 303 328 304 - // add basic acls for this domain 329 + // add basic acls for this domain within transaction 305 330 err = s.enforcer.AddDomain(domain) 306 331 if err != nil { 307 332 log.Println("failed to setup owner of domain", err) ··· 309 334 return 310 335 } 311 336 312 - // add this did as owner of this domain 337 + // add this did as owner of this domain within transaction 313 338 err = s.enforcer.AddOwner(domain, reg.ByDid) 314 339 if err != nil { 315 340 log.Println("failed to setup owner of domain", err) 316 341 http.Error(w, err.Error(), http.StatusInternalServerError) 342 + return 343 + } 344 + 345 + // Commit transaction 346 + if err := tx.Commit(); err != nil { 347 + log.Printf("failed to commit transaction: %s", err) 348 + http.Error(w, "Internal server error", http.StatusInternalServerError) 317 349 return 318 350 } 319 351 ··· 411 443 } 412 444 log.Printf("adding %s to %s\n", memberIdent.Handle.String(), domain) 413 445 414 - // announce this relation into the firehose, store into owners' pds 415 - client, _ := s.auth.AuthorizedClient(r) 416 - currentUser := s.auth.GetUser(r) 417 - addedAt := time.Now().Format(time.RFC3339) 418 - resp, err := comatproto.RepoPutRecord(r.Context(), client, &comatproto.RepoPutRecord_Input{ 419 - Collection: tangled.KnotMemberNSID, 420 - Repo: currentUser.Did, 421 - Rkey: s.TID(), 422 - Record: &lexutil.LexiconTypeDecoder{ 423 - Val: &tangled.KnotMember{ 424 - Member: memberIdent.DID.String(), 425 - Domain: domain, 426 - AddedAt: &addedAt, 427 - }}, 428 - }) 429 - // invalid record 446 + // Start transaction 447 + tx, err := s.db.Db.Begin() 430 448 if err != nil { 431 - log.Printf("failed to create record: %s", err) 449 + log.Printf("failed to start transaction: %s", err) 450 + http.Error(w, "Internal server error", http.StatusInternalServerError) 432 451 return 433 452 } 434 - log.Println("created atproto record: ", resp.Uri) 453 + defer tx.Rollback() // Will rollback if not committed 435 454 436 - secret, err := s.db.GetRegistrationKey(domain) 455 + // Get registration key within transaction 456 + secret, err := s.db.GetRegistrationKeyTx(tx, domain) 437 457 if err != nil { 438 458 log.Printf("no key found for domain %s: %s\n", domain, err) 439 459 return 440 460 } 441 461 462 + // Make the external call to the knot server 442 463 ksClient, err := NewSignedClient(domain, secret) 443 464 if err != nil { 444 465 log.Println("failed to create client to ", domain) ··· 447 468 448 469 ksResp, err := ksClient.AddMember(memberIdent.DID.String(), []string{}) 449 470 if err != nil { 450 - log.Printf("failet to make request to %s: %s", domain, err) 471 + log.Printf("failed to make request to %s: %s", domain, err) 472 + return 451 473 } 452 474 453 475 if ksResp.StatusCode != http.StatusNoContent { ··· 455 477 return 456 478 } 457 479 480 + // Create ATProto record within transaction 481 + client, _ := s.auth.AuthorizedClient(r) 482 + currentUser := s.auth.GetUser(r) 483 + addedAt := time.Now().Format(time.RFC3339) 484 + resp, err := comatproto.RepoPutRecord(r.Context(), client, &comatproto.RepoPutRecord_Input{ 485 + Collection: tangled.KnotMemberNSID, 486 + Repo: currentUser.Did, 487 + Rkey: s.TID(), 488 + Record: &lexutil.LexiconTypeDecoder{ 489 + Val: &tangled.KnotMember{ 490 + Member: memberIdent.DID.String(), 491 + Domain: domain, 492 + AddedAt: &addedAt, 493 + }}, 494 + }) 495 + if err != nil { 496 + log.Printf("failed to create record: %s", err) 497 + return 498 + } 499 + 500 + // Update RBAC within transaction 458 501 err = s.enforcer.AddMember(domain, memberIdent.DID.String()) 459 502 if err != nil { 460 503 w.Write([]byte(fmt.Sprint("failed to add member: ", err))) 461 504 return 462 505 } 463 506 507 + // If everything succeeded, commit the transaction 508 + if err := tx.Commit(); err != nil { 509 + log.Printf("failed to commit transaction: %s", err) 510 + http.Error(w, "Internal server error", http.StatusInternalServerError) 511 + return 512 + } 513 + 514 + log.Println("created atproto record: ", resp.Uri) 464 515 w.Write([]byte(fmt.Sprint("added member: ", memberIdent.Handle.String()))) 465 516 } 466 517 ··· 494 545 return 495 546 } 496 547 497 - secret, err := s.db.GetRegistrationKey(domain) 548 + // Start transaction 549 + tx, err := s.db.Db.Begin() 550 + if err != nil { 551 + log.Printf("failed to start transaction: %s", err) 552 + http.Error(w, "Internal server error", http.StatusInternalServerError) 553 + return 554 + } 555 + defer tx.Rollback() // Will rollback if not committed 556 + 557 + secret, err := s.db.GetRegistrationKeyTx(tx, domain) 498 558 if err != nil { 499 559 log.Printf("no key found for domain %s: %s\n", domain, err) 500 560 return ··· 503 563 client, err := NewSignedClient(domain, secret) 504 564 if err != nil { 505 565 log.Println("failed to create client to ", domain) 566 + return 506 567 } 507 568 508 569 resp, err := client.NewRepo(user.Did, repoName) ··· 515 576 return 516 577 } 517 578 518 - // add to local db 579 + // add to local db within transaction 519 580 repo := &db.Repo{ 520 581 Did: user.Did, 521 582 Name: repoName, 522 583 Knot: domain, 523 584 } 524 - err = s.db.AddRepo(repo) 585 + err = s.db.AddRepoTx(tx, repo) 525 586 if err != nil { 526 587 log.Println("failed to add repo to db", err) 527 588 return 528 589 } 529 590 530 - // acls 591 + // acls within transaction 531 592 err = s.enforcer.AddRepo(user.Did, domain, filepath.Join(user.Did, repoName)) 532 593 if err != nil { 533 594 log.Println("failed to set up acls", err) 595 + return 596 + } 597 + 598 + // Commit transaction 599 + if err := tx.Commit(); err != nil { 600 + log.Printf("failed to commit transaction: %s", err) 601 + http.Error(w, "Internal server error", http.StatusInternalServerError) 534 602 return 535 603 } 536 604