Write on the margins of the internet. Powered by the AT Protocol.
margin.at
extension
web
atproto
comments
package xrpc

import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/go-jose/go-jose/v4"
)
// Client sends XRPC requests to a single PDS. Authenticated calls carry an
// OAuth access token together with a DPoP proof signed by DPoPKey.
type Client struct {
	// PDS is the base URL; request paths of the form "/xrpc/<nsid>" are
	// appended to it.
	PDS string
	// AccessToken is sent as "Authorization: DPoP <token>". When non-empty,
	// its SHA-256 hash is embedded in each DPoP proof as the "ath" claim.
	AccessToken string
	// DPoPKey signs DPoP proofs with ES256; its public half is published in
	// each proof's "jwk" header.
	DPoPKey *ecdsa.PrivateKey
	// DPoPNonce is the latest DPoP-Nonce value seen in a response header.
	// Call overwrites it, and createDPoPProof adds it as the "nonce" claim.
	DPoPNonce string
}

// NewClient builds a Client for the PDS at pds that authenticates with
// accessToken and signs DPoP proofs with dpopKey. No nonce is known yet, so
// DPoPNonce starts out empty.
func NewClient(pds, accessToken string, dpopKey *ecdsa.PrivateKey) *Client {
	c := new(Client)
	c.PDS = pds
	c.AccessToken = accessToken
	c.DPoPKey = dpopKey
	return c
}
34func (c *Client) createDPoPProof(method, uri string) (string, error) {
35 now := time.Now()
36 jti := make([]byte, 16)
37 if _, err := io.ReadFull(rand.Reader, jti); err != nil {
38
39 for i := range jti {
40 jti[i] = byte(now.UnixNano() >> (i * 8))
41 }
42 }
43
44 publicJWK := jose.JSONWebKey{
45 Key: &c.DPoPKey.PublicKey,
46 Algorithm: string(jose.ES256),
47 }
48
49 ath := ""
50 if c.AccessToken != "" {
51 hash := sha256.Sum256([]byte(c.AccessToken))
52 ath = base64.RawURLEncoding.EncodeToString(hash[:])
53 }
54
55 claims := map[string]interface{}{
56 "jti": base64.RawURLEncoding.EncodeToString(jti),
57 "htm": method,
58 "htu": uri,
59 "iat": now.Unix(),
60 "exp": now.Add(5 * time.Minute).Unix(),
61 }
62 if c.DPoPNonce != "" {
63 claims["nonce"] = c.DPoPNonce
64 }
65 if ath != "" {
66 claims["ath"] = ath
67 }
68
69 signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: c.DPoPKey}, &jose.SignerOptions{
70 ExtraHeaders: map[jose.HeaderKey]interface{}{
71 "typ": "dpop+jwt",
72 "jwk": publicJWK,
73 },
74 })
75 if err != nil {
76 return "", err
77 }
78
79 claimsBytes, _ := json.Marshal(claims)
80 sig, err := signer.Sign(claimsBytes)
81 if err != nil {
82 return "", err
83 }
84
85 return sig.CompactSerialize()
86}
87
88func (c *Client) Call(ctx context.Context, method, nsid string, input, output interface{}) error {
89 url := fmt.Sprintf("%s/xrpc/%s", c.PDS, nsid)
90
91 maxRetries := 2
92 for i := 0; i < maxRetries; i++ {
93 var reqBody io.Reader
94 if input != nil {
95
96 data, err := json.Marshal(input)
97 if err != nil {
98 return err
99 }
100 reqBody = bytes.NewReader(data)
101 }
102
103 req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
104 if err != nil {
105 return err
106 }
107
108 if input != nil {
109 req.Header.Set("Content-Type", "application/json")
110 }
111
112 dpopProof, err := c.createDPoPProof(method, url)
113 if err != nil {
114 return fmt.Errorf("failed to create DPoP proof: %w", err)
115 }
116
117 req.Header.Set("Authorization", "DPoP "+c.AccessToken)
118 req.Header.Set("DPoP", dpopProof)
119
120 resp, err := http.DefaultClient.Do(req)
121 if err != nil {
122 return err
123 }
124 defer resp.Body.Close()
125
126 if nonce := resp.Header.Get("DPoP-Nonce"); nonce != "" {
127 c.DPoPNonce = nonce
128 }
129
130 if resp.StatusCode < 400 {
131 if output != nil {
132 return json.NewDecoder(resp.Body).Decode(output)
133 }
134 return nil
135 }
136
137 bodyBytes, _ := io.ReadAll(resp.Body)
138 bodyStr := string(bodyBytes)
139
140 if resp.StatusCode == 401 && (bytes.Contains(bodyBytes, []byte("use_dpop_nonce")) || bytes.Contains(bodyBytes, []byte("UseDpopNonce"))) {
141 continue
142 }
143
144 return fmt.Errorf("XRPC error %d: %s", resp.StatusCode, bodyStr)
145 }
146
147 return fmt.Errorf("XRPC failed after retries")
148}
149
// CreateRecordInput is the JSON request body that CreateRecord sends to
// com.atproto.repo.createRecord.
type CreateRecordInput struct {
	// Repo identifies the repository to write to (CreateRecord passes its repo
	// argument through; presumably a DID or handle — confirm with callers).
	Repo string `json:"repo"`
	// Collection names the collection the record is created in.
	Collection string `json:"collection"`
	// RKey is an optional record key. When empty, the omitempty tag drops it
	// from the request body. CreateRecord never sets it.
	RKey string `json:"rkey,omitempty"`
	// Record is the record value, JSON-encoded as-is.
	Record interface{} `json:"record"`
}

// CreateRecordOutput is the JSON response that CreateRecord decodes from
// com.atproto.repo.createRecord.
type CreateRecordOutput struct {
	// URI is the "uri" field of the response (the new record's URI).
	URI string `json:"uri"`
	// CID is the "cid" field of the response.
	CID string `json:"cid"`
}
162func (c *Client) CreateRecord(ctx context.Context, repo, collection string, record interface{}) (*CreateRecordOutput, error) {
163 input := CreateRecordInput{
164 Repo: repo,
165 Collection: collection,
166 Record: record,
167 }
168
169 var output CreateRecordOutput
170 err := c.Call(ctx, "POST", "com.atproto.repo.createRecord", input, &output)
171 if err != nil {
172 return nil, err
173 }
174
175 return &output, nil
176}
177
// DeleteRecordInput is the JSON request body that DeleteRecord sends to
// com.atproto.repo.deleteRecord.
type DeleteRecordInput struct {
	// Repo identifies the repository holding the record.
	Repo string `json:"repo"`
	// Collection names the record's collection.
	Collection string `json:"collection"`
	// RKey is the record key. It has no omitempty tag, so it is always sent.
	RKey string `json:"rkey"`
}
184func (c *Client) DeleteRecord(ctx context.Context, repo, collection, rkey string) error {
185 input := DeleteRecordInput{
186 Repo: repo,
187 Collection: collection,
188 RKey: rkey,
189 }
190
191 return c.Call(ctx, "POST", "com.atproto.repo.deleteRecord", input, nil)
192}
193
194func (c *Client) DeleteRecordByURI(ctx context.Context, uri string) error {
195 parsed, err := ParseATURI(uri)
196 if err != nil {
197 return err
198 }
199
200 if parsed.Collection == "" || parsed.RKey == "" {
201 return fmt.Errorf("invalid AT-URI: must include collection and rkey")
202 }
203
204 return c.DeleteRecord(ctx, parsed.DID, parsed.Collection, parsed.RKey)
205}
206
// PutRecordInput is the JSON request body that PutRecord sends to
// com.atproto.repo.putRecord.
type PutRecordInput struct {
	// Repo identifies the repository to write to.
	Repo string `json:"repo"`
	// Collection names the record's collection.
	Collection string `json:"collection"`
	// RKey is the key of the record to write. It has no omitempty tag, so it
	// is always sent.
	RKey string `json:"rkey"`
	// Record is the record value, JSON-encoded as-is.
	Record interface{} `json:"record"`
}

// PutRecordOutput is the JSON response that PutRecord decodes from
// com.atproto.repo.putRecord.
type PutRecordOutput struct {
	// URI is the "uri" field of the response.
	URI string `json:"uri"`
	// CID is the "cid" field of the response.
	CID string `json:"cid"`
}
219func (c *Client) PutRecord(ctx context.Context, repo, collection, rkey string, record interface{}) (*PutRecordOutput, error) {
220 input := PutRecordInput{
221 Repo: repo,
222 Collection: collection,
223 RKey: rkey,
224 Record: record,
225 }
226
227 var output PutRecordOutput
228 err := c.Call(ctx, "POST", "com.atproto.repo.putRecord", input, &output)
229 if err != nil {
230 return nil, err
231 }
232
233 return &output, nil
234}
235
// GetRecordOutput is the JSON response that GetRecord decodes from
// com.atproto.repo.getRecord.
type GetRecordOutput struct {
	// URI is the "uri" field of the response.
	URI string `json:"uri"`
	// CID is the "cid" field of the response.
	CID string `json:"cid"`
	// Value holds the record body as undecoded JSON, so callers can unmarshal
	// it into their own record type.
	Value json.RawMessage `json:"value"`
}
242func (c *Client) GetRecord(ctx context.Context, repo, collection, rkey string) (*GetRecordOutput, error) {
243 url := fmt.Sprintf("%s/xrpc/com.atproto.repo.getRecord?repo=%s&collection=%s&rkey=%s",
244 c.PDS, repo, collection, rkey)
245
246 req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
247 if err != nil {
248 return nil, err
249 }
250
251 dpopProof, err := c.createDPoPProof("GET", url)
252 if err != nil {
253 return nil, err
254 }
255
256 req.Header.Set("Authorization", "DPoP "+c.AccessToken)
257 req.Header.Set("DPoP", dpopProof)
258
259 resp, err := http.DefaultClient.Do(req)
260 if err != nil {
261 return nil, err
262 }
263 defer resp.Body.Close()
264
265 if resp.StatusCode >= 400 {
266 bodyBytes, _ := io.ReadAll(resp.Body)
267 return nil, fmt.Errorf("XRPC error %d: %s", resp.StatusCode, string(bodyBytes))
268 }
269
270 var output GetRecordOutput
271 if err := json.NewDecoder(resp.Body).Decode(&output); err != nil {
272 return nil, err
273 }
274
275 return &output, nil
276}
277
// ResolveHandleOutput is the JSON response that ResolveHandle decodes from
// com.atproto.identity.resolveHandle.
type ResolveHandleOutput struct {
	// Did is the "did" field of the response: the DID the handle resolved to.
	Did string `json:"did"`
}
282func (c *Client) ResolveHandle(ctx context.Context, handle string) (string, error) {
283 url := fmt.Sprintf("%s/xrpc/com.atproto.identity.resolveHandle?handle=%s", c.PDS, handle)
284
285 req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
286 if err != nil {
287 return "", err
288 }
289
290 resp, err := http.DefaultClient.Do(req)
291 if err != nil {
292 return "", err
293 }
294 defer resp.Body.Close()
295
296 if resp.StatusCode >= 400 {
297 return "", fmt.Errorf("XRPC error %d", resp.StatusCode)
298 }
299
300 var output ResolveHandleOutput
301 if err := json.NewDecoder(resp.Body).Decode(&output); err != nil {
302 return "", err
303 }
304
305 return output.Did, nil
306}