Social media crossposting tool. Third time's the charm.
Tags: mastodon, misskey, crossposting, bluesky
at next 166 lines 4.7 kB view raw
import argparse
import asyncio
import json
import queue
import threading
from collections.abc import Callable
from typing import Any

import env
from database.connection import DatabasePool
from database.migrations import DatabaseMigrator
from util.util import LOGGER, read_env, shutdown_hook


# Template written to env.SETTINGS_DIR on first launch so the user has
# something concrete to edit.
EXAMPLE_CONFIG = {
    "services": [
        {
            "input": {"type": "bluesky-jetstream", "handle": "bsky.app"},
            "outputs": [
                {
                    "type": "mastodon",
                    "instance": "https://mastodon.social",
                    "token": "env:MASTODON_TOKEN",
                }
            ],
        }
    ]
}


def dump_example_config() -> None:
    """Write EXAMPLE_CONFIG as indented JSON to ``env.SETTINGS_DIR``.

    The parent directory is created if it does not exist yet; an existing
    file at that path is overwritten.
    """
    env.SETTINGS_DIR.parent.mkdir(parents=True, exist_ok=True)
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(env.SETTINGS_DIR, "w", encoding="utf-8") as f:
        json.dump(EXAMPLE_CONFIG, f, indent=2)


def flush_caches() -> None:
    """Flush the atproto session and identity caches, then close the pool.

    The database pool is closed even if the flush raises.
    """
    from atproto.store import flush_caches as flush_atproto_caches
    from atproto.store import get_store

    db_pool = DatabasePool(env.DATABASE_DIR)
    try:
        # Called only for its side effect; presumably this binds the atproto
        # store to this pool before flushing — verify in atproto.store.
        get_store(db_pool)

        LOGGER.info("Flushing atproto caches...")
        sessions, identities = flush_atproto_caches()
        LOGGER.info("Flushed %d sessions and %d identities", sessions, identities)
        LOGGER.info("Cache flush complete!")
    finally:
        db_pool.close()


def _load_settings() -> dict[str, Any]:
    """Read the settings JSON, apply ``read_env``, and check for 'services'.

    Raises:
        KeyError: if the settings have no top-level 'services' entry.
    """
    with open(env.SETTINGS_DIR, encoding="utf-8") as f:
        settings = json.load(f)
    # Presumably resolves "env:NAME" placeholders (see EXAMPLE_CONFIG's
    # token) in place — verify in util.util.read_env.
    read_env(settings)

    if "services" not in settings:
        raise KeyError("No 'services' specified in settings!")
    return settings


def _task_worker(task_queue: queue.Queue[Callable[[], None] | None]) -> None:
    """Run callables from *task_queue* one by one until a ``None`` sentinel.

    A task that raises is logged and does not stop the worker.
    """
    while True:
        task = task_queue.get()
        if task is None:
            break

        try:
            task()
        except Exception:
            LOGGER.exception("Exception in worker thread!")
        finally:
            # Always mark the task done so task_queue.join() cannot hang.
            task_queue.task_done()


def _run_shutdown_hooks() -> None:
    """Call every registered shutdown hook; one failure doesn't skip the rest."""
    for hook in shutdown_hook:
        try:
            hook()
        except Exception:
            LOGGER.exception("Shutdown hook failed!")


def _serve(db_pool: DatabasePool, http_client: Any) -> None:
    """Build the configured services and run them until they end or Ctrl+C.

    Tasks that inputs submit are executed on a single background worker
    thread. That queue is drained before shutdown hooks run.

    Raises:
        KeyError: if the settings or any service entry is missing a
            required field.
    """
    LOGGER.info("Bootstrapping registries...")
    from registry import create_input_service, create_output_service
    from registry_bootstrap import bootstrap

    bootstrap()

    LOGGER.info("Loading settings...")
    settings = _load_settings()

    service_pairs: list[tuple[Any, list[Any]]] = []
    for svc in settings["services"]:
        if "input" not in svc:
            raise KeyError("Each service must have an 'input' field!")
        if "outputs" not in svc:
            raise KeyError("Each service must have an 'outputs' field!")

        inp = create_input_service(db_pool, http_client, svc["input"])
        outs = [
            create_output_service(db_pool, http_client, data) for data in svc["outputs"]
        ]
        service_pairs.append((inp, outs))

    LOGGER.info("Starting task worker...")
    task_queue: queue.Queue[Callable[[], None] | None] = queue.Queue()
    thread = threading.Thread(target=_task_worker, args=(task_queue,), daemon=True)
    thread.start()

    for inp, outs in service_pairs:
        inp.outputs = outs
        # Inputs hand work to the worker thread through this callback.
        inp.submitter = lambda c: task_queue.put(c)

    inputs = [inp for inp, _ in service_pairs]
    LOGGER.info("Starting %d input service(s)...", len(inputs))
    try:
        asyncio.run(_run_all_inputs(inputs))
    except KeyboardInterrupt:
        LOGGER.info("Stopping...")

    # Let already-submitted tasks finish, then tell the worker to exit.
    task_queue.join()
    task_queue.put(None)
    thread.join()

    _run_shutdown_hooks()


def main() -> None:
    """CLI entry point for xpost.

    Steps:
        * With ``--flush-caches``, flush the atproto caches and exit.
        * Otherwise create the data directory if needed.
        * On first launch (no settings file), write the example config and
          exit.
        * Run database migrations.
        * Run the configured services.

    The database pool and HTTP client are always closed, even if startup
    fails.
    """
    parser = argparse.ArgumentParser(
        description="xpost: social media crossposting tool"
    )
    parser.add_argument(
        "--flush-caches",
        action="store_true",
        help="Flush all caches (like sessions and identities)",
    )
    args = parser.parse_args()

    if args.flush_caches:
        flush_caches()
        return

    env.DATA_DIR.mkdir(parents=True, exist_ok=True)

    if not env.SETTINGS_DIR.exists():
        LOGGER.info("First launch detected! Creating %s and exiting!", env.SETTINGS_DIR)
        dump_example_config()
        LOGGER.info("Example config written to %s", env.SETTINGS_DIR)
        LOGGER.info("Please edit the config file and run again!")
        return

    migrator = DatabaseMigrator(env.DATABASE_DIR, env.MIGRATIONS_DIR)
    try:
        migrator.migrate()
    except Exception:
        LOGGER.exception("Failed to migrate database!")
        return
    finally:
        migrator.close()

    # Imported before anything is opened so a missing dependency can't leak
    # the database pool.
    import httpx

    db_pool = DatabasePool(env.DATABASE_DIR)
    try:
        http_client = httpx.Client(timeout=httpx.Timeout(30))
        try:
            _serve(db_pool, http_client)
        finally:
            http_client.close()
    finally:
        db_pool.close()


async def _run_all_inputs(inputs: list[Any]) -> None:
    """Run every input's ``listen()`` coroutine concurrently until all finish.

    A failure in one input does not cancel the others. Every failure is
    logged with its traceback instead of being silently discarded.
    """
    tasks = [asyncio.create_task(inp.listen()) for inp in inputs]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    # gather() returns results in task order, so zip pairs each with its input.
    for inp, result in zip(inputs, results):
        if isinstance(result, asyncio.CancelledError):
            continue  # cancellation is part of normal shutdown, not a crash
        if isinstance(result, BaseException):
            LOGGER.error("Input service %r crashed!", inp, exc_info=result)


if __name__ == "__main__":
    main()