diff --git a/src/template_parser.rs b/src/template_parser.rs index e6e3816c0..6810a960e 100644 --- a/src/template_parser.rs +++ b/src/template_parser.rs @@ -13,7 +13,7 @@ // limitations under the License. use std::ops::{RangeFrom, RangeInclusive}; -use std::{error, fmt}; +use std::{error, fmt, iter}; use itertools::Itertools as _; use jujutsu_lib::backend::{Signature, Timestamp}; @@ -265,6 +265,72 @@ impl<'a, C: 'a> Expression<'a, C> { } } +type OptionalArg<'i> = Option>; + +/// Extracts exactly N required arguments. +fn expect_exact_arguments( + pair: Pair, +) -> TemplateParseResult<[Pair; N]> { + let span = pair.as_span(); + let make_error = || TemplateParseError::invalid_argument_count_exact(N, span); + let mut pairs = pair.into_inner(); + let required: [Pair; N] = pairs + .by_ref() + .take(N) + .collect_vec() + .try_into() + .map_err(|_| make_error())?; + if pairs.next().is_none() { + Ok(required) + } else { + Err(make_error()) + } +} + +/// Extracts N required arguments and remainders. +fn expect_some_arguments( + pair: Pair, +) -> TemplateParseResult<([Pair; N], Pairs)> { + let span = pair.as_span(); + let make_error = || TemplateParseError::invalid_argument_count_range_from(N.., span); + let mut pairs = pair.into_inner(); + let required: [Pair; N] = pairs + .by_ref() + .take(N) + .collect_vec() + .try_into() + .map_err(|_| make_error())?; + Ok((required, pairs)) +} + +/// Extracts N required arguments and M optional arguments. +fn expect_arguments( + pair: Pair, +) -> TemplateParseResult<([Pair; N], [OptionalArg; M])> { + let span = pair.as_span(); + let make_error = || TemplateParseError::invalid_argument_count_range(N..=(N + M), span); + let mut pairs = pair.into_inner().fuse(); + let required: [Pair; N] = pairs + .by_ref() + .take(N) + .collect_vec() + .try_into() + .map_err(|_| make_error())?; + let optional: [OptionalArg; M] = pairs + .by_ref() + .map(Some) + .chain(iter::repeat(None)) + .take(M) + .collect_vec() + .try_into() + .unwrap(); + if pairs.next().is_none() { + Ok((required, optional)) + } else { + Err(make_error()) + } +} + fn parse_method_chain<'a, I: 'a>( pair: Pair, input_property: PropertyAndLabels<'a, I>, @@ -487,26 +553,20 @@ fn parse_term<'a, C: 'a>( Ok(Expression::Property(property)) } Rule::function => { - let (name, args_span, mut args) = { + let (name, args_pair) = { let mut inner = expr.into_inner(); let name = inner.next().unwrap(); let args_pair = inner.next().unwrap(); assert_eq!(name.as_rule(), Rule::identifier); assert_eq!(args_pair.as_rule(), Rule::function_arguments); - (name, args_pair.as_span(), args_pair.into_inner()) + (name, args_pair) }; let expression = match name.as_str() { "label" => { - let arg_count_error = - || TemplateParseError::invalid_argument_count_exact(2, args_span); - let label_pair = args.next().ok_or_else(arg_count_error)?; + let [label_pair, content_pair] = expect_exact_arguments(args_pair)?; let label_property = parse_template_rule(label_pair, parse_keyword)?.into_plain_text(); - let arg_template = args.next().ok_or_else(arg_count_error)?; - if args.next().is_some() { - return Err(arg_count_error()); - } - let content = parse_template_rule(arg_template, parse_keyword)?.into_template(); + let content = parse_template_rule(content_pair, parse_keyword)?.into_template(); let labels = TemplateFunction::new(label_property, |s| { s.split_whitespace().map(ToString::to_string).collect() }); @@ -514,29 +574,19 @@ fn parse_term<'a, C: 'a>( Expression::Template(template) } "if" => { - let arg_count_error = - || TemplateParseError::invalid_argument_count_range(2..=3, args_span); - let condition_pair = args.next().ok_or_else(arg_count_error)?; + let ([condition_pair, true_pair], [false_pair]) = expect_arguments(args_pair)?; let condition_span = condition_pair.as_span(); let condition = parse_template_rule(condition_pair, parse_keyword)? .try_into_boolean() .ok_or_else(|| { TemplateParseError::invalid_argument_type("Boolean", condition_span) })?; - - let true_template = args - .next() - .ok_or_else(arg_count_error) - .and_then(|pair| parse_template_rule(pair, parse_keyword))? - .into_template(); - let false_template = args - .next() + let true_template = + parse_template_rule(true_pair, parse_keyword)?.into_template(); + let false_template = false_pair .map(|pair| parse_template_rule(pair, parse_keyword)) .transpose()? .map(|x| x.into_template()); - if args.next().is_some() { - return Err(arg_count_error()); - } let template = Box::new(ConditionalTemplate::new( condition, true_template, @@ -545,12 +595,10 @@ fn parse_term<'a, C: 'a>( Expression::Template(template) } "separate" => { - let arg_count_error = - || TemplateParseError::invalid_argument_count_range_from(1.., args_span); - let separator_pair = args.next().ok_or_else(arg_count_error)?; + let ([separator_pair], content_pairs) = expect_some_arguments(args_pair)?; let separator = parse_template_rule(separator_pair, parse_keyword)?.into_template(); - let contents = args + let contents = content_pairs .map(|pair| { parse_template_rule(pair, parse_keyword).map(|x| x.into_template()) })