Auto-indexing service and GraphQL API for AT Protocol Records
quickslice.slices.network/
atproto
gleam
graphql
1/// GraphQL WebSocket Handler
2///
3/// Handles WebSocket connections for GraphQL subscriptions using the graphql-ws protocol
4import atproto_auth
5import database/executor.{type Executor}
6import database/repositories/actors
7import gleam/dict.{type Dict}
8import gleam/erlang/process.{type Subject}
9import gleam/http/request.{type Request}
10import gleam/http/response
11import gleam/list
12import gleam/option.{type Option, None, Some}
13import gleam/result
14import gleam/string
15import graphql/lexicon/schema as lexicon_schema
16import graphql/ws
17import lib/oauth/did_cache
18import logging
19import mist.{
20 type Connection, type ResponseData, type WebsocketConnection,
21 type WebsocketMessage,
22}
23import pubsub
24import swell/executor as swell_executor
25import swell/parser
26import swell/schema
27import swell/value
28import websocket_ffi
29
/// Upper bound on concurrently active subscriptions for one WebSocket connection.
const max_subscriptions_per_connection = 100

/// Upper bound on active subscriptions across all connections, compared against
/// the global subscription counter.
const max_subscriptions_global = 10000
34
/// Bump the global atomic subscription counter (Erlang FFI).
/// NOTE(review): callers here discard the returned Int; presumably it is the
/// updated count — confirm in subscription_counter_ffi.
@external(erlang, "subscription_counter_ffi", "increment_global")
fn increment_global_subscriptions() -> Int

/// Lower the global atomic subscription counter (Erlang FFI).
@external(erlang, "subscription_counter_ffi", "decrement_global")
fn decrement_global_subscriptions() -> Int

/// Read the current value of the global atomic subscription counter (Erlang FFI).
@external(erlang, "subscription_counter_ffi", "get_global_count")
fn get_global_subscription_count() -> Int
44
/// Build the GraphQL object handed to subscription resolvers for a PubSub
/// record event.
///
/// The record's JSON `value` is parsed into a GraphQL value, falling back to an
/// empty object when parsing fails. The author's handle is read from the
/// actors table and becomes `null` when no actor row is found or the lookup
/// returns an error.
fn event_to_graphql_value(
  event: pubsub.RecordEvent,
  db: Executor,
) -> value.Value {
  let record_value =
    event.value
    |> lexicon_schema.parse_json_to_value
    |> result.unwrap(value.Object([]))

  let handle_value = case actors.get(db, event.did) {
    Ok([first_actor, ..]) -> value.String(first_actor.handle)
    _ -> value.Null
  }

  value.Object([
    #("uri", value.String(event.uri)),
    #("cid", value.String(event.cid)),
    #("did", value.String(event.did)),
    #("collection", value.String(event.collection)),
    #("indexedAt", value.String(event.indexed_at)),
    #("actorHandle", handle_value),
    #("value", record_value),
  ])
}
74
/// Run a subscription query against one record event.
///
/// The event becomes the context's root data, alongside the subscription's
/// variables. The executor's result is rendered to a JSON string; executor
/// errors are passed through unchanged.
fn execute_subscription_query(
  query: String,
  variables: Dict(String, value.Value),
  graphql_schema: schema.Schema,
  event: pubsub.RecordEvent,
  db: Executor,
) -> Result(String, String) {
  let ctx =
    event
    |> event_to_graphql_value(db)
    |> Some
    |> schema.context_with_variables(variables)

  swell_executor.execute(query, graphql_schema, ctx)
  |> result.map(lexicon_schema.format_response)
}
97
/// Turn a dotted collection NSID into the camelCase prefix used for GraphQL
/// subscription fields.
///
/// The first segment is kept as-is. Each later segment has only its first
/// grapheme upper-cased; the rest of the segment is untouched.
/// Example: "xyz.statusphere.status" -> "xyzStatusphereStatus"
fn collection_to_graphql_name(collection: String) -> String {
  let capitalise_first = fn(segment) {
    case string.pop_grapheme(segment) {
      Error(_) -> segment
      Ok(#(head, tail)) -> string.uppercase(head) <> tail
    }
  }

  case string.split(collection, ".") {
    [] -> ""
    [first, ..rest] -> first <> string.concat(list.map(rest, capitalise_first))
  }
}
118
/// Parse `query` and return the name of the first selection in its first
/// subscription operation, whether anonymous or named.
///
/// Returns a human-readable error in three cases: the query does not parse,
/// it contains no subscription operation, or that operation's first selection
/// is not a field.
fn extract_subscription_field(query: String) -> Result(String, String) {
  let is_subscription = fn(op) {
    case op {
      parser.Subscription(_) | parser.NamedSubscription(_, _, _) -> True
      _ -> False
    }
  }

  use document <- result.try(
    query
    |> parser.parse
    |> result.replace_error("Failed to parse subscription query"),
  )

  use op <- result.try(
    document.operations
    |> list.find(is_subscription)
    |> result.replace_error("No subscription operation found in query"),
  )

  use sels <- result.try(case op {
    parser.Subscription(parser.SelectionSet(sels)) -> Ok(sels)
    parser.NamedSubscription(_, _, parser.SelectionSet(sels)) -> Ok(sels)
    _ -> Error("Invalid subscription operation")
  })

  case sels {
    [parser.Field(name, _, _, _), ..] -> Ok(name)
    _ -> Error("No fields found in subscription")
  }
}
157
/// Bookkeeping for one active subscription on a WebSocket connection.
pub type SubscriptionInfo {
  SubscriptionInfo(
    // Listener process that forwards matching PubSub events; it is killed when
    // the subscription completes or the connection closes.
    pid: process.Pid,
    // Root field selected by the subscription query, e.g. a `...Created`
    // collection field or `notificationCreated`.
    field_name: String,
    // Variables the query executes with, including any injected `viewer_did`.
    variables: Dict(String, value.Value),
  )
}
166
/// Per-connection state threaded through the Mist WebSocket handler.
pub type State {
  State(
    // Database executor; also handed to listener processes for actor lookups.
    db: Executor,
    // Active subscriptions on this connection, keyed by graphql-ws operation id.
    subscriptions: Dict(String, SubscriptionInfo),
    // Mist connection used to send frames back to the client.
    conn: WebsocketConnection,
    // Subject selected on by the connection process; listener processes send
    // `SubscriptionData` results to it.
    subscription_subject: process.Subject(websocket_ffi.SubscriptionMessage),
    // Schema built in `on_init`, used to execute subscription queries.
    schema: schema.Schema,
    // DID of the authenticated viewer, or None when no valid bearer token was
    // presented.
    viewer_did: Option(String),
  )
}
183
/// Handle a WebSocket upgrade request for GraphQL subscriptions (called from
/// the Mist handler).
///
/// Authentication:
/// - If the request has an `Authorization` header using the Bearer scheme, its
///   token is checked with `atproto_auth.verify_token`.
/// - On success, the token's DID becomes the connection's viewer DID.
/// - A missing, malformed or invalid token leaves the connection
///   unauthenticated rather than rejecting it.
///
/// Lifecycle:
/// - When the socket opens, a GraphQL schema is built from the database for
///   this connection. If the build fails, the error is logged and the
///   process panics.
/// - On close, every listener process is killed and the global subscription
///   counter is decremented once per active subscription.
pub fn handle_websocket(
  req: Request(Connection),
  db: Executor,
  did_cache: Subject(did_cache.Message),
  signing_key: Option(String),
  atp_client_id: String,
  plc_url: String,
  domain_authority: String,
) -> response.Response(ResponseData) {
  // Resolve the viewer from request headers before the WebSocket upgrade.
  let viewer_did = case request.get_header(req, "authorization") {
    Ok(auth_header) ->
      case bearer_token(auth_header) {
        Some(token) ->
          case atproto_auth.verify_token(db, token) {
            Ok(user_info) -> Some(user_info.did)
            Error(_) -> None
          }
        None -> None
      }
    Error(_) -> None
  }

  mist.websocket(
    request: req,
    on_init: fn(conn) {
      logging.log(logging.Info, "[websocket] Client connected")

      // Build GraphQL schema for subscriptions
      let graphql_schema = case
        lexicon_schema.build_schema_from_db(
          db,
          did_cache,
          signing_key,
          atp_client_id,
          plc_url,
          domain_authority,
        )
      {
        Ok(schema) -> schema
        Error(err) -> {
          // Schema build failed - this is a critical error for subscriptions
          logging.log(
            logging.Error,
            "[websocket] FATAL: Failed to build GraphQL schema: " <> err,
          )
          // Panic because we can't continue without a schema
          panic as "Cannot initialize WebSocket subscriptions without GraphQL schema"
        }
      }

      // Listener processes send SubscriptionData to this subject; selecting
      // on it delivers those messages to the handler as mist.Custom.
      let subscription_subject = process.new_subject()
      let selector =
        process.new_selector()
        |> process.select(subscription_subject)

      let state =
        State(
          db: db,
          subscriptions: dict.new(),
          conn: conn,
          subscription_subject: subscription_subject,
          schema: graphql_schema,
          viewer_did: viewer_did,
        )

      #(state, Some(selector))
    },
    on_close: fn(state) {
      logging.log(logging.Info, "[websocket] Client disconnected")

      // Listeners are spawned unlinked, so they must be killed explicitly and
      // their global counter slots released here.
      let subscription_count = dict.size(state.subscriptions)
      dict.each(state.subscriptions, fn(_id, info) {
        process.kill(info.pid)
        let _ = decrement_global_subscriptions()
        Nil
      })

      logging.log(
        logging.Info,
        "[websocket] Cleaned up "
          <> string.inspect(subscription_count)
          <> " subscriptions",
      )
      Nil
    },
    handler: handle_ws_message,
  )
}

/// Extract the credential from an `Authorization` header value that uses the
/// Bearer scheme.
///
/// The scheme name is matched case-insensitively, as HTTP authentication
/// schemes are. Surrounding whitespace is trimmed from both the header and
/// the token. Returns `None` for other schemes, a header with no token, or
/// an empty token.
fn bearer_token(auth_header: String) -> Option(String) {
  case string.split_once(string.trim(auth_header), " ") {
    Ok(#(scheme, token)) ->
      case string.lowercase(scheme), string.trim(token) {
        "bearer", "" -> None
        "bearer", trimmed -> Some(trimmed)
        _, _ -> None
      }
    Error(_) -> None
  }
}
288
/// Dispatch one incoming WebSocket event for this connection.
///
/// - Text frames are handled as graphql-ws protocol messages.
/// - Binary frames are logged and ignored.
/// - Close or shutdown stops the handler.
/// - `SubscriptionData` produced by listener processes is sent to the client
///   as a graphql-ws `next` message, but only while that subscription is
///   still active.
fn handle_ws_message(
  state: State,
  message: WebsocketMessage(websocket_ffi.SubscriptionMessage),
  conn: WebsocketConnection,
) {
  case message {
    mist.Text(text) -> handle_text_message(state, conn, text)

    mist.Binary(_) -> {
      logging.log(
        logging.Warning,
        "[websocket] Received binary message, ignoring",
      )
      mist.continue(state)
    }

    mist.Closed | mist.Shutdown -> mist.stop()

    mist.Custom(websocket_ffi.SubscriptionData(id, data)) ->
      // Killing a listener on `complete` does not remove results it already
      // queued in this process's mailbox, and graphql-ws forbids `next` after
      // an operation completes. Drop data for ids that are no longer active.
      case dict.has_key(state.subscriptions, id) {
        True -> {
          let next_msg = ws.format_message(ws.Next(id, data))
          let _ = mist.send_text_frame(conn, next_msg)
          mist.continue(state)
        }
        False -> mist.continue(state)
      }
  }
}
317
/// Handle a text frame carrying a graphql-ws protocol message.
///
/// - `connection_init` is acknowledged.
/// - `subscribe` starts a listener, subject to validation and limits.
/// - `complete` stops the matching listener.
/// - `ping` is answered with `pong`.
///
/// Unexpected message types and unparseable frames are logged and ignored.
/// Every branch keeps the connection open.
fn handle_text_message(state: State, conn: WebsocketConnection, text: String) {
  case ws.parse_message(text) {
    Ok(ws.ConnectionInit(_payload)) -> {
      let ack_msg = ws.format_message(ws.ConnectionAck)
      let _ = mist.send_text_frame(conn, ack_msg)
      logging.log(logging.Info, "[websocket] Connection initialized")
      mist.continue(state)
    }

    Ok(ws.Subscribe(id, query, variables_opt)) ->
      handle_subscribe(state, conn, id, query, variables_opt)

    Ok(ws.Complete(id)) -> handle_complete(state, id)

    Ok(ws.Ping) -> {
      let pong_msg = ws.format_message(ws.Pong)
      let _ = mist.send_text_frame(conn, pong_msg)
      mist.continue(state)
    }

    // Client responded to our ping; nothing to do.
    Ok(ws.Pong) -> mist.continue(state)

    // Server-to-client message types that a client should never send.
    Ok(_) -> {
      logging.log(
        logging.Warning,
        "[websocket] Received unexpected message type",
      )
      mist.continue(state)
    }

    Error(err) -> {
      logging.log(
        logging.Warning,
        "[websocket] Failed to parse message: " <> err,
      )
      mist.continue(state)
    }
  }
}

/// Validate a `subscribe` request and either start it or answer with a
/// graphql-ws `error` message.
///
/// The request is rejected if:
/// - its id is already active,
/// - the per-connection or global limit has been reached, or
/// - the query has no usable subscription field.
fn handle_subscribe(
  state: State,
  conn: WebsocketConnection,
  id: String,
  query: String,
  variables_opt,
) {
  let connection_count = dict.size(state.subscriptions)
  case dict.has_key(state.subscriptions, id) {
    // graphql-ws requires ids to be unique among active operations. Starting
    // a second listener under the same id would overwrite its dict entry,
    // orphaning the old listener and permanently leaking a global counter slot.
    True ->
      reject_subscription(
        state,
        conn,
        id,
        "[websocket] Duplicate subscription id: " <> id,
        "Subscriber for " <> id <> " already exists",
      )
    False ->
      case connection_count >= max_subscriptions_per_connection {
        True ->
          reject_subscription(
            state,
            conn,
            id,
            "[websocket] Subscription limit reached for connection: "
              <> string.inspect(connection_count),
            "Maximum subscriptions per connection exceeded ("
              <> string.inspect(max_subscriptions_per_connection)
              <> ")",
          )
        False -> {
          // NOTE(review): this read and the increment in start_subscription
          // are separate operations, so concurrent connections can briefly
          // push the total past max_subscriptions_global.
          let global_count = get_global_subscription_count()
          case global_count >= max_subscriptions_global {
            True ->
              reject_subscription(
                state,
                conn,
                id,
                "[websocket] Global subscription limit reached: "
                  <> string.inspect(global_count),
                "Global subscription limit exceeded",
              )
            False ->
              case extract_subscription_field(query) {
                Error(err) ->
                  reject_subscription(
                    state,
                    conn,
                    id,
                    "[websocket] Invalid subscription query: " <> err,
                    "Invalid subscription query: " <> err,
                  )
                Ok(field_name) ->
                  start_subscription(state, id, query, field_name, variables_opt)
              }
          }
        }
      }
  }
}

/// Log `log_message` as a warning and send a graphql-ws `error` message with
/// `client_message` for operation `id`. The state is left unchanged.
fn reject_subscription(
  state: State,
  conn: WebsocketConnection,
  id: String,
  log_message: String,
  client_message: String,
) {
  logging.log(logging.Warning, log_message)
  let error_msg = ws.format_message(ws.ErrorMessage(id, client_message))
  let _ = mist.send_text_frame(conn, error_msg)
  mist.continue(state)
}

/// Start a validated subscription:
/// - spawn its PubSub listener,
/// - count it globally, and
/// - record it in the connection state.
fn start_subscription(
  state: State,
  id: String,
  query: String,
  field_name: String,
  variables_opt,
) {
  let variables = subscription_variables(variables_opt, state.viewer_did)

  logging.log(
    logging.Info,
    "[websocket] Subscription started: "
      <> id
      <> " (field: "
      <> field_name
      <> ")",
  )

  // Unlinked: the listener is stopped explicitly on `complete` or close.
  let listener_pid =
    process.spawn_unlinked(fn() {
      subscription_listener(
        state.subscription_subject,
        id,
        query,
        field_name,
        variables,
        state.db,
        state.schema,
      )
    })

  let _ = increment_global_subscriptions()

  let subscription_info = SubscriptionInfo(listener_pid, field_name, variables)
  let new_subscriptions = dict.insert(state.subscriptions, id, subscription_info)
  mist.continue(State(..state, subscriptions: new_subscriptions))
}

/// Build a subscription's variables from the client's JSON and inject the
/// authenticated viewer DID when one is known.
fn subscription_variables(
  variables_opt,
  viewer_did: Option(String),
) -> Dict(String, value.Value) {
  // SECURITY: Strip any client-provided viewer_did - it must come from auth token only
  let base_variables = case variables_opt {
    Some(vars_json) ->
      lexicon_schema.json_string_to_variables_dict(vars_json)
      |> dict.delete("viewer_did")
    None -> dict.new()
  }

  // Stored in variables (not ctx.data) because ctx.data gets overwritten with
  // parent values during field resolution.
  case viewer_did {
    Some(did) -> dict.insert(base_variables, "viewer_did", value.String(did))
    None -> base_variables
  }
}

/// Stop the subscription `id` if it is active: kill its listener, release its
/// global counter slot and forget it. Unknown ids are logged and ignored.
fn handle_complete(state: State, id: String) {
  case dict.get(state.subscriptions, id) {
    Ok(info) -> {
      process.kill(info.pid)
      let _ = decrement_global_subscriptions()

      let new_state =
        State(..state, subscriptions: dict.delete(state.subscriptions, id))

      logging.log(logging.Info, "[websocket] Subscription completed: " <> id)

      mist.continue(new_state)
    }
    Error(_) -> {
      logging.log(
        logging.Warning,
        "[websocket] Complete for unknown subscription: " <> id,
      )
      mist.continue(state)
    }
  }
}
509
/// Decide whether a record event belongs to a per-collection subscription field.
///
/// The expected field name is the camelCase collection name plus a suffix for
/// the operation: `Created`, `Updated` or `Deleted`. It must equal
/// `subscription_field` exactly.
fn event_matches_subscription(
  event: pubsub.RecordEvent,
  subscription_field: String,
) -> Bool {
  let suffix = case event.operation {
    pubsub.Create -> "Created"
    pubsub.Update -> "Updated"
    pubsub.Delete -> "Deleted"
  }

  subscription_field == { collection_to_graphql_name(event.collection) <> suffix }
}
527
/// Check if a record event matches a notification subscription.
///
/// A notification event matches when all of the following hold:
/// 1. `subscribed_did` is non-empty and appears in the event's raw JSON value
///    (the record mentions that DID).
/// 2. The event is NOT authored by the subscribed DID (self-activity is
///    excluded).
/// 3. The operation is Create (notifications are for new records only).
/// 4. The collection is in `collections`, when that filter is a non-empty list.
///
/// NOTE(review): condition 1 is a plain substring test, so a DID that is a
/// prefix of another DID present in the record also matches.
pub fn event_matches_notification_subscription(
  event: pubsub.RecordEvent,
  subscribed_did: String,
  collections: Option(List(String)),
) -> Bool {
  case subscribed_did {
    // The empty string is a substring of every record. Without this guard, a
    // subscription missing its DID would be notified of every new record.
    "" -> False
    _ -> {
      let is_create = event.operation == pubsub.Create
      let not_self_authored = event.did != subscribed_did
      let matches_collection = case collections {
        None | Some([]) -> True
        Some(cols) -> list.contains(cols, event.collection)
      }

      // The cheap checks come first; the substring scan runs last.
      is_create
      && not_self_authored
      && matches_collection
      && string.contains(event.value, subscribed_did)
    }
  }
}
558
/// Look up `key` in the variables and return it only if it holds a
/// `value.String`. Returns None when the key is missing or holds another kind
/// of value.
fn get_variable_string(
  variables: Dict(String, value.Value),
  key: String,
) -> Option(String) {
  variables
  |> dict.get(key)
  |> result.try(fn(found) {
    case found {
      value.String(text) -> Ok(text)
      _ -> Error(Nil)
    }
  })
  |> option.from_result
}
569
/// Look up `key` and, if it holds a `value.List`, return the items that are
/// strings or enum values as plain strings. Any other items are skipped.
///
/// Returns None when the key is missing or does not hold a list.
fn get_variable_string_list(
  variables: Dict(String, value.Value),
  key: String,
) -> Option(List(String)) {
  let as_text = fn(item) {
    case item {
      value.String(text) | value.Enum(text) -> Ok(text)
      _ -> Error(Nil)
    }
  }

  case dict.get(variables, key) {
    Ok(value.List(items)) -> Some(list.filter_map(items, as_text))
    _ -> None
  }
}
590
/// Test one PubSub event against a subscription. When it matches, run the
/// subscription query and send the JSON result to the connection's subject as
/// `SubscriptionData`.
///
/// Matching rules:
/// - `notificationCreated` subscriptions use the `did` and `collections`
///   variables.
/// - Every other field uses the event's derived field name.
///
/// Query execution errors are logged and the event is dropped.
fn process_event(
  event: pubsub.RecordEvent,
  subscription_subject: process.Subject(websocket_ffi.SubscriptionMessage),
  subscription_id: String,
  subscription_field: String,
  query: String,
  variables: Dict(String, value.Value),
  db: Executor,
  graphql_schema: schema.Schema,
) -> Nil {
  let is_match = case subscription_field == "notificationCreated" {
    True -> {
      let did = option.unwrap(get_variable_string(variables, "did"), "")
      // Collection filters arrive as enum names such as APP_BSKY_FEED_LIKE.
      // Lowercasing and replacing underscores with dots turns them into NSIDs.
      // NOTE(review): assumes NSIDs contain no underscores or uppercase letters
      // — confirm against how the enum values are generated.
      let nsids =
        get_variable_string_list(variables, "collections")
        |> option.map(fn(names) {
          list.map(names, fn(name) {
            name
            |> string.lowercase
            |> string.replace("_", ".")
          })
        })
      event_matches_notification_subscription(event, did, nsids)
    }
    False -> event_matches_subscription(event, subscription_field)
  }

  case is_match {
    False -> Nil
    True ->
      case
        execute_subscription_query(query, variables, graphql_schema, event, db)
      {
        Ok(result_json) ->
          process.send(
            subscription_subject,
            websocket_ffi.SubscriptionData(subscription_id, result_json),
          )
        Error(err) ->
          logging.log(
            logging.Error,
            "[websocket] Failed to execute subscription query: " <> err,
          )
      }
  }
}
660
/// Block on `selector` for the next PubSub record event, pass it to
/// `process_event`, then recurse.
///
/// This function never returns on its own; the listener process runs it until
/// the process is killed.
fn event_loop(
  selector: process.Selector(pubsub.RecordEvent),
  subscription_subject: process.Subject(websocket_ffi.SubscriptionMessage),
  subscription_id: String,
  subscription_field: String,
  query: String,
  variables: Dict(String, value.Value),
  db: Executor,
  graphql_schema: schema.Schema,
) -> Nil {
  selector
  |> process.selector_receive_forever
  |> process_event(
    subscription_subject,
    subscription_id,
    subscription_field,
    query,
    variables,
    db,
    graphql_schema,
  )

  event_loop(
    selector,
    subscription_subject,
    subscription_id,
    subscription_field,
    query,
    variables,
    db,
    graphql_schema,
  )
}
699
/// Entry point of a subscription's listener process.
///
/// Subscribes the calling process to PubSub, then hands incoming record
/// events to `event_loop`, which forwards matches to `subscription_subject`.
fn subscription_listener(
  subscription_subject: process.Subject(websocket_ffi.SubscriptionMessage),
  subscription_id: String,
  query: String,
  subscription_field: String,
  variables: Dict(String, value.Value),
  db: Executor,
  graphql_schema: schema.Schema,
) -> Nil {
  // Must be called from inside the listener process so that this process is
  // the one registered with PubSub.
  let pubsub_subject = pubsub.subscribe()
  let events = process.select(process.new_selector(), pubsub_subject)

  event_loop(
    events,
    subscription_subject,
    subscription_id,
    subscription_field,
    query,
    variables,
    db,
    graphql_schema,
  )
}