A better Rust ATProto crate
1//! State module generation for builders
2//!
3//! Generates the state trait, Empty state, and SetX<S> transition types
4//! that enable type-safe builder patterns.
5
6use std::collections::HashSet;
7
8use heck::{ToPascalCase, ToSnakeCase};
9use jacquard_common::smol_str::SmolStr;
10use proc_macro2::TokenStream;
11use quote::{format_ident, quote};
12
13use crate::codegen::utils::make_ident;
14
15/// Information about a required field for builder state generation
16#[derive(Debug, Clone, Hash, PartialEq, Eq)]
17pub struct RequiredField {
18 /// Field name (snake_case)
19 pub name_snake: SmolStr,
20 /// Field name (PascalCase)
21 pub name_pascal: SmolStr,
22}
23
24impl RequiredField {
25 pub fn new(field_name: &str) -> Self {
26 Self {
27 name_snake: SmolStr::new(field_name.to_snake_case()),
28 name_pascal: SmolStr::new(field_name.to_pascal_case()),
29 }
30 }
31}
32
33/// Collect required fields from a builder schema
34pub fn collect_required_fields(schema: &super::BuilderSchema<'_>) -> Vec<RequiredField> {
35 let required = schema.required().unwrap_or(&[]);
36 let nullable = schema.nullable().unwrap_or(&[]);
37
38 let set: HashSet<_> = required
39 .iter()
40 .filter_map(|field_name| {
41 if nullable.contains(&field_name) {
42 None
43 } else {
44 let field_name_str: &str = field_name.as_ref();
45 Some(RequiredField::new(field_name_str))
46 }
47 })
48 .collect();
49
50 set.into_iter().collect()
51}
52
53/// Generate the complete state module for a builder
54pub fn generate_state_module(type_name: &str, required_fields: &[RequiredField]) -> TokenStream {
55 let state_mod_name = format_ident!("{}_state", type_name.to_snake_case());
56
57 let state_trait = generate_state_trait(required_fields);
58 let empty_state = generate_empty_state(required_fields);
59 let transition_types = generate_transition_types(required_fields);
60 let members_mod = generate_members_module(required_fields);
61
62 quote! {
63 pub mod #state_mod_name {
64 pub use crate::builder_types::{Set, Unset, IsSet, IsUnset};
65 #[allow(unused)]
66 use ::core::marker::PhantomData;
67
68 mod sealed {
69 pub trait Sealed {}
70 }
71
72 #state_trait
73 #empty_state
74 #transition_types
75 #members_mod
76 }
77 }
78}
79
80/// Generate the State trait with associated types for each required field
81fn generate_state_trait(required_fields: &[RequiredField]) -> TokenStream {
82 let associated_types = required_fields.iter().map(|field| {
83 let field_pascal = format_ident!("{}", field.name_pascal.as_str());
84 quote! { type #field_pascal; }
85 });
86
87 quote! {
88 /// State trait tracking which required fields have been set
89 pub trait State: sealed::Sealed {
90 #( #associated_types )*
91 }
92 }
93}
94
95/// Generate the Empty state (all fields Unset)
96fn generate_empty_state(required_fields: &[RequiredField]) -> TokenStream {
97 let field_impls = required_fields.iter().map(|field| {
98 let field_pascal = format_ident!("{}", field.name_pascal.as_str());
99 quote! {
100 type #field_pascal = Unset;
101 }
102 });
103
104 quote! {
105 /// Empty state - all required fields are unset
106 pub struct Empty(());
107
108 impl sealed::Sealed for Empty {}
109
110 impl State for Empty {
111 #( #field_impls )*
112 }
113 }
114}
115
116/// Generate SetX<S> transition types for each required field
117fn generate_transition_types(required_fields: &[RequiredField]) -> TokenStream {
118 let transition_impls = required_fields.iter().map(|field| {
119 let struct_name = format_ident!("Set{}", field.name_pascal.as_str());
120
121 let doc = format!(
122 "State transition - sets the `{}` field to Set",
123 field.name_snake
124 );
125
126 // Generate associated type impls - this field is Set, others preserve S's state
127 let field_impls = required_fields.iter().map(|other_field| {
128 let other_pascal = format_ident!("{}", other_field.name_pascal.as_str());
129 let other_snake = make_ident(other_field.name_snake.as_str());
130
131 if field.name_snake == other_field.name_snake {
132 // This field becomes Set
133 quote! {
134 type #other_pascal = Set<members::#other_snake>;
135 }
136 } else {
137 // Other fields preserve their state from S
138 quote! {
139 type #other_pascal = S::#other_pascal;
140 }
141 }
142 });
143
144 quote! {
145 #[doc = #doc]
146 pub struct #struct_name<S: State = Empty>(PhantomData<fn() -> S>);
147
148 impl<S: State> sealed::Sealed for #struct_name<S> {}
149
150 impl<S: State> State for #struct_name<S> {
151 #( #field_impls )*
152 }
153 }
154 });
155
156 quote! {
157 #( #transition_impls )*
158 }
159}
160
161/// Generate the members module with marker types for each field
162fn generate_members_module(required_fields: &[RequiredField]) -> TokenStream {
163 let member_types = required_fields.iter().map(|field| {
164 let field_snake = make_ident(field.name_snake.as_str());
165 let doc = format!("Marker type for the `{}` field", field.name_snake);
166
167 quote! {
168 #[doc = #doc]
169 pub struct #field_snake(());
170 }
171 });
172
173 quote! {
174 /// Marker types for field names
175 #[allow(non_camel_case_types)]
176 pub mod members {
177 #( #member_types )*
178 }
179 }
180}