A better Rust ATProto crate
at main 180 lines 5.6 kB view raw
1//! State module generation for builders 2//! 3//! Generates the state trait, Empty state, and SetX<S> transition types 4//! that enable type-safe builder patterns. 5 6use std::collections::HashSet; 7 8use heck::{ToPascalCase, ToSnakeCase}; 9use jacquard_common::smol_str::SmolStr; 10use proc_macro2::TokenStream; 11use quote::{format_ident, quote}; 12 13use crate::codegen::utils::make_ident; 14 15/// Information about a required field for builder state generation 16#[derive(Debug, Clone, Hash, PartialEq, Eq)] 17pub struct RequiredField { 18 /// Field name (snake_case) 19 pub name_snake: SmolStr, 20 /// Field name (PascalCase) 21 pub name_pascal: SmolStr, 22} 23 24impl RequiredField { 25 pub fn new(field_name: &str) -> Self { 26 Self { 27 name_snake: SmolStr::new(field_name.to_snake_case()), 28 name_pascal: SmolStr::new(field_name.to_pascal_case()), 29 } 30 } 31} 32 33/// Collect required fields from a builder schema 34pub fn collect_required_fields(schema: &super::BuilderSchema<'_>) -> Vec<RequiredField> { 35 let required = schema.required().unwrap_or(&[]); 36 let nullable = schema.nullable().unwrap_or(&[]); 37 38 let set: HashSet<_> = required 39 .iter() 40 .filter_map(|field_name| { 41 if nullable.contains(&field_name) { 42 None 43 } else { 44 let field_name_str: &str = field_name.as_ref(); 45 Some(RequiredField::new(field_name_str)) 46 } 47 }) 48 .collect(); 49 50 set.into_iter().collect() 51} 52 53/// Generate the complete state module for a builder 54pub fn generate_state_module(type_name: &str, required_fields: &[RequiredField]) -> TokenStream { 55 let state_mod_name = format_ident!("{}_state", type_name.to_snake_case()); 56 57 let state_trait = generate_state_trait(required_fields); 58 let empty_state = generate_empty_state(required_fields); 59 let transition_types = generate_transition_types(required_fields); 60 let members_mod = generate_members_module(required_fields); 61 62 quote! { 63 pub mod #state_mod_name { 64 pub use crate::builder_types::{Set, Unset, IsSet, IsUnset}; 65 #[allow(unused)] 66 use ::core::marker::PhantomData; 67 68 mod sealed { 69 pub trait Sealed {} 70 } 71 72 #state_trait 73 #empty_state 74 #transition_types 75 #members_mod 76 } 77 } 78} 79 80/// Generate the State trait with associated types for each required field 81fn generate_state_trait(required_fields: &[RequiredField]) -> TokenStream { 82 let associated_types = required_fields.iter().map(|field| { 83 let field_pascal = format_ident!("{}", field.name_pascal.as_str()); 84 quote! { type #field_pascal; } 85 }); 86 87 quote! { 88 /// State trait tracking which required fields have been set 89 pub trait State: sealed::Sealed { 90 #( #associated_types )* 91 } 92 } 93} 94 95/// Generate the Empty state (all fields Unset) 96fn generate_empty_state(required_fields: &[RequiredField]) -> TokenStream { 97 let field_impls = required_fields.iter().map(|field| { 98 let field_pascal = format_ident!("{}", field.name_pascal.as_str()); 99 quote! { 100 type #field_pascal = Unset; 101 } 102 }); 103 104 quote! { 105 /// Empty state - all required fields are unset 106 pub struct Empty(()); 107 108 impl sealed::Sealed for Empty {} 109 110 impl State for Empty { 111 #( #field_impls )* 112 } 113 } 114} 115 116/// Generate SetX<S> transition types for each required field 117fn generate_transition_types(required_fields: &[RequiredField]) -> TokenStream { 118 let transition_impls = required_fields.iter().map(|field| { 119 let struct_name = format_ident!("Set{}", field.name_pascal.as_str()); 120 121 let doc = format!( 122 "State transition - sets the `{}` field to Set", 123 field.name_snake 124 ); 125 126 // Generate associated type impls - this field is Set, others preserve S's state 127 let field_impls = required_fields.iter().map(|other_field| { 128 let other_pascal = format_ident!("{}", other_field.name_pascal.as_str()); 129 let other_snake = make_ident(other_field.name_snake.as_str()); 130 131 if field.name_snake == other_field.name_snake { 132 // This field becomes Set 133 quote! { 134 type #other_pascal = Set<members::#other_snake>; 135 } 136 } else { 137 // Other fields preserve their state from S 138 quote! { 139 type #other_pascal = S::#other_pascal; 140 } 141 } 142 }); 143 144 quote! { 145 #[doc = #doc] 146 pub struct #struct_name<S: State = Empty>(PhantomData<fn() -> S>); 147 148 impl<S: State> sealed::Sealed for #struct_name<S> {} 149 150 impl<S: State> State for #struct_name<S> { 151 #( #field_impls )* 152 } 153 } 154 }); 155 156 quote! { 157 #( #transition_impls )* 158 } 159} 160 161/// Generate the members module with marker types for each field 162fn generate_members_module(required_fields: &[RequiredField]) -> TokenStream { 163 let member_types = required_fields.iter().map(|field| { 164 let field_snake = make_ident(field.name_snake.as_str()); 165 let doc = format!("Marker type for the `{}` field", field.name_snake); 166 167 quote! { 168 #[doc = #doc] 169 pub struct #field_snake(()); 170 } 171 }); 172 173 quote! { 174 /// Marker types for field names 175 #[allow(non_camel_case_types)] 176 pub mod members { 177 #( #member_types )* 178 } 179 } 180}