A better Rust ATProto crate
at main 184 lines 6.4 kB view raw
1use crate::codegen::nsid_utils::{NsidPath, RefPath}; 2use crate::corpus::LexiconCorpus; 3use crate::error::Result; 4use heck::ToPascalCase; 5use jacquard_common::CowStr; 6use proc_macro2::TokenStream; 7use quote::quote; 8use std::cell::RefCell; 9use std::collections::{HashMap, HashSet}; 10 11/// Information about a union variant 12#[derive(Debug, Clone)] 13pub struct UnionVariant { 14 /// The original ref string (normalized) 15 pub ref_str: String, 16 /// The variant name (may be disambiguated) 17 pub variant_name: String, 18 /// The Rust type for this variant 19 pub rust_type: TokenStream, 20} 21 22/// Context for tracking namespace dependencies during union generation 23pub struct UnionGenContext<'a> { 24 pub corpus: &'a LexiconCorpus, 25 pub namespace_deps: &'a RefCell<HashMap<String, HashSet<String>>>, 26 pub current_nsid: &'a str, 27} 28 29impl<'a> UnionGenContext<'a> { 30 /// Build variants for a union with collision detection and disambiguation 31 pub fn build_union_variants( 32 &self, 33 refs: &[CowStr<'static>], 34 ref_to_rust_type: impl Fn(&str) -> Result<TokenStream>, 35 ) -> Result<Vec<UnionVariant>> { 36 let current_nsid_path = NsidPath::parse(self.current_nsid); 37 let current_namespace = current_nsid_path.namespace(); 38 39 // First pass: collect all variant names and detect collisions 40 #[derive(Debug)] 41 struct VariantInfo { 42 ref_str: String, 43 ref_nsid: String, 44 simple_name: String, 45 is_current_namespace: bool, 46 } 47 48 let mut variant_infos = Vec::new(); 49 for ref_str in refs { 50 let normalized_ref = RefPath::normalize(ref_str, self.current_nsid); 51 let ref_path = RefPath::parse(&normalized_ref, None); 52 let ref_nsid_str = ref_path.nsid(); 53 let ref_def = ref_path.def(); 54 55 // Skip unknown refs 56 if !self.corpus.ref_exists(&normalized_ref) { 57 continue; 58 } 59 60 let is_current_namespace = ref_nsid_str.starts_with(&current_namespace); 61 let is_same_module = ref_nsid_str == self.current_nsid; 62 63 // Generate simple variant name 64 let last_segment = ref_nsid_str.split('.').last().unwrap(); 65 let simple_name = if ref_def == "main" { 66 last_segment.to_pascal_case() 67 } else if last_segment == "defs" { 68 ref_def.to_pascal_case() 69 } else if is_same_module { 70 ref_def.to_pascal_case() 71 } else { 72 format!( 73 "{}{}", 74 last_segment.to_pascal_case(), 75 ref_def.to_pascal_case() 76 ) 77 }; 78 79 variant_infos.push(VariantInfo { 80 ref_str: normalized_ref.clone(), 81 ref_nsid: ref_nsid_str.to_string(), 82 simple_name, 83 is_current_namespace, 84 }); 85 } 86 87 // Second pass: detect collisions and disambiguate 88 let mut name_counts: HashMap<String, usize> = HashMap::new(); 89 for info in &variant_infos { 90 *name_counts.entry(info.simple_name.clone()).or_insert(0) += 1; 91 } 92 93 let mut variants = Vec::new(); 94 for info in variant_infos { 95 let has_collision = name_counts.get(&info.simple_name).copied().unwrap_or(0) > 1; 96 97 // Track namespace dependency for foreign refs 98 if !info.is_current_namespace { 99 let ref_nsid_path = NsidPath::parse(&info.ref_nsid); 100 let foreign_namespace = ref_nsid_path.namespace(); 101 self.namespace_deps 102 .borrow_mut() 103 .entry(current_namespace.clone()) 104 .or_default() 105 .insert(foreign_namespace); 106 } 107 108 // Disambiguate: add second NSID segment prefix only to foreign refs when there's a collision 109 let variant_name = if has_collision && !info.is_current_namespace { 110 let ref_nsid_path = NsidPath::parse(&info.ref_nsid); 111 let segments = ref_nsid_path.segments(); 112 let prefix = if segments.len() >= 2 { 113 segments[1].to_pascal_case() 114 } else { 115 segments[0].to_pascal_case() 116 }; 117 format!("{}{}", prefix, info.simple_name) 118 } else { 119 info.simple_name.clone() 120 }; 121 122 let rust_type = ref_to_rust_type(&info.ref_str)?; 123 124 variants.push(UnionVariant { 125 ref_str: info.ref_str, 126 variant_name, 127 rust_type, 128 }); 129 } 130 131 Ok(variants) 132 } 133 134 /// Build variants for a union without collision detection (simple mode) 135 pub fn build_simple_union_variants( 136 &self, 137 refs: &[CowStr<'static>], 138 ref_to_rust_type: impl Fn(&str) -> Result<TokenStream>, 139 ) -> Result<Vec<UnionVariant>> { 140 let mut variants = Vec::new(); 141 142 for ref_str in refs { 143 let ref_str_s = ref_str.as_ref(); 144 let normalized_ref = RefPath::normalize(ref_str, self.current_nsid); 145 let ref_path = RefPath::parse(&normalized_ref, None); 146 let ref_nsid = ref_path.nsid(); 147 let ref_def = ref_path.def(); 148 149 let variant_name = if ref_def == "main" { 150 let ref_nsid_path = NsidPath::parse(ref_nsid); 151 ref_nsid_path.last_segment().to_pascal_case() 152 } else { 153 ref_def.to_pascal_case() 154 }; 155 156 let rust_type = ref_to_rust_type(&normalized_ref)?; 157 158 variants.push(UnionVariant { 159 ref_str: ref_str_s.to_string(), 160 variant_name, 161 rust_type, 162 }); 163 } 164 165 Ok(variants) 166 } 167} 168 169/// Generate variant tokens for a union enum 170pub fn generate_variant_tokens(variants: &[UnionVariant]) -> Vec<TokenStream> { 171 variants 172 .iter() 173 .map(|variant| { 174 let variant_ident = syn::Ident::new(&variant.variant_name, proc_macro2::Span::call_site()); 175 let ref_str_literal = &variant.ref_str; 176 let rust_type = &variant.rust_type; 177 178 quote! { 179 #[serde(rename = #ref_str_literal)] 180 #variant_ident(Box<#rust_type>) 181 } 182 }) 183 .collect() 184}