A better Rust ATProto crate
at main 344 lines 14 kB view raw
1use crate::error::{CodegenError, Result}; 2use proc_macro2::TokenStream; 3use quote::quote; 4use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; 5 6use super::nsid_utils::NsidPath; 7use super::utils::{make_ident, sanitize_name}; 8use super::CodeGenerator; 9 10impl<'c> CodeGenerator<'c> { 11 /// Generate all code for the corpus, organized by file 12 /// Returns a map of file paths to (tokens, optional NSID) 13 pub fn generate_all( 14 &self, 15 ) -> Result<BTreeMap<std::path::PathBuf, (TokenStream, Option<String>)>> { 16 let mut file_contents: BTreeMap<std::path::PathBuf, Vec<TokenStream>> = BTreeMap::new(); 17 let mut file_nsids: BTreeMap<std::path::PathBuf, String> = BTreeMap::new(); 18 19 // Generate code for all lexicons 20 for (nsid, doc) in self.corpus.iter() { 21 let file_path = self.nsid_to_file_path(nsid.as_ref()); 22 23 // Track which NSID this file is for 24 file_nsids.insert(file_path.clone(), nsid.to_string()); 25 26 for (def_name, def) in &doc.defs { 27 let tokens = self.generate_def(nsid.as_ref(), def_name.as_ref(), def)?; 28 file_contents 29 .entry(file_path.clone()) 30 .or_default() 31 .push(tokens); 32 } 33 } 34 35 // Combine all tokens for each file 36 let mut result = BTreeMap::new(); 37 for (path, tokens_vec) in file_contents { 38 let nsid = file_nsids.get(&path).cloned(); 39 result.insert(path, (quote! { #(#tokens_vec)* }, nsid)); 40 } 41 42 Ok(result) 43 } 44 45 /// Generate parent module files with pub mod declarations 46 pub fn generate_module_tree( 47 &self, 48 file_map: &BTreeMap<std::path::PathBuf, (TokenStream, Option<String>)>, 49 defs_only: &BTreeMap<std::path::PathBuf, (TokenStream, Option<String>)>, 50 subscription_files: &HashSet<std::path::PathBuf>, 51 ) -> BTreeMap<std::path::PathBuf, (TokenStream, Option<String>)> { 52 // Track what modules each directory needs to declare 53 // Key: directory path, Value: set of module names (file stems) 54 let mut dir_modules: BTreeMap<std::path::PathBuf, BTreeSet<String>> = BTreeMap::new(); 55 56 // Collect all parent directories that have files 57 let mut all_dirs: BTreeSet<std::path::PathBuf> = BTreeSet::new(); 58 for path in file_map.keys() { 59 if let Some(parent_dir) = path.parent() { 60 all_dirs.insert(parent_dir.to_path_buf()); 61 } 62 } 63 64 for path in file_map.keys() { 65 if let Some(parent_dir) = path.parent() { 66 if let Some(file_stem) = path.file_stem().and_then(|s| s.to_str()) { 67 // Skip mod.rs and lib.rs - they're module files, not modules to declare 68 if file_stem == "mod" || file_stem == "lib" { 69 continue; 70 } 71 72 // Always add the module declaration to parent 73 dir_modules 74 .entry(parent_dir.to_path_buf()) 75 .or_default() 76 .insert(file_stem.to_string()); 77 } 78 } 79 } 80 81 // Generate module files 82 let mut result = BTreeMap::new(); 83 84 for (dir, module_names) in dir_modules { 85 let mod_file_path = if dir.components().count() == 0 { 86 // Root directory -> lib.rs for library crates 87 std::path::PathBuf::from("lib.rs") 88 } else { 89 // Subdirectory: app_bsky/feed -> app_bsky/feed.rs (Rust 2018 style) 90 let dir_name = dir.file_name().and_then(|s| s.to_str()).unwrap_or("mod"); 91 let sanitized_dir_name = sanitize_name(dir_name); 92 let mut path = dir 93 .parent() 94 .unwrap_or_else(|| std::path::Path::new("")) 95 .to_path_buf(); 96 path.push(format!("{}.rs", sanitized_dir_name)); 97 path 98 }; 99 100 let is_root = dir.components().count() == 0; 101 let mods: Vec<_> = module_names 102 .iter() 103 .map(|name| { 104 let ident = make_ident(name); 105 106 // Check if this module is a subscription endpoint 107 let mut module_path = dir.clone(); 108 module_path.push(format!("{}.rs", name)); 109 let is_subscription = subscription_files.contains(&module_path); 110 111 if is_root && name != "builder_types" { 112 // Top-level modules get feature gates (except builder_types which is always needed) 113 quote! { 114 #[cfg(feature = #name)] 115 pub mod #ident; 116 } 117 } else if is_subscription { 118 // Subscription modules get streaming feature gate 119 quote! { 120 #[cfg(feature = "streaming")] 121 pub mod #ident; 122 } 123 } else { 124 quote! { pub mod #ident; } 125 } 126 }) 127 .collect(); 128 129 // If this file already exists in defs_only (e.g., from defs), merge the content 130 let module_tokens = if is_root { 131 // lib.rs needs extern crate alloc for no_std compatibility 132 quote! { extern crate alloc; #(#mods)* } 133 } else { 134 quote! { #(#mods)* } 135 }; 136 if let Some((existing_tokens, nsid)) = defs_only.get(&mod_file_path) { 137 // Put module declarations FIRST, then existing defs content 138 result.insert( 139 mod_file_path, 140 (quote! { #module_tokens #existing_tokens }, nsid.clone()), 141 ); 142 } else { 143 result.insert(mod_file_path, (module_tokens, None)); 144 } 145 } 146 147 result 148 } 149 150 /// Write all generated code to disk 151 pub fn write_to_disk(&self, output_dir: &std::path::Path) -> Result<()> { 152 // Generate all code (defs only) 153 let defs_files = self.generate_all()?; 154 let mut all_files = defs_files.clone(); 155 156 // Generate common builder types (Set, Unset, IsSet, IsUnset) 157 let common_types_path = std::path::PathBuf::from("builder_types.rs"); 158 let common_types_tokens = super::builder_gen::common::generate_common_types(); 159 all_files.insert(common_types_path, (common_types_tokens, None)); 160 161 // Get subscription files for feature gating 162 let subscription_files = self.subscription_files.borrow(); 163 164 // Generate module tree iteratively until no new files appear 165 loop { 166 let module_map = self.generate_module_tree(&all_files, &defs_files, &subscription_files); 167 let old_count = all_files.len(); 168 169 // Merge new module files 170 for (path, tokens) in module_map { 171 all_files.insert(path, tokens); 172 } 173 174 if all_files.len() == old_count { 175 // No new files added 176 break; 177 } 178 } 179 180 // Write to disk 181 for (path, (tokens, nsid)) in all_files { 182 let full_path = output_dir.join(&path); 183 184 // Create parent directories 185 if let Some(parent) = full_path.parent() { 186 std::fs::create_dir_all(parent)?; 187 } 188 189 // Format code 190 let file: syn::File = syn::parse2(tokens.clone()).map_err(|e| CodegenError::TokenParseError { 191 path: path.clone(), 192 source: e, 193 tokens: tokens.to_string(), 194 })?; 195 let mut formatted = prettyplease::unparse(&file); 196 197 // Add blank lines between top-level items for better readability 198 let lines: Vec<&str> = formatted.lines().collect(); 199 let mut result_lines = Vec::new(); 200 201 for (i, line) in lines.iter().enumerate() { 202 result_lines.push(*line); 203 204 // Add blank line after closing braces that are at column 0 (top-level items) 205 if *line == "}" && i + 1 < lines.len() && !lines[i + 1].is_empty() { 206 result_lines.push(""); 207 } 208 209 // Add blank line after last pub mod declaration before structs/enums 210 if line.starts_with("pub mod ") && i + 1 < lines.len() { 211 let next_line = lines[i + 1]; 212 if !next_line.starts_with("pub mod ") && !next_line.is_empty() { 213 result_lines.push(""); 214 } 215 } 216 } 217 218 formatted = result_lines.join("\n"); 219 220 // Add header comment 221 let header = if let Some(nsid) = nsid { 222 format!( 223 "// @generated by jacquard-lexicon. DO NOT EDIT.\n//\n// Lexicon: {}\n//\n// This file was automatically generated from Lexicon schemas.\n// Any manual changes will be overwritten on the next regeneration.\n\n", 224 nsid 225 ) 226 } else { 227 "// @generated by jacquard-lexicon. DO NOT EDIT.\n//\n// This file was automatically generated from Lexicon schemas.\n// Any manual changes will be overwritten on the next regeneration.\n\n".to_string() 228 }; 229 formatted = format!("{}{}", header, formatted); 230 231 // Write file 232 std::fs::write(&full_path, formatted)?; 233 } 234 235 Ok(()) 236 } 237 238 /// Get namespace dependencies collected during code generation 239 pub fn get_namespace_dependencies( 240 &self, 241 ) -> HashMap<String, HashSet<String>> { 242 self.namespace_deps.borrow().clone() 243 } 244 245 /// Generate Cargo.toml features section from namespace dependencies 246 pub fn generate_cargo_features(&self, lib_rs_path: Option<&std::path::Path>) -> String { 247 use std::fmt::Write; 248 249 let deps = self.namespace_deps.borrow(); 250 let mut all_namespaces: HashSet<String> = 251 HashSet::new(); 252 253 // Collect all namespaces from the corpus (first two segments of each NSID) 254 for (nsid, _doc) in self.corpus.iter() { 255 let nsid_path = NsidPath::parse(nsid.as_str()); 256 let namespace = nsid_path.namespace(); 257 all_namespaces.insert(namespace); 258 } 259 260 // Also collect existing feature names from lib.rs 261 let mut existing_features = HashSet::new(); 262 if let Some(lib_rs) = lib_rs_path { 263 if let Ok(content) = std::fs::read_to_string(lib_rs) { 264 for line in content.lines() { 265 if let Some(feature) = line 266 .trim() 267 .strip_prefix("#[cfg(feature = \"") 268 .and_then(|s| s.strip_suffix("\")]")) 269 { 270 existing_features.insert(feature.to_string()); 271 } 272 } 273 } 274 } 275 276 let mut output = String::new(); 277 writeln!(&mut output, "# Generated namespace features").unwrap(); 278 279 // Convert namespace to feature name (matching module path sanitization) 280 let to_feature_name = |ns: &str| { 281 ns.split('.') 282 .map(|segment| { 283 // Apply same sanitization as module names 284 let mut result = segment.replace('-', "_"); 285 // Prefix with underscore if starts with digit 286 if result.chars().next().map_or(false, |c| c.is_ascii_digit()) { 287 result.insert(0, '_'); 288 } 289 result 290 }) 291 .collect::<Vec<_>>() 292 .join("_") 293 }; 294 295 // Collect all feature names (from corpus + existing lib.rs) 296 let mut all_feature_names = HashSet::new(); 297 for ns in &all_namespaces { 298 all_feature_names.insert(to_feature_name(ns)); 299 } 300 all_feature_names.extend(existing_features); 301 302 // Sort for consistent output 303 let mut feature_names: Vec<_> = all_feature_names.iter().collect(); 304 feature_names.sort(); 305 306 // Map namespace to feature name for dependency lookup 307 let mut ns_to_feature: HashMap<&str, String> = 308 HashMap::new(); 309 for ns in &all_namespaces { 310 ns_to_feature.insert(ns.as_str(), to_feature_name(ns)); 311 } 312 313 for feature_name in feature_names { 314 // Find corresponding namespace for this feature (if any) to look up deps 315 let feature_deps: Vec<String> = all_namespaces 316 .iter() 317 .find(|ns| to_feature_name(ns) == *feature_name) 318 .and_then(|ns| deps.get(ns.as_str())) 319 .map(|ns_deps| { 320 let mut dep_features: Vec<_> = ns_deps 321 .iter() 322 .map(|d| format!("\"{}\"", to_feature_name(d))) 323 .collect(); 324 dep_features.sort(); 325 dep_features 326 }) 327 .unwrap_or_default(); 328 329 if !feature_deps.is_empty() { 330 writeln!( 331 &mut output, 332 "{} = [{}]", 333 feature_name, 334 feature_deps.join(", ") 335 ) 336 .unwrap(); 337 } else { 338 writeln!(&mut output, "{} = []", feature_name).unwrap(); 339 } 340 } 341 342 output 343 } 344}