# A from-scratch atproto PDS implementation in Python
# (mirrors https://github.com/DavidBuchanan314/millipds); file as seen at `main`.
"""Integration tests that run a millipds instance on a random local port.

Each test gets a fresh database in a temporary directory (see the ``test_pds``
fixture) and talks to the server over HTTP using an aiohttp client session.
"""

import asyncio
import dataclasses
import os
import tempfile
import unittest.mock
import urllib.parse

import aiohttp
import aiohttp.web
import pytest

from millipds import service
from millipds import database
from millipds import crypto


@dataclasses.dataclass
class PDSInfo:
    """Handle on the PDS instance started by the ``test_pds`` fixture."""

    # Base URL of the running server, e.g. "http://localhost:12345".
    endpoint: str
    # The database object the server was started with.
    db: database.Database


# Keep a reference to the unpatched method so the mock below can delegate to it.
old_web_tcpsite_start = aiohttp.web.TCPSite.start


def make_capture_random_bound_port_web_tcpsite_startstart(queue: asyncio.Queue):
    """Build a replacement for ``aiohttp.web.TCPSite.start``.

    The replacement starts the site via the original ``start`` and then puts
    the port number of the site's first listening socket onto ``queue``, so a
    caller that asked for port 0 can learn which port was actually bound.
    """

    async def mock_start(site: aiohttp.web.TCPSite, *args, **kwargs):
        await old_web_tcpsite_start(site, *args, **kwargs)
        # getsockname() returns (host, port, ...); index 1 is the port.
        await queue.put(site._server.sockets[0].getsockname()[1])

    return mock_start


async def service_run_and_capture_port(queue: asyncio.Queue, **kwargs):
    """Run ``service.run(**kwargs)`` with ``TCPSite.start`` patched.

    While the service runs, every ``TCPSite.start`` call reports its bound
    port on ``queue`` (see the factory above).
    """
    mock_start = make_capture_random_bound_port_web_tcpsite_startstart(queue)
    with unittest.mock.patch.object(
        aiohttp.web.TCPSite, "start", new=mock_start
    ):
        await service.run(**kwargs)


# Manual toggle between two test identities; the ``else`` branch is active.
if 0:
    TEST_DID = "did:web:alice.test"
    TEST_HANDLE = "alice.test"
    TEST_PASSWORD = "alice_pw"
else:
    TEST_DID = "did:plc:bwxddkvw5c6pkkntbtp2j4lx"
    TEST_HANDLE = "local.dev.retr0.id"
    TEST_PASSWORD = "lol"
TEST_PRIVKEY = crypto.keygen_p256()


def _configure_pds_hostname(db: database.Database, hostname: str) -> None:
    """Write the PDS/appview config for a server reachable at ``hostname``."""
    db.update_config(
        pds_pfx=f"http://{hostname}",
        pds_did=f"did:web:{urllib.parse.quote(hostname)}",
        bsky_appview_pfx="https://api.bsky.app",
        bsky_appview_did="did:web:api.bsky.app",
    )


@pytest.fixture
async def test_pds(aiolib):
    """Start a PDS on a random port with one test account; yield a PDSInfo.

    Raises whatever exception made the service exit if it dies before binding
    a port, or RuntimeError if it exits cleanly before binding.
    """
    queue = asyncio.Queue()
    with tempfile.TemporaryDirectory() as tempdir:
        async with aiohttp.ClientSession() as client:
            db_path = f"{tempdir}/millipds-0000.db"
            db = database.Database(path=db_path)

            # Placeholder config: the real port is only known once the server
            # has bound its socket. NOTE(review): presumably service.run needs
            # a config to exist up front; original ordering is kept.
            _configure_pds_hostname(db, "localhost:0")

            service_run_task = asyncio.create_task(
                service_run_and_capture_port(
                    queue,
                    db=db,
                    client=client,
                    sock_path=None,
                    host="localhost",
                    port=0,
                )
            )
            queue_get_task = asyncio.create_task(queue.get())
            # asyncio.wait returns two *sets* of tasks, so membership (not
            # equality) is the right check.
            done, _pending = await asyncio.wait(
                (queue_get_task, service_run_task),
                return_when=asyncio.FIRST_COMPLETED,
            )
            if queue_get_task not in done:
                # The service finished before reporting a port: clean up and
                # surface the real reason instead of an InvalidStateError.
                queue_get_task.cancel()
                db.con.close()
                startup_error = service_run_task.exception()
                if startup_error is not None:
                    raise startup_error
                raise RuntimeError(
                    "millipds service exited before binding a port"
                )
            port = queue_get_task.result()

            hostname = f"localhost:{port}"
            _configure_pds_hostname(db, hostname)
            db.create_account(
                did=TEST_DID,
                handle=TEST_HANDLE,
                password=TEST_PASSWORD,
                privkey=TEST_PRIVKEY,
            )

            try:
                yield PDSInfo(
                    endpoint=f"http://{hostname}",
                    db=db,
                )
            finally:
                db.con.close()
                service_run_task.cancel()
                try:
                    await service_run_task
                except asyncio.CancelledError:
                    pass


@pytest.fixture
async def s(aiolib):
    """Yield an aiohttp client session for talking to the PDS under test."""
    async with aiohttp.ClientSession() as session:
        yield session


@pytest.fixture
def pds_host(test_pds) -> str:
    """Base URL of the running test PDS."""
    return test_pds.endpoint


async def test_hello_world(s, pds_host):
    async with s.get(pds_host + "/") as r:
        body = await r.text()
        print(body)
        assert "Hello" in body


async def test_describeServer(s, pds_host):
    async with s.get(pds_host + "/xrpc/com.atproto.server.describeServer") as r:
        print(await r.json())
        assert r.status == 200


async def test_createSession_no_args(s, pds_host):
    # no args
    async with s.post(pds_host + "/xrpc/com.atproto.server.createSession") as r:
        assert r.status != 200


invalid_logins = [
    {"identifier": [], "password": TEST_PASSWORD},
    {"identifier": "example.invalid", "password": "wrongPassword123"},
    {"identifier": TEST_HANDLE, "password": "wrongPassword123"},
]


@pytest.mark.parametrize("login_data", invalid_logins)
async def test_invalid_logins(s, pds_host, login_data):
    async with s.post(
        pds_host + "/xrpc/com.atproto.server.createSession",
        json=login_data,
    ) as r:
        assert r.status != 200


valid_logins = [
    {"identifier": TEST_HANDLE, "password": TEST_PASSWORD},
    {"identifier": TEST_DID, "password": TEST_PASSWORD},
]


@pytest.mark.parametrize("login_data", valid_logins)
async def test_valid_logins(s, pds_host, login_data):
    async with s.post(
        pds_host + "/xrpc/com.atproto.server.createSession",
        json=login_data,
    ) as r:
        session = await r.json()
        assert session["did"] == TEST_DID
        assert session["handle"] == TEST_HANDLE
        assert "accessJwt" in session
        assert "refreshJwt" in session

    token = session["accessJwt"]
    auth_headers = {"Authorization": "Bearer " + token}

    # good auth
    async with s.get(
        pds_host + "/xrpc/com.atproto.server.getSession",
        headers=auth_headers,
    ) as r:
        print(await r.json())
        assert r.status == 200

    # bad auth (token with its last character removed)
    async with s.get(
        pds_host + "/xrpc/com.atproto.server.getSession",
        headers={"Authorization": "Bearer " + token[:-1]},
    ) as r:
        print(await r.text())
        assert r.status != 200

    # bad auth (malformed header)
    async with s.get(
        pds_host + "/xrpc/com.atproto.server.getSession",
        headers={"Authorization": "Bearest"},
    ) as r:
        print(await r.text())
        assert r.status != 200


async def test_sync_getRepo(s, pds_host):
    async with s.get(
        pds_host + "/xrpc/com.atproto.sync.getRepo",
        params={"did": TEST_DID},
    ) as r:
        assert r.status == 200


@pytest.fixture
async def auth_headers(s, pds_host):
    """Log in with the first valid login and return Bearer auth headers."""
    async with s.post(
        pds_host + "/xrpc/com.atproto.server.createSession",
        json=valid_logins[0],
    ) as r:
        session = await r.json()
    token = session["accessJwt"]
    return {"Authorization": "Bearer " + token}


async def _create_test_likes(s, pds_host, auth_headers):
    """Create 10 batches of 30 like records via applyWrites.

    Record keys are "<batch>-<index>"; every request must return HTTP 200.
    """
    for batch in range(10):
        writes = [
            {
                "$type": "com.atproto.repo.applyWrites#create",
                "action": "create",
                "collection": "app.bsky.feed.like",
                "rkey": f"{batch}-{index}",
                "value": {"blah": "test record"},
            }
            for index in range(30)
        ]
        async with s.post(
            pds_host + "/xrpc/com.atproto.repo.applyWrites",
            headers=auth_headers,
            json={"repo": TEST_DID, "writes": writes},
        ) as r:
            print(await r.json())
            assert r.status == 200


@pytest.fixture
async def populated_pds_host(s, pds_host, auth_headers):
    """Base URL of a test PDS whose repo has been filled with like records."""
    # same thing as test_repo_applyWrites, for now
    await _create_test_likes(s, pds_host, auth_headers)
    return pds_host


async def test_repo_applyWrites(s, pds_host, auth_headers):
    # TODO: test more than just "create"!
    await _create_test_likes(s, pds_host, auth_headers)


async def test_repo_uploadBlob(s, pds_host, auth_headers):
    blob = os.urandom(0x100000)

    for _ in range(2):  # test reupload is nop
        async with s.post(
            pds_host + "/xrpc/com.atproto.repo.uploadBlob",
            headers=auth_headers | {"content-type": "blah"},
            data=blob,
        ) as r:
            res = await r.json()
            print(res)
            assert r.status == 200

    # getBlob should still 404 because refcount==0
    async with s.get(
        pds_host + "/xrpc/com.atproto.sync.getBlob",
        params={"did": TEST_DID, "cid": res["blob"]["ref"]["$link"]},
    ) as r:
        assert r.status == 404

    # get the blob refcount >0
    async with s.post(
        pds_host + "/xrpc/com.atproto.repo.createRecord",
        headers=auth_headers,
        json={
            "repo": TEST_DID,
            "collection": "app.bsky.feed.post",
            "record": {"myblob": res},
        },
    ) as r:
        print(await r.json())
        assert r.status == 200

    async with s.get(
        pds_host + "/xrpc/com.atproto.sync.getBlob",
        params={"did": TEST_DID, "cid": res["blob"]["ref"]["$link"]},
    ) as r:
        downloaded_blob = await r.read()
        assert downloaded_blob == blob

    async with s.get(
        pds_host + "/xrpc/com.atproto.sync.getRepo",
        params={"did": TEST_DID},
    ) as r:
        assert r.status == 200
        # Dump the repo CAR into the working directory; the context manager
        # makes sure the file handle is flushed and closed.
        with open("repo.car", "wb") as car_file:
            car_file.write(await r.read())


async def test_sync_getRepo_not_found(s, pds_host):
    async with s.get(
        pds_host + "/xrpc/com.atproto.sync.getRepo",
        params={"did": "did:web:nonexistent.invalid"},
    ) as r:
        assert r.status == 404


async def test_sync_getRecord_nonexistent(s, populated_pds_host):
    # nonexistent DID should still 404
    async with s.get(
        populated_pds_host + "/xrpc/com.atproto.sync.getRecord",
        params={
            "did": "did:web:nonexistent.invalid",
            "collection": "app.bsky.feed.post",
            "rkey": "nonexistent",
        },
    ) as r:
        assert r.status == 404

    # but extant DID with nonexistent record should 200, with exclusion proof CAR
    async with s.get(
        populated_pds_host + "/xrpc/com.atproto.sync.getRecord",
        params={
            "did": TEST_DID,
            "collection": "app.bsky.feed.post",
            "rkey": "nonexistent",
        },
    ) as r:
        assert r.status == 200
        assert r.content_type == "application/vnd.ipld.car"
        proof_car = await r.read()
        assert proof_car  # nonempty
        # TODO: make sure the proof is valid


async def test_sync_getRecord_existent(s, populated_pds_host):
    async with s.get(
        populated_pds_host + "/xrpc/com.atproto.sync.getRecord",
        params={
            "did": TEST_DID,
            "collection": "app.bsky.feed.like",
            "rkey": "1-1",
        },
    ) as r:
        assert r.status == 200
        assert r.content_type == "application/vnd.ipld.car"
        proof_car = await r.read()
        assert proof_car  # nonempty
        # TODO: make sure the proof is valid, and contains the record
        assert b"test record" in proof_car


async def test_seviceauth(s, test_pds, auth_headers):
    async with s.get(
        test_pds.endpoint + "/xrpc/com.atproto.server.getServiceAuth",
        headers=auth_headers,
        params={
            "aud": test_pds.db.config["pds_did"],
            "lxm": "com.atproto.server.getSession",
        },
    ) as r:
        assert r.status == 200
        token = (await r.json())["token"]

    # test if the token works
    async with s.get(
        test_pds.endpoint + "/xrpc/com.atproto.server.getSession",
        headers={"Authorization": "Bearer " + token},
    ) as r:
        assert r.status == 200
        await r.json()


async def test_refreshSession(s, pds_host):
    async with s.post(
        pds_host + "/xrpc/com.atproto.server.createSession",
        json=valid_logins[0],
    ) as r:
        assert r.status == 200
        session = await r.json()
        orig_session_token = session["accessJwt"]
        orig_refresh_token = session["refreshJwt"]

    # can't refresh using the session token
    async with s.post(
        pds_host + "/xrpc/com.atproto.server.refreshSession",
        headers={"Authorization": "Bearer " + orig_session_token},
    ) as r:
        assert r.status != 200

    # correctly refresh using the refresh token
    async with s.post(
        pds_host + "/xrpc/com.atproto.server.refreshSession",
        headers={"Authorization": "Bearer " + orig_refresh_token},
    ) as r:
        assert r.status == 200
        refreshed = await r.json()
        new_session_token = refreshed["accessJwt"]
        new_refresh_token = refreshed["refreshJwt"]

    # test if the new session token works
    async with s.get(
        pds_host + "/xrpc/com.atproto.server.getSession",
        headers={"Authorization": "Bearer " + new_session_token},
    ) as r:
        assert r.status == 200
        await r.json()

    # test that the old session token is invalid
    # XXX: in the future we might relax this behaviour
    async with s.get(
        pds_host + "/xrpc/com.atproto.server.getSession",
        headers={"Authorization": "Bearer " + orig_session_token},
    ) as r:
        assert r.status != 200

    # test that the old refresh token is invalid
    async with s.post(
        pds_host + "/xrpc/com.atproto.server.refreshSession",
        headers={"Authorization": "Bearer " + orig_refresh_token},
    ) as r:
        assert r.status != 200


async def test_deleteSession(s, pds_host):
    async with s.post(
        pds_host + "/xrpc/com.atproto.server.createSession",
        json=valid_logins[0],
    ) as r:
        assert r.status == 200
        session = await r.json()
        session_token = session["accessJwt"]
        refresh_token = session["refreshJwt"]

    # sanity-check that the session token currently works
    async with s.get(
        pds_host + "/xrpc/com.atproto.server.getSession",
        headers={"Authorization": "Bearer " + session_token},
    ) as r:
        assert r.status == 200
        await r.json()

    # can't delete using the session token
    async with s.post(
        pds_host + "/xrpc/com.atproto.server.deleteSession",
        headers={"Authorization": "Bearer " + session_token},
    ) as r:
        assert r.status != 200

    # can delete using the refresh token
    async with s.post(
        pds_host + "/xrpc/com.atproto.server.deleteSession",
        headers={"Authorization": "Bearer " + refresh_token},
    ) as r:
        assert r.status == 200

    # test that the session token is invalid now
    # XXX: in the future we might relax this behaviour
    async with s.get(
        pds_host + "/xrpc/com.atproto.server.getSession",
        headers={"Authorization": "Bearer " + session_token},
    ) as r:
        assert r.status != 200

    # test that the refresh token is invalid too
    async with s.post(
        pds_host + "/xrpc/com.atproto.server.refreshSession",
        headers={"Authorization": "Bearer " + refresh_token},
    ) as r:
        assert r.status != 200


async def test_updateHandle(s, pds_host, auth_headers):
    async with s.post(
        pds_host + "/xrpc/com.atproto.identity.updateHandle",
        headers=auth_headers,
        json={"handle": "juliet.test"},
    ) as r:
        assert r.status == 200

    async with s.get(
        pds_host + "/xrpc/com.atproto.repo.describeRepo",
        params={"repo": TEST_DID},
    ) as r:
        assert r.status == 200
        repo_info = await r.json()
        assert repo_info["handle"] == "juliet.test"