Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Luau: Added support for parsing user-defined type functions ([#938](https://github.com/JohnnyMorganz/StyLua/issues/938))

### Fixed

- Luau: fixed parentheses incorrectly removed in `(expr :: assertion) < foo` when multilining the expression, leading to a syntax error ([#940](https://github.com/JohnnyMorganz/StyLua/issues/940))
Expand Down
14 changes: 14 additions & 0 deletions src/formatters/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,20 @@ fn stmt_remove_leading_newlines(stmt: Stmt) -> Stmt {
type_declaration.type_token(),
with_type_token
),
#[cfg(feature = "luau")]
Stmt::ExportedTypeFunction(exported_type_function) => update_first_token!(
ExportedTypeFunction,
exported_type_function,
exported_type_function.export_token(),
with_export_token
),
#[cfg(feature = "luau")]
Stmt::TypeFunction(type_function) => update_first_token!(
TypeFunction,
type_function,
type_function.type_token(),
with_type_token
),
#[cfg(feature = "lua52")]
Stmt::Goto(goto) => update_first_token!(Goto, goto, goto.goto_token(), with_goto_token),
#[cfg(feature = "lua52")]
Expand Down
105 changes: 100 additions & 5 deletions src/formatters/luau.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use crate::{
context::{create_indent_trivia, create_newline_trivia, Context},
context::{
create_function_definition_trivia, create_indent_trivia, create_newline_trivia, Context,
},
fmt_op, fmt_symbol,
formatters::{
assignment::hang_equal_token,
expression::{format_expression, format_var},
functions::format_function_body,
general::{
format_contained_punctuated_multiline, format_contained_span, format_punctuated,
format_symbol, format_token_reference,
Expand All @@ -23,10 +26,10 @@ use crate::{
};
use full_moon::ast::{
luau::{
CompoundAssignment, CompoundOp, ExportedTypeDeclaration, GenericDeclaration,
GenericDeclarationParameter, GenericParameterInfo, IndexedTypeInfo, TypeArgument,
TypeAssertion, TypeDeclaration, TypeField, TypeFieldKey, TypeInfo, TypeIntersection,
TypeSpecifier, TypeUnion,
CompoundAssignment, CompoundOp, ExportedTypeDeclaration, ExportedTypeFunction,
GenericDeclaration, GenericDeclarationParameter, GenericParameterInfo, IndexedTypeInfo,
TypeArgument, TypeAssertion, TypeDeclaration, TypeField, TypeFieldKey, TypeFunction,
TypeInfo, TypeIntersection, TypeSpecifier, TypeUnion,
},
punctuated::Pair,
};
Expand Down Expand Up @@ -1351,6 +1354,63 @@ pub fn format_type_declaration_stmt(
format_type_declaration(ctx, type_declaration, true, shape)
}

fn format_type_function(
ctx: &Context,
type_function: &TypeFunction,
add_leading_trivia: bool,
shape: Shape,
) -> TypeFunction {
const TYPE_TOKEN_LENGTH: usize = "type ".len();
const FUNCTION_TOKEN_LENGTH: usize = "function ".len();

// Calculate trivia
let trailing_trivia = vec![create_newline_trivia(ctx)];
let function_definition_trivia = vec![create_function_definition_trivia(ctx)];

let mut type_token = format_symbol(
ctx,
type_function.type_token(),
&TokenReference::new(
vec![],
Token::new(TokenType::Identifier {
identifier: "type".into(),
}),
vec![Token::new(TokenType::spaces(1))],
),
shape,
);

if add_leading_trivia {
let leading_trivia = vec![create_indent_trivia(ctx, shape)];
type_token = type_token.update_leading_trivia(FormatTriviaType::Append(leading_trivia))
}

let function_token = fmt_symbol!(ctx, type_function.function_token(), "function ", shape);
let function_name = format_token_reference(ctx, type_function.function_name(), shape)
.update_trailing_trivia(FormatTriviaType::Append(function_definition_trivia));

let shape = shape
+ (TYPE_TOKEN_LENGTH
+ FUNCTION_TOKEN_LENGTH
+ strip_trivia(&function_name).to_string().len());
let function_body = format_function_body(ctx, type_function.function_body(), shape)
.update_trailing_trivia(FormatTriviaType::Append(trailing_trivia));

TypeFunction::new(function_name, function_body)
.with_type_token(type_token)
.with_function_token(function_token)
}

/// Wrapper around `format_type_function` for statements
/// This is required as `format_type_function` is also used for ExportedTypeFunction, and we don't want leading trivia there
pub fn format_type_function_stmt(
ctx: &Context,
type_function: &TypeFunction,
shape: Shape,
) -> TypeFunction {
format_type_function(ctx, type_function, true, shape)
}

fn format_generic_parameter(
ctx: &Context,
generic_parameter: &GenericDeclarationParameter,
Expand Down Expand Up @@ -1478,3 +1538,38 @@ pub fn format_exported_type_declaration(
.with_export_token(export_token)
.with_type_declaration(type_declaration)
}

pub fn format_exported_type_function(
ctx: &Context,
exported_type_function: &ExportedTypeFunction,
shape: Shape,
) -> ExportedTypeFunction {
// Calculate trivia
let shape = shape.reset();
let leading_trivia = vec![create_indent_trivia(ctx, shape)];

let export_token = format_symbol(
ctx,
exported_type_function.export_token(),
&TokenReference::new(
vec![],
Token::new(TokenType::Identifier {
identifier: "export".into(),
}),
vec![Token::new(TokenType::spaces(1))],
),
shape,
)
.update_leading_trivia(FormatTriviaType::Append(leading_trivia));
let type_function = format_type_function(
ctx,
exported_type_function.type_function(),
false,
shape + 7, // 7 = "export "
);

exported_type_function
.to_owned()
.with_export_token(export_token)
.with_type_function(type_function)
}
36 changes: 34 additions & 2 deletions src/formatters/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
use crate::formatters::lua52::{format_goto, format_goto_no_trivia, format_label};
#[cfg(feature = "luau")]
use crate::formatters::luau::{
format_compound_assignment, format_exported_type_declaration, format_type_declaration_stmt,
format_type_specifier,
format_compound_assignment, format_exported_type_declaration, format_exported_type_function,
format_type_declaration_stmt, format_type_function_stmt, format_type_specifier,
};
use crate::{
context::{create_indent_trivia, create_newline_trivia, Context, FormatNode},
Expand Down Expand Up @@ -793,6 +793,8 @@ pub fn format_function_call_stmt(
/// These are used for range formatting
pub(crate) mod stmt_block {
use crate::{context::Context, formatters::block::format_block, shape::Shape};
#[cfg(feature = "luau")]
use full_moon::ast::luau::TypeFunction;
use full_moon::ast::{
Call, Expression, Field, FunctionArgs, FunctionCall, Index, Prefix, Stmt, Suffix,
TableConstructor,
Expand Down Expand Up @@ -907,6 +909,17 @@ pub(crate) mod stmt_block {
.with_suffixes(suffixes)
}

#[cfg(feature = "luau")]
fn format_type_function_block(
ctx: &Context,
type_function: &TypeFunction,
shape: Shape,
) -> TypeFunction {
let block = format_block(ctx, type_function.function_body().block(), shape);
let body = type_function.function_body().to_owned().with_block(block);
type_function.to_owned().with_function_body(body)
}

/// Only formats a block within an expression
pub fn format_expression_block(
ctx: &Context,
Expand Down Expand Up @@ -1057,6 +1070,23 @@ pub(crate) mod stmt_block {
Stmt::ExportedTypeDeclaration(node) => Stmt::ExportedTypeDeclaration(node.to_owned()),
#[cfg(feature = "luau")]
Stmt::TypeDeclaration(node) => Stmt::TypeDeclaration(node.to_owned()),
#[cfg(feature = "luau")]
Stmt::ExportedTypeFunction(exported_type_function) => {
let type_function = format_type_function_block(
ctx,
exported_type_function.type_function(),
block_shape,
);
Stmt::ExportedTypeFunction(
exported_type_function
.to_owned()
.with_type_function(type_function),
)
}
#[cfg(feature = "luau")]
Stmt::TypeFunction(type_function) => {
Stmt::TypeFunction(format_type_function_block(ctx, type_function, block_shape))
}
#[cfg(feature = "lua52")]
Stmt::Goto(node) => Stmt::Goto(node.to_owned()),
#[cfg(feature = "lua52")]
Expand Down Expand Up @@ -1090,6 +1120,8 @@ pub fn format_stmt(ctx: &Context, stmt: &Stmt, shape: Shape) -> Stmt {
#[cfg(feature = "luau")] CompoundAssignment = format_compound_assignment,
#[cfg(feature = "luau")] ExportedTypeDeclaration = format_exported_type_declaration,
#[cfg(feature = "luau")] TypeDeclaration = format_type_declaration_stmt,
#[cfg(feature = "luau")] ExportedTypeFunction = format_exported_type_function,
#[cfg(feature = "luau")] TypeFunction = format_type_function_stmt,
#[cfg(feature = "lua52")] Goto = format_goto,
#[cfg(feature = "lua52")] Label = format_label,
})
Expand Down
23 changes: 21 additions & 2 deletions src/formatters/trivia.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use full_moon::ast::lua54::Attribute;
use full_moon::ast::luau::{
ElseIfExpression, GenericDeclaration, GenericDeclarationParameter, GenericParameterInfo,
IfExpression, IndexedTypeInfo, InterpolatedString, InterpolatedStringSegment, TypeArgument,
TypeAssertion, TypeDeclaration, TypeField, TypeFieldKey, TypeInfo, TypeIntersection,
TypeSpecifier, TypeUnion,
TypeAssertion, TypeDeclaration, TypeField, TypeFieldKey, TypeFunction, TypeInfo,
TypeIntersection, TypeSpecifier, TypeUnion,
};
use full_moon::ast::{
punctuated::Punctuated, span::ContainedSpan, BinOp, Call, Expression, FunctionArgs,
Expand Down Expand Up @@ -680,6 +680,18 @@ define_update_trivia!(Stmt, |this, leading, trailing| {
}
#[cfg(feature = "luau")]
Stmt::TypeDeclaration(stmt) => Stmt::TypeDeclaration(stmt.update_trivia(leading, trailing)),
#[cfg(feature = "luau")]
Stmt::ExportedTypeFunction(stmt) => {
let export_token = stmt.export_token().update_leading_trivia(leading);
let type_function = stmt.type_function().update_trailing_trivia(trailing);
Stmt::ExportedTypeFunction(
stmt.to_owned()
.with_export_token(export_token)
.with_type_function(type_function),
)
}
#[cfg(feature = "luau")]
Stmt::TypeFunction(stmt) => Stmt::TypeFunction(stmt.update_trivia(leading, trailing)),
#[cfg(feature = "lua52")]
Stmt::Goto(stmt) => Stmt::Goto(
stmt.to_owned()
Expand Down Expand Up @@ -920,6 +932,13 @@ define_update_trivia!(TypeDeclaration, |this, leading, trailing| {
.with_type_definition(this.type_definition().update_trailing_trivia(trailing))
});

#[cfg(feature = "luau")]
define_update_trivia!(TypeFunction, |this, leading, trailing| {
this.to_owned()
.with_type_token(this.type_token().update_leading_trivia(leading))
.with_function_body(this.function_body().update_trailing_trivia(trailing))
});

#[cfg(feature = "luau")]
define_update_trailing_trivia!(IndexedTypeInfo, |this, trailing| {
match this {
Expand Down
46 changes: 32 additions & 14 deletions src/formatters/trivia_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ use crate::{
#[cfg(feature = "luau")]
use full_moon::ast::luau::{
GenericDeclarationParameter, GenericParameterInfo, IndexedTypeInfo, TypeArgument,
TypeDeclaration, TypeInfo, TypeIntersection, TypeSpecifier, TypeUnion,
TypeDeclaration, TypeFunction, TypeInfo, TypeIntersection, TypeSpecifier, TypeUnion,
};
use full_moon::{
ast::{
punctuated::{Pair, Punctuated},
BinOp, Block, Call, Expression, Field, FunctionArgs, Index, LastStmt, LocalAssignment,
Parameter, Prefix, Stmt, Suffix, TableConstructor, UnOp, Var, VarExpression,
BinOp, Block, Call, Expression, Field, FunctionArgs, FunctionBody, Index, LastStmt,
LocalAssignment, Parameter, Prefix, Stmt, Suffix, TableConstructor, UnOp, Var,
VarExpression,
},
node::Node,
tokenizer::{Token, TokenKind, TokenReference, TokenType},
Expand Down Expand Up @@ -263,6 +264,12 @@ impl GetTrailingTrivia for FunctionArgs {
}
}

impl GetTrailingTrivia for FunctionBody {
fn trailing_trivia(&self) -> Vec<Token> {
GetTrailingTrivia::trailing_trivia(self.end_token())
}
}

impl GetTrailingTrivia for Prefix {
fn trailing_trivia(&self) -> Vec<Token> {
match self {
Expand Down Expand Up @@ -522,7 +529,6 @@ pub fn take_leading_comments<T: GetLeadingTrivia + UpdateLeadingTrivia>(
)
}

#[cfg(feature = "luau")]
pub fn take_trailing_trivia<T: GetTrailingTrivia + UpdateTrailingTrivia>(
node: &T,
) -> (T, Vec<Token>) {
Expand Down Expand Up @@ -714,6 +720,13 @@ impl GetTrailingTrivia for TypeDeclaration {
}
}

#[cfg(feature = "luau")]
impl GetTrailingTrivia for TypeFunction {
fn trailing_trivia(&self) -> Vec<Token> {
self.function_body().trailing_trivia()
}
}

#[cfg(feature = "luau")]
impl GetTrailingTrivia for TypeSpecifier {
fn trailing_trivia(&self) -> Vec<Token> {
Expand Down Expand Up @@ -873,22 +886,14 @@ pub fn get_stmt_trailing_trivia(stmt: Stmt) -> (Stmt, Vec<Token>) {
end_stmt_trailing_trivia!(If, stmt)
}
Stmt::FunctionDeclaration(stmt) => {
let end_token = stmt.body().end_token();
let trailing_trivia = end_token.trailing_trivia().map(|x| x.to_owned()).collect();
let new_end_token = end_token.update_trailing_trivia(FormatTriviaType::Replace(vec![]));

let body = stmt.body().to_owned().with_end_token(new_end_token);
let (body, trailing_trivia) = take_trailing_trivia(stmt.body());
(
Stmt::FunctionDeclaration(stmt.with_body(body)),
trailing_trivia,
)
}
Stmt::LocalFunction(stmt) => {
let end_token = stmt.body().end_token();
let trailing_trivia = end_token.trailing_trivia().map(|x| x.to_owned()).collect();
let new_end_token = end_token.update_trailing_trivia(FormatTriviaType::Replace(vec![]));

let body = stmt.body().to_owned().with_end_token(new_end_token);
let (body, trailing_trivia) = take_trailing_trivia(stmt.body());
(Stmt::LocalFunction(stmt.with_body(body)), trailing_trivia)
}
Stmt::NumericFor(stmt) => {
Expand Down Expand Up @@ -922,6 +927,19 @@ pub fn get_stmt_trailing_trivia(stmt: Stmt) -> (Stmt, Vec<Token>) {
let (type_declaration, trailing_trivia) = take_trailing_trivia(&stmt);
(Stmt::TypeDeclaration(type_declaration), trailing_trivia)
}
#[cfg(feature = "luau")]
Stmt::ExportedTypeFunction(stmt) => {
let (type_function, trailing_trivia) = take_trailing_trivia(stmt.type_function());
(
Stmt::ExportedTypeFunction(stmt.with_type_function(type_function)),
trailing_trivia,
)
}
#[cfg(feature = "luau")]
Stmt::TypeFunction(stmt) => {
let (type_declaration, trailing_trivia) = take_trailing_trivia(&stmt);
(Stmt::TypeFunction(type_declaration), trailing_trivia)
}
#[cfg(feature = "lua52")]
Stmt::Goto(stmt) => {
let trailing_trivia = stmt
Expand Down
5 changes: 5 additions & 0 deletions tests/inputs-luau/type-functions-1.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
type function Foo(x)
end

export type function Foo(x)
end
9 changes: 9 additions & 0 deletions tests/snapshots/tests__luau@type-functions-1.lua.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
source: tests/tests.rs
expression: "format(&contents, LuaVersion::Luau)"
input_file: tests/inputs-luau/type-functions-1.lua
snapshot_kind: text
---
type function Foo(x) end

export type function Foo(x) end