just playing with tangled
1// Copyright 2020-2024 The Jujutsu Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Domain-specific language helpers.
16
17use std::ascii;
18use std::collections::HashMap;
19use std::fmt;
20use std::slice;
21
22use itertools::Itertools as _;
23use pest::iterators::Pair;
24use pest::iterators::Pairs;
25use pest::RuleType;
26
/// Accumulates diagnostic messages produced while parsing.
///
/// `T` is normally the language's parse error type, which carries a message
/// and a source span with `'static` lifetime.
#[derive(Debug)]
pub struct Diagnostics<T> {
    // Could grow into [{ kind: Warning|Error, message: T }, ..] later.
    diagnostics: Vec<T>,
}

impl<T> Diagnostics<T> {
    /// Returns a collector that holds no messages yet.
    pub fn new() -> Self {
        Self {
            diagnostics: vec![],
        }
    }

    /// Tells whether no diagnostic message has been recorded.
    pub fn is_empty(&self) -> bool {
        self.diagnostics.is_empty()
    }

    /// Number of recorded diagnostic messages.
    pub fn len(&self) -> usize {
        self.diagnostics.len()
    }

    /// Iterates over the recorded messages in insertion order.
    pub fn iter(&self) -> slice::Iter<'_, T> {
        self.diagnostics.as_slice().iter()
    }

    /// Records a message at warning level.
    pub fn add_warning(&mut self, diag: T) {
        self.diagnostics.push(diag);
    }

    /// Consumes `diagnostics`, converts each of its messages with `f`, and
    /// appends the results here. Used to carry over messages of another type,
    /// such as fileset warnings emitted within a `file()` revset.
    pub fn extend_with<U>(&mut self, diagnostics: Diagnostics<U>, f: impl FnMut(U) -> T) {
        let converted = diagnostics.diagnostics.into_iter().map(f);
        self.diagnostics.extend(converted);
    }
}

impl<T> Default for Diagnostics<T> {
    fn default() -> Self {
        Diagnostics::new()
    }
}

impl<'a, T> IntoIterator for &'a Diagnostics<T> {
    type Item = &'a T;
    type IntoIter = slice::Iter<'a, T>;

    fn into_iter(self) -> Self::IntoIter {
        self.diagnostics.iter()
    }
}
87
88/// AST node without type or name checking.
89#[derive(Clone, Debug, Eq, PartialEq)]
90pub struct ExpressionNode<'i, T> {
91 /// Expression item such as identifier, literal, function call, etc.
92 pub kind: T,
93 /// Span of the node.
94 pub span: pest::Span<'i>,
95}
96
97impl<'i, T> ExpressionNode<'i, T> {
98 /// Wraps the given expression and span.
99 pub fn new(kind: T, span: pest::Span<'i>) -> Self {
100 ExpressionNode { kind, span }
101 }
102}
103
104/// Function call in AST.
105#[derive(Clone, Debug, Eq, PartialEq)]
106pub struct FunctionCallNode<'i, T> {
107 /// Function name.
108 pub name: &'i str,
109 /// Span of the function name.
110 pub name_span: pest::Span<'i>,
111 /// List of positional arguments.
112 pub args: Vec<ExpressionNode<'i, T>>,
113 /// List of keyword arguments.
114 pub keyword_args: Vec<KeywordArgument<'i, T>>,
115 /// Span of the arguments list.
116 pub args_span: pest::Span<'i>,
117}
118
119/// Keyword argument pair in AST.
120#[derive(Clone, Debug, Eq, PartialEq)]
121pub struct KeywordArgument<'i, T> {
122 /// Parameter name.
123 pub name: &'i str,
124 /// Span of the parameter name.
125 pub name_span: pest::Span<'i>,
126 /// Value expression.
127 pub value: ExpressionNode<'i, T>,
128}
129
130impl<'i, T> FunctionCallNode<'i, T> {
131 /// Number of arguments assuming named arguments are all unique.
132 pub fn arity(&self) -> usize {
133 self.args.len() + self.keyword_args.len()
134 }
135
136 /// Ensures that no arguments passed.
137 pub fn expect_no_arguments(&self) -> Result<(), InvalidArguments<'i>> {
138 let ([], []) = self.expect_arguments()?;
139 Ok(())
140 }
141
142 /// Extracts exactly N required arguments.
143 pub fn expect_exact_arguments<const N: usize>(
144 &self,
145 ) -> Result<&[ExpressionNode<'i, T>; N], InvalidArguments<'i>> {
146 let (args, []) = self.expect_arguments()?;
147 Ok(args)
148 }
149
150 /// Extracts N required arguments and remainders.
151 #[expect(clippy::type_complexity)]
152 pub fn expect_some_arguments<const N: usize>(
153 &self,
154 ) -> Result<(&[ExpressionNode<'i, T>; N], &[ExpressionNode<'i, T>]), InvalidArguments<'i>> {
155 self.ensure_no_keyword_arguments()?;
156 if self.args.len() >= N {
157 let (required, rest) = self.args.split_at(N);
158 Ok((required.try_into().unwrap(), rest))
159 } else {
160 Err(self.invalid_arguments_count(N, None))
161 }
162 }
163
164 /// Extracts N required arguments and M optional arguments.
165 #[expect(clippy::type_complexity)]
166 pub fn expect_arguments<const N: usize, const M: usize>(
167 &self,
168 ) -> Result<
169 (
170 &[ExpressionNode<'i, T>; N],
171 [Option<&ExpressionNode<'i, T>>; M],
172 ),
173 InvalidArguments<'i>,
174 > {
175 self.ensure_no_keyword_arguments()?;
176 let count_range = N..=(N + M);
177 if count_range.contains(&self.args.len()) {
178 let (required, rest) = self.args.split_at(N);
179 let mut optional = rest.iter().map(Some).collect_vec();
180 optional.resize(M, None);
181 Ok((
182 required.try_into().unwrap(),
183 optional.try_into().ok().unwrap(),
184 ))
185 } else {
186 let (min, max) = count_range.into_inner();
187 Err(self.invalid_arguments_count(min, Some(max)))
188 }
189 }
190
191 /// Extracts N required arguments and M optional arguments. Some of them can
192 /// be specified as keyword arguments.
193 ///
194 /// `names` is a list of parameter names. Unnamed positional arguments
195 /// should be padded with `""`.
196 #[expect(clippy::type_complexity)]
197 pub fn expect_named_arguments<const N: usize, const M: usize>(
198 &self,
199 names: &[&str],
200 ) -> Result<
201 (
202 [&ExpressionNode<'i, T>; N],
203 [Option<&ExpressionNode<'i, T>>; M],
204 ),
205 InvalidArguments<'i>,
206 > {
207 if self.keyword_args.is_empty() {
208 let (required, optional) = self.expect_arguments::<N, M>()?;
209 Ok((required.each_ref(), optional))
210 } else {
211 let (required, optional) = self.expect_named_arguments_vec(names, N, N + M)?;
212 Ok((
213 required.try_into().ok().unwrap(),
214 optional.try_into().ok().unwrap(),
215 ))
216 }
217 }
218
219 #[expect(clippy::type_complexity)]
220 fn expect_named_arguments_vec(
221 &self,
222 names: &[&str],
223 min: usize,
224 max: usize,
225 ) -> Result<
226 (
227 Vec<&ExpressionNode<'i, T>>,
228 Vec<Option<&ExpressionNode<'i, T>>>,
229 ),
230 InvalidArguments<'i>,
231 > {
232 assert!(names.len() <= max);
233
234 if self.args.len() > max {
235 return Err(self.invalid_arguments_count(min, Some(max)));
236 }
237 let mut extracted = Vec::with_capacity(max);
238 extracted.extend(self.args.iter().map(Some));
239 extracted.resize(max, None);
240
241 for arg in &self.keyword_args {
242 let name = arg.name;
243 let span = arg.name_span.start_pos().span(&arg.value.span.end_pos());
244 let pos = names.iter().position(|&n| n == name).ok_or_else(|| {
245 self.invalid_arguments(format!(r#"Unexpected keyword argument "{name}""#), span)
246 })?;
247 if extracted[pos].is_some() {
248 return Err(self.invalid_arguments(
249 format!(r#"Got multiple values for keyword "{name}""#),
250 span,
251 ));
252 }
253 extracted[pos] = Some(&arg.value);
254 }
255
256 let optional = extracted.split_off(min);
257 let required = extracted.into_iter().flatten().collect_vec();
258 if required.len() != min {
259 return Err(self.invalid_arguments_count(min, Some(max)));
260 }
261 Ok((required, optional))
262 }
263
264 fn ensure_no_keyword_arguments(&self) -> Result<(), InvalidArguments<'i>> {
265 if let (Some(first), Some(last)) = (self.keyword_args.first(), self.keyword_args.last()) {
266 let span = first.name_span.start_pos().span(&last.value.span.end_pos());
267 Err(self.invalid_arguments("Unexpected keyword arguments".to_owned(), span))
268 } else {
269 Ok(())
270 }
271 }
272
273 fn invalid_arguments(&self, message: String, span: pest::Span<'i>) -> InvalidArguments<'i> {
274 InvalidArguments {
275 name: self.name,
276 message,
277 span,
278 }
279 }
280
281 fn invalid_arguments_count(&self, min: usize, max: Option<usize>) -> InvalidArguments<'i> {
282 let message = match (min, max) {
283 (min, Some(max)) if min == max => format!("Expected {min} arguments"),
284 (min, Some(max)) => format!("Expected {min} to {max} arguments"),
285 (min, None) => format!("Expected at least {min} arguments"),
286 };
287 self.invalid_arguments(message, self.args_span)
288 }
289
290 fn invalid_arguments_count_with_arities(
291 &self,
292 arities: impl IntoIterator<Item = usize>,
293 ) -> InvalidArguments<'i> {
294 let message = format!("Expected {} arguments", arities.into_iter().join(", "));
295 self.invalid_arguments(message, self.args_span)
296 }
297}
298
299/// Unexpected number of arguments, or invalid combination of arguments.
300///
301/// This error is supposed to be converted to language-specific parse error
302/// type, where lifetime `'i` will be eliminated.
303#[derive(Clone, Debug)]
304pub struct InvalidArguments<'i> {
305 /// Function name.
306 pub name: &'i str,
307 /// Error message.
308 pub message: String,
309 /// Span of the bad arguments.
310 pub span: pest::Span<'i>,
311}
312
313/// Expression item that can be transformed recursively by using `folder: F`.
314pub trait FoldableExpression<'i>: Sized {
315 /// Transforms `self` by applying the `folder` to inner items.
316 fn fold<F>(self, folder: &mut F, span: pest::Span<'i>) -> Result<Self, F::Error>
317 where
318 F: ExpressionFolder<'i, Self> + ?Sized;
319}
320
321/// Visitor-like interface to transform AST nodes recursively.
322pub trait ExpressionFolder<'i, T: FoldableExpression<'i>> {
323 /// Transform error.
324 type Error;
325
326 /// Transforms the expression `node`. By default, inner items are
327 /// transformed recursively.
328 fn fold_expression(
329 &mut self,
330 node: ExpressionNode<'i, T>,
331 ) -> Result<ExpressionNode<'i, T>, Self::Error> {
332 let ExpressionNode { kind, span } = node;
333 let kind = kind.fold(self, span)?;
334 Ok(ExpressionNode { kind, span })
335 }
336
337 /// Transforms identifier.
338 fn fold_identifier(&mut self, name: &'i str, span: pest::Span<'i>) -> Result<T, Self::Error>;
339
340 /// Transforms function call.
341 fn fold_function_call(
342 &mut self,
343 function: Box<FunctionCallNode<'i, T>>,
344 span: pest::Span<'i>,
345 ) -> Result<T, Self::Error>;
346}
347
348/// Transforms list of `nodes` by using `folder`.
349pub fn fold_expression_nodes<'i, F, T>(
350 folder: &mut F,
351 nodes: Vec<ExpressionNode<'i, T>>,
352) -> Result<Vec<ExpressionNode<'i, T>>, F::Error>
353where
354 F: ExpressionFolder<'i, T> + ?Sized,
355 T: FoldableExpression<'i>,
356{
357 nodes
358 .into_iter()
359 .map(|node| folder.fold_expression(node))
360 .try_collect()
361}
362
363/// Transforms function call arguments by using `folder`.
364pub fn fold_function_call_args<'i, F, T>(
365 folder: &mut F,
366 function: FunctionCallNode<'i, T>,
367) -> Result<FunctionCallNode<'i, T>, F::Error>
368where
369 F: ExpressionFolder<'i, T> + ?Sized,
370 T: FoldableExpression<'i>,
371{
372 Ok(FunctionCallNode {
373 name: function.name,
374 name_span: function.name_span,
375 args: fold_expression_nodes(folder, function.args)?,
376 keyword_args: function
377 .keyword_args
378 .into_iter()
379 .map(|arg| {
380 Ok(KeywordArgument {
381 name: arg.name,
382 name_span: arg.name_span,
383 value: folder.fold_expression(arg.value)?,
384 })
385 })
386 .try_collect()?,
387 args_span: function.args_span,
388 })
389}
390
391/// Helper to parse string literal.
392#[derive(Debug)]
393pub struct StringLiteralParser<R> {
394 /// String content part.
395 pub content_rule: R,
396 /// Escape sequence part including backslash character.
397 pub escape_rule: R,
398}
399
400impl<R: RuleType> StringLiteralParser<R> {
401 /// Parses the given string literal `pairs` into string.
402 pub fn parse(&self, pairs: Pairs<R>) -> String {
403 let mut result = String::new();
404 for part in pairs {
405 if part.as_rule() == self.content_rule {
406 result.push_str(part.as_str());
407 } else if part.as_rule() == self.escape_rule {
408 match &part.as_str()[1..] {
409 "\"" => result.push('"'),
410 "\\" => result.push('\\'),
411 "t" => result.push('\t'),
412 "r" => result.push('\r'),
413 "n" => result.push('\n'),
414 "0" => result.push('\0'),
415 "e" => result.push('\x1b'),
416 hex if hex.starts_with('x') => {
417 result.push(char::from(
418 u8::from_str_radix(&hex[1..], 16).expect("hex characters"),
419 ));
420 }
421 char => panic!("invalid escape: \\{char:?}"),
422 }
423 } else {
424 panic!("unexpected part of string: {part:?}");
425 }
426 }
427 result
428 }
429}
430
/// Escapes special characters in `unescaped` so that it can be embedded in a
/// double-quoted string literal.
///
/// Double quotes, backslashes, tab, carriage return, newline and NUL get their
/// short backslash forms. Any other ASCII control character is written as
/// `\xNN` via [`ascii::escape_default`]. All other characters, including
/// non-ASCII ones, are copied unchanged.
pub fn escape_string(unescaped: &str) -> String {
    let mut escaped = String::with_capacity(unescaped.len());
    for ch in unescaped.chars() {
        let short_form = match ch {
            '"' => Some(r#"\""#),
            '\\' => Some(r#"\\"#),
            '\t' => Some(r#"\t"#),
            '\r' => Some(r#"\r"#),
            '\n' => Some(r#"\n"#),
            '\0' => Some(r#"\0"#),
            _ => None,
        };
        if let Some(seq) = short_form {
            escaped.push_str(seq);
        } else if ch.is_ascii_control() {
            // ASCII, so the `u8` cast is lossless.
            escaped.extend(ascii::escape_default(ch as u8).map(char::from));
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
452
453/// Helper to parse function call.
454#[derive(Debug)]
455pub struct FunctionCallParser<R> {
456 /// Function name.
457 pub function_name_rule: R,
458 /// List of positional and keyword arguments.
459 pub function_arguments_rule: R,
460 /// Pair of parameter name and value.
461 pub keyword_argument_rule: R,
462 /// Parameter name.
463 pub argument_name_rule: R,
464 /// Value expression.
465 pub argument_value_rule: R,
466}
467
468impl<R: RuleType> FunctionCallParser<R> {
469 /// Parses the given `pair` as function call.
470 pub fn parse<'i, T, E: From<InvalidArguments<'i>>>(
471 &self,
472 pair: Pair<'i, R>,
473 // parse_name can be defined for any Pair<'_, R>, but parse_value should
474 // be allowed to construct T by capturing Pair<'i, R>.
475 parse_name: impl Fn(Pair<'i, R>) -> Result<&'i str, E>,
476 parse_value: impl Fn(Pair<'i, R>) -> Result<ExpressionNode<'i, T>, E>,
477 ) -> Result<FunctionCallNode<'i, T>, E> {
478 let (name_pair, args_pair) = pair.into_inner().collect_tuple().unwrap();
479 assert_eq!(name_pair.as_rule(), self.function_name_rule);
480 assert_eq!(args_pair.as_rule(), self.function_arguments_rule);
481 let name_span = name_pair.as_span();
482 let args_span = args_pair.as_span();
483 let function_name = parse_name(name_pair)?;
484 let mut args = Vec::new();
485 let mut keyword_args = Vec::new();
486 for pair in args_pair.into_inner() {
487 let span = pair.as_span();
488 if pair.as_rule() == self.argument_value_rule {
489 if !keyword_args.is_empty() {
490 return Err(InvalidArguments {
491 name: function_name,
492 message: "Positional argument follows keyword argument".to_owned(),
493 span,
494 }
495 .into());
496 }
497 args.push(parse_value(pair)?);
498 } else if pair.as_rule() == self.keyword_argument_rule {
499 let (name_pair, value_pair) = pair.into_inner().collect_tuple().unwrap();
500 assert_eq!(name_pair.as_rule(), self.argument_name_rule);
501 assert_eq!(value_pair.as_rule(), self.argument_value_rule);
502 let name_span = name_pair.as_span();
503 let arg = KeywordArgument {
504 name: parse_name(name_pair)?,
505 name_span,
506 value: parse_value(value_pair)?,
507 };
508 keyword_args.push(arg);
509 } else {
510 panic!("unexpected argument rule {pair:?}");
511 }
512 }
513 Ok(FunctionCallNode {
514 name: function_name,
515 name_span,
516 args,
517 keyword_args,
518 args_span,
519 })
520 }
521}
522
523/// Map of symbol and function aliases.
524#[derive(Clone, Debug, Default)]
525pub struct AliasesMap<P, V> {
526 symbol_aliases: HashMap<String, V>,
527 // name: [(params, defn)] (sorted by arity)
528 function_aliases: HashMap<String, Vec<(Vec<String>, V)>>,
529 // Parser type P helps prevent misuse of AliasesMap of different language.
530 parser: P,
531}
532
533impl<P, V> AliasesMap<P, V> {
534 /// Creates an empty aliases map with default-constructed parser.
535 pub fn new() -> Self
536 where
537 P: Default,
538 {
539 Self {
540 symbol_aliases: Default::default(),
541 function_aliases: Default::default(),
542 parser: Default::default(),
543 }
544 }
545
546 /// Adds new substitution rule `decl = defn`.
547 ///
548 /// Returns error if `decl` is invalid. The `defn` part isn't checked. A bad
549 /// `defn` will be reported when the alias is substituted.
550 pub fn insert(&mut self, decl: impl AsRef<str>, defn: impl Into<V>) -> Result<(), P::Error>
551 where
552 P: AliasDeclarationParser,
553 {
554 match self.parser.parse_declaration(decl.as_ref())? {
555 AliasDeclaration::Symbol(name) => {
556 self.symbol_aliases.insert(name, defn.into());
557 }
558 AliasDeclaration::Function(name, params) => {
559 let overloads = self.function_aliases.entry(name).or_default();
560 match overloads.binary_search_by_key(¶ms.len(), |(params, _)| params.len()) {
561 Ok(i) => overloads[i] = (params, defn.into()),
562 Err(i) => overloads.insert(i, (params, defn.into())),
563 }
564 }
565 }
566 Ok(())
567 }
568
569 /// Iterates symbol names in arbitrary order.
570 pub fn symbol_names(&self) -> impl Iterator<Item = &str> {
571 self.symbol_aliases.keys().map(|n| n.as_ref())
572 }
573
574 /// Iterates function names in arbitrary order.
575 pub fn function_names(&self) -> impl Iterator<Item = &str> {
576 self.function_aliases.keys().map(|n| n.as_ref())
577 }
578
579 /// Looks up symbol alias by name. Returns identifier and definition text.
580 pub fn get_symbol(&self, name: &str) -> Option<(AliasId<'_>, &V)> {
581 self.symbol_aliases
582 .get_key_value(name)
583 .map(|(name, defn)| (AliasId::Symbol(name), defn))
584 }
585
586 /// Looks up function alias by name and arity. Returns identifier, list of
587 /// parameter names, and definition text.
588 pub fn get_function(&self, name: &str, arity: usize) -> Option<(AliasId<'_>, &[String], &V)> {
589 let overloads = self.get_function_overloads(name)?;
590 overloads.find_by_arity(arity)
591 }
592
593 /// Looks up function aliases by name.
594 fn get_function_overloads(&self, name: &str) -> Option<AliasFunctionOverloads<'_, V>> {
595 let (name, overloads) = self.function_aliases.get_key_value(name)?;
596 Some(AliasFunctionOverloads { name, overloads })
597 }
598}
599
600#[derive(Clone, Debug)]
601struct AliasFunctionOverloads<'a, V> {
602 name: &'a String,
603 overloads: &'a Vec<(Vec<String>, V)>,
604}
605
606impl<'a, V> AliasFunctionOverloads<'a, V> {
607 // TODO: Perhaps, V doesn't have to be captured, but "currently, all type
608 // parameters are required to be mentioned in the precise captures list" as
609 // of rustc 1.85.0.
610 fn arities(&self) -> impl DoubleEndedIterator<Item = usize> + ExactSizeIterator + use<'a, V> {
611 self.overloads.iter().map(|(params, _)| params.len())
612 }
613
614 fn min_arity(&self) -> usize {
615 self.arities().next().unwrap()
616 }
617
618 fn max_arity(&self) -> usize {
619 self.arities().next_back().unwrap()
620 }
621
622 fn find_by_arity(&self, arity: usize) -> Option<(AliasId<'a>, &'a [String], &'a V)> {
623 let index = self
624 .overloads
625 .binary_search_by_key(&arity, |(params, _)| params.len())
626 .ok()?;
627 let (params, defn) = &self.overloads[index];
628 // Exact parameter names aren't needed to identify a function, but they
629 // provide a better error indication. (e.g. "foo(x, y)" is easier to
630 // follow than "foo/2".)
631 Some((AliasId::Function(self.name, params), params, defn))
632 }
633}
634
/// Borrowed identifier of an alias expression.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AliasId<'a> {
    /// Symbol alias name.
    Symbol(&'a str),
    /// Function alias name and its parameter names.
    Function(&'a str, &'a [String]),
    /// Name of a function alias parameter.
    Parameter(&'a str),
}

impl fmt::Display for AliasId<'_> {
    /// Shows symbols and parameters as their bare name, and functions as
    /// `name(param1, param2, ...)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            AliasId::Symbol(name) | AliasId::Parameter(name) => f.write_str(name),
            AliasId::Function(name, params) => write!(f, "{name}({})", params.join(", ")),
        }
    }
}
657
/// Parsed left-hand side (`decl`) of an alias rule.
#[derive(Clone, Debug)]
pub enum AliasDeclaration {
    /// Name of a symbol alias.
    Symbol(String),
    /// Name of a function alias and its parameter names.
    Function(String, Vec<String>),
}
666
667// AliasDeclarationParser and AliasDefinitionParser can be merged into a single
668// trait, but it's unclear whether doing that would simplify the abstraction.
669
670/// Parser for symbol and function alias declaration.
671pub trait AliasDeclarationParser {
672 /// Parse error type.
673 type Error;
674
675 /// Parses symbol or function name and parameters.
676 fn parse_declaration(&self, source: &str) -> Result<AliasDeclaration, Self::Error>;
677}
678
679/// Parser for symbol and function alias definition.
680pub trait AliasDefinitionParser {
681 /// Expression item type.
682 type Output<'i>;
683 /// Parse error type.
684 type Error;
685
686 /// Parses alias body.
687 fn parse_definition<'i>(
688 &self,
689 source: &'i str,
690 ) -> Result<ExpressionNode<'i, Self::Output<'i>>, Self::Error>;
691}
692
693/// Expression item that supports alias substitution.
694pub trait AliasExpandableExpression<'i>: FoldableExpression<'i> {
695 /// Wraps identifier.
696 fn identifier(name: &'i str) -> Self;
697 /// Wraps function call.
698 fn function_call(function: Box<FunctionCallNode<'i, Self>>) -> Self;
699 /// Wraps substituted expression.
700 fn alias_expanded(id: AliasId<'i>, subst: Box<ExpressionNode<'i, Self>>) -> Self;
701}
702
703/// Error that may occur during alias substitution.
704pub trait AliasExpandError: Sized {
705 /// Unexpected number of arguments, or invalid combination of arguments.
706 fn invalid_arguments(err: InvalidArguments<'_>) -> Self;
707 /// Recursion detected during alias substitution.
708 fn recursive_expansion(id: AliasId<'_>, span: pest::Span<'_>) -> Self;
709 /// Attaches alias trace to the current error.
710 fn within_alias_expansion(self, id: AliasId<'_>, span: pest::Span<'_>) -> Self;
711}
712
713/// Expands aliases recursively in tree of `T`.
714#[derive(Debug)]
715struct AliasExpander<'i, T, P> {
716 /// Alias symbols and functions that are globally available.
717 aliases_map: &'i AliasesMap<P, String>,
718 /// Stack of aliases and local parameters currently expanding.
719 states: Vec<AliasExpandingState<'i, T>>,
720}
721
722#[derive(Debug)]
723struct AliasExpandingState<'i, T> {
724 id: AliasId<'i>,
725 locals: HashMap<&'i str, ExpressionNode<'i, T>>,
726}
727
728impl<'i, T, P, E> AliasExpander<'i, T, P>
729where
730 T: AliasExpandableExpression<'i> + Clone,
731 P: AliasDefinitionParser<Output<'i> = T, Error = E>,
732 E: AliasExpandError,
733{
734 fn expand_defn(
735 &mut self,
736 id: AliasId<'i>,
737 defn: &'i str,
738 locals: HashMap<&'i str, ExpressionNode<'i, T>>,
739 span: pest::Span<'i>,
740 ) -> Result<T, E> {
741 // The stack should be short, so let's simply do linear search.
742 if self.states.iter().any(|s| s.id == id) {
743 return Err(E::recursive_expansion(id, span));
744 }
745 self.states.push(AliasExpandingState { id, locals });
746 // Parsed defn could be cached if needed.
747 let result = self
748 .aliases_map
749 .parser
750 .parse_definition(defn)
751 .and_then(|node| self.fold_expression(node))
752 .map(|node| T::alias_expanded(id, Box::new(node)))
753 .map_err(|e| e.within_alias_expansion(id, span));
754 self.states.pop();
755 result
756 }
757}
758
759impl<'i, T, P, E> ExpressionFolder<'i, T> for AliasExpander<'i, T, P>
760where
761 T: AliasExpandableExpression<'i> + Clone,
762 P: AliasDefinitionParser<Output<'i> = T, Error = E>,
763 E: AliasExpandError,
764{
765 type Error = E;
766
767 fn fold_identifier(&mut self, name: &'i str, span: pest::Span<'i>) -> Result<T, Self::Error> {
768 if let Some(subst) = self.states.last().and_then(|s| s.locals.get(name)) {
769 let id = AliasId::Parameter(name);
770 Ok(T::alias_expanded(id, Box::new(subst.clone())))
771 } else if let Some((id, defn)) = self.aliases_map.get_symbol(name) {
772 let locals = HashMap::new(); // Don't spill out the current scope
773 self.expand_defn(id, defn, locals, span)
774 } else {
775 Ok(T::identifier(name))
776 }
777 }
778
779 fn fold_function_call(
780 &mut self,
781 function: Box<FunctionCallNode<'i, T>>,
782 span: pest::Span<'i>,
783 ) -> Result<T, Self::Error> {
784 // For better error indication, builtin functions are shadowed by name,
785 // not by (name, arity).
786 if let Some(overloads) = self.aliases_map.get_function_overloads(function.name) {
787 // TODO: add support for keyword arguments
788 function
789 .ensure_no_keyword_arguments()
790 .map_err(E::invalid_arguments)?;
791 let Some((id, params, defn)) = overloads.find_by_arity(function.arity()) else {
792 let min = overloads.min_arity();
793 let max = overloads.max_arity();
794 let err = if max - min + 1 == overloads.arities().len() {
795 function.invalid_arguments_count(min, Some(max))
796 } else {
797 function.invalid_arguments_count_with_arities(overloads.arities())
798 };
799 return Err(E::invalid_arguments(err));
800 };
801 // Resolve arguments in the current scope, and pass them in to the alias
802 // expansion scope.
803 let args = fold_expression_nodes(self, function.args)?;
804 let locals = params.iter().map(|s| s.as_str()).zip(args).collect();
805 self.expand_defn(id, defn, locals, span)
806 } else {
807 let function = Box::new(fold_function_call_args(self, *function)?);
808 Ok(T::function_call(function))
809 }
810 }
811}
812
813/// Expands aliases recursively.
814pub fn expand_aliases<'i, T, P>(
815 node: ExpressionNode<'i, T>,
816 aliases_map: &'i AliasesMap<P, String>,
817) -> Result<ExpressionNode<'i, T>, P::Error>
818where
819 T: AliasExpandableExpression<'i> + Clone,
820 P: AliasDefinitionParser<Output<'i> = T>,
821 P::Error: AliasExpandError,
822{
823 let mut expander = AliasExpander {
824 aliases_map,
825 states: Vec::new(),
826 };
827 expander.fold_expression(node)
828}
829
830/// Collects similar names from the `candidates` list.
831pub fn collect_similar<I>(name: &str, candidates: I) -> Vec<String>
832where
833 I: IntoIterator,
834 I::Item: AsRef<str>,
835{
836 candidates
837 .into_iter()
838 .filter(|cand| {
839 // The parameter is borrowed from clap f5540d26
840 strsim::jaro(name, cand.as_ref()) > 0.7
841 })
842 .map(|s| s.as_ref().to_owned())
843 .sorted_unstable()
844 .collect()
845}
846
#[cfg(test)]
mod tests {
    use super::*;

    /// Zero-width span over an empty input; spans don't matter in these tests.
    fn dummy_span() -> pest::Span<'static> {
        pest::Span::new("", 0, 0).unwrap()
    }

    /// Builds a call to `name` with the given positional and keyword arguments.
    fn call(
        name: &'static str,
        args: impl Into<Vec<ExpressionNode<'static, u32>>>,
        keyword_args: impl Into<Vec<KeywordArgument<'static, u32>>>,
    ) -> FunctionCallNode<'static, u32> {
        FunctionCallNode {
            name,
            name_span: dummy_span(),
            args: args.into(),
            keyword_args: keyword_args.into(),
            args_span: dummy_span(),
        }
    }

    /// Builds an argument node holding `v`.
    fn val(v: u32) -> ExpressionNode<'static, u32> {
        ExpressionNode::new(v, dummy_span())
    }

    /// Builds a keyword argument `name` holding `v`.
    fn kwarg(name: &'static str, v: u32) -> KeywordArgument<'static, u32> {
        KeywordArgument {
            name,
            name_span: dummy_span(),
            value: val(v),
        }
    }

    #[test]
    fn test_expect_arguments() {
        // No arguments at all.
        let no_args = call("foo", [], []);
        assert!(no_args.expect_no_arguments().is_ok());
        assert!(no_args.expect_some_arguments::<0>().is_ok());
        assert!(no_args.expect_arguments::<0, 0>().is_ok());
        assert!(no_args.expect_named_arguments::<0, 0>(&[]).is_ok());

        // A single positional argument.
        let one_positional = call("foo", [val(0)], []);
        assert!(one_positional.expect_no_arguments().is_err());
        assert_eq!(
            one_positional.expect_some_arguments::<0>().unwrap(),
            (&[], [val(0)].as_slice())
        );
        assert_eq!(
            one_positional.expect_some_arguments::<1>().unwrap(),
            (&[val(0)], [].as_slice())
        );
        assert!(one_positional.expect_arguments::<0, 0>().is_err());
        assert_eq!(
            one_positional.expect_arguments::<0, 1>().unwrap(),
            (&[], [Some(&val(0))])
        );
        assert_eq!(
            one_positional.expect_arguments::<1, 1>().unwrap(),
            (&[val(0)], [None])
        );
        assert!(one_positional.expect_named_arguments::<0, 0>(&[]).is_err());
        assert_eq!(
            one_positional.expect_named_arguments::<0, 1>(&["a"]).unwrap(),
            ([], [Some(&val(0))])
        );
        assert_eq!(
            one_positional.expect_named_arguments::<1, 0>(&["a"]).unwrap(),
            ([&val(0)], [])
        );

        // A single keyword argument.
        let one_keyword = call("foo", [], [kwarg("a", 0)]);
        assert!(one_keyword.expect_no_arguments().is_err());
        assert!(one_keyword.expect_some_arguments::<1>().is_err());
        assert!(one_keyword.expect_arguments::<0, 1>().is_err());
        assert!(one_keyword.expect_arguments::<1, 0>().is_err());
        assert!(one_keyword.expect_named_arguments::<0, 0>(&[]).is_err());
        assert!(one_keyword.expect_named_arguments::<0, 1>(&[]).is_err());
        assert!(one_keyword.expect_named_arguments::<1, 0>(&[]).is_err());
        assert_eq!(
            one_keyword.expect_named_arguments::<1, 0>(&["a"]).unwrap(),
            ([&val(0)], [])
        );
        assert_eq!(
            one_keyword.expect_named_arguments::<1, 1>(&["a", "b"]).unwrap(),
            ([&val(0)], [None])
        );
        assert!(one_keyword.expect_named_arguments::<1, 1>(&["b", "a"]).is_err());

        // Positional followed by keyword arguments.
        let mixed = call("foo", [val(0)], [kwarg("a", 1), kwarg("b", 2)]);
        assert!(mixed.expect_named_arguments::<0, 0>(&[]).is_err());
        assert!(mixed.expect_named_arguments::<1, 1>(&["a", "b"]).is_err());
        assert_eq!(
            mixed.expect_named_arguments::<1, 2>(&["c", "a", "b"]).unwrap(),
            ([&val(0)], [Some(&val(1)), Some(&val(2))])
        );
        assert_eq!(
            mixed.expect_named_arguments::<2, 1>(&["c", "b", "a"]).unwrap(),
            ([&val(0), &val(2)], [Some(&val(1))])
        );
        assert_eq!(
            mixed.expect_named_arguments::<0, 3>(&["c", "b", "a"]).unwrap(),
            ([], [Some(&val(0)), Some(&val(2)), Some(&val(1))])
        );

        // The same keyword given twice.
        let duplicate_keyword = call("foo", [], [kwarg("a", 0), kwarg("a", 1)]);
        assert!(duplicate_keyword
            .expect_named_arguments::<1, 1>(&["", "a"])
            .is_err());
    }
}