# this repo has no description
# at main 619 lines 22 kB view raw
import json
import yaml
import dotenv
import os
import logging
import re
import sys
from typing import Optional, Dict, Any, List, Tuple
from atproto_client import Client, Session, SessionEvent, models

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("bluesky_session_handler")

# Load the environment variables
dotenv.load_dotenv(override=True)


# Strip fields. A list of fields to remove from a JSON object
STRIP_FIELDS = [
    "cid",
    "rev",
    "did",
    "uri",
    "langs",
    "threadgate",
    "py_type",
    "labels",
    "facets",
    "avatar",
    "viewer",
    "indexed_at",
    "tags",
    "associated",
    "thread_context",
    "aspect_ratio",
    "thumb",
    "fullsize",
    "root",
    "created_at",
    "verification",
    "like_count",
    "quote_count",
    "reply_count",
    "repost_count",
    "embedding_disabled",
    "thread_muted",
    "reply_disabled",
    "pinned",
    "like",
    "repost",
    "blocked_by",
    "blocking",
    "blocking_by_list",
    "followed_by",
    "following",
    "known_followers",
    "muted",
    "muted_by_list",
    "root_author_like",
    "entities",
    "ref",
    "mime_type",
    "size",
]

# Rich-text patterns, compiled once. They are matched against the UTF-8
# encoded post text because facet indices are byte offsets, not character
# offsets. Group 1 is the mention/URL itself, without the leading delimiter.
_MENTION_REGEX = re.compile(
    rb"(?:^|[$|\W])(@([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)"
)
_URL_REGEX = re.compile(
    rb"(?:^|[$|\W])(https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*[-a-zA-Z0-9@%_\+~#//=])?)"
)


def convert_to_basic_types(obj):
    """Convert complex Python objects to basic types for JSON/YAML serialization.

    Objects exposing ``__dict__`` are replaced by their (recursively converted)
    attribute dictionary, dicts and lists/tuples are converted element-wise,
    scalars are returned unchanged, and anything else is stringified.
    """
    if hasattr(obj, '__dict__'):
        # Convert objects with __dict__ to their dictionary representation
        return convert_to_basic_types(obj.__dict__)
    if isinstance(obj, dict):
        return {key: convert_to_basic_types(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Tuples are emitted as lists so they serialize as plain sequences
        # instead of being stringified.
        return [convert_to_basic_types(item) for item in obj]
    if isinstance(obj, (str, int, float, bool)) or obj is None:
        return obj
    # For other types, fall back to the string representation
    return str(obj)


def _is_empty_value(value) -> bool:
    """Return True for None, empty dicts/lists and blank strings."""
    if value is None:
        return True
    if isinstance(value, (dict, list)):
        return len(value) == 0
    return isinstance(value, str) and value.strip() == ""


def strip_fields(obj, strip_field_list):
    """Recursively strip fields from a JSON-like object, in place.

    Args:
        obj: A dict/list structure (other values are returned untouched)
        strip_field_list: Collection of dictionary keys to remove

    Returns:
        The same object, with the listed keys, keys starting with ``__``, and
        empty/null dictionary values removed. ``None`` items are removed from
        lists.
    """
    if isinstance(strip_field_list, (list, tuple)):
        # Hash the names once; the recursion then reuses the same frozenset.
        strip_field_list = frozenset(strip_field_list)

    if isinstance(obj, dict):
        # Remove fields from strip list and pydantic metadata
        for field in list(obj):
            is_dunder = isinstance(field, str) and field.startswith("__")
            if field in strip_field_list or is_dunder:
                del obj[field]

        # Recursively process remaining values, dropping the ones left empty
        for key, value in list(obj.items()):
            cleaned = strip_fields(value, strip_field_list)
            if _is_empty_value(cleaned):
                del obj[key]
            else:
                obj[key] = cleaned

    elif isinstance(obj, list):
        processed = [strip_fields(value, strip_field_list) for value in obj]
        # Remove None values from list (slice assignment keeps it in place)
        obj[:] = [item for item in processed if item is not None]

    return obj


def flatten_thread_structure(thread_data):
    """
    Flatten a nested thread structure into a list while preserving all data.

    Args:
        thread_data: The thread data from get_post_thread

    Returns:
        Dict with 'posts' key containing a list of posts in chronological order
        (oldest ancestor first, the requested post last)
    """
    # Collect the chain from the requested node up through its parents ...
    chain = []
    node = getattr(thread_data, 'thread', None)
    while node:
        chain.append(node)
        node = getattr(node, 'parent', None)

    # ... then emit it root-first to keep chronological order.
    posts = []
    for node in reversed(chain):
        post = getattr(node, 'post', None)
        if not post:
            continue
        # Convert to dict if needed to ensure we can process it
        if hasattr(post, '__dict__'):
            posts.append(post.__dict__.copy())
        elif isinstance(post, dict):
            posts.append(post.copy())
        else:
            posts.append({})

    # Return a simple structure with posts list
    return {'posts': posts}


def thread_to_yaml_string(thread, strip_metadata=True):
    """
    Convert thread data to a YAML-formatted string for LLM parsing.

    Args:
        thread: The thread data from get_post_thread
        strip_metadata: Whether to strip metadata fields for cleaner output

    Returns:
        YAML-formatted string representation of the thread
    """
    # First flatten the thread structure to avoid deep nesting
    flattened = flatten_thread_structure(thread)

    # Convert complex objects to basic types (this builds a fresh structure,
    # so the in-place stripping below never touches the caller's objects)
    basic_thread = convert_to_basic_types(flattened)

    if strip_metadata:
        cleaned_thread = strip_fields(basic_thread, STRIP_FIELDS)
    else:
        cleaned_thread = basic_thread

    return yaml.dump(cleaned_thread, indent=2, allow_unicode=True, default_flow_style=False)


def get_session(username: str) -> Optional[str]:
    """Return the session string saved for *username*, or None if there is none."""
    try:
        with open(f"session_{username}.txt", encoding="UTF-8") as f:
            return f.read()
    except FileNotFoundError:
        logger.debug(f"No existing session found for {username}")
        return None


def save_session(username: str, session_string: str) -> None:
    """Write the exported session string for *username* to its session file."""
    with open(f"session_{username}.txt", "w", encoding="UTF-8") as f:
        f.write(session_string)
    logger.debug(f"Session saved for {username}")


def on_session_change(username: str, event: SessionEvent, session: Session) -> None:
    """Persist the session whenever the client creates or refreshes it."""
    logger.debug(f"Session changed: {event} {repr(session)}")
    if event in (SessionEvent.CREATE, SessionEvent.REFRESH):
        logger.debug(f"Saving changed session for {username}")
        save_session(username, session.export())


def init_client(username: str, password: str, pds_uri: str = "https://bsky.social") -> Client:
    """
    Create a client for the given PDS and log in.

    A previously saved session string is reused when one exists; otherwise a
    new session is created with the username and password.

    Args:
        username: Account handle used for login and for the session file name
        password: Account password, only used when no saved session exists
        pds_uri: Base URI of the PDS. None falls back to https://bsky.social

    Returns:
        The logged-in client
    """
    if pds_uri is None:
        logger.warning(
            "No PDS URI provided. Falling back to bsky.social. Note! If you are on a non-Bluesky PDS, this can cause logins to fail. Please provide a PDS URI using the PDS_URI environment variable."
        )
        pds_uri = "https://bsky.social"

    # Print the PDS URI
    logger.debug(f"Using PDS URI: {pds_uri}")

    client = Client(pds_uri)
    client.on_session_change(
        lambda event, session: on_session_change(username, event, session)
    )

    session_string = get_session(username)
    if session_string:
        logger.debug(f"Reusing existing session for {username}")
        client.login(session_string=session_string)
    else:
        logger.debug(f"Creating new session for {username}")
        client.login(username, password)

    return client


def default_login() -> Client:
    """
    Log in using credentials from the config file or environment variables.

    Returns:
        The logged-in client

    Raises:
        SystemExit: With a non-zero status when no username or password is
            available.
    """
    # Try to load from config first, fall back to environment variables
    try:
        from config_loader import get_bluesky_config
        config = get_bluesky_config()
        username = config['username']
        password = config['password']
        pds_uri = config['pds_uri']
    except (ImportError, FileNotFoundError, KeyError) as e:
        logger.warning(
            f"Could not load from config file ({e}), falling back to environment variables")
        username = os.getenv("BSKY_USERNAME")
        password = os.getenv("BSKY_PASSWORD")
        pds_uri = os.getenv("PDS_URI", "https://bsky.social")

    # Treat empty strings like missing values, and exit with a failure status
    # so that shells and supervisors can see that startup did not succeed.
    if not username:
        logger.error(
            "No username provided. Please provide a username using the BSKY_USERNAME environment variable or config.yaml."
        )
        sys.exit(1)

    if not password:
        logger.error(
            "No password provided. Please provide a password using the BSKY_PASSWORD environment variable or config.yaml."
        )
        sys.exit(1)

    return init_client(username, password, pds_uri)


def remove_outside_quotes(text: str) -> str:
    """
    Remove outside double quotes from response text.

    Only handles double quotes to avoid interfering with contractions:
    - Double quotes: "text" → text
    - Preserves single quotes and internal quotes

    Args:
        text: The text to process

    Returns:
        Text with outside double quotes removed
    """
    if not text or len(text) < 2:
        return text

    text = text.strip()

    # Only remove double quotes from start and end. The length is re-checked
    # after stripping: a lone '"' both starts and ends with a quote, and must
    # not be reduced to an empty string.
    if len(text) >= 2 and text.startswith('"') and text.endswith('"'):
        return text[1:-1]

    return text


def reply_to_post(client: Client, text: str, reply_to_uri: str, reply_to_cid: str, root_uri: Optional[str] = None, root_cid: Optional[str] = None, lang: Optional[str] = None) -> Dict[str, Any]:
    """
    Reply to a post on Bluesky with rich text support.

    Mentions and URLs in the text are turned into facets. A mention whose
    handle cannot be resolved is posted as plain text.

    Args:
        client: Authenticated Bluesky client
        text: The reply text
        reply_to_uri: The URI of the post being replied to (parent)
        reply_to_cid: The CID of the post being replied to (parent)
        root_uri: The URI of the root post (if replying to a reply). If None, uses reply_to_uri
        root_cid: The CID of the root post (if replying to a reply). If None, uses reply_to_cid
        lang: Language code for the post (e.g., 'en-US', 'es', 'ja')

    Returns:
        The response from sending the post
    """
    # If root is not (fully) provided, this is a reply to the root post. A
    # strong ref needs both halves, so a partial root falls back as a pair.
    if root_uri is None or root_cid is None:
        root_uri = reply_to_uri
        root_cid = reply_to_cid

    # Create references for the reply
    parent_ref = models.create_strong_ref(
        models.ComAtprotoRepoStrongRef.Main(uri=reply_to_uri, cid=reply_to_cid))
    root_ref = models.create_strong_ref(
        models.ComAtprotoRepoStrongRef.Main(uri=root_uri, cid=root_cid))

    # Parse rich text facets (mentions and URLs)
    facets = []
    text_bytes = text.encode("UTF-8")

    # Parse mentions
    for m in _MENTION_REGEX.finditer(text_bytes):
        handle = m.group(1)[1:].decode("UTF-8")  # Remove @ prefix
        try:
            # Resolve handle to DID using the API
            resolve_resp = client.app.bsky.actor.get_profile({'actor': handle})
            if resolve_resp and hasattr(resolve_resp, 'did'):
                facets.append(
                    models.AppBskyRichtextFacet.Main(
                        index=models.AppBskyRichtextFacet.ByteSlice(
                            # Group 1 excludes the optional leading delimiter
                            byteStart=m.start(1),
                            byteEnd=m.end(1)
                        ),
                        features=[models.AppBskyRichtextFacet.Mention(
                            did=resolve_resp.did)]
                    )
                )
        except Exception as e:
            # Best effort: an unresolved mention is simply left unlinked
            error_str = str(e)
            if 'Could not find user info' in error_str or 'InvalidRequest' in error_str:
                logger.warning(
                    f"User @{handle} not found (account may be deleted/suspended), skipping mention facet")
            elif 'BadRequestError' in error_str:
                logger.warning(
                    f"Bad request when resolving @{handle}, skipping mention facet: {e}")
            else:
                logger.debug(f"Failed to resolve handle @{handle}: {e}")
            continue

    # Parse URLs
    for m in _URL_REGEX.finditer(text_bytes):
        url = m.group(1).decode("UTF-8")
        facets.append(
            models.AppBskyRichtextFacet.Main(
                index=models.AppBskyRichtextFacet.ByteSlice(
                    byteStart=m.start(1),
                    byteEnd=m.end(1)
                ),
                features=[models.AppBskyRichtextFacet.Link(uri=url)]
            )
        )

    # Send the reply; facets are only passed when any were found
    response = client.send_post(
        text=text,
        reply_to=models.AppBskyFeedPost.ReplyRef(
            parent=parent_ref, root=root_ref),
        facets=facets or None,
        langs=[lang] if lang else None
    )

    logger.info(f"Reply sent successfully: {response.uri}")
    return response


def get_post_thread(client: Client, uri: str) -> Optional[Dict[str, Any]]:
    """
    Get the thread containing a post to find root post information.

    Args:
        client: Authenticated Bluesky client
        uri: The URI of the post

    Returns:
        The thread data or None if not found
    """
    try:
        return client.app.bsky.feed.get_post_thread(
            {'uri': uri, 'parent_height': 60, 'depth': 10})
    except Exception as e:
        error_str = str(e)
        # Handle specific error cases more gracefully
        if 'Could not find user info' in error_str or 'InvalidRequest' in error_str:
            logger.warning(
                f"User account not found for post URI {uri} (account may be deleted/suspended)")
        elif 'NotFound' in error_str or 'Post not found' in error_str:
            logger.warning(f"Post not found for URI {uri}")
        elif 'BadRequestError' in error_str:
            logger.warning(f"Bad request error for URI {uri}: {e}")
        else:
            logger.error(f"Error fetching post thread: {e}")
        return None


def _field(obj: Any, name: str) -> Any:
    """Read *name* from a dict key or an object attribute, None if absent."""
    if isinstance(obj, dict):
        return obj.get(name)
    return getattr(obj, name, None)


def _notification_post_ref(notification: Any) -> Tuple[Optional[str], Optional[str]]:
    """Return the (uri, cid) of the post a notification refers to."""
    return _field(notification, 'uri'), _field(notification, 'cid')


def _find_thread_root(thread_data: Any, post_uri: str, post_cid: str) -> Tuple[str, str]:
    """Return the (uri, cid) of the root of the thread containing the post.

    Falls back to the post itself when no root can be determined.
    """
    thread = getattr(thread_data, 'thread', None) if thread_data else None
    if not thread:
        return post_uri, post_cid

    # Prefer the root reference stored in the post's own record: it stays
    # correct even when an ancestor is deleted or blocked, or when the thread
    # is deeper than the requested parent height.
    record = _field(getattr(thread, 'post', None), 'record')
    root_ref = _field(_field(record, 'reply'), 'root')
    ref_uri, ref_cid = _field(root_ref, 'uri'), _field(root_ref, 'cid')
    if ref_uri and ref_cid:
        return ref_uri, ref_cid

    # Otherwise keep going up the hydrated parents until we find the root
    root_uri, root_cid = post_uri, post_cid
    current = thread
    while getattr(current, 'parent', None):
        current = current.parent
        post = getattr(current, 'post', None)
        if hasattr(post, 'uri') and hasattr(post, 'cid'):
            root_uri, root_cid = post.uri, post.cid
    return root_uri, root_cid


def _safe_reply_to_post(client: Client, text: str, parent_uri: str, parent_cid: str, root_uri: str, root_cid: str, lang: Optional[str]) -> Optional[Any]:
    """Call reply_to_post, returning None instead of raising on failure."""
    try:
        return reply_to_post(
            client=client,
            text=text,
            reply_to_uri=parent_uri,
            reply_to_cid=parent_cid,
            root_uri=root_uri,
            root_cid=root_cid,
            lang=lang
        )
    except Exception as e:
        logger.error(f"Reply attempt failed: {e}")
        return None


def reply_to_notification(client: Client, notification: Any, reply_text: str, lang: str = "en-US") -> Optional[Dict[str, Any]]:
    """
    Reply to a notification (mention or reply).

    Args:
        client: Authenticated Bluesky client
        notification: The notification object from list_notifications
        reply_text: The text to reply with
        lang: Language code for the post (defaults to "en-US")

    Returns:
        The response from sending the reply or None if failed
    """
    try:
        # Get the post URI and CID from the notification (handle both dict and object)
        post_uri, post_cid = _notification_post_ref(notification)
        if not post_uri or not post_cid:
            logger.error("Notification doesn't have required uri/cid fields")
            return None

        # Get the thread to find the root post; without thread data the
        # notification's post is used as the root (a direct reply)
        thread_data = get_post_thread(client, post_uri)
        root_uri, root_cid = _find_thread_root(thread_data, post_uri, post_cid)

        # Reply to the notification
        return reply_to_post(
            client=client,
            text=reply_text,
            reply_to_uri=post_uri,
            reply_to_cid=post_cid,
            root_uri=root_uri,
            root_cid=root_cid,
            lang=lang
        )

    except Exception as e:
        logger.error(f"Error replying to notification: {e}")
        return None


def reply_with_thread_to_notification(client: Client, notification: Any, reply_messages: List[str], lang: str = "en-US") -> Optional[List[Dict[str, Any]]]:
    """
    Reply to a notification with a threaded chain of messages (max 15).

    Each message is posted as a reply to the previous one. When a message
    cannot be posted, a system failure notice is posted in its place and the
    chain continues; if the notice cannot be posted either, sending stops.

    Args:
        client: Authenticated Bluesky client
        notification: The notification object from list_notifications
        reply_messages: List of reply texts (max 15 messages, each max 300 chars)
        lang: Language code for the posts (defaults to "en-US")

    Returns:
        List of responses from the posts that were sent, or None if nothing
        could be sent
    """
    try:
        # Validate input
        if not reply_messages:
            logger.error("Reply messages list cannot be empty")
            return None
        if len(reply_messages) > 15:
            logger.error(
                f"Cannot send more than 15 reply messages (got {len(reply_messages)})")
            return None

        # Get the post URI and CID from the notification (handle both dict and object)
        post_uri, post_cid = _notification_post_ref(notification)
        if not post_uri or not post_cid:
            logger.error("Notification doesn't have required uri/cid fields")
            return None

        # Get the thread to find the root post
        thread_data = get_post_thread(client, post_uri)
        root_uri, root_cid = _find_thread_root(thread_data, post_uri, post_cid)

        # Send replies in sequence, creating a thread
        responses = []
        current_parent_uri = post_uri
        current_parent_cid = post_cid

        for i, message in enumerate(reply_messages):
            logger.info(
                f"Sending reply {i+1}/{len(reply_messages)}: {message[:50]}...")

            # Send this reply. Failures are caught per message so the replies
            # already posted are still reported to the caller.
            response = _safe_reply_to_post(
                client, message, current_parent_uri, current_parent_cid,
                root_uri, root_cid, lang)

            if not response:
                logger.error(
                    f"Failed to send reply {i+1}, posting system failure message")
                # Try to post a system failure message
                response = _safe_reply_to_post(
                    client,
                    "[SYSTEM FAILURE: COULD NOT POST MESSAGE, PLEASE TRY AGAIN]",
                    current_parent_uri, current_parent_cid,
                    root_uri, root_cid, lang)
                if not response:
                    logger.error(
                        "Could not even send system failure message, stopping thread")
                    return responses if responses else None

            responses.append(response)
            # The post just sent becomes the parent of the next reply
            current_parent_uri = response.uri
            current_parent_cid = response.cid

        logger.info(f"Successfully sent {len(responses)} threaded replies")
        return responses

    except Exception as e:
        logger.error(f"Error sending threaded reply to notification: {e}")
        return None


if __name__ == "__main__":
    client = default_login()
    # do something with the client
    logger.info("Client is ready to use!")