xpost — a social media crossposting tool (third rewrite: "3rd time's the charm").
Tags: mastodon, misskey, crossposting, bluesky
1import argparse
2import asyncio
3import json
4import queue
5import threading
6from collections.abc import Callable
7from typing import Any
8
9import env
10from database.connection import DatabasePool
11from database.migrations import DatabaseMigrator
12from util.util import LOGGER, read_env, shutdown_hook
13
14
# Minimal working configuration written to disk on first launch: one service
# that mirrors a Bluesky account (read via Jetstream) to a Mastodon instance.
# Values prefixed with "env:" are presumably resolved from environment
# variables by util.read_env() — verify against that helper.
EXAMPLE_CONFIG = {
    "services": [
        {
            "input": {"type": "bluesky-jetstream", "handle": "bsky.app"},
            "outputs": [
                {
                    "type": "mastodon",
                    "instance": "https://mastodon.social",
                    "token": "env:MASTODON_TOKEN",
                }
            ],
        }
    ]
}
29
30
def dump_example_config() -> None:
    """Write EXAMPLE_CONFIG as pretty-printed JSON to the settings file.

    Creates the parent directory tree if it does not exist yet.

    NOTE(review): despite the ``_DIR`` suffix, ``env.SETTINGS_DIR`` appears to
    be the settings *file* path — it is opened for writing here.
    """
    env.SETTINGS_DIR.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8 so the written file does not depend on the locale's
    # default encoding (e.g. cp1252 on Windows).
    with open(env.SETTINGS_DIR, "w", encoding="utf-8") as f:
        json.dump(EXAMPLE_CONFIG, f, indent=2)
35
36
def flush_caches() -> None:
    """Flush the persisted atproto caches (sessions and identities).

    Opens its own short-lived database pool, binds the atproto store to it,
    flushes, and always closes the pool afterwards.
    """
    # Imported lazily so the cache-flush CLI path doesn't pay the cost of the
    # atproto package unless it's actually used.
    from atproto.store import flush_caches as flush_atproto_caches
    from atproto.store import get_store

    db_pool = DatabasePool(env.DATABASE_DIR)
    try:
        # get_store() binds the store singleton to this pool; the return
        # value itself is not needed here.
        get_store(db_pool)

        LOGGER.info("Flushing atproto caches...")
        sessions, identities = flush_atproto_caches()
        LOGGER.info("Flushed %d sessions and %d identities", sessions, identities)
        LOGGER.info("Cache flush complete!")
    finally:
        # Close the pool even if the flush raises, so connections don't leak.
        db_pool.close()
50
51
def main() -> None:
    """CLI entry point: parse arguments, bootstrap configured services, run.

    Startup order matters: data dir → settings file → database migration →
    pool/HTTP client → registry bootstrap → service construction → worker
    thread → asyncio input loop → ordered teardown.
    """
    parser = argparse.ArgumentParser(
        description="xpost: social media crossposting tool"
    )
    parser.add_argument(
        "--flush-caches",
        action="store_true",
        help="Flush all caches (like sessions and identities)",
    )
    args = parser.parse_args()

    # Maintenance mode: flush caches and exit without starting any services.
    if args.flush_caches:
        flush_caches()
        return

    if not env.DATA_DIR.exists():
        env.DATA_DIR.mkdir(parents=True)

    # First launch: write an example config and exit so the user can edit it.
    if not env.SETTINGS_DIR.exists():
        LOGGER.info("First launch detected! Creating %s and exiting!", env.SETTINGS_DIR)
        dump_example_config()
        LOGGER.info("Example config written to %s", env.SETTINGS_DIR)
        LOGGER.info("Please edit the config file and run again!")
        return

    # Apply pending schema migrations before opening the main pool; abort on
    # failure rather than run against an inconsistent database.
    migrator = DatabaseMigrator(env.DATABASE_DIR, env.MIGRATIONS_DIR)
    try:
        migrator.migrate()
    except Exception:
        LOGGER.exception("Failed to migrate database!")
        return
    finally:
        migrator.close()

    db_pool = DatabasePool(env.DATABASE_DIR)
    # Deferred import: httpx is only needed once we actually run services.
    import httpx

    http_client = httpx.Client(timeout=httpx.Timeout(30))

    LOGGER.info("Bootstrapping registries...")
    from registry import create_input_service, create_output_service
    from registry_bootstrap import bootstrap

    bootstrap()

    LOGGER.info("Loading settings...")

    with open(env.SETTINGS_DIR) as f:
        settings = json.load(f)
    # Presumably resolves "env:..." placeholders in settings in place —
    # verify against util.read_env.
    read_env(settings)

    if "services" not in settings:
        raise KeyError("No 'services' specified in settings!")

    # Build (input, [outputs]) pairs; each service entry must declare both.
    service_pairs: list[tuple[Any, list[Any]]] = []
    for svc in settings["services"]:
        if "input" not in svc:
            raise KeyError("Each service must have an 'input' field!")
        if "outputs" not in svc:
            raise KeyError("Each service must have an 'outputs' field!")

        inp = create_input_service(db_pool, http_client, svc["input"])
        outs = [
            create_output_service(db_pool, http_client, data) for data in svc["outputs"]
        ]
        service_pairs.append((inp, outs))

    LOGGER.info("Starting task worker...")

    def worker(task_queue: queue.Queue[Callable[[], None] | None]) -> None:
        """Drain tasks until the None sentinel; log (not crash on) task errors."""
        while True:
            task = task_queue.get()
            if task is None:
                break

            try:
                task()
            except Exception:
                LOGGER.exception("Exception in worker thread!")
            finally:
                # Always mark done so task_queue.join() below can't hang.
                task_queue.task_done()

    task_queue: queue.Queue[Callable[[], None] | None] = queue.Queue()
    thread = threading.Thread(target=worker, args=(task_queue,), daemon=True)
    thread.start()

    # Wire each input to its outputs; submitted callables all run on the
    # single worker thread, serializing output-side work.
    for inp, outs in service_pairs:
        inp.outputs = outs
        inp.submitter = lambda c: task_queue.put(c)

    inputs = [inp for inp, _ in service_pairs]
    LOGGER.info("Starting %d input service(s)...", len(inputs))
    try:
        asyncio.run(_run_all_inputs(inputs))
    except KeyboardInterrupt:
        LOGGER.info("Stopping...")

    # Drain any queued tasks, then stop the worker via the None sentinel.
    task_queue.join()
    task_queue.put(None)
    thread.join()

    # Run registered shutdown hooks before releasing shared resources.
    for shook in shutdown_hook:
        shook()

    db_pool.close()
    http_client.close()
158
159
160async def _run_all_inputs(inputs: list[Any]) -> None:
161 tasks = [asyncio.create_task(inp.listen()) for inp in inputs]
162 await asyncio.gather(*tasks, return_exceptions=True)
163
164
# Script entry point — only runs when executed directly, not on import.
if __name__ == "__main__":
    main()