A from-scratch atproto PDS implementation in Python (mirrors https://github.com/DavidBuchanan314/millipds)
1import os
2import asyncio
3import tempfile
4import urllib.parse
5import unittest.mock
6import pytest
7import dataclasses
8import aiohttp
9import aiohttp.web
10
11from millipds import service
12from millipds import database
13from millipds import crypto
14
15
@dataclasses.dataclass()
class PDSInfo:
    """Details of a running test PDS, as yielded by the ``test_pds`` fixture."""

    # Base URL of the server, built as "http://localhost:<port>".
    endpoint: str
    # The database instance the server was started with.
    db: database.Database
20
21
# Reference to aiohttp's real TCPSite.start, captured at import time (before
# any patching) so the replacement below can still delegate to it.
old_web_tcpsite_start = aiohttp.web.TCPSite.start


def make_capture_random_bound_port_web_tcpsite_startstart(queue: asyncio.Queue):
    """Build a drop-in replacement for ``aiohttp.web.TCPSite.start``.

    The returned coroutine function runs the original ``TCPSite.start`` and
    then puts the port of the server's first listening socket onto *queue*.
    This lets a caller bind to port 0 and learn which port was actually used.

    Args:
        queue: receives one port number per ``start`` call.

    Returns:
        An async function taking ``(site, *args, **kwargs)`` like
        ``TCPSite.start``.
    """

    async def mock_start(site: aiohttp.web.TCPSite, *args, **kwargs):
        # `queue` is only read here, so no `nonlocal` declaration is needed.
        await old_web_tcpsite_start(site, *args, **kwargs)
        # NOTE: relies on aiohttp's private `_server` attribute; may break on
        # aiohttp upgrades.
        bound_port = site._server.sockets[0].getsockname()[1]
        await queue.put(bound_port)

    return mock_start
32
33
async def service_run_and_capture_port(queue: asyncio.Queue, **kwargs):
    """Run ``service.run(**kwargs)`` while reporting bound ports to *queue*.

    ``aiohttp.web.TCPSite.start`` is patched for the duration of the call so
    that each started site pushes its port onto *queue*. The patch is removed
    when ``service.run`` returns, raises, or is cancelled.
    """
    patcher = unittest.mock.patch.object(
        aiohttp.web.TCPSite,
        "start",
        new=make_capture_random_bound_port_web_tcpsite_startstart(queue),
    )
    with patcher:
        await service.run(**kwargs)
40
41
# Credentials of the single account the test PDS is seeded with.
# Change the condition to use the did:web "alice.test" identity instead.
if 0:
    TEST_DID, TEST_HANDLE, TEST_PASSWORD = (
        "did:web:alice.test",
        "alice.test",
        "alice_pw",
    )
else:
    TEST_DID, TEST_HANDLE, TEST_PASSWORD = (
        "did:plc:bwxddkvw5c6pkkntbtp2j4lx",
        "local.dev.retr0.id",
        "lol",
    )

# Key for the test account; a new one is generated each time this module is
# imported.
TEST_PRIVKEY = crypto.keygen_p256()
51
52
@pytest.fixture
async def test_pds(aiolib):
    """Run a throwaway PDS on a random localhost port and yield its details.

    Setup:
    - A fresh database is created in a temporary directory.
    - ``service.run`` is started as a background task on ``localhost`` port 0.
    - The port that was actually bound is captured via the patched
      ``TCPSite.start``.
    - The config is rewritten to that port, and the test account
      (TEST_DID / TEST_HANDLE / TEST_PASSWORD) is created.

    Teardown cancels and awaits the server task before closing the database.

    Yields:
        PDSInfo: the server's base URL and its ``Database``.

    Raises:
        The service's own exception if ``service.run`` fails before binding a
        port, or ``RuntimeError`` if it returns without ever binding one.
    """
    queue = asyncio.Queue()
    with tempfile.TemporaryDirectory() as tempdir:
        async with aiohttp.ClientSession() as client:
            db_path = f"{tempdir}/millipds-0000.db"
            db = database.Database(path=db_path)

            # Placeholder config: the real port is only known once the server
            # has bound its socket, so this is rewritten below.
            hostname = "localhost:0"
            db.update_config(
                pds_pfx=f"http://{hostname}",
                pds_did=f"did:web:{urllib.parse.quote(hostname)}",
                bsky_appview_pfx="https://api.bsky.app",
                bsky_appview_did="did:web:api.bsky.app",
            )

            service_run_task = asyncio.create_task(
                service_run_and_capture_port(
                    queue,
                    db=db,
                    client=client,
                    sock_path=None,
                    host="localhost",
                    port=0,
                )
            )
            queue_get_task = asyncio.create_task(queue.get())
            try:
                # Wait for whichever happens first: the port being reported,
                # or the service exiting during startup.
                done, _pending = await asyncio.wait(
                    (queue_get_task, service_run_task),
                    return_when=asyncio.FIRST_COMPLETED,
                )
                # `done` is a set of tasks, so test membership, not equality.
                if queue_get_task not in done:
                    # Awaiting the finished task re-raises its exception, if
                    # any; otherwise the service returned without binding.
                    await service_run_task
                    raise RuntimeError(
                        "service.run() exited before binding a port"
                    )
                port = queue_get_task.result()

                hostname = f"localhost:{port}"
                db.update_config(
                    pds_pfx=f"http://{hostname}",
                    pds_did=f"did:web:{urllib.parse.quote(hostname)}",
                    bsky_appview_pfx="https://api.bsky.app",
                    bsky_appview_did="did:web:api.bsky.app",
                )
                db.create_account(
                    did=TEST_DID,
                    handle=TEST_HANDLE,
                    password=TEST_PASSWORD,
                    privkey=TEST_PRIVKEY,
                )

                yield PDSInfo(
                    endpoint=f"http://{hostname}",
                    db=db,
                )
            finally:
                # No-op if the port was already received.
                queue_get_task.cancel()
                # Stop the server before closing the DB it is using.
                service_run_task.cancel()
                try:
                    await service_run_task
                except asyncio.CancelledError:
                    pass
                finally:
                    db.con.close()
115
116
@pytest.fixture
async def s(aiolib):
    """Yield a fresh aiohttp client session, closed when the test finishes."""
    session = aiohttp.ClientSession()
    async with session:
        yield session
121
122
@pytest.fixture
def pds_host(test_pds) -> str:
    """Base URL of the running test PDS (``test_pds.endpoint``)."""
    endpoint = test_pds.endpoint
    return endpoint
126
127
async def test_hello_world(s, pds_host):
    """The root path responds with text containing "Hello"."""
    async with s.get(pds_host + "/") as resp:
        body = await resp.text()
    print(body)
    assert "Hello" in body
133
134
async def test_describeServer(s, pds_host):
    """describeServer responds 200 with a JSON body."""
    async with s.get(pds_host + "/xrpc/com.atproto.server.describeServer") as r:
        # Check the status explicitly; parsing the body alone does not prove
        # the request succeeded.
        assert r.status == 200
        print(await r.json())
138
139
async def test_createSession_no_args(s, pds_host):
    """createSession with no request body must not succeed."""
    url = f"{pds_host}/xrpc/com.atproto.server.createSession"
    async with s.post(url) as resp:
        status = resp.status
    assert status != 200
144
145
# createSession payloads that must be rejected:
#   - an identifier of the wrong type (a list instead of a string),
#   - an identifier that is not the test account,
#   - the test account's handle with the wrong password.
invalid_logins = [
    dict(identifier=[], password=TEST_PASSWORD),
    dict(identifier="example.invalid", password="wrongPassword123"),
    dict(identifier=TEST_HANDLE, password="wrongPassword123"),
]
151
152
@pytest.mark.parametrize("login_data", invalid_logins)
async def test_invalid_logins(s, pds_host, login_data):
    """createSession rejects each bad credential payload."""
    endpoint = pds_host + "/xrpc/com.atproto.server.createSession"
    async with s.post(endpoint, json=login_data) as resp:
        assert resp.status != 200
160
161
# createSession payloads that must succeed: the test account, identified by
# handle or by DID. Entry 0 (the handle login) is also used elsewhere in this
# module to obtain tokens.
valid_logins = [
    dict(identifier=TEST_HANDLE, password=TEST_PASSWORD),
    dict(identifier=TEST_DID, password=TEST_PASSWORD),
]
166
167
@pytest.mark.parametrize("login_data", valid_logins)
async def test_valid_logins(s, pds_host, login_data):
    """A valid login returns working tokens; tampered or malformed auth fails."""
    xrpc = pds_host + "/xrpc/"

    async with s.post(
        xrpc + "com.atproto.server.createSession", json=login_data
    ) as resp:
        session = await resp.json()
    assert session["did"] == TEST_DID
    assert session["handle"] == TEST_HANDLE
    assert "accessJwt" in session
    assert "refreshJwt" in session

    access_token = session["accessJwt"]
    get_session_url = xrpc + "com.atproto.server.getSession"

    # good auth: the real access token is accepted
    async with s.get(
        get_session_url,
        headers={"Authorization": "Bearer " + access_token},
    ) as resp:
        print(await resp.json())
        assert resp.status == 200

    # bad auth: a truncated token, then a header with no valid Bearer scheme
    for bad_auth in ("Bearer " + access_token[:-1], "Bearest"):
        async with s.get(
            get_session_url,
            headers={"Authorization": bad_auth},
        ) as resp:
            print(await resp.text())
            assert resp.status != 200
206
207
async def test_sync_getRepo(s, pds_host):
    """sync.getRepo succeeds for the test account's DID."""
    url = f"{pds_host}/xrpc/com.atproto.sync.getRepo"
    query = {"did": TEST_DID}
    async with s.get(url, params=query) as resp:
        assert resp.status == 200
213 assert r.status == 200
214
215
@pytest.fixture
async def auth_headers(s, pds_host):
    """Log in as the test account and return an ``Authorization`` header dict."""
    async with s.post(
        f"{pds_host}/xrpc/com.atproto.server.createSession",
        json=valid_logins[0],
    ) as resp:
        access_jwt = (await resp.json())["accessJwt"]
    return {"Authorization": "Bearer " + access_jwt}
225
226
@pytest.fixture
async def populated_pds_host(s, pds_host, auth_headers):
    """Seed the test repo with like records and return the PDS base URL.

    Writes 10 applyWrites batches of 30 "create" operations each. Every record
    is in ``app.bsky.feed.like``, has rkey ``"{batch}-{index}"`` (e.g. "1-1"),
    and holds the value ``{"blah": "test record"}``.
    """
    # same thing as test_repo_applyWrites, for now
    url = pds_host + "/xrpc/com.atproto.repo.applyWrites"
    for batch in range(10):
        writes = [
            {
                "$type": "com.atproto.repo.applyWrites#create",
                "action": "create",
                "collection": "app.bsky.feed.like",
                "rkey": f"{batch}-{index}",
                "value": {"blah": "test record"},
            }
            for index in range(30)
        ]
        payload = {"repo": TEST_DID, "writes": writes}
        async with s.post(url, headers=auth_headers, json=payload) as resp:
            print(await resp.json())
            assert resp.status == 200
    return pds_host
251
252
async def test_repo_applyWrites(s, pds_host, auth_headers):
    """applyWrites accepts 10 batches of 30 "create" operations."""
    # TODO: test more than just "create"!

    def create_op(batch, index):
        # One applyWrites "create" entry for a like record.
        return {
            "$type": "com.atproto.repo.applyWrites#create",
            "action": "create",
            "collection": "app.bsky.feed.like",
            "rkey": f"{batch}-{index}",
            "value": {"blah": "test record"},
        }

    endpoint = f"{pds_host}/xrpc/com.atproto.repo.applyWrites"
    for batch in range(10):
        body = {
            "repo": TEST_DID,
            "writes": [create_op(batch, index) for index in range(30)],
        }
        async with s.post(endpoint, headers=auth_headers, json=body) as resp:
            print(await resp.json())
            assert resp.status == 200
275
276
async def test_repo_uploadBlob(s, pds_host, auth_headers):
    """Exercise blob upload, refcount-gated download, and repo export.

    Steps:
    - Upload a 1 MiB random blob twice (the re-upload is expected to be a
      no-op).
    - Check that getBlob 404s while no record references the blob.
    - Link the blob from a record via createRecord.
    - Check that getBlob now returns the exact uploaded bytes.
    - Export the repo and write it to ``repo.car`` in the current working
      directory (presumably for manual inspection).
    """
    blob = os.urandom(0x100000)

    for _ in range(2):  # test reupload is nop
        async with s.post(
            pds_host + "/xrpc/com.atproto.repo.uploadBlob",
            headers=auth_headers | {"content-type": "blah"},
            data=blob,
        ) as r:
            res = await r.json()
            print(res)
            assert r.status == 200

    blob_query = {"did": TEST_DID, "cid": res["blob"]["ref"]["$link"]}

    # getBlob should still 404 because refcount==0
    async with s.get(
        pds_host + "/xrpc/com.atproto.sync.getBlob",
        params=blob_query,
    ) as r:
        assert r.status == 404

    # get the blob refcount >0
    async with s.post(
        pds_host + "/xrpc/com.atproto.repo.createRecord",
        headers=auth_headers,
        json={
            "repo": TEST_DID,
            "collection": "app.bsky.feed.post",
            "record": {"myblob": res},
        },
    ) as r:
        print(await r.json())
        assert r.status == 200

    async with s.get(
        pds_host + "/xrpc/com.atproto.sync.getBlob",
        params=blob_query,
    ) as r:
        assert r.status == 200
        downloaded_blob = await r.read()
    assert downloaded_blob == blob

    async with s.get(
        pds_host + "/xrpc/com.atproto.sync.getRepo",
        params={"did": TEST_DID},
    ) as r:
        assert r.status == 200
        repo_car = await r.read()

    # Use a context manager so the file handle is closed deterministically
    # rather than leaked until garbage collection.
    with open("repo.car", "wb") as f:
        f.write(repo_car)
323
324
async def test_sync_getRepo_not_found(s, pds_host):
    """sync.getRepo 404s for a DID the PDS does not host."""
    unknown_did = "did:web:nonexistent.invalid"
    url = f"{pds_host}/xrpc/com.atproto.sync.getRepo"
    async with s.get(url, params={"did": unknown_did}) as resp:
        assert resp.status == 404
331
332
async def test_sync_getRecord_nonexistent(s, populated_pds_host):
    """getRecord for a missing record.

    An unknown DID gets a 404. A known DID whose record is absent gets a 200
    with a non-empty CAR (the exclusion proof).
    """
    url = populated_pds_host + "/xrpc/com.atproto.sync.getRecord"
    missing_post = {"collection": "app.bsky.feed.post", "rkey": "nonexistent"}

    # nonexistent DID should still 404
    async with s.get(
        url, params={"did": "did:web:nonexistent.invalid", **missing_post}
    ) as resp:
        assert resp.status == 404

    # but extant DID with nonexistent record should 200, with exclusion proof CAR
    async with s.get(url, params={"did": TEST_DID, **missing_post}) as resp:
        assert resp.status == 200
        assert resp.content_type == "application/vnd.ipld.car"
        proof_car = await resp.read()
    assert proof_car  # nonempty
    # TODO: make sure the proof is valid
358 # TODO: make sure the proof is valid
359
360
async def test_sync_getRecord_existent(s, populated_pds_host):
    """getRecord for a seeded like returns a CAR containing the record bytes."""
    url = populated_pds_host + "/xrpc/com.atproto.sync.getRecord"
    query = {
        "did": TEST_DID,
        "collection": "app.bsky.feed.like",
        "rkey": "1-1",
    }
    async with s.get(url, params=query) as resp:
        assert resp.status == 200
        assert resp.content_type == "application/vnd.ipld.car"
        proof_car = await resp.read()
    assert proof_car  # nonempty
    # TODO: make sure the proof is valid, and contains the record
    assert b"test record" in proof_car
375 assert b"test record" in proof_car
376
377
async def test_seviceauth(s, test_pds, auth_headers):
    """A service-auth token minted for the PDS's own DID can call getSession."""
    xrpc = test_pds.endpoint + "/xrpc/"
    query = {
        "aud": test_pds.db.config["pds_did"],
        "lxm": "com.atproto.server.getSession",
    }
    async with s.get(
        xrpc + "com.atproto.server.getServiceAuth",
        headers=auth_headers,
        params=query,
    ) as resp:
        assert resp.status == 200
        service_token = (await resp.json())["token"]

    # test if the token works
    async with s.get(
        xrpc + "com.atproto.server.getSession",
        headers={"Authorization": "Bearer " + service_token},
    ) as resp:
        assert resp.status == 200
        await resp.json()
396 await r.json()
397
398
async def test_refreshSession(s, pds_host):
    """refreshSession rotates tokens.

    Only the refresh token may be used to refresh. After a refresh, the new
    access token works, and both the old access token and the old refresh
    token are rejected.
    """

    def bearer(token):
        return {"Authorization": "Bearer " + token}

    xrpc = pds_host + "/xrpc/com.atproto.server."

    async with s.post(xrpc + "createSession", json=valid_logins[0]) as resp:
        assert resp.status == 200
        created = await resp.json()
    orig_access, orig_refresh = created["accessJwt"], created["refreshJwt"]

    # can't refresh using the session token
    async with s.post(xrpc + "refreshSession", headers=bearer(orig_access)) as resp:
        assert resp.status != 200

    # correctly refresh using the refresh token
    async with s.post(xrpc + "refreshSession", headers=bearer(orig_refresh)) as resp:
        assert resp.status == 200
        refreshed = await resp.json()
    new_access, _new_refresh = refreshed["accessJwt"], refreshed["refreshJwt"]

    # test if the new session token works
    async with s.get(xrpc + "getSession", headers=bearer(new_access)) as resp:
        assert resp.status == 200
        await resp.json()

    # test that the old session token is invalid
    # XXX: in the future we might relax this behaviour
    async with s.get(xrpc + "getSession", headers=bearer(orig_access)) as resp:
        assert resp.status != 200

    # test that the old refresh token is invalid
    async with s.post(xrpc + "refreshSession", headers=bearer(orig_refresh)) as resp:
        assert resp.status != 200
447 assert r.status != 200
448
449
async def test_deleteSession(s, pds_host):
    """deleteSession requires the refresh token and then invalidates both tokens."""

    def auth(token):
        return {"Authorization": "Bearer " + token}

    xrpc = pds_host + "/xrpc/com.atproto.server."

    async with s.post(xrpc + "createSession", json=valid_logins[0]) as resp:
        assert resp.status == 200
        tokens = await resp.json()
    session_token, refresh_token = tokens["accessJwt"], tokens["refreshJwt"]

    # sanity-check that the session token currently works
    async with s.get(xrpc + "getSession", headers=auth(session_token)) as resp:
        assert resp.status == 200
        await resp.json()

    # can't delete using the session token
    async with s.post(xrpc + "deleteSession", headers=auth(session_token)) as resp:
        assert resp.status != 200

    # can delete using the refresh token
    async with s.post(xrpc + "deleteSession", headers=auth(refresh_token)) as resp:
        assert resp.status == 200

    # test that the session token is invalid now
    # XXX: in the future we might relax this behaviour
    async with s.get(xrpc + "getSession", headers=auth(session_token)) as resp:
        assert resp.status != 200

    # test that the refresh token is invalid too
    async with s.post(xrpc + "refreshSession", headers=auth(refresh_token)) as resp:
        assert resp.status != 200
495 assert r.status != 200
496
497
async def test_updateHandle(s, pds_host, auth_headers):
    """After updateHandle, describeRepo reports the new handle."""
    new_handle = "juliet.test"

    async with s.post(
        f"{pds_host}/xrpc/com.atproto.identity.updateHandle",
        headers=auth_headers,
        json={"handle": new_handle},
    ) as resp:
        assert resp.status == 200

    async with s.get(
        f"{pds_host}/xrpc/com.atproto.repo.describeRepo",
        params={"repo": TEST_DID},
    ) as resp:
        assert resp.status == 200
        repo_info = await resp.json()
    assert repo_info["handle"] == new_handle
512 assert r["handle"] == "juliet.test"