Write on the margins of the internet. Powered by the AT Protocol. margin.at
extension web atproto comments
at ui-refactor 306 lines 7.0 kB view raw
1package xrpc 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "crypto/rand" 8 "crypto/sha256" 9 "encoding/base64" 10 "encoding/json" 11 "fmt" 12 "io" 13 "net/http" 14 "time" 15 16 "github.com/go-jose/go-jose/v4" 17) 18 19type Client struct { 20 PDS string 21 AccessToken string 22 DPoPKey *ecdsa.PrivateKey 23 DPoPNonce string 24} 25 26func NewClient(pds, accessToken string, dpopKey *ecdsa.PrivateKey) *Client { 27 return &Client{ 28 PDS: pds, 29 AccessToken: accessToken, 30 DPoPKey: dpopKey, 31 } 32} 33 34func (c *Client) createDPoPProof(method, uri string) (string, error) { 35 now := time.Now() 36 jti := make([]byte, 16) 37 if _, err := io.ReadFull(rand.Reader, jti); err != nil { 38 39 for i := range jti { 40 jti[i] = byte(now.UnixNano() >> (i * 8)) 41 } 42 } 43 44 publicJWK := jose.JSONWebKey{ 45 Key: &c.DPoPKey.PublicKey, 46 Algorithm: string(jose.ES256), 47 } 48 49 ath := "" 50 if c.AccessToken != "" { 51 hash := sha256.Sum256([]byte(c.AccessToken)) 52 ath = base64.RawURLEncoding.EncodeToString(hash[:]) 53 } 54 55 claims := map[string]interface{}{ 56 "jti": base64.RawURLEncoding.EncodeToString(jti), 57 "htm": method, 58 "htu": uri, 59 "iat": now.Unix(), 60 "exp": now.Add(5 * time.Minute).Unix(), 61 } 62 if c.DPoPNonce != "" { 63 claims["nonce"] = c.DPoPNonce 64 } 65 if ath != "" { 66 claims["ath"] = ath 67 } 68 69 signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: c.DPoPKey}, &jose.SignerOptions{ 70 ExtraHeaders: map[jose.HeaderKey]interface{}{ 71 "typ": "dpop+jwt", 72 "jwk": publicJWK, 73 }, 74 }) 75 if err != nil { 76 return "", err 77 } 78 79 claimsBytes, _ := json.Marshal(claims) 80 sig, err := signer.Sign(claimsBytes) 81 if err != nil { 82 return "", err 83 } 84 85 return sig.CompactSerialize() 86} 87 88func (c *Client) Call(ctx context.Context, method, nsid string, input, output interface{}) error { 89 url := fmt.Sprintf("%s/xrpc/%s", c.PDS, nsid) 90 91 maxRetries := 2 92 for i := 0; i < maxRetries; i++ { 93 var reqBody io.Reader 94 if input != nil { 95 96 data, err := json.Marshal(input) 97 if err != nil { 98 return err 99 } 100 reqBody = bytes.NewReader(data) 101 } 102 103 req, err := http.NewRequestWithContext(ctx, method, url, reqBody) 104 if err != nil { 105 return err 106 } 107 108 if input != nil { 109 req.Header.Set("Content-Type", "application/json") 110 } 111 112 dpopProof, err := c.createDPoPProof(method, url) 113 if err != nil { 114 return fmt.Errorf("failed to create DPoP proof: %w", err) 115 } 116 117 req.Header.Set("Authorization", "DPoP "+c.AccessToken) 118 req.Header.Set("DPoP", dpopProof) 119 120 resp, err := http.DefaultClient.Do(req) 121 if err != nil { 122 return err 123 } 124 defer resp.Body.Close() 125 126 if nonce := resp.Header.Get("DPoP-Nonce"); nonce != "" { 127 c.DPoPNonce = nonce 128 } 129 130 if resp.StatusCode < 400 { 131 if output != nil { 132 return json.NewDecoder(resp.Body).Decode(output) 133 } 134 return nil 135 } 136 137 bodyBytes, _ := io.ReadAll(resp.Body) 138 bodyStr := string(bodyBytes) 139 140 if resp.StatusCode == 401 && (bytes.Contains(bodyBytes, []byte("use_dpop_nonce")) || bytes.Contains(bodyBytes, []byte("UseDpopNonce"))) { 141 continue 142 } 143 144 return fmt.Errorf("XRPC error %d: %s", resp.StatusCode, bodyStr) 145 } 146 147 return fmt.Errorf("XRPC failed after retries") 148} 149 150type CreateRecordInput struct { 151 Repo string `json:"repo"` 152 Collection string `json:"collection"` 153 RKey string `json:"rkey,omitempty"` 154 Record interface{} `json:"record"` 155} 156 157type CreateRecordOutput struct { 158 URI string `json:"uri"` 159 CID string `json:"cid"` 160} 161 162func (c *Client) CreateRecord(ctx context.Context, repo, collection string, record interface{}) (*CreateRecordOutput, error) { 163 input := CreateRecordInput{ 164 Repo: repo, 165 Collection: collection, 166 Record: record, 167 } 168 169 var output CreateRecordOutput 170 err := c.Call(ctx, "POST", "com.atproto.repo.createRecord", input, &output) 171 if err != nil { 172 return nil, err 173 } 174 175 return &output, nil 176} 177 178type DeleteRecordInput struct { 179 Repo string `json:"repo"` 180 Collection string `json:"collection"` 181 RKey string `json:"rkey"` 182} 183 184func (c *Client) DeleteRecord(ctx context.Context, repo, collection, rkey string) error { 185 input := DeleteRecordInput{ 186 Repo: repo, 187 Collection: collection, 188 RKey: rkey, 189 } 190 191 return c.Call(ctx, "POST", "com.atproto.repo.deleteRecord", input, nil) 192} 193 194func (c *Client) DeleteRecordByURI(ctx context.Context, uri string) error { 195 parsed, err := ParseATURI(uri) 196 if err != nil { 197 return err 198 } 199 200 if parsed.Collection == "" || parsed.RKey == "" { 201 return fmt.Errorf("invalid AT-URI: must include collection and rkey") 202 } 203 204 return c.DeleteRecord(ctx, parsed.DID, parsed.Collection, parsed.RKey) 205} 206 207type PutRecordInput struct { 208 Repo string `json:"repo"` 209 Collection string `json:"collection"` 210 RKey string `json:"rkey"` 211 Record interface{} `json:"record"` 212} 213 214type PutRecordOutput struct { 215 URI string `json:"uri"` 216 CID string `json:"cid"` 217} 218 219func (c *Client) PutRecord(ctx context.Context, repo, collection, rkey string, record interface{}) (*PutRecordOutput, error) { 220 input := PutRecordInput{ 221 Repo: repo, 222 Collection: collection, 223 RKey: rkey, 224 Record: record, 225 } 226 227 var output PutRecordOutput 228 err := c.Call(ctx, "POST", "com.atproto.repo.putRecord", input, &output) 229 if err != nil { 230 return nil, err 231 } 232 233 return &output, nil 234} 235 236type GetRecordOutput struct { 237 URI string `json:"uri"` 238 CID string `json:"cid"` 239 Value json.RawMessage `json:"value"` 240} 241 242func (c *Client) GetRecord(ctx context.Context, repo, collection, rkey string) (*GetRecordOutput, error) { 243 url := fmt.Sprintf("%s/xrpc/com.atproto.repo.getRecord?repo=%s&collection=%s&rkey=%s", 244 c.PDS, repo, collection, rkey) 245 246 req, err := http.NewRequestWithContext(ctx, "GET", url, nil) 247 if err != nil { 248 return nil, err 249 } 250 251 dpopProof, err := c.createDPoPProof("GET", url) 252 if err != nil { 253 return nil, err 254 } 255 256 req.Header.Set("Authorization", "DPoP "+c.AccessToken) 257 req.Header.Set("DPoP", dpopProof) 258 259 resp, err := http.DefaultClient.Do(req) 260 if err != nil { 261 return nil, err 262 } 263 defer resp.Body.Close() 264 265 if resp.StatusCode >= 400 { 266 bodyBytes, _ := io.ReadAll(resp.Body) 267 return nil, fmt.Errorf("XRPC error %d: %s", resp.StatusCode, string(bodyBytes)) 268 } 269 270 var output GetRecordOutput 271 if err := json.NewDecoder(resp.Body).Decode(&output); err != nil { 272 return nil, err 273 } 274 275 return &output, nil 276} 277 278type ResolveHandleOutput struct { 279 Did string `json:"did"` 280} 281 282func (c *Client) ResolveHandle(ctx context.Context, handle string) (string, error) { 283 url := fmt.Sprintf("%s/xrpc/com.atproto.identity.resolveHandle?handle=%s", c.PDS, handle) 284 285 req, err := http.NewRequestWithContext(ctx, "GET", url, nil) 286 if err != nil { 287 return "", err 288 } 289 290 resp, err := http.DefaultClient.Do(req) 291 if err != nil { 292 return "", err 293 } 294 defer resp.Body.Close() 295 296 if resp.StatusCode >= 400 { 297 return "", fmt.Errorf("XRPC error %d", resp.StatusCode) 298 } 299 300 var output ResolveHandleOutput 301 if err := json.NewDecoder(resp.Body).Decode(&output); err != nil { 302 return "", err 303 } 304 305 return output.Did, nil 306}