this repo has no description

Fix extracting tail-position statements

authored by gearsco.de and committed by

Louis Pilfold 9ef66cf0 dc632c22

+230 -67
+181 -65
compiler-core/src/language_server/code_action.rs
··· 8450 8450 edits: TextEdits<'a>, 8451 8451 function: Option<ExtractedFunction<'a>>, 8452 8452 function_end_position: Option<u32>, 8453 + /// Since the `visit_typed_statement` visitor function doesn't tell us when 8454 + /// a statement is the last in a block or function, we need to track that 8455 + /// manually. 8456 + last_statement_location: Option<SrcSpan>, 8453 8457 } 8454 8458 8455 8459 /// Information about a section of code we are extracting as a function. ··· 8479 8483 fn location(&self) -> SrcSpan { 8480 8484 match &self.value { 8481 8485 ExtractedValue::Expression(expression) => expression.location(), 8482 - ExtractedValue::Statements(location) => *location, 8486 + ExtractedValue::Statements { location, .. } => *location, 8483 8487 } 8484 8488 } 8485 8489 } ··· 8487 8491 #[derive(Debug)] 8488 8492 enum ExtractedValue<'a> { 8489 8493 Expression(&'a TypedExpr), 8490 - Statements(SrcSpan), 8494 + Statements { 8495 + location: SrcSpan, 8496 + position: StatementPosition, 8497 + }, 8498 + } 8499 + 8500 + /// When we are extracting multiple statements, there are two possible cases: 8501 + /// The first is if we are extracting statements in the middle of a function. 8502 + /// In this case, we will need to return some number of arguments, or `Nil`. 8503 + /// For example: 8504 + /// 8505 + /// ```gleam 8506 + /// pub fn main() { 8507 + /// let message = "Hello!" 8508 + /// let log_message = "[INFO] " <> message 8509 + /// //^ Select from here 8510 + /// io.println(log_message) 8511 + /// // ^ Until here 8512 + /// 8513 + /// do_some_more_things() 8514 + /// } 8515 + /// ``` 8516 + /// 8517 + /// Here, the extracted function doesn't bind any variables which we need 8518 + /// afterwards, it purely performs side effects. In this case we can just return 8519 + /// `Nil` from the new function. 8520 + /// 8521 + /// However, consider the following: 8522 + /// 8523 + /// ```gleam 8524 + /// pub fn main() { 8525 + /// let a = 1 8526 + /// let b = 2 8527 + /// //^ Select from here 8528 + /// a + b 8529 + /// // ^ Until here 8530 + /// } 8531 + /// ``` 8532 + /// 8533 + /// Here, despite us not needing any variables from the extracted code, there 8534 + /// is one key difference: the `a + b` expression is at the end of the function, 8535 + /// and so its value is returned from the entire function. This is known as the 8536 + /// "tail" position. In that case, we can't return `Nil` as that would make the 8537 + /// `main` function return `Nil` instead of the result of the addition. If we 8538 + /// extract the tail-position statement, we need to return that last value rather 8539 + /// than `Nil`. 8540 + /// 8541 + #[derive(Debug)] 8542 + enum StatementPosition { 8543 + Tail { type_: Arc<Type> }, 8544 + NotTail, 8491 8545 } 8492 8546 8493 8547 impl<'a> ExtractFunction<'a> { ··· 8502 8556 edits: TextEdits::new(line_numbers), 8503 8557 function: None, 8504 8558 function_end_position: None, 8559 + last_statement_location: None, 8505 8560 } 8506 8561 } 8507 8562 ··· 8524 8579 }; 8525 8580 8526 8581 match extracted.value { 8527 - ExtractedValue::Expression(expression) => { 8528 - self.extract_expression(expression, extracted.parameters, end) 8582 + // If we extract a block, it isn't very helpful to have the body of the 8583 + // extracted function just be a single block expression, so instead we 8584 + // extract the statements inside the block. For example, the following 8585 + // code: 8586 + // 8587 + // ```gleam 8588 + // pub fn main() { 8589 + // let x = { 8590 + // // ^ Select from here 8591 + // let a = 1 8592 + // let b = 2 8593 + // a + b 8594 + // } 8595 + // //^ Until here 8596 + // x 8597 + // } 8598 + // ``` 8599 + // 8600 + // Would produce the following extracted function: 8601 + // 8602 + // ```gleam 8603 + // fn function() { 8604 + // let a = 1 8605 + // let b = 2 8606 + // a + b 8607 + // } 8608 + // ``` 8609 + // 8610 + // Rather than: 8611 + // 8612 + // ```gleam 8613 + // fn function() { 8614 + // { 8615 + // let a = 1 8616 + // let b = 2 8617 + // a + b 8618 + // } 8619 + // } 8620 + // ``` 8621 + // 8622 + ExtractedValue::Expression(TypedExpr::Block { 8623 + statements, 8624 + location: full_location, 8625 + }) => { 8626 + let location = SrcSpan::new( 8627 + statements.first().location().start, 8628 + statements.last().location().end, 8629 + ); 8630 + self.extract_code_in_tail_position( 8631 + *full_location, 8632 + location, 8633 + statements.last().type_(), 8634 + extracted.parameters, 8635 + end, 8636 + ) 8529 8637 } 8530 - ExtractedValue::Statements(location) => self.extract_statements( 8638 + ExtractedValue::Expression(expression) => self.extract_code_in_tail_position( 8639 + expression.location(), 8640 + expression.location(), 8641 + expression.type_(), 8642 + extracted.parameters, 8643 + end, 8644 + ), 8645 + ExtractedValue::Statements { 8646 + location, 8647 + position: StatementPosition::NotTail, 8648 + } => self.extract_statements( 8531 8649 location, 8532 8650 extracted.parameters, 8533 8651 extracted.returned_variables, 8652 + end, 8653 + ), 8654 + ExtractedValue::Statements { 8655 + location, 8656 + position: StatementPosition::Tail { type_ }, 8657 + } => self.extract_code_in_tail_position( 8658 + location, 8659 + location, 8660 + type_, 8661 + extracted.parameters, 8534 8662 end, 8535 8663 ), 8536 8664 } ··· 8561 8689 } 8562 8690 } 8563 8691 8564 - fn extract_expression( 8692 + /// Extracts code from the end of a function or block. This could either be 8693 + /// a single expression, or multiple statements followed by a final expression. 8694 + fn extract_code_in_tail_position( 8565 8695 &mut self, 8566 - expression: &TypedExpr, 8696 + location: SrcSpan, 8697 + code_location: SrcSpan, 8698 + type_: Arc<Type>, 8567 8699 parameters: Vec<(EcoString, Arc<Type>)>, 8568 8700 function_end: u32, 8569 8701 ) { 8570 - // If we extract a block, it isn't very helpful to have the body of the 8571 - // extracted function just be a single block expression, so instead we 8572 - // extract the statements inside the block. For example, the following 8573 - // code: 8574 - // 8575 - // ```gleam 8576 - // pub fn main() { 8577 - // let x = { 8578 - // // ^ Select from here 8579 - // let a = 1 8580 - // let b = 2 8581 - // a + b 8582 - // } 8583 - // //^ Until here 8584 - // x 8585 - // } 8586 - // ``` 8587 - // 8588 - // Would produce the following extracted function: 8589 - // 8590 - // ```gleam 8591 - // fn function() { 8592 - // let a = 1 8593 - // let b = 2 8594 - // a + b 8595 - // } 8596 - // ``` 8597 - // 8598 - // Rather than: 8599 - // 8600 - // ```gleam 8601 - // fn function() { 8602 - // { 8603 - // let a = 1 8604 - // let b = 2 8605 - // a + b 8606 - // } 8607 - // } 8608 - // ``` 8609 - // 8610 - let extracted_code_location = if let TypedExpr::Block { statements, .. } = expression { 8611 - SrcSpan::new( 8612 - statements.first().location().start, 8613 - statements.last().location().end, 8614 - ) 8615 - } else { 8616 - expression.location() 8617 - }; 8618 - 8619 - let expression_code = code_at(self.module, extracted_code_location); 8702 + let expression_code = code_at(self.module, code_location); 8620 8703 8621 8704 let name = self.function_name(); 8622 8705 let arguments = parameters.iter().map(|(name, _)| name).join(", "); ··· 8626 8709 // it with the call and preserve all other semantics; only one value can 8627 8710 // be returned from the expression, unlike when extracting multiple 8628 8711 // statements. 8629 - self.edits.replace(expression.location(), call); 8712 + self.edits.replace(location, call); 8630 8713 8631 8714 let mut printer = Printer::new(&self.module.ast.names); 8632 8715 ··· 8634 8717 .iter() 8635 8718 .map(|(name, type_)| eco_format!("{name}: {}", printer.print_type(type_))) 8636 8719 .join(", "); 8637 - let return_type = printer.print_type(&expression.type_()); 8720 + let return_type = printer.print_type(&type_); 8638 8721 8639 8722 let function = format!( 8640 8723 "\n\nfn {name}({parameters}) -> {return_type} {{ ··· 8861 8944 8862 8945 if within(self.params.range, range) { 8863 8946 self.function_end_position = Some(function.end_position); 8947 + self.last_statement_location = function.body.last().map(|last| last.location()); 8864 8948 8865 8949 ast::visit::visit_typed_function(self, function); 8866 8950 } 8867 8951 } 8868 8952 8953 + fn visit_typed_expr_block( 8954 + &mut self, 8955 + location: &'ast SrcSpan, 8956 + statements: &'ast [TypedStatement], 8957 + ) { 8958 + let last_statement_location = self.last_statement_location; 8959 + self.last_statement_location = statements.last().map(|last| last.location()); 8960 + 8961 + ast::visit::visit_typed_expr_block(self, location, statements); 8962 + 8963 + self.last_statement_location = last_statement_location; 8964 + } 8965 + 8869 8966 fn visit_typed_expr(&mut self, expression: &'ast TypedExpr) { 8870 8967 // If we have already determined what code we want to extract, we don't 8871 8968 // want to extract this instead. This expression would be inside the ··· 8884 8981 } 8885 8982 8886 8983 fn visit_typed_statement(&mut self, statement: &'ast TypedStatement) { 8887 - if self.can_extract(statement.location()) { 8984 + let location = statement.location(); 8985 + if self.can_extract(location) { 8986 + let position = if let Some(last_statement_location) = self.last_statement_location 8987 + && location == last_statement_location 8988 + { 8989 + StatementPosition::Tail { 8990 + type_: statement.type_(), 8991 + } 8992 + } else { 8993 + StatementPosition::NotTail 8994 + }; 8995 + 8888 8996 match &mut self.function { 8889 8997 None => { 8890 - self.function = Some(ExtractedFunction::new(ExtractedValue::Statements( 8891 - statement.location(), 8892 - ))); 8998 + self.function = Some(ExtractedFunction::new(ExtractedValue::Statements { 8999 + location, 9000 + position, 9001 + })); 8893 9002 } 8894 9003 // If we have already chosen an expression to extract, that means 8895 9004 // that this statement is within the already extracted expression, ··· 8902 9011 // be included within list, so we merge th spans to ensure it is 8903 9012 // included. 8904 9013 Some(ExtractedFunction { 8905 - value: ExtractedValue::Statements(location), 9014 + value: 9015 + ExtractedValue::Statements { 9016 + location, 9017 + position: extracted_position, 9018 + }, 8906 9019 .. 8907 - }) => *location = location.merge(&statement.location()), 9020 + }) => { 9021 + *location = location.merge(&statement.location()); 9022 + *extracted_position = position; 9023 + } 8908 9024 } 8909 9025 } 8910 9026 ast::visit::visit_typed_statement(self, statement);
+17
compiler-core/src/language_server/tests/action.rs
··· 10537 10537 find_position_of("let a").select_until(find_position_of("let b")) 10538 10538 ); 10539 10539 } 10540 + 10541 + #[test] 10542 + fn extract_statements_in_tail_position() { 10543 + assert_code_action!( 10544 + EXTRACT_FUNCTION, 10545 + r#" 10546 + pub fn main() { 10547 + let a = 1 10548 + let b = 2 10549 + let c = 3 10550 + let d = 4 10551 + a * b + c * d 10552 + } 10553 + "#, 10554 + find_position_of("let c").select_until(find_position_of("* d")) 10555 + ); 10556 + }
+31
compiler-core/src/language_server/tests/snapshots/gleam_core__language_server__tests__action__extract_statements_in_tail_position.snap
··· 1 + --- 2 + source: compiler-core/src/language_server/tests/action.rs 3 + expression: "\npub fn main() {\n let a = 1\n let b = 2\n let c = 3\n let d = 4\n a * b + c * d\n}\n" 4 + --- 5 + ----- BEFORE ACTION 6 + 7 + pub fn main() { 8 + let a = 1 9 + let b = 2 10 + let c = 3 11 + ▔▔▔▔▔▔▔▔▔ 12 + let d = 4 13 + ▔▔▔▔▔▔▔▔▔▔▔ 14 + a * b + c * d 15 + ▔▔▔▔▔▔▔▔▔▔▔▔↑ 16 + } 17 + 18 + 19 + ----- AFTER ACTION 20 + 21 + pub fn main() { 22 + let a = 1 23 + let b = 2 24 + function(a, b) 25 + } 26 + 27 + fn function(a: Int, b: Int) -> Int { 28 + let c = 3 29 + let d = 4 30 + a * b + c * d 31 + }
+1 -2
compiler-core/src/language_server/tests/snapshots/gleam_core__language_server__tests__action__selected_statements_do_not_select_outer_block.snap
··· 28 28 echo c 29 29 } 30 30 31 - fn function() -> Nil { 31 + fn function() -> Int { 32 32 let a = 10 33 33 let b = 20 34 34 a + b 35 - Nil 36 35 }