···84508450 edits: TextEdits<'a>,
84518451 function: Option<ExtractedFunction<'a>>,
84528452 function_end_position: Option<u32>,
84538453+ /// Since the `visit_typed_statement` visitor function doesn't tell us when
84548454+ /// a statement is the last in a block or function, we need to track that
84558455+ /// manually.
84568456+ last_statement_location: Option<SrcSpan>,
84538457}
8454845884558459/// Information about a section of code we are extracting as a function.
···84798483 fn location(&self) -> SrcSpan {
// Every kind of extracted value has a span: an expression reports its own
// location, while extracted statements carry their span directly.
84808484 match &self.value {
84818485 ExtractedValue::Expression(expression) => expression.location(),
84828482- ExtractedValue::Statements(location) => *location,
84868486+ ExtractedValue::Statements { location, .. } => *location,
84838487 }
84848488 }
84858489}
···84878491#[derive(Debug)]
84888492enum ExtractedValue<'a> {
/// A single typed expression chosen for extraction.
84898493 Expression(&'a TypedExpr),
84908490- Statements(SrcSpan),
/// A run of statements chosen for extraction. `location` is the span of the
/// selected statements (it grows by merging as further statements are
/// visited), and `position` records whether the most recently added
/// statement is in tail position.
84948494+ Statements {
84958495+ location: SrcSpan,
84968496+ position: StatementPosition,
84978497+ },
84988498+}
84998499+85008500+/// When we are extracting multiple statements, there are two possible cases:
85018501+/// The first is if we are extracting statements in the middle of a function.
85028502+/// In this case, we will need to return some number of variables, or `Nil`.
85038503+/// For example:
85048504+///
85058505+/// ```gleam
85068506+/// pub fn main() {
85078507+/// let message = "Hello!"
85088508+/// let log_message = "[INFO] " <> message
85098509+/// //^ Select from here
85108510+/// io.println(log_message)
85118511+/// // ^ Until here
85128512+///
85138513+/// do_some_more_things()
85148514+/// }
85158515+/// ```
85168516+///
85178517+/// Here, the extracted function doesn't bind any variables which we need
85188518+/// afterwards, it purely performs side effects. In this case we can just return
85198519+/// `Nil` from the new function.
85208520+///
85218521+/// However, consider the following:
85228522+///
85238523+/// ```gleam
85248524+/// pub fn main() {
85258525+/// let a = 1
85268526+/// let b = 2
85278527+/// //^ Select from here
85288528+/// a + b
85298529+/// // ^ Until here
85308530+/// }
85318531+/// ```
85328532+///
85338533+/// Here, despite us not needing any variables from the extracted code, there
85348534+/// is one key difference: the `a + b` expression is at the end of the function,
85358535+/// and so its value is returned from the entire function. This is known as the
85368536+/// "tail" position. In that case, we can't return `Nil` as that would make the
85378537+/// `main` function return `Nil` instead of the result of the addition. If we
85388538+/// extract the tail-position statement, we need to return that last value rather
85398539+/// than `Nil`.
85408540+///
85418541+#[derive(Debug)]
85428542+enum StatementPosition {
/// The last extracted statement is also the last statement of its enclosing
/// function or block, so its value must be returned by the extracted
/// function. `type_` is that statement's type, which is printed as the
/// return type of the new function.
85438543+ Tail { type_: Arc<Type> },
/// The extracted statements are not in tail position; they are extracted
/// via `extract_statements`, together with the variables they return.
85448544+ NotTail,
84918545}
8492854684938547impl<'a> ExtractFunction<'a> {
···85028556 edits: TextEdits::new(line_numbers),
85038557 function: None,
85048558 function_end_position: None,
85598559+ last_statement_location: None,
85058560 }
85068561 }
85078562···85248579 };
8525858085268581 match extracted.value {
85278527- ExtractedValue::Expression(expression) => {
85288528- self.extract_expression(expression, extracted.parameters, end)
85828582+ // If we extract a block, it isn't very helpful to have the body of the
85838583+ // extracted function just be a single block expression, so instead we
85848584+ // extract the statements inside the block. For example, the following
85858585+ // code:
85868586+ //
85878587+ // ```gleam
85888588+ // pub fn main() {
85898589+ // let x = {
85908590+ // // ^ Select from here
85918591+ // let a = 1
85928592+ // let b = 2
85938593+ // a + b
85948594+ // }
85958595+ // //^ Until here
85968596+ // x
85978597+ // }
85988598+ // ```
85998599+ //
86008600+ // Would produce the following extracted function:
86018601+ //
86028602+ // ```gleam
86038603+ // fn function() {
86048604+ // let a = 1
86058605+ // let b = 2
86068606+ // a + b
86078607+ // }
86088608+ // ```
86098609+ //
86108610+ // Rather than:
86118611+ //
86128612+ // ```gleam
86138613+ // fn function() {
86148614+ // {
86158615+ // let a = 1
86168616+ // let b = 2
86178617+ // a + b
86188618+ // }
86198619+ // }
86208620+ // ```
86218621+ //
86228622+ ExtractedValue::Expression(TypedExpr::Block {
86238623+ statements,
86248624+ location: full_location,
86258625+ }) => {
86268626+ let location = SrcSpan::new(
86278627+ statements.first().location().start,
86288628+ statements.last().location().end,
86298629+ );
86308630+ self.extract_code_in_tail_position(
86318631+ *full_location,
86328632+ location,
86338633+ statements.last().type_(),
86348634+ extracted.parameters,
86358635+ end,
86368636+ )
85298637 }
85308530- ExtractedValue::Statements(location) => self.extract_statements(
86388638+ ExtractedValue::Expression(expression) => self.extract_code_in_tail_position(
86398639+ expression.location(),
86408640+ expression.location(),
86418641+ expression.type_(),
86428642+ extracted.parameters,
86438643+ end,
86448644+ ),
86458645+ ExtractedValue::Statements {
86468646+ location,
86478647+ position: StatementPosition::NotTail,
86488648+ } => self.extract_statements(
85318649 location,
85328650 extracted.parameters,
85338651 extracted.returned_variables,
86528652+ end,
86538653+ ),
86548654+ ExtractedValue::Statements {
86558655+ location,
86568656+ position: StatementPosition::Tail { type_ },
86578657+ } => self.extract_code_in_tail_position(
86588658+ location,
86598659+ location,
86608660+ type_,
86618661+ extracted.parameters,
85348662 end,
85358663 ),
85368664 }
···85618689 }
85628690 }
8563869185648564- fn extract_expression(
86928692+ /// Extracts code whose value the new function returns: either a single expression, or
86938693+ /// statements ending a function or block. `location` is replaced by the call; `code_location` is the code copied out.
86948694+ fn extract_code_in_tail_position(
85658695 &mut self,
85668566- expression: &TypedExpr,
86968696+ location: SrcSpan,
86978697+ code_location: SrcSpan,
86988698+ type_: Arc<Type>,
85678699 parameters: Vec<(EcoString, Arc<Type>)>,
85688700 function_end: u32,
85698701 ) {
85708570- // If we extract a block, it isn't very helpful to have the body of the
85718571- // extracted function just be a single block expression, so instead we
85728572- // extract the statements inside the block. For example, the following
85738573- // code:
85748574- //
85758575- // ```gleam
85768576- // pub fn main() {
85778577- // let x = {
85788578- // // ^ Select from here
85798579- // let a = 1
85808580- // let b = 2
85818581- // a + b
85828582- // }
85838583- // //^ Until here
85848584- // x
85858585- // }
85868586- // ```
85878587- //
85888588- // Would produce the following extracted function:
85898589- //
85908590- // ```gleam
85918591- // fn function() {
85928592- // let a = 1
85938593- // let b = 2
85948594- // a + b
85958595- // }
85968596- // ```
85978597- //
85988598- // Rather than:
85998599- //
86008600- // ```gleam
86018601- // fn function() {
86028602- // {
86038603- // let a = 1
86048604- // let b = 2
86058605- // a + b
86068606- // }
86078607- // }
86088608- // ```
86098609- //
86108610- let extracted_code_location = if let TypedExpr::Block { statements, .. } = expression {
86118611- SrcSpan::new(
86128612- statements.first().location().start,
86138613- statements.last().location().end,
86148614- )
86158615- } else {
86168616- expression.location()
86178617- };
86188618-86198619- let expression_code = code_at(self.module, extracted_code_location);
87028702+ let expression_code = code_at(self.module, code_location);
8620870386218704 let name = self.function_name();
86228705 let arguments = parameters.iter().map(|(name, _)| name).join(", ");
···86268709 // it with the call and preserve all other semantics; only one value can
86278710 // be returned from the expression, unlike when extracting multiple
86288711 // statements.
86298629- self.edits.replace(expression.location(), call);
87128712+ self.edits.replace(location, call);
8630871386318714 let mut printer = Printer::new(&self.module.ast.names);
86328715···86348717 .iter()
86358718 .map(|(name, type_)| eco_format!("{name}: {}", printer.print_type(type_)))
86368719 .join(", ");
86378637- let return_type = printer.print_type(&expression.type_());
87208720+ let return_type = printer.print_type(&type_);
8638872186398722 let function = format!(
86408723 "\n\nfn {name}({parameters}) -> {return_type} {{
···8861894488628945 if within(self.params.range, range) {
88638946 self.function_end_position = Some(function.end_position);
89478947+ self.last_statement_location = function.body.last().map(|last| last.location());
8864894888658949 ast::visit::visit_typed_function(self, function);
88668950 }
88678951 }
/// Visitor hook for block expressions. While the block's statements are being
/// visited, `last_statement_location` points at the block's final statement so
/// that `visit_typed_statement` can recognise a statement in tail position.
8868895289538953+ fn visit_typed_expr_block(
89548954+ &mut self,
89558955+ location: &'ast SrcSpan,
89568956+ statements: &'ast [TypedStatement],
89578957+ ) {
// Remember the enclosing scope's last statement so it can be restored once
// this (possibly nested) block has been visited.
89588958+ let last_statement_location = self.last_statement_location;
// `None` when the block has no statements.
89598959+ self.last_statement_location = statements.last().map(|last| last.location());
89608960+89618961+ ast::visit::visit_typed_expr_block(self, location, statements);
// Back in the enclosing scope: its own last statement applies again.
89628962+89638963+ self.last_statement_location = last_statement_location;
89648964+ }
89658965+88698966 fn visit_typed_expr(&mut self, expression: &'ast TypedExpr) {
88708967 // If we have already determined what code we want to extract, we don't
88718968 // want to extract this instead. This expression would be inside the
···88848981 }
8885898288868983 fn visit_typed_statement(&mut self, statement: &'ast TypedStatement) {
88878887- if self.can_extract(statement.location()) {
89848984+ let location = statement.location();
89858985+ if self.can_extract(location) {
89868986+ let position = if let Some(last_statement_location) = self.last_statement_location
89878987+ && location == last_statement_location
89888988+ {
89898989+ StatementPosition::Tail {
89908990+ type_: statement.type_(),
89918991+ }
89928992+ } else {
89938993+ StatementPosition::NotTail
89948994+ };
89958995+88888996 match &mut self.function {
88898997 None => {
88908890- self.function = Some(ExtractedFunction::new(ExtractedValue::Statements(
88918891- statement.location(),
88928892- )));
89988998+ self.function = Some(ExtractedFunction::new(ExtractedValue::Statements {
89998999+ location,
90009000+ position,
90019001+ }));
88939002 }
88949003 // If we have already chosen an expression to extract, that means
88959004 // that this statement is within the already extracted expression,
···89029011 // be included within list, so we merge the spans to ensure it is
89039012 // included.
89049013 Some(ExtractedFunction {
89058905- value: ExtractedValue::Statements(location),
90149014+ value:
90159015+ ExtractedValue::Statements {
90169016+ location,
90179017+ position: extracted_position,
90189018+ },
89069019 ..
89078907- }) => *location = location.merge(&statement.location()),
90209020+ }) => {
90219021+ *location = location.merge(&statement.location());
90229022+ *extracted_position = position;
90239023+ }
89089024 }
89099025 }
89109026 ast::visit::visit_typed_statement(self, statement);
+17
compiler-core/src/language_server/tests/action.rs
···1053710537 find_position_of("let a").select_until(find_position_of("let b"))
1053810538 );
1053910539}
// Selects from `let c` up to the `* d` in the final expression, so the selection
// ends on the last statement of `main`. The extracted function is expected to
// return the value of that final expression rather than `Nil`.
1054010540+1054110541+#[test]
1054210542+fn extract_statements_in_tail_position() {
1054310543+ assert_code_action!(
1054410544+ EXTRACT_FUNCTION,
1054510545+ r#"
1054610546+pub fn main() {
1054710547+ let a = 1
1054810548+ let b = 2
1054910549+ let c = 3
1055010550+ let d = 4
1055110551+ a * b + c * d
1055210552+}
1055310553+"#,
1055410554+ find_position_of("let c").select_until(find_position_of("* d"))
1055510555+ );
1055610556+}
···11+---
22+source: compiler-core/src/language_server/tests/action.rs
33+expression: "\npub fn main() {\n let a = 1\n let b = 2\n let c = 3\n let d = 4\n a * b + c * d\n}\n"
44+---
55+----- BEFORE ACTION
66+77+pub fn main() {
88+ let a = 1
99+ let b = 2
1010+ let c = 3
1111+ ▔▔▔▔▔▔▔▔▔
1212+ let d = 4
1313+▔▔▔▔▔▔▔▔▔▔▔
1414+ a * b + c * d
1515+▔▔▔▔▔▔▔▔▔▔▔▔↑
1616+}
1717+1818+1919+----- AFTER ACTION
2020+2121+pub fn main() {
2222+ let a = 1
2323+ let b = 2
2424+ function(a, b)
2525+}
2626+2727+fn function(a: Int, b: Int) -> Int {
2828+ let c = 3
2929+ let d = 4
3030+ a * b + c * d
3131+}