A better Rust ATProto crate
1use crate::codegen::nsid_utils::{NsidPath, RefPath};
2use crate::corpus::LexiconCorpus;
3use crate::error::Result;
4use heck::ToPascalCase;
5use jacquard_common::CowStr;
6use proc_macro2::TokenStream;
7use quote::quote;
8use std::cell::RefCell;
9use std::collections::{HashMap, HashSet};
10
11/// Information about a union variant
12#[derive(Debug, Clone)]
13pub struct UnionVariant {
14 /// The original ref string (normalized)
15 pub ref_str: String,
16 /// The variant name (may be disambiguated)
17 pub variant_name: String,
18 /// The Rust type for this variant
19 pub rust_type: TokenStream,
20}
21
22/// Context for tracking namespace dependencies during union generation
23pub struct UnionGenContext<'a> {
24 pub corpus: &'a LexiconCorpus,
25 pub namespace_deps: &'a RefCell<HashMap<String, HashSet<String>>>,
26 pub current_nsid: &'a str,
27}
28
29impl<'a> UnionGenContext<'a> {
30 /// Build variants for a union with collision detection and disambiguation
31 pub fn build_union_variants(
32 &self,
33 refs: &[CowStr<'static>],
34 ref_to_rust_type: impl Fn(&str) -> Result<TokenStream>,
35 ) -> Result<Vec<UnionVariant>> {
36 let current_nsid_path = NsidPath::parse(self.current_nsid);
37 let current_namespace = current_nsid_path.namespace();
38
39 // First pass: collect all variant names and detect collisions
40 #[derive(Debug)]
41 struct VariantInfo {
42 ref_str: String,
43 ref_nsid: String,
44 simple_name: String,
45 is_current_namespace: bool,
46 }
47
48 let mut variant_infos = Vec::new();
49 for ref_str in refs {
50 let normalized_ref = RefPath::normalize(ref_str, self.current_nsid);
51 let ref_path = RefPath::parse(&normalized_ref, None);
52 let ref_nsid_str = ref_path.nsid();
53 let ref_def = ref_path.def();
54
55 // Skip unknown refs
56 if !self.corpus.ref_exists(&normalized_ref) {
57 continue;
58 }
59
60 let is_current_namespace = ref_nsid_str.starts_with(¤t_namespace);
61 let is_same_module = ref_nsid_str == self.current_nsid;
62
63 // Generate simple variant name
64 let last_segment = ref_nsid_str.split('.').last().unwrap();
65 let simple_name = if ref_def == "main" {
66 last_segment.to_pascal_case()
67 } else if last_segment == "defs" {
68 ref_def.to_pascal_case()
69 } else if is_same_module {
70 ref_def.to_pascal_case()
71 } else {
72 format!(
73 "{}{}",
74 last_segment.to_pascal_case(),
75 ref_def.to_pascal_case()
76 )
77 };
78
79 variant_infos.push(VariantInfo {
80 ref_str: normalized_ref.clone(),
81 ref_nsid: ref_nsid_str.to_string(),
82 simple_name,
83 is_current_namespace,
84 });
85 }
86
87 // Second pass: detect collisions and disambiguate
88 let mut name_counts: HashMap<String, usize> = HashMap::new();
89 for info in &variant_infos {
90 *name_counts.entry(info.simple_name.clone()).or_insert(0) += 1;
91 }
92
93 let mut variants = Vec::new();
94 for info in variant_infos {
95 let has_collision = name_counts.get(&info.simple_name).copied().unwrap_or(0) > 1;
96
97 // Track namespace dependency for foreign refs
98 if !info.is_current_namespace {
99 let ref_nsid_path = NsidPath::parse(&info.ref_nsid);
100 let foreign_namespace = ref_nsid_path.namespace();
101 self.namespace_deps
102 .borrow_mut()
103 .entry(current_namespace.clone())
104 .or_default()
105 .insert(foreign_namespace);
106 }
107
108 // Disambiguate: add second NSID segment prefix only to foreign refs when there's a collision
109 let variant_name = if has_collision && !info.is_current_namespace {
110 let ref_nsid_path = NsidPath::parse(&info.ref_nsid);
111 let segments = ref_nsid_path.segments();
112 let prefix = if segments.len() >= 2 {
113 segments[1].to_pascal_case()
114 } else {
115 segments[0].to_pascal_case()
116 };
117 format!("{}{}", prefix, info.simple_name)
118 } else {
119 info.simple_name.clone()
120 };
121
122 let rust_type = ref_to_rust_type(&info.ref_str)?;
123
124 variants.push(UnionVariant {
125 ref_str: info.ref_str,
126 variant_name,
127 rust_type,
128 });
129 }
130
131 Ok(variants)
132 }
133
134 /// Build variants for a union without collision detection (simple mode)
135 pub fn build_simple_union_variants(
136 &self,
137 refs: &[CowStr<'static>],
138 ref_to_rust_type: impl Fn(&str) -> Result<TokenStream>,
139 ) -> Result<Vec<UnionVariant>> {
140 let mut variants = Vec::new();
141
142 for ref_str in refs {
143 let ref_str_s = ref_str.as_ref();
144 let normalized_ref = RefPath::normalize(ref_str, self.current_nsid);
145 let ref_path = RefPath::parse(&normalized_ref, None);
146 let ref_nsid = ref_path.nsid();
147 let ref_def = ref_path.def();
148
149 let variant_name = if ref_def == "main" {
150 let ref_nsid_path = NsidPath::parse(ref_nsid);
151 ref_nsid_path.last_segment().to_pascal_case()
152 } else {
153 ref_def.to_pascal_case()
154 };
155
156 let rust_type = ref_to_rust_type(&normalized_ref)?;
157
158 variants.push(UnionVariant {
159 ref_str: ref_str_s.to_string(),
160 variant_name,
161 rust_type,
162 });
163 }
164
165 Ok(variants)
166 }
167}
168
169/// Generate variant tokens for a union enum
170pub fn generate_variant_tokens(variants: &[UnionVariant]) -> Vec<TokenStream> {
171 variants
172 .iter()
173 .map(|variant| {
174 let variant_ident = syn::Ident::new(&variant.variant_name, proc_macro2::Span::call_site());
175 let ref_str_literal = &variant.ref_str;
176 let rust_type = &variant.rust_type;
177
178 quote! {
179 #[serde(rename = #ref_str_literal)]
180 #variant_ident(Box<#rust_type>)
181 }
182 })
183 .collect()
184}