Social media crossposting tool. Third time's the charm.
mastodon
misskey
crossposting
bluesky
1import asyncio
2import json
3import re
4from abc import ABC
5from dataclasses import dataclass, field
6from typing import Any, cast, override
7
8import httpx
9import websockets
10
11import env
12from atproto.models import AtUri
13from atproto.store import get_store
14from bluesky.info import SERVICE, BlueskyService, validate_and_transform
15from bluesky.richtext import richtext_to_tokens
16from cross.attachments import (
17 Attachment,
18 LabelsAttachment,
19 LanguagesAttachment,
20 MediaAttachment,
21 QuoteAttachment,
22 RemoteUrlAttachment,
23)
24from cross.media import Blob, download_blob
25from cross.post import Post, PostRef
26from cross.service import InputService
27from database.connection import DatabasePool
28from util.util import Result
29
30
31@dataclass(kw_only=True)
32class BlueskyInputOptions:
33 handle: str | None = None
34 did: str | None = None
35 pds: str | None = None
36 filters: list[re.Pattern[str]] = field(default_factory=lambda: [])
37
38 @classmethod
39 def from_dict(cls, data: dict[str, Any]) -> "BlueskyInputOptions":
40 validate_and_transform(data)
41
42 if "filters" in data:
43 data["filters"] = [re.compile(r) for r in data["filters"]]
44
45 return BlueskyInputOptions(**data)
46
47
48class BlueskyBaseInputService(BlueskyService, InputService, ABC):
49 def __init__(self, db: DatabasePool, http: httpx.Client) -> None:
50 super().__init__(SERVICE, db)
51 self.http = http
52
53 def _on_post(self, record: dict[str, Any]):
54 post_uri = cast(str, record["$xpost.strongRef"]["uri"])
55 post_cid = cast(str, record["$xpost.strongRef"]["cid"])
56
57 if self._is_post_crossposted(self.url, self.did, post_uri):
58 self.log.info("Skipping '%s': already crossposted", post_uri)
59 return
60
61 parent_uri = cast(
62 str, None if not record.get("reply") else record["reply"]["parent"]["uri"]
63 )
64 parent = None
65 if parent_uri:
66 did, _, _ = AtUri.record_uri(parent_uri)
67 if did != self.did:
68 self.log.info("Skipping '%s': reply to other user..", post_uri)
69 return
70
71 parent = self._get_post(self.url, self.did, parent_uri)
72 if not parent:
73 self.log.info(
74 "Skipping '%s': parent '%s' not found in db", post_uri, parent_uri
75 )
76 return
77
78 tokens = richtext_to_tokens(record["text"], record.get("facets", []))
79 post = Post(
80 id=post_uri,
81 author=self.did,
82 service=self.url,
83 parent_id=parent_uri,
84 tokens=tokens,
85 )
86
87 did, _, rid = AtUri.record_uri(post_uri)
88 post.attachments.put(
89 RemoteUrlAttachment(url=f"https://bsky.app/profile/{did}/post/{rid}")
90 )
91
92 embed: dict[str, Any] = record.get("embed", {})
93 attachments: list[Attachment] = []
94 blob_urls: list[tuple[str, str, str | None]] = []
95
96 def handle_embeds(
97 embed: dict[str, Any],
98 ) -> Result[None, str]:
99 if "$type" not in embed:
100 return Result.ok(None)
101
102 match cast(str, embed["$type"]):
103 case "app.bsky.embed.record" | "app.bsky.embed.recordWithMedia":
104 rcrd = (
105 embed["record"]["record"]
106 if embed["record"].get("record")
107 else embed["record"]
108 )
109 did, collection, _ = AtUri.record_uri(rcrd["uri"])
110 if collection != "app.bsky.feed.post":
111 return Result.err(f"unhandled record collection '{collection}'")
112 if did != self.did:
113 return Result.err(f"quote of other user '{did}'")
114
115 rquote = self._get_post(self.url, did, rcrd["uri"])
116 if not rquote:
117 return Result.err(f"quote '{rcrd['uri']}' not found in db")
118
119 attachments.append(
120 QuoteAttachment(quoted_id=rcrd["uri"], quoted_user=did)
121 )
122 if embed.get("media"):
123 return handle_embeds(embed["media"])
124 case "app.bsky.embed.images":
125 for image in embed["images"]:
126 blob_cid = image["image"]["ref"]["$link"]
127 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}"
128 blob_urls.append((url, blob_cid, image.get("alt")))
129 case "app.bsky.embed.video":
130 blob_cid = embed["video"]["ref"]["$link"]
131 url = f"{self.pds}/xrpc/com.atproto.sync.getBlob?did={self.did}&cid={blob_cid}"
132 blob_urls.append((url, blob_cid, embed.get("alt")))
133 case _:
134 self.log.warning(f"unhandled embed type '{embed['$type']}'")
135 return Result.ok(None)
136
137 embeds = handle_embeds(embed)
138 if not embeds.is_ok():
139 self.log.info("Skipping '%s': %s", post_uri, embeds.error())
140 return
141
142 for a in attachments:
143 post.attachments.put(a)
144
145 if blob_urls:
146 blobs: list[Blob] = []
147 for url, cid, alt in blob_urls:
148 self.log.info("Downloading '%s'...", cid)
149 blob = download_blob(url, alt, client=self.http)
150 if not blob.is_ok():
151 self.log.error(
152 "Skipping '%s': failed to download blob. %s",
153 post_uri,
154 blob.error(),
155 )
156 return
157 blobs.append(blob.value())
158 post.attachments.put(MediaAttachment(blobs=blobs))
159
160 if "langs" in record:
161 post.attachments.put(LanguagesAttachment(langs=record["langs"]))
162 if "labels" in record:
163 post.attachments.put(
164 LabelsAttachment(
165 labels=[
166 label["val"].replace("-", " ") for label in record["values"]
167 ]
168 ),
169 )
170
171 if parent:
172 self._insert_post(
173 {
174 "user": self.did,
175 "service": self.url,
176 "identifier": post_uri,
177 "parent": parent["id"],
178 "root": parent["root"] or parent["id"],
179 "extra_data": json.dumps({"cid": post_cid}),
180 }
181 )
182 else:
183 self._insert_post(
184 {
185 "user": self.did,
186 "service": self.url,
187 "identifier": post_uri,
188 "extra_data": json.dumps({"cid": post_cid}),
189 }
190 )
191
192 self.log.info("Crossposting: '%s'", post_uri)
193 for out in self.outputs:
194 self.submitter(lambda: out.accept_post(post))
195
196 def _on_repost(self, record: dict[str, Any]):
197 post_uri = cast(str, record["$xpost.strongRef"]["uri"])
198 post_cid = cast(str, record["$xpost.strongRef"]["cid"])
199
200 reposted_uri = cast(str, record["subject"]["uri"])
201 reposted = self._get_post(self.url, self.did, reposted_uri)
202 if not reposted:
203 self.log.info(
204 "Skipping repost '%s': reposted post '%s' not found in db",
205 post_uri,
206 reposted_uri,
207 )
208 return
209
210 self._insert_post(
211 {
212 "user": self.did,
213 "service": self.url,
214 "identifier": post_uri,
215 "reposted": reposted["id"],
216 "extra_data": json.dumps({"cid": post_cid}),
217 }
218 )
219
220 repost_ref = PostRef(id=post_uri, author=self.did, service=self.url)
221 reposted_ref = PostRef(id=reposted_uri, author=self.did, service=self.url)
222
223 self.log.info("Crossposting: '%s'", post_uri)
224 for out in self.outputs:
225 self.submitter(lambda: out.accept_repost(repost_ref, reposted_ref))
226
227 def _on_delete_post(self, post_id: str, repost: bool):
228 post = self._get_post(self.url, self.did, post_id)
229 if not post:
230 self.log.warning("Skipping delete '%s': post not found in db", post_id)
231 return
232
233 post_ref = PostRef(id=post_id, author=self.did, service=self.url)
234 if repost:
235 self.log.info("Deleting repost: '%s'", post_id)
236 for output in self.outputs:
237 self.submitter(lambda: output.delete_repost(post_ref))
238 else:
239 self.log.info("Deleting post: '%s'", post_id)
240 for output in self.outputs:
241 self.submitter(lambda: output.delete_post(post_ref))
242 self.submitter(lambda: self._delete_post_by_id(post["id"]))
243
244
245class BlueskyJetstreamInputService(BlueskyBaseInputService):
246 def __init__(
247 self,
248 db: DatabasePool,
249 http: httpx.Client,
250 options: BlueskyInputOptions,
251 ) -> None:
252 super().__init__(db, http)
253 self.options: BlueskyInputOptions = options
254 self._store = get_store(db)
255 self._init_identity()
256
257 @override
258 def get_identity_options(self) -> tuple[str | None, str | None, str | None]:
259 return (self.options.handle, self.options.did, self.options.pds)
260
261 def _accept_msg(self, msg: websockets.Data) -> None:
262 data: dict[str, Any] = cast(dict[str, Any], json.loads(msg))
263 if data.get("did") != self.did:
264 return
265 commit: dict[str, Any] | None = data.get("commit")
266 if not commit:
267 return
268
269 commit_type: str = cast(str, commit["operation"])
270 match commit_type:
271 case "create":
272 record: dict[str, Any] = cast(dict[str, Any], commit["record"])
273 record["$xpost.strongRef"] = {
274 "cid": commit["cid"],
275 "uri": f"at://{self.did}/{commit['collection']}/{commit['rkey']}",
276 }
277
278 match cast(str, commit["collection"]):
279 case "app.bsky.feed.post":
280 self._on_post(record)
281 case "app.bsky.feed.repost":
282 self._on_repost(record)
283 case _:
284 pass
285 case "delete":
286 post_id: str = (
287 f"at://{self.did}/{commit['collection']}/{commit['rkey']}"
288 )
289 match cast(str, commit["collection"]):
290 case "app.bsky.feed.post":
291 self._on_delete_post(post_id, False)
292 case "app.bsky.feed.repost":
293 self._on_delete_post(post_id, True)
294 case _:
295 pass
296 case _:
297 pass
298
299 @override
300 async def listen(self):
301 url = env.JETSTREAM_URL + "?"
302 url += "wantedCollections=app.bsky.feed.post"
303 url += "&wantedCollections=app.bsky.feed.repost"
304 url += f"&wantedDids={self.did}"
305
306 async for ws in websockets.connect(
307 url,
308 ping_interval=20,
309 ping_timeout=10,
310 close_timeout=5,
311 ):
312 try:
313 self.log.info("Listening to '%s'...", env.JETSTREAM_URL)
314
315 async def listen_for_messages():
316 async for msg in ws:
317 self.submitter(lambda: self._accept_msg(msg))
318
319 listen = asyncio.create_task(listen_for_messages())
320
321 _ = await asyncio.gather(listen)
322 except websockets.ConnectionClosedError as e:
323 self.log.error(e, stack_info=True, exc_info=True)
324 self.log.info("Reconnecting to '%s'...", env.JETSTREAM_URL)
325 continue
326 except TimeoutError as e:
327 self.log.error("Connection timeout: '%s'", e)
328 self.log.info("Reconnecting to '%s'...", env.JETSTREAM_URL)
329 continue