social media crossposting tool. 3rd time's the charm
mastodon misskey crossposting bluesky
at next 329 lines 12 kB view raw
1import asyncio 2import json 3import re 4from abc import ABC 5from dataclasses import dataclass, field 6from typing import Any, cast, override 7 8import httpx 9import websockets 10 11import env 12from atproto.models import AtUri 13from atproto.store import get_store 14from bluesky.info import SERVICE, BlueskyService, validate_and_transform 15from bluesky.richtext import richtext_to_tokens 16from cross.attachments import ( 17 Attachment, 18 LabelsAttachment, 19 LanguagesAttachment, 20 MediaAttachment, 21 QuoteAttachment, 22 RemoteUrlAttachment, 23) 24from cross.media import Blob, download_blob 25from cross.post import Post, PostRef 26from cross.service import InputService 27from database.connection import DatabasePool 28from util.util import Result 29 30 31@dataclass(kw_only=True) 32class BlueskyInputOptions: 33 handle: str | None = None 34 did: str | None = None 35 pds: str | None = None 36 filters: list[re.Pattern[str]] = field(default_factory=lambda: []) 37 38 @classmethod 39 def from_dict(cls, data: dict[str, Any]) -> "BlueskyInputOptions": 40 validate_and_transform(data) 41 42 if "filters" in data: 43 data["filters"] = [re.compile(r) for r in data["filters"]] 44 45 return BlueskyInputOptions(**data) 46 47 48class BlueskyBaseInputService(BlueskyService, InputService, ABC): 49 def __init__(self, db: DatabasePool, http: httpx.Client) -> None: 50 super().__init__(SERVICE, db) 51 self.http = http 52 53 def _on_post(self, record: dict[str, Any]): 54 post_uri = cast(str, record["$xpost.strongRef"]["uri"]) 55 post_cid = cast(str, record["$xpost.strongRef"]["cid"]) 56 57 if self._is_post_crossposted(self.url, self.did, post_uri): 58 self.log.info("Skipping '%s': already crossposted", post_uri) 59 return 60 61 parent_uri = cast( 62 str, None if not record.get("reply") else record["reply"]["parent"]["uri"] 63 ) 64 parent = None 65 if parent_uri: 66 did, _, _ = AtUri.record_uri(parent_uri) 67 if did != self.did: 68 self.log.info("Skipping '%s': reply to other user..", post_uri) 69 return 70 71 parent = self._get_post(self.url, self.did, parent_uri) 72 if not parent: 73 self.log.info( 74 "Skipping '%s': parent '%s' not found in db", post_uri, parent_uri 75 ) 76 return 77 78 tokens = richtext_to_tokens(record["text"], record.get("facets", [])) 79 post = Post( 80 id=post_uri, 81 author=self.did, 82 service=self.url, 83 parent_id=parent_uri, 84 tokens=tokens, 85 ) 86 87 did, _, rid = AtUri.record_uri(post_uri) 88 post.attachments.put( 89 RemoteUrlAttachment(url=f"https://bsky.app/profile/{did}/post/{rid}") 90 ) 91 92 embed: dict[str, Any] = record.get("embed", {}) 93 attachments: list[Attachment] = [] 94 blob_urls: list[tuple[str, str, str | None]] = [] 95 96 def handle_embeds( 97 embed: dict[str, Any], 98 ) -> Result[None, str]: 99 if "$type" not in embed: 100 return Result.ok(None) 101 102 match cast(str, embed["$type"]): 103 case "app.bsky.embed.record" | "app.bsky.embed.recordWithMedia": 104 rcrd = ( 105 embed["record"]["record"] 106 if embed["record"].get("record") 107 else embed["record"] 108 ) 109 did, collection, _ = AtUri.record_uri(rcrd["uri"]) 110 if collection != "app.bsky.feed.post": 111 return Result.err(f"unhandled record collection '{collection}'") 112 if did != self.did: 113 return Result.err(f"quote of other user '{did}'") 114 115 rquote = self._get_post(self.url, did, rcrd["uri"]) 116 if not rquote: 117 return Result.err(f"quote '{rcrd['uri']}' not found in db") 118 119 attachments.append( 120 QuoteAttachment(quoted_id=rcrd["uri"], quoted_user=did) 121 ) 122 if embed.get("media"): 123 return handle_embeds(embed["media"]) 124 case "app.bsky.embed.images": 125 for image in embed["images"]: 126 blob_cid = image["image"]["ref"]["$link"] 127 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}" 128 blob_urls.append((url, blob_cid, image.get("alt"))) 129 case "app.bsky.embed.video": 130 blob_cid = embed["video"]["ref"]["$link"] 131 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}" 132 blob_urls.append((url, blob_cid, embed.get("alt"))) 133 case _: 134 self.log.warning(f"unhandled embed type '{embed['$type']}'") 135 return Result.ok(None) 136 137 embeds = handle_embeds(embed) 138 if not embeds.is_ok(): 139 self.log.info("Skipping '%s': %s", post_uri, embeds.error()) 140 return 141 142 for a in attachments: 143 post.attachments.put(a) 144 145 if blob_urls: 146 blobs: list[Blob] = [] 147 for url, cid, alt in blob_urls: 148 self.log.info("Downloading '%s'...", cid) 149 blob = download_blob(url, alt, client=self.http) 150 if not blob.is_ok(): 151 self.log.error( 152 "Skipping '%s': failed to download blob. %s", 153 post_uri, 154 blob.error(), 155 ) 156 return 157 blobs.append(blob.value()) 158 post.attachments.put(MediaAttachment(blobs=blobs)) 159 160 if "langs" in record: 161 post.attachments.put(LanguagesAttachment(langs=record["langs"])) 162 if "labels" in record: 163 post.attachments.put( 164 LabelsAttachment( 165 labels=[ 166 label["val"].replace("-", " ") for label in record["values"] 167 ] 168 ), 169 ) 170 171 if parent: 172 self._insert_post( 173 { 174 "user": self.did, 175 "service": self.url, 176 "identifier": post_uri, 177 "parent": parent["id"], 178 "root": parent["root"] or parent["id"], 179 "extra_data": json.dumps({"cid": post_cid}), 180 } 181 ) 182 else: 183 self._insert_post( 184 { 185 "user": self.did, 186 "service": self.url, 187 "identifier": post_uri, 188 "extra_data": json.dumps({"cid": post_cid}), 189 } 190 ) 191 192 self.log.info("Crossposting: '%s'", post_uri) 193 for out in self.outputs: 194 self.submitter(lambda: out.accept_post(post)) 195 196 def _on_repost(self, record: dict[str, Any]): 197 post_uri = cast(str, record["$xpost.strongRef"]["uri"]) 198 post_cid = cast(str, record["$xpost.strongRef"]["cid"]) 199 200 reposted_uri = cast(str, record["subject"]["uri"]) 201 reposted = self._get_post(self.url, self.did, reposted_uri) 202 if not reposted: 203 self.log.info( 204 "Skipping repost '%s': reposted post '%s' not found in db", 205 post_uri, 206 reposted_uri, 207 ) 208 return 209 210 self._insert_post( 211 { 212 "user": self.did, 213 "service": self.url, 214 "identifier": post_uri, 215 "reposted": reposted["id"], 216 "extra_data": json.dumps({"cid": post_cid}), 217 } 218 ) 219 220 repost_ref = PostRef(id=post_uri, author=self.did, service=self.url) 221 reposted_ref = PostRef(id=reposted_uri, author=self.did, service=self.url) 222 223 self.log.info("Crossposting: '%s'", post_uri) 224 for out in self.outputs: 225 self.submitter(lambda: out.accept_repost(repost_ref, reposted_ref)) 226 227 def _on_delete_post(self, post_id: str, repost: bool): 228 post = self._get_post(self.url, self.did, post_id) 229 if not post: 230 self.log.warning("Skipping delete '%s': post not found in db", post_id) 231 return 232 233 post_ref = PostRef(id=post_id, author=self.did, service=self.url) 234 if repost: 235 self.log.info("Deleting repost: '%s'", post_id) 236 for output in self.outputs: 237 self.submitter(lambda: output.delete_repost(post_ref)) 238 else: 239 self.log.info("Deleting post: '%s'", post_id) 240 for output in self.outputs: 241 self.submitter(lambda: output.delete_post(post_ref)) 242 self.submitter(lambda: self._delete_post_by_id(post["id"])) 243 244 245class BlueskyJetstreamInputService(BlueskyBaseInputService): 246 def __init__( 247 self, 248 db: DatabasePool, 249 http: httpx.Client, 250 options: BlueskyInputOptions, 251 ) -> None: 252 super().__init__(db, http) 253 self.options: BlueskyInputOptions = options 254 self._store = get_store(db) 255 self._init_identity() 256 257 @override 258 def get_identity_options(self) -> tuple[str | None, str | None, str | None]: 259 return (self.options.handle, self.options.did, self.options.pds) 260 261 def _accept_msg(self, msg: websockets.Data) -> None: 262 data: dict[str, Any] = cast(dict[str, Any], json.loads(msg)) 263 if data.get("did") != self.did: 264 return 265 commit: dict[str, Any] | None = data.get("commit") 266 if not commit: 267 return 268 269 commit_type: str = cast(str, commit["operation"]) 270 match commit_type: 271 case "create": 272 record: dict[str, Any] = cast(dict[str, Any], commit["record"]) 273 record["$xpost.strongRef"] = { 274 "cid": commit["cid"], 275 "uri": f"at://{self.did}/{commit['collection']}/{commit['rkey']}", 276 } 277 278 match cast(str, commit["collection"]): 279 case "app.bsky.feed.post": 280 self._on_post(record) 281 case "app.bsky.feed.repost": 282 self._on_repost(record) 283 case _: 284 pass 285 case "delete": 286 post_id: str = ( 287 f"at://{self.did}/{commit['collection']}/{commit['rkey']}" 288 ) 289 match cast(str, commit["collection"]): 290 case "app.bsky.feed.post": 291 self._on_delete_post(post_id, False) 292 case "app.bsky.feed.repost": 293 self._on_delete_post(post_id, True) 294 case _: 295 pass 296 case _: 297 pass 298 299 @override 300 async def listen(self): 301 url = env.JETSTREAM_URL + "?" 302 url += "wantedCollections=app.bsky.feed.post" 303 url += "&wantedCollections=app.bsky.feed.repost" 304 url += f"&wantedDids={self.did}" 305 306 async for ws in websockets.connect( 307 url, 308 ping_interval=20, 309 ping_timeout=10, 310 close_timeout=5, 311 ): 312 try: 313 self.log.info("Listening to '%s'...", env.JETSTREAM_URL) 314 315 async def listen_for_messages(): 316 async for msg in ws: 317 self.submitter(lambda: self._accept_msg(msg)) 318 319 listen = asyncio.create_task(listen_for_messages()) 320 321 _ = await asyncio.gather(listen) 322 except websockets.ConnectionClosedError as e: 323 self.log.error(e, stack_info=True, exc_info=True) 324 self.log.info("Reconnecting to '%s'...", env.JETSTREAM_URL) 325 continue 326 except TimeoutError as e: 327 self.log.error("Connection timeout: '%s'", e) 328 self.log.info("Reconnecting to '%s'...", env.JETSTREAM_URL) 329 continue