Auto-indexing service and GraphQL API for AT Protocol Records quickslice.slices.network/
atproto gleam graphql
at main 729 lines 22 kB view raw
//// GraphQL WebSocket handler.
////
//// Handles WebSocket connections for GraphQL subscriptions using the
//// graphql-ws protocol. Each active subscription is served by its own
//// listener process that receives PubSub record events, runs the
//// subscription query against matching events, and forwards the result to
//// the connection handler.

import atproto_auth
import database/executor.{type Executor}
import database/repositories/actors
import gleam/bool
import gleam/dict.{type Dict}
import gleam/erlang/process.{type Subject}
import gleam/http/request.{type Request}
import gleam/http/response
import gleam/list
import gleam/option.{type Option, None, Some}
import gleam/result
import gleam/string
import graphql/lexicon/schema as lexicon_schema
import graphql/ws
import lib/oauth/did_cache
import logging
import mist.{
  type Connection, type ResponseData, type WebsocketConnection,
  type WebsocketMessage,
}
import pubsub
import swell/executor as swell_executor
import swell/parser
import swell/schema
import swell/value
import websocket_ffi

/// Maximum number of concurrent subscriptions one WebSocket connection may
/// hold.
const max_subscriptions_per_connection = 100

/// Maximum number of concurrent subscriptions across all connections, as
/// tracked by the counter in `subscription_counter_ffi`.
const max_subscriptions_global = 10_000

// Bindings to the Erlang `subscription_counter_ffi` module, which keeps the
// global subscription count shared by all connections.
// NOTE(review): every caller in this module ignores the Int returned by
// increment/decrement; presumably it is the updated count — confirm against
// the FFI source before relying on it.

@external(erlang, "subscription_counter_ffi", "increment_global")
fn increment_global_subscriptions() -> Int

@external(erlang, "subscription_counter_ffi", "decrement_global")
fn decrement_global_subscriptions() -> Int

@external(erlang, "subscription_counter_ffi", "get_global_count")
fn get_global_subscription_count() -> Int

/// Convert a PubSub `RecordEvent` into the GraphQL value a subscription
/// query is executed against.
///
/// The record JSON becomes the `value` field; if it cannot be parsed an empty
/// object is used instead. `actorHandle` comes from the actors table and is
/// `Null` when no actor row is returned (or the lookup fails).
fn event_to_graphql_value(
  event: pubsub.RecordEvent,
  db: Executor,
) -> value.Value {
  let value_object = case lexicon_schema.parse_json_to_value(event.value) {
    Ok(val) -> val
    Error(_) -> value.Object([])
  }

  let actor_handle = case actors.get(db, event.did) {
    Ok([actor, ..]) -> value.String(actor.handle)
    _ -> value.Null
  }

  value.Object([
    #("uri", value.String(event.uri)),
    #("cid", value.String(event.cid)),
    #("did", value.String(event.did)),
    #("collection", value.String(event.collection)),
    #("indexedAt", value.String(event.indexed_at)),
    #("actorHandle", actor_handle),
    #("value", value_object),
  ])
}

/// Execute a subscription query against one event.
///
/// The event (see `event_to_graphql_value`) becomes the root data of the
/// execution context alongside `variables`. Returns the formatted JSON
/// response, or the executor's error message.
fn execute_subscription_query(
  query: String,
  variables: Dict(String, value.Value),
  graphql_schema: schema.Schema,
  event: pubsub.RecordEvent,
  db: Executor,
) -> Result(String, String) {
  let event_value = event_to_graphql_value(event, db)
  let ctx = schema.context_with_variables(Some(event_value), variables)

  // The swell executor handles subscription operations directly.
  use response <- result.try(swell_executor.execute(query, graphql_schema, ctx))

  Ok(lexicon_schema.format_response(response))
}

/// Convert a collection NSID to its GraphQL field-name form.
///
/// Segments are joined with the first segment unchanged and the first
/// grapheme of each later segment uppercased, e.g.
/// "xyz.statusphere.status" -> "xyzStatusphereStatus".
fn collection_to_graphql_name(collection: String) -> String {
  collection
  |> string.split(".")
  |> list.index_map(fn(part, index) {
    case index {
      0 -> part
      _ ->
        case string.pop_grapheme(part) {
          Ok(#(first, rest)) -> string.uppercase(first) <> rest
          Error(_) -> part
        }
    }
  })
  |> string.join("")
}

/// Extract the name of the first field selected by the first subscription
/// operation in `query`.
///
/// Errors when the query does not parse, has no subscription operation, or
/// the operation's selection set does not start with a plain field.
fn extract_subscription_field(query: String) -> Result(String, String) {
  use document <- result.try(
    parser.parse(query)
    |> result.map_error(fn(_) { "Failed to parse subscription query" }),
  )

  let subscription_op =
    document.operations
    |> list.find(fn(op) {
      case op {
        parser.Subscription(_) | parser.NamedSubscription(_, _, _) -> True
        _ -> False
      }
    })

  use op <- result.try(
    subscription_op
    |> result.replace_error("No subscription operation found in query"),
  )

  let selections = case op {
    parser.Subscription(parser.SelectionSet(sels)) -> Ok(sels)
    parser.NamedSubscription(_, _, parser.SelectionSet(sels)) -> Ok(sels)
    _ -> Error("Invalid subscription operation")
  }

  use sels <- result.try(selections)

  case sels {
    [parser.Field(name, _, _, _), ..] -> Ok(name)
    _ -> Error("No fields found in subscription")
  }
}

/// Subscription metadata
pub type SubscriptionInfo {
  SubscriptionInfo(
    // Listener process serving this subscription; killed on complete/close.
    pid: process.Pid,
    // Root field selected by the subscription query.
    field_name: String,
    // Variables the query is executed with (including injected viewer_did).
    variables: Dict(String, value.Value),
  )
}

/// State for WebSocket connection
pub type State {
  State(
    db: Executor,
    // Map of subscription ID -> subscription info
    subscriptions: Dict(String, SubscriptionInfo),
    // The WebSocket connection for sending frames
    conn: WebsocketConnection,
    // Subject listener processes send subscription data to
    subscription_subject: process.Subject(websocket_ffi.SubscriptionMessage),
    // GraphQL schema for executing subscription queries
    schema: schema.Schema,
    // Authenticated viewer DID (extracted from the auth token)
    viewer_did: Option(String),
  )
}

/// Extract the credentials from an `Authorization: Bearer <token>` header.
///
/// The auth scheme is matched case-insensitively and surrounding whitespace
/// is trimmed from the token. Returns `None` for any other scheme or an empty
/// token.
fn extract_bearer_token(auth_header: String) -> Option(String) {
  case string.lowercase(string.slice(auth_header, 0, 7)) {
    "bearer " -> {
      case string.trim(string.drop_start(auth_header, 7)) {
        "" -> None
        token -> Some(token)
      }
    }
    _ -> None
  }
}

/// Handle a WebSocket upgrade request (called from the Mist handler).
///
/// The bearer token, if any, is verified before the upgrade. A missing or
/// invalid token does not reject the connection; it just leaves
/// `viewer_did` as `None`. The GraphQL schema is built from the database when
/// the socket is initialised; if that fails the connection process panics.
/// On close, every listener process is killed and the global subscription
/// counter is decremented once per active subscription.
pub fn handle_websocket(
  req: Request(Connection),
  db: Executor,
  did_cache: Subject(did_cache.Message),
  signing_key: Option(String),
  atp_client_id: String,
  plc_url: String,
  domain_authority: String,
) -> response.Response(ResponseData) {
  let auth_token = case request.get_header(req, "authorization") {
    Ok(auth_header) -> extract_bearer_token(auth_header)
    Error(_) -> None
  }

  let viewer_did = case auth_token {
    Some(token) ->
      case atproto_auth.verify_token(db, token) {
        Ok(user_info) -> Some(user_info.did)
        Error(_) -> None
      }
    None -> None
  }

  mist.websocket(
    request: req,
    on_init: fn(conn) {
      logging.log(logging.Info, "[websocket] Client connected")

      let graphql_schema = case
        lexicon_schema.build_schema_from_db(
          db,
          did_cache,
          signing_key,
          atp_client_id,
          plc_url,
          domain_authority,
        )
      {
        Ok(schema) -> schema
        Error(err) -> {
          logging.log(
            logging.Error,
            "[websocket] FATAL: Failed to build GraphQL schema: " <> err,
          )
          // Subscriptions cannot be served without a schema.
          panic as "Cannot initialize WebSocket subscriptions without GraphQL schema"
        }
      }

      // Listener processes send their results to this subject; selecting on
      // it delivers them to the handler as `mist.Custom` messages.
      let subscription_subject = process.new_subject()
      let selector =
        process.new_selector()
        |> process.select(subscription_subject)

      let state =
        State(
          db: db,
          subscriptions: dict.new(),
          conn: conn,
          subscription_subject: subscription_subject,
          schema: graphql_schema,
          viewer_did: viewer_did,
        )

      #(state, Some(selector))
    },
    on_close: fn(state) {
      logging.log(logging.Info, "[websocket] Client disconnected")

      let subscription_count = dict.size(state.subscriptions)
      dict.each(state.subscriptions, fn(_id, info) {
        process.kill(info.pid)
        let _ = decrement_global_subscriptions()
        Nil
      })

      logging.log(
        logging.Info,
        "[websocket] Cleaned up "
          <> string.inspect(subscription_count)
          <> " subscriptions",
      )
      Nil
    },
    handler: handle_ws_message,
  )
}

/// Handle one WebSocket event: client frames plus data forwarded from
/// subscription listener processes.
fn handle_ws_message(
  state: State,
  message: WebsocketMessage(websocket_ffi.SubscriptionMessage),
  conn: WebsocketConnection,
) {
  case message {
    mist.Text(text) -> handle_text_message(state, conn, text)
    mist.Binary(_) -> {
      logging.log(
        logging.Warning,
        "[websocket] Received binary message, ignoring",
      )
      mist.continue(state)
    }
    mist.Closed | mist.Shutdown -> mist.stop()
    mist.Custom(websocket_ffi.SubscriptionData(id, data)) -> {
      // A listener may have queued data just before its subscription was
      // completed and the listener killed. graphql-ws forbids `next` after
      // `complete`, so data for ids that are no longer active is dropped.
      // NOTE(review): if a client reuses a completed id immediately, stale
      // data from the old listener can still reach the new subscription.
      case dict.has_key(state.subscriptions, id) {
        True -> {
          let next_msg = ws.format_message(ws.Next(id, data))
          let _ = mist.send_text_frame(conn, next_msg)
          Nil
        }
        False -> Nil
      }
      mist.continue(state)
    }
  }
}

/// Handle a text frame carrying a graphql-ws protocol message.
fn handle_text_message(state: State, conn: WebsocketConnection, text: String) {
  case ws.parse_message(text) {
    Ok(ws.ConnectionInit(_payload)) -> {
      let ack_msg = ws.format_message(ws.ConnectionAck)
      let _ = mist.send_text_frame(conn, ack_msg)
      logging.log(logging.Info, "[websocket] Connection initialized")
      mist.continue(state)
    }

    Ok(ws.Subscribe(id, query, variables_opt)) ->
      handle_subscribe(state, conn, id, query, variables_opt)

    Ok(ws.Complete(id)) -> handle_complete(state, id)

    Ok(ws.Ping) -> {
      let pong_msg = ws.format_message(ws.Pong)
      let _ = mist.send_text_frame(conn, pong_msg)
      mist.continue(state)
    }

    // Client answered one of our pings; nothing to do.
    Ok(ws.Pong) -> mist.continue(state)

    Ok(_) -> {
      // Server-to-client message types we should never receive.
      logging.log(
        logging.Warning,
        "[websocket] Received unexpected message type",
      )
      mist.continue(state)
    }

    Error(err) -> {
      logging.log(
        logging.Warning,
        "[websocket] Failed to parse message: " <> err,
      )
      mist.continue(state)
    }
  }
}

/// Send a graphql-ws `error` message for subscription `id`.
fn send_error(conn: WebsocketConnection, id: String, message: String) -> Nil {
  let error_msg = ws.format_message(ws.ErrorMessage(id, message))
  let _ = mist.send_text_frame(conn, error_msg)
  Nil
}

/// Check a `subscribe` request before starting it.
///
/// Rejects an id that is already active (otherwise the old listener would
/// leak and the global counter would drift), requests over the
/// per-connection or global limits, and invalid queries. Returns the
/// subscription field name, or `#(log_message, client_message)` describing
/// the rejection.
fn validate_subscribe(
  state: State,
  id: String,
  query: String,
) -> Result(String, #(String, String)) {
  use <- bool.guard(
    dict.has_key(state.subscriptions, id),
    Error(#(
      "[websocket] Duplicate subscription id: " <> id,
      "Subscriber for " <> id <> " already exists",
    )),
  )

  let connection_count = dict.size(state.subscriptions)
  use <- bool.guard(
    connection_count >= max_subscriptions_per_connection,
    Error(#(
      "[websocket] Subscription limit reached for connection: "
        <> string.inspect(connection_count),
      "Maximum subscriptions per connection exceeded ("
        <> string.inspect(max_subscriptions_per_connection)
        <> ")",
    )),
  )

  // NOTE(review): the check and the later increment are separate calls, so
  // concurrent connections can briefly push the count past the limit.
  let global_count = get_global_subscription_count()
  use <- bool.guard(
    global_count >= max_subscriptions_global,
    Error(#(
      "[websocket] Global subscription limit reached: "
        <> string.inspect(global_count),
      "Global subscription limit exceeded",
    )),
  )

  extract_subscription_field(query)
  |> result.map_error(fn(err) {
    #(
      "[websocket] Invalid subscription query: " <> err,
      "Invalid subscription query: " <> err,
    )
  })
}

/// Build the variables a subscription query runs with.
///
/// SECURITY: any client-supplied `viewer_did` is discarded; `viewer_did` is
/// only ever set from the DID authenticated for this connection. It is kept
/// in the variables (not `ctx.data`) because `ctx.data` gets overwritten with
/// parent values during field resolution.
fn subscription_variables(
  variables_opt,
  viewer_did: Option(String),
) -> Dict(String, value.Value) {
  let base_variables = case variables_opt {
    Some(vars_json) ->
      lexicon_schema.json_string_to_variables_dict(vars_json)
      |> dict.delete("viewer_did")
    None -> dict.new()
  }

  case viewer_did {
    Some(did) -> dict.insert(base_variables, "viewer_did", value.String(did))
    None -> base_variables
  }
}

/// Start a subscription requested by a graphql-ws `subscribe` message, or
/// reply with an `error` message when `validate_subscribe` rejects it.
///
/// On success an unlinked listener process is spawned, the global counter is
/// incremented, and the subscription is recorded in the connection state.
fn handle_subscribe(
  state: State,
  conn: WebsocketConnection,
  id: String,
  query: String,
  variables_opt,
) {
  case validate_subscribe(state, id, query) {
    Error(#(log_message, client_message)) -> {
      logging.log(logging.Warning, log_message)
      send_error(conn, id, client_message)
      mist.continue(state)
    }
    Ok(field_name) -> {
      let variables = subscription_variables(variables_opt, state.viewer_did)

      logging.log(
        logging.Info,
        "[websocket] Subscription started: "
          <> id
          <> " (field: "
          <> field_name
          <> ")",
      )

      // Unlinked so a crashing listener does not take the connection down.
      let listener_pid =
        process.spawn_unlinked(fn() {
          subscription_listener(
            state.subscription_subject,
            id,
            query,
            field_name,
            variables,
            state.db,
            state.schema,
          )
        })

      let _ = increment_global_subscriptions()

      let subscription_info =
        SubscriptionInfo(listener_pid, field_name, variables)
      let new_subscriptions =
        dict.insert(state.subscriptions, id, subscription_info)
      mist.continue(State(..state, subscriptions: new_subscriptions))
    }
  }
}

/// Stop the subscription named by a graphql-ws `complete` message: kill its
/// listener, decrement the global counter, and forget it. Unknown ids are
/// logged and ignored.
fn handle_complete(state: State, id: String) {
  case dict.get(state.subscriptions, id) {
    Ok(info) -> {
      process.kill(info.pid)
      let _ = decrement_global_subscriptions()

      logging.log(logging.Info, "[websocket] Subscription completed: " <> id)

      let new_subscriptions = dict.delete(state.subscriptions, id)
      mist.continue(State(..state, subscriptions: new_subscriptions))
    }
    Error(_) -> {
      logging.log(
        logging.Warning,
        "[websocket] Complete for unknown subscription: " <> id,
      )
      mist.continue(state)
    }
  }
}

/// Check whether a record event matches a per-collection subscription field.
///
/// The expected field name is the collection's GraphQL name plus
/// "Created"/"Updated"/"Deleted" for the event's operation, compared exactly.
fn event_matches_subscription(
  event: pubsub.RecordEvent,
  subscription_field: String,
) -> Bool {
  let graphql_name = collection_to_graphql_name(event.collection)
  let event_field = case event.operation {
    pubsub.Create -> graphql_name <> "Created"
    pubsub.Update -> graphql_name <> "Updated"
    pubsub.Delete -> graphql_name <> "Deleted"
  }

  event_field == subscription_field
}

/// Check if a record event matches a notification subscription
///
/// A notification event matches when:
/// 1. The subscribed DID is non-empty (an empty string is contained in every
///    value and would otherwise match every record)
/// 2. The event value contains the subscribed DID (is a mention)
/// 3. The event is NOT authored by the subscribed DID (excludes self)
/// 4. The operation is Create (notifications are for new records only)
/// 5. The collection matches the filter (if provided; `None` or an empty
///    list means no filtering)
///
/// NOTE(review): containment is a plain substring check on the raw value, so
/// a DID that is a prefix of another DID would also match — confirm this is
/// acceptable for the DID methods in use.
pub fn event_matches_notification_subscription(
  event: pubsub.RecordEvent,
  subscribed_did: String,
  collections: Option(List(String)),
) -> Bool {
  let has_did = subscribed_did != ""
  let contains_did = string.contains(event.value, subscribed_did)
  let not_self_authored = event.did != subscribed_did
  let is_create = event.operation == pubsub.Create
  let matches_collection = case collections {
    None -> True
    Some([]) -> True
    Some(cols) -> list.contains(cols, event.collection)
  }

  has_did && contains_did && not_self_authored && is_create && matches_collection
}

/// Read a string variable; `None` when it is missing or not a string.
fn get_variable_string(
  variables: Dict(String, value.Value),
  key: String,
) -> Option(String) {
  case dict.get(variables, key) {
    Ok(value.String(s)) -> Some(s)
    _ -> None
  }
}

/// Read a list variable of strings/enum values; non-string items are
/// skipped. `None` when the variable is missing or not a list.
fn get_variable_string_list(
  variables: Dict(String, value.Value),
  key: String,
) -> Option(List(String)) {
  case dict.get(variables, key) {
    Ok(value.List(items)) ->
      Some(
        list.filter_map(items, fn(item) {
          case item {
            value.String(s) -> Ok(s)
            value.Enum(e) -> Ok(e)
            _ -> Error(Nil)
          }
        }),
      )
    _ -> None
  }
}

/// Read the optional `collections` filter of a notification subscription,
/// converting enum names to NSIDs (APP_BSKY_FEED_LIKE -> app.bsky.feed.like).
///
/// NOTE(review): the conversion lowercases everything, so an NSID with an
/// uppercase letter in its name segment can never match — confirm the enum
/// only covers all-lowercase NSIDs.
fn notification_collections(
  variables: Dict(String, value.Value),
) -> Option(List(String)) {
  case get_variable_string_list(variables, "collections") {
    Some(enum_values) ->
      Some(
        list.map(enum_values, fn(enum_val) {
          enum_val
          |> string.lowercase()
          |> string.replace("_", ".")
        }),
      )
    None -> None
  }
}

/// Run the subscription query for `event` and send the result to the
/// connection handler if the event matches the subscription.
///
/// `notificationCreated` subscriptions match on their `did` and
/// `collections` variables; a subscription without a `did` never matches.
/// NOTE(review): `did` is client-supplied and not checked against the
/// connection's viewer_did, so any client may watch any DID's mentions.
/// Every other field matches by collection and operation. Query errors are
/// logged and nothing is sent.
fn process_event(
  event: pubsub.RecordEvent,
  subscription_subject: process.Subject(websocket_ffi.SubscriptionMessage),
  subscription_id: String,
  subscription_field: String,
  query: String,
  variables: Dict(String, value.Value),
  db: Executor,
  graphql_schema: schema.Schema,
) -> Nil {
  let matches = case subscription_field {
    "notificationCreated" ->
      case get_variable_string(variables, "did") {
        None -> False
        Some(subscribed_did) ->
          event_matches_notification_subscription(
            event,
            subscribed_did,
            notification_collections(variables),
          )
      }
    _ -> event_matches_subscription(event, subscription_field)
  }

  case matches {
    False -> Nil
    True ->
      case
        execute_subscription_query(query, variables, graphql_schema, event, db)
      {
        Ok(result_json) ->
          process.send(
            subscription_subject,
            websocket_ffi.SubscriptionData(subscription_id, result_json),
          )
        Error(err) ->
          logging.log(
            logging.Error,
            "[websocket] Failed to execute subscription query: " <> err,
          )
      }
  }
}

/// Receive and process events forever; the loop ends only when the listener
/// process is killed.
fn event_loop(
  selector: process.Selector(pubsub.RecordEvent),
  subscription_subject: process.Subject(websocket_ffi.SubscriptionMessage),
  subscription_id: String,
  subscription_field: String,
  query: String,
  variables: Dict(String, value.Value),
  db: Executor,
  graphql_schema: schema.Schema,
) -> Nil {
  let event = process.selector_receive_forever(selector)

  process_event(
    event,
    subscription_subject,
    subscription_id,
    subscription_field,
    query,
    variables,
    db,
    graphql_schema,
  )

  event_loop(
    selector,
    subscription_subject,
    subscription_id,
    subscription_field,
    query,
    variables,
    db,
    graphql_schema,
  )
}

/// Body of a subscription listener process: subscribe to PubSub from this
/// process, then forward matching events to the connection handler.
fn subscription_listener(
  subscription_subject: process.Subject(websocket_ffi.SubscriptionMessage),
  subscription_id: String,
  query: String,
  subscription_field: String,
  variables: Dict(String, value.Value),
  db: Executor,
  graphql_schema: schema.Schema,
) -> Nil {
  // Must run in this process so PubSub delivers events here.
  let my_subject = pubsub.subscribe()

  let selector =
    process.new_selector()
    |> process.select(my_subject)

  event_loop(
    selector,
    subscription_subject,
    subscription_id,
    subscription_field,
    query,
    variables,
    db,
    graphql_schema,
  )
}