A better Rust ATProto crate
1use crate::error::{CodegenError, Result};
2use proc_macro2::TokenStream;
3use quote::quote;
4use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
5
6use super::nsid_utils::NsidPath;
7use super::utils::{make_ident, sanitize_name};
8use super::CodeGenerator;
9
10impl<'c> CodeGenerator<'c> {
11 /// Generate all code for the corpus, organized by file
12 /// Returns a map of file paths to (tokens, optional NSID)
13 pub fn generate_all(
14 &self,
15 ) -> Result<BTreeMap<std::path::PathBuf, (TokenStream, Option<String>)>> {
16 let mut file_contents: BTreeMap<std::path::PathBuf, Vec<TokenStream>> = BTreeMap::new();
17 let mut file_nsids: BTreeMap<std::path::PathBuf, String> = BTreeMap::new();
18
19 // Generate code for all lexicons
20 for (nsid, doc) in self.corpus.iter() {
21 let file_path = self.nsid_to_file_path(nsid.as_ref());
22
23 // Track which NSID this file is for
24 file_nsids.insert(file_path.clone(), nsid.to_string());
25
26 for (def_name, def) in &doc.defs {
27 let tokens = self.generate_def(nsid.as_ref(), def_name.as_ref(), def)?;
28 file_contents
29 .entry(file_path.clone())
30 .or_default()
31 .push(tokens);
32 }
33 }
34
35 // Combine all tokens for each file
36 let mut result = BTreeMap::new();
37 for (path, tokens_vec) in file_contents {
38 let nsid = file_nsids.get(&path).cloned();
39 result.insert(path, (quote! { #(#tokens_vec)* }, nsid));
40 }
41
42 Ok(result)
43 }
44
45 /// Generate parent module files with pub mod declarations
46 pub fn generate_module_tree(
47 &self,
48 file_map: &BTreeMap<std::path::PathBuf, (TokenStream, Option<String>)>,
49 defs_only: &BTreeMap<std::path::PathBuf, (TokenStream, Option<String>)>,
50 subscription_files: &HashSet<std::path::PathBuf>,
51 ) -> BTreeMap<std::path::PathBuf, (TokenStream, Option<String>)> {
52 // Track what modules each directory needs to declare
53 // Key: directory path, Value: set of module names (file stems)
54 let mut dir_modules: BTreeMap<std::path::PathBuf, BTreeSet<String>> = BTreeMap::new();
55
56 // Collect all parent directories that have files
57 let mut all_dirs: BTreeSet<std::path::PathBuf> = BTreeSet::new();
58 for path in file_map.keys() {
59 if let Some(parent_dir) = path.parent() {
60 all_dirs.insert(parent_dir.to_path_buf());
61 }
62 }
63
64 for path in file_map.keys() {
65 if let Some(parent_dir) = path.parent() {
66 if let Some(file_stem) = path.file_stem().and_then(|s| s.to_str()) {
67 // Skip mod.rs and lib.rs - they're module files, not modules to declare
68 if file_stem == "mod" || file_stem == "lib" {
69 continue;
70 }
71
72 // Always add the module declaration to parent
73 dir_modules
74 .entry(parent_dir.to_path_buf())
75 .or_default()
76 .insert(file_stem.to_string());
77 }
78 }
79 }
80
81 // Generate module files
82 let mut result = BTreeMap::new();
83
84 for (dir, module_names) in dir_modules {
85 let mod_file_path = if dir.components().count() == 0 {
86 // Root directory -> lib.rs for library crates
87 std::path::PathBuf::from("lib.rs")
88 } else {
89 // Subdirectory: app_bsky/feed -> app_bsky/feed.rs (Rust 2018 style)
90 let dir_name = dir.file_name().and_then(|s| s.to_str()).unwrap_or("mod");
91 let sanitized_dir_name = sanitize_name(dir_name);
92 let mut path = dir
93 .parent()
94 .unwrap_or_else(|| std::path::Path::new(""))
95 .to_path_buf();
96 path.push(format!("{}.rs", sanitized_dir_name));
97 path
98 };
99
100 let is_root = dir.components().count() == 0;
101 let mods: Vec<_> = module_names
102 .iter()
103 .map(|name| {
104 let ident = make_ident(name);
105
106 // Check if this module is a subscription endpoint
107 let mut module_path = dir.clone();
108 module_path.push(format!("{}.rs", name));
109 let is_subscription = subscription_files.contains(&module_path);
110
111 if is_root && name != "builder_types" {
112 // Top-level modules get feature gates (except builder_types which is always needed)
113 quote! {
114 #[cfg(feature = #name)]
115 pub mod #ident;
116 }
117 } else if is_subscription {
118 // Subscription modules get streaming feature gate
119 quote! {
120 #[cfg(feature = "streaming")]
121 pub mod #ident;
122 }
123 } else {
124 quote! { pub mod #ident; }
125 }
126 })
127 .collect();
128
129 // If this file already exists in defs_only (e.g., from defs), merge the content
130 let module_tokens = if is_root {
131 // lib.rs needs extern crate alloc for no_std compatibility
132 quote! { extern crate alloc; #(#mods)* }
133 } else {
134 quote! { #(#mods)* }
135 };
136 if let Some((existing_tokens, nsid)) = defs_only.get(&mod_file_path) {
137 // Put module declarations FIRST, then existing defs content
138 result.insert(
139 mod_file_path,
140 (quote! { #module_tokens #existing_tokens }, nsid.clone()),
141 );
142 } else {
143 result.insert(mod_file_path, (module_tokens, None));
144 }
145 }
146
147 result
148 }
149
150 /// Write all generated code to disk
151 pub fn write_to_disk(&self, output_dir: &std::path::Path) -> Result<()> {
152 // Generate all code (defs only)
153 let defs_files = self.generate_all()?;
154 let mut all_files = defs_files.clone();
155
156 // Generate common builder types (Set, Unset, IsSet, IsUnset)
157 let common_types_path = std::path::PathBuf::from("builder_types.rs");
158 let common_types_tokens = super::builder_gen::common::generate_common_types();
159 all_files.insert(common_types_path, (common_types_tokens, None));
160
161 // Get subscription files for feature gating
162 let subscription_files = self.subscription_files.borrow();
163
164 // Generate module tree iteratively until no new files appear
165 loop {
166 let module_map = self.generate_module_tree(&all_files, &defs_files, &subscription_files);
167 let old_count = all_files.len();
168
169 // Merge new module files
170 for (path, tokens) in module_map {
171 all_files.insert(path, tokens);
172 }
173
174 if all_files.len() == old_count {
175 // No new files added
176 break;
177 }
178 }
179
180 // Write to disk
181 for (path, (tokens, nsid)) in all_files {
182 let full_path = output_dir.join(&path);
183
184 // Create parent directories
185 if let Some(parent) = full_path.parent() {
186 std::fs::create_dir_all(parent)?;
187 }
188
189 // Format code
190 let file: syn::File = syn::parse2(tokens.clone()).map_err(|e| CodegenError::TokenParseError {
191 path: path.clone(),
192 source: e,
193 tokens: tokens.to_string(),
194 })?;
195 let mut formatted = prettyplease::unparse(&file);
196
197 // Add blank lines between top-level items for better readability
198 let lines: Vec<&str> = formatted.lines().collect();
199 let mut result_lines = Vec::new();
200
201 for (i, line) in lines.iter().enumerate() {
202 result_lines.push(*line);
203
204 // Add blank line after closing braces that are at column 0 (top-level items)
205 if *line == "}" && i + 1 < lines.len() && !lines[i + 1].is_empty() {
206 result_lines.push("");
207 }
208
209 // Add blank line after last pub mod declaration before structs/enums
210 if line.starts_with("pub mod ") && i + 1 < lines.len() {
211 let next_line = lines[i + 1];
212 if !next_line.starts_with("pub mod ") && !next_line.is_empty() {
213 result_lines.push("");
214 }
215 }
216 }
217
218 formatted = result_lines.join("\n");
219
220 // Add header comment
221 let header = if let Some(nsid) = nsid {
222 format!(
223 "// @generated by jacquard-lexicon. DO NOT EDIT.\n//\n// Lexicon: {}\n//\n// This file was automatically generated from Lexicon schemas.\n// Any manual changes will be overwritten on the next regeneration.\n\n",
224 nsid
225 )
226 } else {
227 "// @generated by jacquard-lexicon. DO NOT EDIT.\n//\n// This file was automatically generated from Lexicon schemas.\n// Any manual changes will be overwritten on the next regeneration.\n\n".to_string()
228 };
229 formatted = format!("{}{}", header, formatted);
230
231 // Write file
232 std::fs::write(&full_path, formatted)?;
233 }
234
235 Ok(())
236 }
237
238 /// Get namespace dependencies collected during code generation
239 pub fn get_namespace_dependencies(
240 &self,
241 ) -> HashMap<String, HashSet<String>> {
242 self.namespace_deps.borrow().clone()
243 }
244
245 /// Generate Cargo.toml features section from namespace dependencies
246 pub fn generate_cargo_features(&self, lib_rs_path: Option<&std::path::Path>) -> String {
247 use std::fmt::Write;
248
249 let deps = self.namespace_deps.borrow();
250 let mut all_namespaces: HashSet<String> =
251 HashSet::new();
252
253 // Collect all namespaces from the corpus (first two segments of each NSID)
254 for (nsid, _doc) in self.corpus.iter() {
255 let nsid_path = NsidPath::parse(nsid.as_str());
256 let namespace = nsid_path.namespace();
257 all_namespaces.insert(namespace);
258 }
259
260 // Also collect existing feature names from lib.rs
261 let mut existing_features = HashSet::new();
262 if let Some(lib_rs) = lib_rs_path {
263 if let Ok(content) = std::fs::read_to_string(lib_rs) {
264 for line in content.lines() {
265 if let Some(feature) = line
266 .trim()
267 .strip_prefix("#[cfg(feature = \"")
268 .and_then(|s| s.strip_suffix("\")]"))
269 {
270 existing_features.insert(feature.to_string());
271 }
272 }
273 }
274 }
275
276 let mut output = String::new();
277 writeln!(&mut output, "# Generated namespace features").unwrap();
278
279 // Convert namespace to feature name (matching module path sanitization)
280 let to_feature_name = |ns: &str| {
281 ns.split('.')
282 .map(|segment| {
283 // Apply same sanitization as module names
284 let mut result = segment.replace('-', "_");
285 // Prefix with underscore if starts with digit
286 if result.chars().next().map_or(false, |c| c.is_ascii_digit()) {
287 result.insert(0, '_');
288 }
289 result
290 })
291 .collect::<Vec<_>>()
292 .join("_")
293 };
294
295 // Collect all feature names (from corpus + existing lib.rs)
296 let mut all_feature_names = HashSet::new();
297 for ns in &all_namespaces {
298 all_feature_names.insert(to_feature_name(ns));
299 }
300 all_feature_names.extend(existing_features);
301
302 // Sort for consistent output
303 let mut feature_names: Vec<_> = all_feature_names.iter().collect();
304 feature_names.sort();
305
306 // Map namespace to feature name for dependency lookup
307 let mut ns_to_feature: HashMap<&str, String> =
308 HashMap::new();
309 for ns in &all_namespaces {
310 ns_to_feature.insert(ns.as_str(), to_feature_name(ns));
311 }
312
313 for feature_name in feature_names {
314 // Find corresponding namespace for this feature (if any) to look up deps
315 let feature_deps: Vec<String> = all_namespaces
316 .iter()
317 .find(|ns| to_feature_name(ns) == *feature_name)
318 .and_then(|ns| deps.get(ns.as_str()))
319 .map(|ns_deps| {
320 let mut dep_features: Vec<_> = ns_deps
321 .iter()
322 .map(|d| format!("\"{}\"", to_feature_name(d)))
323 .collect();
324 dep_features.sort();
325 dep_features
326 })
327 .unwrap_or_default();
328
329 if !feature_deps.is_empty() {
330 writeln!(
331 &mut output,
332 "{} = [{}]",
333 feature_name,
334 feature_deps.join(", ")
335 )
336 .unwrap();
337 } else {
338 writeln!(&mut output, "{} = []", feature_name).unwrap();
339 }
340 }
341
342 output
343 }
344}