···6767from getpass import getpass
68686969from docopt import docopt
7070-import aiohttp
7070+from .ssrf import get_ssrf_safe_client
7171+71727273import cbrrr
7374···234235 elif args["run"]:
235236236237 async def run_service_with_client():
237237- async with aiohttp.ClientSession() as client:
238238+ # TODO: option to use regular unsafe client for local dev testing
239239+ async with get_ssrf_safe_client() as client:
238240 await service.run(
239241 db=db,
240242 client=client,
+37
src/millipds/ssrf.py
···11+"""
22+This is a bit of a bodge, for now.
33+44+See https://github.com/aio-libs/aiohttp/discussions/10224 for the discussion
55+that led to this, and maybe a better solution in the future.
66+"""
77+88+import ipaddress
99+from aiohttp import TCPConnector, ClientSession
1010+import aiohttp.connector
1111+from aiohttp.resolver import DefaultResolver, AbstractResolver
1212+1313+# XXX: monkeypatch to force all hosts to go through the resolver
1414+# (without this, bare IPs in the URL will bypass the resolver, where our SSRF check is)
1515+aiohttp.connector.is_ip_address = lambda _: False
1616+1717+class SSRFException(ValueError):
1818+ pass
1919+2020+class SSRFSafeResolverWrapper(AbstractResolver):
2121+ def __init__(self, resolver: AbstractResolver):
2222+ self.resolver = resolver
2323+2424+ async def resolve(self, host: str, port: int, family: int):
2525+ result = await self.resolver.resolve(host, port, family)
2626+ for host in result:
2727+ if ipaddress.ip_address(host["host"]).is_private:
2828+ raise SSRFException("Can't connect to private IP: " + host["host"])
2929+ return result
3030+3131+ async def close(self) -> None:
3232+ await self.resolver.close()
3333+3434+def get_ssrf_safe_client() -> ClientSession:
3535+ resolver = SSRFSafeResolverWrapper(DefaultResolver())
3636+ connector = TCPConnector(resolver=resolver)
3737+ return ClientSession(connector=connector)