this repo has no description
at 20bc691d73a47fecbfa97590f9b88b3842d9df07 553 lines 20 kB view raw
# Bluesky session handling and reply helpers built on atproto_client.
import json
import logging
import os
import re
import sys
from typing import Any, Dict, List, Optional, Tuple

import dotenv
import yaml
from atproto_client import Client, Session, SessionEvent, models

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("bluesky_session_handler")

# Load environment variables from a .env file; override=True lets values in
# the file replace variables already set in the process environment.
dotenv.load_dotenv(override=True)

# Strip fields. A list of fields to remove from a JSON object
# (used by strip_fields() when producing compact thread output).
STRIP_FIELDS = [
    "cid",
    "rev",
    "did",
    "uri",
    "langs",
    "threadgate",
    "py_type",
    "labels",
    "facets",
    "avatar",
    "viewer",
    "indexed_at",
    "tags",
    "associated",
    "thread_context",
    "aspect_ratio",
    "thumb",
    "fullsize",
    "root",
    "created_at",
    "verification",
    "like_count",
    "quote_count",
    "reply_count",
    "repost_count",
    "embedding_disabled",
    "thread_muted",
    "reply_disabled",
    "pinned",
    "like",
    "repost",
    "blocked_by",
    "blocking",
    "blocking_by_list",
    "followed_by",
    "following",
    "known_followers",
    "muted",
    "muted_by_list",
    "root_author_like",
    "entities",
    "ref",
    "mime_type",
    "size",
]

# Rich-text detection patterns, matched against the UTF-8 bytes of a post.
# Group 1 is the mention/URL itself; the leading non-capturing group lets a
# match start at the beginning of the text or after a non-word character.
_MENTION_REGEX = re.compile(
    rb"(?:^|[$|\W])(@([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)"
)
_URL_REGEX = re.compile(
    rb"(?:^|[$|\W])(https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*[-a-zA-Z0-9@%_\+~#//=])?)"
)

# Posted in place of a threaded reply that could not be sent.
_SYSTEM_FAILURE_TEXT = "[SYSTEM FAILURE: COULD NOT POST MESSAGE, PLEASE TRY AGAIN]"


def convert_to_basic_types(obj):
    """Recursively convert ``obj`` into JSON/YAML-friendly basic types.

    Objects exposing ``__dict__`` (e.g. SDK model instances) become dicts,
    dicts and lists are converted element-wise, ``str``/``int``/``float``/
    ``bool``/``None`` pass through unchanged, and anything else is rendered
    with ``str()``.
    """
    if hasattr(obj, '__dict__'):
        return convert_to_basic_types(obj.__dict__)
    if isinstance(obj, dict):
        return {key: convert_to_basic_types(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_to_basic_types(item) for item in obj]
    if isinstance(obj, (str, int, float, bool)) or obj is None:
        return obj
    # Anything else (tuples, bytes, datetimes, ...) becomes its string form.
    return str(obj)


def _is_empty_value(value) -> bool:
    """Return True for values strip_fields() drops from dicts: None, {}, [] or blank strings."""
    return (
        value is None
        or (isinstance(value, (dict, list)) and len(value) == 0)
        or (isinstance(value, str) and value.strip() == "")
    )


def strip_fields(obj, strip_field_list):
    """Recursively strip fields from a JSON-like object, in place.

    In dicts, keys listed in ``strip_field_list`` and keys starting with
    ``"__"`` are removed, then values are processed recursively and entries
    that end up None, empty, or whitespace-only are dropped. In lists, items
    are processed recursively and None items are removed.

    Args:
        obj: A structure of dicts/lists/scalars (e.g. from convert_to_basic_types).
        strip_field_list: Field names to remove at every nesting level.

    Returns:
        ``obj`` itself (mutated), or the unchanged scalar.
    """
    if isinstance(obj, dict):
        # Remove fields from the strip list and pydantic/dunder metadata first.
        flagged = [key for key in obj if key in strip_field_list or key.startswith("__")]
        for key in flagged:
            obj.pop(key, None)

        # Recurse into the remaining values, pruning those left empty.
        for key, value in list(obj.items()):
            cleaned = strip_fields(value, strip_field_list)
            obj[key] = cleaned
            if _is_empty_value(cleaned):
                obj.pop(key, None)

    elif isinstance(obj, list):
        for i, value in enumerate(obj):
            obj[i] = strip_fields(value, strip_field_list)
        # Only None items are dropped from lists; empty containers are kept.
        obj[:] = [item for item in obj if item is not None]

    return obj


def flatten_thread_structure(thread_data):
    """
    Flatten a fetched thread into a list of posts, oldest ancestor first.

    Only the focal post and its ancestors (reached via each node's
    ``parent``) are collected; reply subtrees are not traversed. Nodes
    without a ``post`` (e.g. missing or blocked posts) are skipped.

    Args:
        thread_data: The thread data from get_post_thread

    Returns:
        Dict with 'posts' key containing a list of shallow-copied post dicts
    """
    posts = []

    def traverse_thread(node):
        """Append ``node``'s ancestors, then ``node``'s own post, to ``posts``."""
        if not node:
            return

        # Visit the parent first so posts come out in chronological order.
        if hasattr(node, 'parent') and node.parent:
            traverse_thread(node.parent)

        if hasattr(node, 'post') and node.post:
            # Copy so later processing does not mutate the SDK objects.
            if hasattr(node.post, '__dict__'):
                post_dict = node.post.__dict__.copy()
            elif isinstance(node.post, dict):
                post_dict = node.post.copy()
            else:
                post_dict = {}
            posts.append(post_dict)

    if hasattr(thread_data, 'thread'):
        traverse_thread(thread_data.thread)
    elif hasattr(thread_data, '__dict__') and 'thread' in thread_data.__dict__:
        traverse_thread(thread_data.__dict__['thread'])

    return {'posts': posts}


def thread_to_yaml_string(thread, strip_metadata=True):
    """
    Convert thread data to a YAML-formatted string for LLM parsing.

    Args:
        thread: The thread data from get_post_thread
        strip_metadata: Whether to strip STRIP_FIELDS and empty values for cleaner output

    Returns:
        YAML-formatted string representation of the thread
    """
    # Flatten first to avoid deeply nested parent chains in the output.
    flattened = flatten_thread_structure(thread)
    basic_thread = convert_to_basic_types(flattened)

    # basic_thread is a freshly built structure, so stripping it in place
    # does not touch the original SDK objects.
    cleaned_thread = strip_fields(basic_thread, STRIP_FIELDS) if strip_metadata else basic_thread

    return yaml.dump(cleaned_thread, indent=2, allow_unicode=True, default_flow_style=False)


def _session_path(username: str) -> str:
    """Return the path of the file caching ``username``'s session string."""
    return f"session_{username}.txt"


def get_session(username: str) -> Optional[str]:
    """Return the cached session string for ``username``, or None if no cache file exists."""
    try:
        with open(_session_path(username), encoding="UTF-8") as f:
            return f.read()
    except FileNotFoundError:
        logger.debug(f"No existing session found for {username}")
        return None


def save_session(username: str, session_string: str) -> None:
    """Persist ``session_string`` for ``username``.

    The session string is a login credential (init_client passes it to
    ``client.login``), so a newly created file gets owner-only permissions
    (0o600) on systems that honour POSIX modes. An existing file keeps its
    current permissions.
    """
    fd = os.open(_session_path(username), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="UTF-8") as f:
        f.write(session_string)
    logger.debug(f"Session saved for {username}")


def on_session_change(username: str, event: SessionEvent, session: Session) -> None:
    """Persist the session whenever it is created or refreshed."""
    # Deliberately not logging repr(session): the session carries credentials
    # (it is exported and later used to log in) and would leak into logs.
    logger.debug(f"Session changed: {event} for {username}")
    if event in (SessionEvent.CREATE, SessionEvent.REFRESH):
        logger.debug(f"Saving changed session for {username}")
        save_session(username, session.export())


def _create_client(pds_uri: str, username: str) -> Client:
    """Build an unauthenticated Client whose session changes are saved for ``username``."""
    client = Client(pds_uri)
    client.on_session_change(
        lambda event, session: on_session_change(username, event, session)
    )
    return client


def init_client(username: str, password: str) -> Client:
    """Return a logged-in Client for ``username``.

    The PDS is taken from the PDS_URI environment variable (default
    https://bsky.social). A cached session is reused when present. If that
    login fails (e.g. the saved session is stale or corrupt), a fresh client
    logs in with ``password`` instead.
    """
    pds_uri = os.getenv("PDS_URI")
    if not pds_uri:
        logger.warning(
            "No PDS URI provided. Falling back to bsky.social. Note! If you are on a non-Bluesky PDS, this can cause logins to fail. Please provide a PDS URI using the PDS_URI environment variable."
        )
        pds_uri = "https://bsky.social"

    logger.debug(f"Using PDS URI: {pds_uri}")

    session_string = get_session(username)
    if session_string:
        logger.debug(f"Reusing existing session for {username}")
        client = _create_client(pds_uri, username)
        try:
            client.login(session_string=session_string)
            return client
        except Exception as e:
            # Don't let a bad cache file block login. Use a new client so no
            # state from the failed session attempt carries over.
            logger.warning(
                f"Could not reuse saved session for {username} ({e}); logging in with password instead"
            )

    logger.debug(f"Creating new session for {username}")
    client = _create_client(pds_uri, username)
    client.login(username, password)
    return client


def default_login() -> Client:
    """Log in using the BSKY_USERNAME / BSKY_PASSWORD environment variables.

    Exits the process with status 1 if either variable is missing or empty.
    """
    username = os.getenv("BSKY_USERNAME")
    password = os.getenv("BSKY_PASSWORD")

    if not username:
        logger.error(
            "No username provided. Please provide a username using the BSKY_USERNAME environment variable."
        )
        sys.exit(1)

    if not password:
        logger.error(
            "No password provided. Please provide a password using the BSKY_PASSWORD environment variable."
        )
        sys.exit(1)

    return init_client(username, password)


def reply_to_post(client: Client, text: str, reply_to_uri: str, reply_to_cid: str, root_uri: Optional[str] = None, root_cid: Optional[str] = None, lang: Optional[str] = None) -> Dict[str, Any]:
    """
    Reply to a post on Bluesky with rich text support.

    Mentions (@handle) are resolved to DIDs via get_profile; unresolvable
    handles are left as plain text. URLs become link facets.

    Args:
        client: Authenticated Bluesky client
        text: The reply text
        reply_to_uri: The URI of the post being replied to (parent)
        reply_to_cid: The CID of the post being replied to (parent)
        root_uri: The URI of the thread's root post. If None, the parent is used as root
        root_cid: The CID of the root post. Ignored (replaced by reply_to_cid) when root_uri is None
        lang: Language code for the post (e.g., 'en-US', 'es', 'ja')

    Returns:
        The response from sending the post (has ``uri`` and ``cid``)

    Errors from creating or sending the post propagate to the caller.
    """
    # Without an explicit root, this is a reply directly to the root post.
    if root_uri is None:
        root_uri = reply_to_uri
        root_cid = reply_to_cid

    parent_ref = models.create_strong_ref(models.ComAtprotoRepoStrongRef.Main(uri=reply_to_uri, cid=reply_to_cid))
    root_ref = models.create_strong_ref(models.ComAtprotoRepoStrongRef.Main(uri=root_uri, cid=root_cid))

    # Facet indices are byte offsets into the UTF-8 text, so match on bytes.
    text_bytes = text.encode("UTF-8")
    facets = []

    for m in _MENTION_REGEX.finditer(text_bytes):
        handle = m.group(1)[1:].decode("UTF-8")  # drop the leading "@"
        try:
            profile = client.app.bsky.actor.get_profile({'actor': handle})
            if profile and hasattr(profile, 'did'):
                facets.append(
                    models.AppBskyRichtextFacet.Main(
                        index=models.AppBskyRichtextFacet.ByteSlice(
                            byteStart=m.start(1),
                            byteEnd=m.end(1)
                        ),
                        features=[models.AppBskyRichtextFacet.Mention(did=profile.did)]
                    )
                )
        except Exception as e:
            logger.debug(f"Failed to resolve handle {handle}: {e}")
            continue

    for m in _URL_REGEX.finditer(text_bytes):
        facets.append(
            models.AppBskyRichtextFacet.Main(
                index=models.AppBskyRichtextFacet.ByteSlice(
                    byteStart=m.start(1),
                    byteEnd=m.end(1)
                ),
                features=[models.AppBskyRichtextFacet.Link(uri=m.group(1).decode("UTF-8"))]
            )
        )

    response = client.send_post(
        text=text,
        reply_to=models.AppBskyFeedPost.ReplyRef(parent=parent_ref, root=root_ref),
        facets=facets or None,
        langs=[lang] if lang else None,
    )

    logger.info(f"Reply sent successfully: {response.uri}")
    return response


def get_post_thread(client: Client, uri: str) -> Optional[Dict[str, Any]]:
    """
    Get the thread containing a post (up to 60 ancestors and 10 levels of replies).

    Args:
        client: Authenticated Bluesky client
        uri: The URI of the post

    Returns:
        The thread data, or None if fetching failed (the error is logged)
    """
    try:
        thread = client.app.bsky.feed.get_post_thread({'uri': uri, 'parent_height': 60, 'depth': 10})
        return thread
    except Exception as e:
        logger.error(f"Error fetching post thread: {e}")
        return None


def _extract_uri_cid(notification: Any) -> Tuple[Optional[str], Optional[str]]:
    """Return the (uri, cid) of the post a notification refers to, or (None, None)."""
    if isinstance(notification, dict):
        return notification.get('uri'), notification.get('cid')
    if hasattr(notification, 'uri') and hasattr(notification, 'cid'):
        return notification.uri, notification.cid
    return None, None


def _find_thread_root(thread_data: Any, default_uri: str, default_cid: str) -> Tuple[str, str]:
    """Return the (uri, cid) of the thread root, defaulting to the given post.

    The focal post record's own ``reply.root`` reference is preferred. The
    fetched ancestor chain can be truncated (parent_height) or broken by
    missing/blocked posts, so climbing it is only the fallback.
    """
    thread = getattr(thread_data, 'thread', None) if thread_data else None
    if not thread:
        return default_uri, default_cid

    record = getattr(getattr(thread, 'post', None), 'record', None)
    root_ref = getattr(getattr(record, 'reply', None), 'root', None)
    ref_uri = getattr(root_ref, 'uri', None)
    ref_cid = getattr(root_ref, 'cid', None)
    if ref_uri and ref_cid:
        return ref_uri, ref_cid

    # Fallback: the highest fetched ancestor that carries a post.
    root_uri, root_cid = default_uri, default_cid
    current = thread
    while getattr(current, 'parent', None):
        current = current.parent
        post = getattr(current, 'post', None)
        if hasattr(post, 'uri') and hasattr(post, 'cid'):
            root_uri, root_cid = post.uri, post.cid
    return root_uri, root_cid


def reply_to_notification(client: Client, notification: Any, reply_text: str, lang: str = "en-US") -> Optional[Dict[str, Any]]:
    """
    Reply to a notification (mention or reply).

    Args:
        client: Authenticated Bluesky client
        notification: The notification object (or dict) from list_notifications
        reply_text: The text to reply with
        lang: Language code for the post (defaults to "en-US")

    Returns:
        The response from sending the reply or None if failed (errors are logged)
    """
    try:
        post_uri, post_cid = _extract_uri_cid(notification)
        if not post_uri or not post_cid:
            logger.error("Notification doesn't have required uri/cid fields")
            return None

        # If the thread can't be fetched, the notified post is treated as root.
        thread_data = get_post_thread(client, post_uri)
        root_uri, root_cid = _find_thread_root(thread_data, post_uri, post_cid)

        return reply_to_post(
            client=client,
            text=reply_text,
            reply_to_uri=post_uri,
            reply_to_cid=post_cid,
            root_uri=root_uri,
            root_cid=root_cid,
            lang=lang
        )

    except Exception as e:
        logger.error(f"Error replying to notification: {e}")
        return None


def _send_reply_or_none(client: Client, text: str, parent_uri: str, parent_cid: str, root_uri: str, root_cid: str, lang: str) -> Optional[Any]:
    """Call reply_to_post, returning None (after logging) instead of raising."""
    try:
        return reply_to_post(
            client=client,
            text=text,
            reply_to_uri=parent_uri,
            reply_to_cid=parent_cid,
            root_uri=root_uri,
            root_cid=root_cid,
            lang=lang
        )
    except Exception as e:
        logger.error(f"Error sending reply: {e}")
        return None


def reply_with_thread_to_notification(client: Client, notification: Any, reply_messages: List[str], lang: str = "en-US") -> Optional[List[Dict[str, Any]]]:
    """
    Reply to a notification with a threaded chain of messages (max 4).

    Each message replies to the previous one. If a message fails to post, a
    system-failure notice is posted in its place and the chain continues from
    that notice. If the notice also fails, the replies sent so far are returned.

    Args:
        client: Authenticated Bluesky client
        notification: The notification object (or dict) from list_notifications
        reply_messages: List of reply texts (max 4 messages, each max 300 chars)
        lang: Language code for the posts (defaults to "en-US")

    Returns:
        List of responses from the posts sent, or None if nothing could be sent
        or the input was invalid
    """
    try:
        if not reply_messages:
            logger.error("Reply messages list cannot be empty")
            return None
        if len(reply_messages) > 4:
            logger.error(f"Cannot send more than 4 reply messages (got {len(reply_messages)})")
            return None

        post_uri, post_cid = _extract_uri_cid(notification)
        if not post_uri or not post_cid:
            logger.error("Notification doesn't have required uri/cid fields")
            return None

        thread_data = get_post_thread(client, post_uri)
        root_uri, root_cid = _find_thread_root(thread_data, post_uri, post_cid)

        responses = []
        current_parent_uri = post_uri
        current_parent_cid = post_cid

        for i, message in enumerate(reply_messages):
            logger.info(f"Sending reply {i+1}/{len(reply_messages)}: {message[:50]}...")

            response = _send_reply_or_none(
                client, message, current_parent_uri, current_parent_cid, root_uri, root_cid, lang
            )

            if not response:
                logger.error(f"Failed to send reply {i+1}, posting system failure message")
                failure_response = _send_reply_or_none(
                    client, _SYSTEM_FAILURE_TEXT, current_parent_uri, current_parent_cid, root_uri, root_cid, lang
                )
                if not failure_response:
                    logger.error("Could not even send system failure message, stopping thread")
                    return responses if responses else None
                responses.append(failure_response)
                # Continue the chain beneath the failure notice.
                current_parent_uri = failure_response.uri
                current_parent_cid = failure_response.cid
            else:
                responses.append(response)
                if i < len(reply_messages) - 1:  # Not the last message
                    current_parent_uri = response.uri
                    current_parent_cid = response.cid

        logger.info(f"Successfully sent {len(responses)} threaded replies")
        return responses

    except Exception as e:
        logger.error(f"Error sending threaded reply to notification: {e}")
        return None


if __name__ == "__main__":
    client = default_login()
    # do something with the client
    logger.info("Client is ready to use!")