import asyncio import json import re from abc import ABC from dataclasses import dataclass, field from typing import Any, cast, override import httpx import websockets import env from atproto.models import AtUri from atproto.store import get_store from bluesky.info import SERVICE, BlueskyService, validate_and_transform from bluesky.richtext import richtext_to_tokens from cross.attachments import ( Attachment, LabelsAttachment, LanguagesAttachment, MediaAttachment, QuoteAttachment, RemoteUrlAttachment, ) from cross.media import Blob, download_blob from cross.post import Post, PostRef from cross.service import InputService from database.connection import DatabasePool from util.util import Result @dataclass(kw_only=True) class BlueskyInputOptions: handle: str | None = None did: str | None = None pds: str | None = None filters: list[re.Pattern[str]] = field(default_factory=lambda: []) @classmethod def from_dict(cls, data: dict[str, Any]) -> "BlueskyInputOptions": validate_and_transform(data) if "filters" in data: data["filters"] = [re.compile(r) for r in data["filters"]] return BlueskyInputOptions(**data) class BlueskyBaseInputService(BlueskyService, InputService, ABC): def __init__(self, db: DatabasePool, http: httpx.Client) -> None: super().__init__(SERVICE, db) self.http = http def _on_post(self, record: dict[str, Any]): post_uri = cast(str, record["$xpost.strongRef"]["uri"]) post_cid = cast(str, record["$xpost.strongRef"]["cid"]) if self._is_post_crossposted(self.url, self.did, post_uri): self.log.info("Skipping '%s': already crossposted", post_uri) return parent_uri = cast( str, None if not record.get("reply") else record["reply"]["parent"]["uri"] ) parent = None if parent_uri: did, _, _ = AtUri.record_uri(parent_uri) if did != self.did: self.log.info("Skipping '%s': reply to other user..", post_uri) return parent = self._get_post(self.url, self.did, parent_uri) if not parent: self.log.info( "Skipping '%s': parent '%s' not found in db", post_uri, parent_uri ) return tokens = richtext_to_tokens(record["text"], record.get("facets", [])) post = Post( id=post_uri, author=self.did, service=self.url, parent_id=parent_uri, tokens=tokens, ) did, _, rid = AtUri.record_uri(post_uri) post.attachments.put( RemoteUrlAttachment(url=f"https://bsky.app/profile/{did}/post/{rid}") ) embed: dict[str, Any] = record.get("embed", {}) attachments: list[Attachment] = [] blob_urls: list[tuple[str, str, str | None]] = [] def handle_embeds( embed: dict[str, Any], ) -> Result[None, str]: if "$type" not in embed: return Result.ok(None) match cast(str, embed["$type"]): case "app.bsky.embed.record" | "app.bsky.embed.recordWithMedia": rcrd = ( embed["record"]["record"] if embed["record"].get("record") else embed["record"] ) did, collection, _ = AtUri.record_uri(rcrd["uri"]) if collection != "app.bsky.feed.post": return Result.err(f"unhandled record collection '{collection}'") if did != self.did: return Result.err(f"quote of other user '{did}'") rquote = self._get_post(self.url, did, rcrd["uri"]) if not rquote: return Result.err(f"quote '{rcrd['uri']}' not found in db") attachments.append( QuoteAttachment(quoted_id=rcrd["uri"], quoted_user=did) ) if embed.get("media"): return handle_embeds(embed["media"]) case "app.bsky.embed.images": for image in embed["images"]: blob_cid = image["image"]["ref"]["$link"] url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}" blob_urls.append((url, blob_cid, image.get("alt"))) case "app.bsky.embed.video": blob_cid = embed["video"]["ref"]["$link"] url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}" blob_urls.append((url, blob_cid, embed.get("alt"))) case _: self.log.warning(f"unhandled embed type '{embed['$type']}'") return Result.ok(None) embeds = handle_embeds(embed) if not embeds.is_ok(): self.log.info("Skipping '%s': %s", post_uri, embeds.error()) return for a in attachments: post.attachments.put(a) if blob_urls: blobs: list[Blob] = [] for url, cid, alt in blob_urls: self.log.info("Downloading '%s'...", cid) blob = download_blob(url, alt, client=self.http) if not blob.is_ok(): self.log.error( "Skipping '%s': failed to download blob. %s", post_uri, blob.error(), ) return blobs.append(blob.value()) post.attachments.put(MediaAttachment(blobs=blobs)) if "langs" in record: post.attachments.put(LanguagesAttachment(langs=record["langs"])) if "labels" in record: post.attachments.put( LabelsAttachment( labels=[ label["val"].replace("-", " ") for label in record["values"] ] ), ) if parent: self._insert_post( { "user": self.did, "service": self.url, "identifier": post_uri, "parent": parent["id"], "root": parent["root"] or parent["id"], "extra_data": json.dumps({"cid": post_cid}), } ) else: self._insert_post( { "user": self.did, "service": self.url, "identifier": post_uri, "extra_data": json.dumps({"cid": post_cid}), } ) self.log.info("Crossposting: '%s'", post_uri) for out in self.outputs: self.submitter(lambda: out.accept_post(post)) def _on_repost(self, record: dict[str, Any]): post_uri = cast(str, record["$xpost.strongRef"]["uri"]) post_cid = cast(str, record["$xpost.strongRef"]["cid"]) reposted_uri = cast(str, record["subject"]["uri"]) reposted = self._get_post(self.url, self.did, reposted_uri) if not reposted: self.log.info( "Skipping repost '%s': reposted post '%s' not found in db", post_uri, reposted_uri, ) return self._insert_post( { "user": self.did, "service": self.url, "identifier": post_uri, "reposted": reposted["id"], "extra_data": json.dumps({"cid": post_cid}), } ) repost_ref = PostRef(id=post_uri, author=self.did, service=self.url) reposted_ref = PostRef(id=reposted_uri, author=self.did, service=self.url) self.log.info("Crossposting: '%s'", post_uri) for out in self.outputs: self.submitter(lambda: out.accept_repost(repost_ref, reposted_ref)) def _on_delete_post(self, post_id: str, repost: bool): post = self._get_post(self.url, self.did, post_id) if not post: self.log.warning("Skipping delete '%s': post not found in db", post_id) return post_ref = PostRef(id=post_id, author=self.did, service=self.url) if repost: self.log.info("Deleting repost: '%s'", post_id) for output in self.outputs: self.submitter(lambda: output.delete_repost(post_ref)) else: self.log.info("Deleting post: '%s'", post_id) for output in self.outputs: self.submitter(lambda: output.delete_post(post_ref)) self.submitter(lambda: self._delete_post_by_id(post["id"])) class BlueskyJetstreamInputService(BlueskyBaseInputService): def __init__( self, db: DatabasePool, http: httpx.Client, options: BlueskyInputOptions, ) -> None: super().__init__(db, http) self.options: BlueskyInputOptions = options self._store = get_store(db) self._init_identity() @override def get_identity_options(self) -> tuple[str | None, str | None, str | None]: return (self.options.handle, self.options.did, self.options.pds) def _accept_msg(self, msg: websockets.Data) -> None: data: dict[str, Any] = cast(dict[str, Any], json.loads(msg)) if data.get("did") != self.did: return commit: dict[str, Any] | None = data.get("commit") if not commit: return commit_type: str = cast(str, commit["operation"]) match commit_type: case "create": record: dict[str, Any] = cast(dict[str, Any], commit["record"]) record["$xpost.strongRef"] = { "cid": commit["cid"], "uri": f"at://{self.did}/{commit['collection']}/{commit['rkey']}", } match cast(str, commit["collection"]): case "app.bsky.feed.post": self._on_post(record) case "app.bsky.feed.repost": self._on_repost(record) case _: pass case "delete": post_id: str = ( f"at://{self.did}/{commit['collection']}/{commit['rkey']}" ) match cast(str, commit["collection"]): case "app.bsky.feed.post": self._on_delete_post(post_id, False) case "app.bsky.feed.repost": self._on_delete_post(post_id, True) case _: pass case _: pass @override async def listen(self): url = env.JETSTREAM_URL + "?" url += "wantedCollections=app.bsky.feed.post" url += "&wantedCollections=app.bsky.feed.repost" url += f"&wantedDids={self.did}" async for ws in websockets.connect( url, ping_interval=20, ping_timeout=10, close_timeout=5, ): try: self.log.info("Listening to '%s'...", env.JETSTREAM_URL) async def listen_for_messages(): async for msg in ws: self.submitter(lambda: self._accept_msg(msg)) listen = asyncio.create_task(listen_for_messages()) _ = await asyncio.gather(listen) except websockets.ConnectionClosedError as e: self.log.error(e, stack_info=True, exc_info=True) self.log.info("Reconnecting to '%s'...", env.JETSTREAM_URL) continue except TimeoutError as e: self.log.error("Connection timeout: '%s'", e) self.log.info("Reconnecting to '%s'...", env.JETSTREAM_URL) continue