"""Parser for Rust source AST construction."""
from .RustAst import (
AssignmentNode,
BinaryOpNode,
ForNode,
WhileNode,
LoopNode,
MatchNode,
MatchArmNode,
FunctionCallNode,
FunctionNode,
IfNode,
MemberAccessNode,
ReturnNode,
BreakNode,
ContinueNode,
ShaderNode,
StructNode,
StructInitializationNode,
EnumNode,
EnumVariantNode,
AssociatedTypeNode,
TypeAliasNode,
ImplNode,
TraitNode,
UnaryOpNode,
VariableNode,
LetNode,
VectorConstructorNode,
TernaryOpNode,
UseNode,
AttributeNode,
ConstNode,
StaticNode,
ArrayAccessNode,
RangeNode,
TupleNode,
ArrayNode,
ReferenceNode,
DereferenceNode,
CastNode,
BlockNode,
)
from .RustLexer import RustLexer, Lexer, TokenType
[docs]
class RustParser:
"""Parse Rust tokens into the Rust backend shader AST."""
def __init__(self, tokens):
"""Initialize the parser with a token stream from ``RustLexer``."""
self.tokens = tokens
self.current_index = 0
self.current_token = tokens[0] if tokens else None
[docs]
def parse(self):
    """Parse the complete Rust token stream into a shader AST.

    Each top-level item may be preceded by ``#[...]`` attributes and/or a
    ``pub`` visibility marker.  Attributes are forwarded to the item parsers
    that accept them (struct, enum, type alias, fn); for the remaining item
    kinds they are parsed and dropped.  Tokens that do not start a known item
    are skipped one at a time.

    Returns:
        ShaderNode built from the collected structs, functions, globals,
        impl blocks, use statements, traits, enums and type aliases.
    """
    structs = []
    functions = []
    global_variables = []
    impl_blocks = []
    use_statements = []
    traits = []
    enums = []
    type_aliases = []
    while self.current_token[0] != "EOF":
        attrs = None
        if self.current_token[0] == "POUND":
            attrs = self.parse_attributes()
        visibility = None
        if self.current_token[0] == "PUB":
            visibility = "pub"
            self.eat("PUB")

        kind = self.current_token[0]
        if kind == "STRUCT":
            structs.append(
                self.parse_struct(attributes=attrs, visibility=visibility)
            )
        elif kind == "ENUM":
            enums.append(self.parse_enum(attributes=attrs, visibility=visibility))
        elif kind == "TYPE":
            type_aliases.append(
                self.parse_type_alias(attributes=attrs, visibility=visibility)
            )
        elif kind == "FN":
            functions.append(
                self.parse_function(attributes=attrs, visibility=visibility)
            )
        elif kind == "CONST":
            global_variables.append(self.parse_const(visibility=visibility))
        elif kind == "STATIC":
            global_variables.append(self.parse_static(visibility=visibility))
        elif kind == "TRAIT":
            traits.append(self.parse_trait(visibility=visibility))
        elif kind == "USE":
            use_statements.append(
                self.parse_use_statement(visibility=visibility)
            )
        elif kind == "IMPL" and visibility is None:
            # ``pub impl`` is not an item; in that case IMPL is skipped below.
            impl_blocks.append(self.parse_impl_block())
        elif kind != "EOF":
            # Unknown token: skip it and try again at the next one.
            self.eat(kind)
    return ShaderNode(
        structs,
        functions,
        global_variables,
        impl_blocks,
        use_statements,
        traits,
        enums,
        type_aliases,
    )
[docs]
def eat(self, expected_type):
"""Consume and return the current token when it matches ``expected_type``."""
if self.current_token[0] == expected_type:
token = self.current_token
self.current_index += 1
if self.current_index < len(self.tokens):
self.current_token = self.tokens[self.current_index]
else:
self.current_token = ("EOF", "")
return token
else:
raise SyntaxError(f"Expected {expected_type}, got {self.current_token[0]}")
def skip_until(self, token_type):
while self.current_token[0] != token_type and self.current_token[0] != "EOF":
self.current_index += 1
if self.current_index < len(self.tokens):
self.current_token = self.tokens[self.current_index]
else:
self.current_token = ("EOF", "")
def parse_use_statement(self, visibility=None):
self.eat("USE")
path = []
path.append(self.current_token[1])
self.eat("IDENTIFIER")
while self.current_token[0] == "DOUBLE_COLON":
self.eat("DOUBLE_COLON")
if self.current_token[0] == "MULTIPLY":
path.append("*")
self.eat("MULTIPLY")
elif self.current_token[0] == "LBRACE":
self.eat("LBRACE")
items = []
while (
self.current_token[0] != "RBRACE" and self.current_token[0] != "EOF"
):
items.append(self.current_token[1])
self.eat("IDENTIFIER")
if self.current_token[0] == "COMMA":
self.eat("COMMA")
else:
break
self.eat("RBRACE")
path.append("{" + ", ".join(items) + "}")
else:
path.append(self.current_token[1])
self.eat("IDENTIFIER")
alias = None
if self.current_token[0] == "AS":
self.eat("AS")
alias = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("SEMICOLON")
return UseNode("::".join(path), alias, visibility)
def parse_attributes(self):
attrs = []
while self.current_token[0] == "POUND":
self.eat("POUND")
self.eat("LBRACKET")
attr_name = self.current_token[1]
self.eat("IDENTIFIER")
attr_args = []
if self.current_token[0] == "LPAREN":
self.eat("LPAREN")
while self.current_token[0] != "RPAREN":
attr_args.append(self.current_token[1])
self.eat(self.current_token[0])
if self.current_token[0] == "COMMA":
self.eat("COMMA")
self.eat("RPAREN")
self.eat("RBRACKET")
attrs.append(AttributeNode(attr_name, attr_args))
return attrs
def parse_struct(self, attributes=None, visibility=None):
self.eat("STRUCT")
name = self.current_token[1]
self.eat("IDENTIFIER")
generics = []
if self.current_token[0] == "LESS_THAN":
generics = self.parse_generics()
self.eat("LBRACE")
members = []
while self.current_token[0] != "RBRACE" and self.current_token[0] != "EOF":
member_attrs = []
if self.current_token[0] == "POUND":
member_attrs = self.parse_attributes()
if self.current_token[0] == "PUB":
self.eat("PUB")
member_name = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("COLON")
member_type = self.parse_type()
var = VariableNode(member_type, member_name, attributes=member_attrs)
members.append(var)
if self.current_token[0] == "COMMA":
self.eat("COMMA")
if self.current_token[0] == "PUB":
continue
elif self.current_token[0] == "IDENTIFIER":
continue
self.eat("RBRACE")
return StructNode(name, members, attributes, visibility, generics)
def parse_enum(self, attributes=None, visibility=None):
self.eat("ENUM")
name = self.current_token[1]
self.eat("IDENTIFIER")
generics = []
if self.current_token[0] == "LESS_THAN":
generics = self.parse_generics()
where_clauses = []
if self.current_token[0] == "WHERE":
where_clauses = self.parse_where_clause({"LBRACE"})
self.eat("LBRACE")
variants = []
while self.current_token[0] != "RBRACE" and self.current_token[0] != "EOF":
variant_attrs = []
if self.current_token[0] == "POUND":
variant_attrs = self.parse_attributes()
variant_name = self.current_token[1]
self.eat("IDENTIFIER")
kind = "unit"
fields = []
if self.current_token[0] == "LPAREN":
kind = "tuple"
fields = self.parse_tuple_variant_fields()
elif self.current_token[0] == "LBRACE":
kind = "struct"
fields = self.parse_struct_variant_fields()
value = None
if self.current_token[0] == "EQUALS":
self.eat("EQUALS")
value = self.parse_expression()
variants.append(
EnumVariantNode(variant_name, kind, fields, value, variant_attrs)
)
if self.current_token[0] == "COMMA":
self.eat("COMMA")
continue
break
self.eat("RBRACE")
return EnumNode(name, variants, attributes, visibility, generics, where_clauses)
def parse_tuple_variant_fields(self):
self.eat("LPAREN")
fields = []
while self.current_token[0] != "RPAREN" and self.current_token[0] != "EOF":
fields.append(self.parse_type())
if self.current_token[0] == "COMMA":
self.eat("COMMA")
continue
break
self.eat("RPAREN")
return fields
def parse_struct_variant_fields(self):
self.eat("LBRACE")
fields = []
while self.current_token[0] != "RBRACE" and self.current_token[0] != "EOF":
field_attrs = []
if self.current_token[0] == "POUND":
field_attrs = self.parse_attributes()
if self.current_token[0] == "PUB":
self.eat("PUB")
field_name = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("COLON")
field_type = self.parse_type()
fields.append(VariableNode(field_type, field_name, attributes=field_attrs))
if self.current_token[0] == "COMMA":
self.eat("COMMA")
continue
break
self.eat("RBRACE")
return fields
def parse_impl_block(self):
self.eat("IMPL")
generics = []
if self.current_token[0] == "LESS_THAN":
generics = self.parse_generics()
trait_name = None
struct_name = self.parse_type()
if self.current_token[0] == "FOR":
trait_name = struct_name
self.eat("FOR")
struct_name = self.parse_type()
where_clauses = []
if self.current_token[0] == "WHERE":
where_clauses = self.parse_where_clause({"LBRACE"})
self.eat("LBRACE")
functions = []
type_aliases = []
while self.current_token[0] != "RBRACE" and self.current_token[0] != "EOF":
if self.current_token[0] == "FN":
f = self.parse_function()
functions.append(f)
elif self.current_token[0] == "TYPE":
type_aliases.append(self.parse_type_alias())
elif self.current_token[0] == "PUB":
visibility = "pub"
self.eat("PUB")
if self.current_token[0] == "FN":
f = self.parse_function(visibility=visibility)
functions.append(f)
elif self.current_token[0] == "TYPE":
type_aliases.append(self.parse_type_alias(visibility=visibility))
else:
if self.current_token[0] == "EOF":
break
self.eat(self.current_token[0])
else:
if self.current_token[0] == "EOF":
break
self.eat(self.current_token[0])
self.eat("RBRACE")
return ImplNode(
struct_name,
functions,
trait_name,
generics,
where_clauses,
type_aliases,
)
def parse_trait(self, visibility=None):
self.eat("TRAIT")
name = self.current_token[1]
self.eat("IDENTIFIER")
generics = []
if self.current_token[0] == "LESS_THAN":
generics = self.parse_generics()
where_clauses = []
if self.current_token[0] == "WHERE":
where_clauses = self.parse_where_clause({"LBRACE"})
self.eat("LBRACE")
functions = []
associated_types = []
while self.current_token[0] != "RBRACE" and self.current_token[0] != "EOF":
if self.current_token[0] == "FN":
f = self.parse_function_signature() # Traits only have signatures
functions.append(f)
elif self.current_token[0] == "TYPE":
associated_types.append(self.parse_associated_type())
else:
if self.current_token[0] == "EOF":
break
self.eat(self.current_token[0])
self.eat("RBRACE")
return TraitNode(
name,
functions,
generics,
visibility,
where_clauses,
associated_types,
)
def parse_generics(self):
self.eat("LESS_THAN")
generics = []
while (
self.current_token[0] != "GREATER_THAN" and self.current_token[0] != "EOF"
):
parameter = self.collect_token_text_until({"COMMA", "GREATER_THAN"})
if parameter:
generics.append(parameter)
if self.current_token[0] == "COMMA":
self.eat("COMMA")
continue
break
self.eat("GREATER_THAN")
return generics
def parse_type(self):
if self.current_token[0] == "AMPERSAND":
self.eat("AMPERSAND")
if self.current_token[0] == "MUT":
self.eat("MUT")
return f"&mut {self.parse_type()}"
return f"&{self.parse_type()}"
if self.current_token[0] == "LBRACKET":
self.eat("LBRACKET")
element_type = self.parse_type()
size = None
if self.current_token[0] == "SEMICOLON":
self.eat("SEMICOLON")
size = self.parse_array_type_size()
self.eat("RBRACKET")
return self.format_array_type(element_type, size)
type_parts = []
type_parts.append(self.current_token[1])
self.eat(self.current_token[0])
while True:
if self.current_token[0] == "DOUBLE_COLON":
type_parts.append("::")
self.eat("DOUBLE_COLON")
type_parts.append(self.current_token[1])
self.eat(self.current_token[0])
elif self.current_token[0] == "LESS_THAN":
type_parts.append(self.parse_generic_argument_suffix())
else:
break
if self.current_token[0] == "LBRACKET":
type_parts.append("[")
self.eat("LBRACKET")
if self.current_token[0] == "NUMBER":
type_parts.append(self.current_token[1])
self.eat("NUMBER")
type_parts.append("]")
self.eat("RBRACKET")
return "".join(type_parts)
def parse_generic_argument_suffix(self):
self.eat("LESS_THAN")
arguments = self.collect_token_text_until({"GREATER_THAN"})
self.eat("GREATER_THAN")
return f"<{arguments}>"
def parse_array_type_size(self):
parts = []
depth = 0
while self.current_token[0] != "EOF":
token_type, token_value = self.current_token
if token_type == "RBRACKET" and depth == 0:
break
if token_type in ["LPAREN", "LBRACKET", "LBRACE"]:
depth += 1
elif token_type in ["RPAREN", "RBRACKET", "RBRACE"]:
depth -= 1
parts.append(str(token_value))
self.eat(token_type)
return "".join(parts).strip() or None
def format_array_type(self, element_type, size):
    """Append an array dimension to ``element_type``.

    The new ``[size]`` (``[]`` when ``size`` is ``None``) is inserted before
    any dimensions the element type already carries.
    """
    suffix = "[]" if size is None else f"[{size}]"
    base_type, bracket, existing_suffix = element_type.partition("[")
    if not bracket:
        return f"{element_type}{suffix}"
    return f"{base_type}{suffix}[{existing_suffix}"
def parse_where_clause(self, terminators=None):
terminators = set(terminators or {"LBRACE"})
self.eat("WHERE")
predicates = []
while (
self.current_token[0] not in terminators and self.current_token[0] != "EOF"
):
if self.current_token[0] == "COMMA":
self.eat("COMMA")
continue
type_param = self.collect_token_text_until({"COLON", *terminators})
if not type_param or self.current_token[0] in terminators:
break
self.eat("COLON")
bounds = []
bound_terminators = {"PLUS", "COMMA", *terminators}
while (
self.current_token[0] not in {"COMMA", *terminators}
and self.current_token[0] != "EOF"
):
bound = self.collect_token_text_until(bound_terminators)
if bound:
bounds.append(bound)
if self.current_token[0] == "PLUS":
self.eat("PLUS")
continue
break
predicates.append((type_param, bounds))
if self.current_token[0] == "COMMA":
self.eat("COMMA")
else:
break
return predicates
def parse_associated_type(self):
self.eat("TYPE")
name = self.current_token[1]
self.eat("IDENTIFIER")
bounds = []
if self.current_token[0] == "COLON":
self.eat("COLON")
bound_terminators = {"PLUS", "EQUALS", "WHERE", "SEMICOLON"}
while self.current_token[0] not in {"EQUALS", "WHERE", "SEMICOLON"}:
bound = self.collect_token_text_until(bound_terminators)
if bound:
bounds.append(bound)
if self.current_token[0] == "PLUS":
self.eat("PLUS")
continue
break
where_clauses = []
if self.current_token[0] == "WHERE":
where_clauses = self.parse_where_clause({"EQUALS", "SEMICOLON"})
default_type = None
if self.current_token[0] == "EQUALS":
self.eat("EQUALS")
default_type = self.collect_token_text_until({"SEMICOLON"})
self.eat("SEMICOLON")
return AssociatedTypeNode(name, bounds, default_type, where_clauses)
def parse_type_alias(self, visibility=None, attributes=None):
self.eat("TYPE")
name = self.current_token[1]
self.eat("IDENTIFIER")
generics = []
if self.current_token[0] == "LESS_THAN":
generics = self.parse_generics()
where_clauses = []
if self.current_token[0] == "WHERE":
where_clauses = self.parse_where_clause({"EQUALS", "SEMICOLON"})
alias_type = None
if self.current_token[0] == "EQUALS":
self.eat("EQUALS")
alias_type = self.parse_type()
self.eat("SEMICOLON")
return TypeAliasNode(
name,
alias_type,
generics,
visibility,
where_clauses,
attributes,
)
def collect_token_text_until(self, terminators):
parts = []
depth = 0
while self.current_token[0] != "EOF":
token_type, token_value = self.current_token
if depth == 0 and token_type in terminators:
break
if token_type in {"LESS_THAN", "LPAREN", "LBRACKET"}:
depth += 1
elif token_type == "SHIFT_RIGHT":
if depth > 1:
depth -= 2
parts.append(str(token_value))
self.eat(token_type)
continue
if depth == 1:
depth -= 1
parts.append(">")
self.split_shift_right_token()
continue
elif token_type in {"GREATER_THAN", "RPAREN", "RBRACKET"}:
if depth == 0 and token_type in terminators:
break
depth = max(0, depth - 1)
parts.append(str(token_value))
self.eat(token_type)
return self.format_token_parts(parts)
def split_shift_right_token(self):
self.tokens[self.current_index] = ("GREATER_THAN", ">")
self.current_token = self.tokens[self.current_index]
def format_token_parts(self, parts):
formatted = []
previous = None
for part in parts:
if part == ",":
formatted.append(", ")
else:
if self.needs_token_part_space(previous, part):
formatted.append(" ")
formatted.append(part)
previous = part
return "".join(formatted).strip()
def needs_token_part_space(self, previous, current):
if previous is None:
return False
if previous in {"<", "[", "(", "::", ","}:
return False
if current in {">", ">>", "]", ")", ",", "::", "<", "[", "(", ":"}:
return False
if previous in {":", "+", "=", "->", "=>"}:
return True
if current in {"+", "=", "->", "=>"}:
return True
return self.is_token_word(previous) and self.is_token_word(current)
def is_token_word(self, part):
    """Return True when ``part`` is alphanumeric once underscores are removed."""
    without_underscores = part.replace("_", "")
    return without_underscores.isalnum()
[docs]
def parse_struct_initialization(self, struct_name):
"""Parse struct initialization syntax: Name { field: value, ... }"""
self.eat("LBRACE")
fields = []
while self.current_token[0] != "RBRACE" and self.current_token[0] != "EOF":
# Parse field name
field_name = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("COLON")
# Parse field value
field_value = self.parse_expression()
fields.append((field_name, field_value))
if self.current_token[0] == "COMMA":
self.eat("COMMA")
else:
break
self.eat("RBRACE")
return StructInitializationNode(struct_name, fields)
def parse_function(self, attributes=None, visibility=None):
self.eat("FN")
name = self.current_token[1]
self.eat("IDENTIFIER")
generics = []
if self.current_token[0] == "LESS_THAN":
generics = self.parse_generics()
params = self.parse_parameters()
return_type = "()"
if self.current_token[0] == "ARROW":
self.eat("ARROW")
return_type = self.parse_type()
where_clauses = []
if self.current_token[0] == "WHERE":
where_clauses = self.parse_where_clause({"LBRACE"})
self.eat("LBRACE")
body = self.parse_block()
self.eat("RBRACE")
return FunctionNode(
return_type,
name,
params,
body,
attributes,
visibility,
generics,
where_clauses,
)
def parse_function_signature(self):
# For trait function signatures (no body)
self.eat("FN")
name = self.current_token[1]
self.eat("IDENTIFIER")
generics = []
if self.current_token[0] == "LESS_THAN":
generics = self.parse_generics()
params = self.parse_parameters()
return_type = "()"
if self.current_token[0] == "ARROW":
self.eat("ARROW")
return_type = self.parse_type()
where_clauses = []
if self.current_token[0] == "WHERE":
where_clauses = self.parse_where_clause({"SEMICOLON"})
self.eat("SEMICOLON")
return FunctionNode(
return_type,
name,
params,
[],
[],
None,
generics,
where_clauses,
)
def parse_parameters(self):
self.eat("LPAREN")
params = []
if self.current_token[0] != "RPAREN":
if self.current_token[0] == "SELF":
params.append(VariableNode("Self", "self"))
self.eat("SELF")
if self.current_token[0] == "COMMA":
self.eat("COMMA")
elif self.current_token[0] == "AMPERSAND":
self.eat("AMPERSAND")
if self.current_token[0] == "MUT":
self.eat("MUT")
params.append(VariableNode("&mut Self", "self"))
else:
params.append(VariableNode("&Self", "self"))
self.eat("SELF")
if self.current_token[0] == "COMMA":
self.eat("COMMA")
while self.current_token[0] != "RPAREN":
param_attrs = []
while self.current_token[0] == "POUND":
param_attrs.extend(self.parse_attributes())
is_mutable = False
if self.current_token[0] == "MUT":
is_mutable = True
self.eat("MUT")
param_name = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("COLON")
param_type = self.parse_type()
param = VariableNode(param_type, param_name, is_mutable)
if param_attrs:
param.attributes = param_attrs
params.append(param)
if self.current_token[0] == "COMMA":
self.eat("COMMA")
else:
break
self.eat("RPAREN")
return params
def parse_const(self, visibility=None):
self.eat("CONST")
name = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("COLON")
const_type = self.parse_type()
self.eat("EQUALS")
value = self.parse_expression()
self.eat("SEMICOLON")
return ConstNode(name, const_type, value, visibility)
def parse_static(self, visibility=None):
self.eat("STATIC")
is_mutable = False
if self.current_token[0] == "MUT":
is_mutable = True
self.eat("MUT")
name = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("COLON")
static_type = self.parse_type()
self.eat("EQUALS")
value = self.parse_expression()
self.eat("SEMICOLON")
return StaticNode(name, static_type, value, is_mutable, visibility)
def parse_block(self):
statements = []
while self.current_token[0] != "RBRACE" and self.current_token[0] != "EOF":
if self.current_token[0] == "LET":
stmt = self.parse_let_statement()
statements.append(stmt)
elif self.current_token[0] == "IDENTIFIER" or self.current_token[0] in [
"VEC2",
"VEC3",
"VEC4",
"MAT2",
"MAT3",
"MAT4",
]:
left = self.parse_expression()
if self.current_token[0] in [
"EQUALS",
"PLUS_EQUALS",
"MINUS_EQUALS",
"MULTIPLY_EQUALS",
"DIVIDE_EQUALS",
"MOD_EQUALS",
]:
op = self.current_token[1]
self.eat(self.current_token[0])
right = self.parse_expression()
self.eat("SEMICOLON")
statements.append(AssignmentNode(left, right, op))
else:
if self.current_token[0] == "SEMICOLON":
self.eat("SEMICOLON")
statements.append(left)
elif self.current_token[0] == "RETURN":
self.eat("RETURN")
value = None
if self.current_token[0] != "SEMICOLON":
value = self.parse_expression()
self.eat("SEMICOLON")
statements.append(ReturnNode(value))
elif self.current_token[0] == "BREAK":
self.eat("BREAK")
label = None
value = None
if self.current_token[0] != "SEMICOLON":
if self.current_token[0] == "IDENTIFIER":
label = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("SEMICOLON")
statements.append(BreakNode(label, value))
elif self.current_token[0] == "CONTINUE":
self.eat("CONTINUE")
label = None
if self.current_token[0] == "IDENTIFIER":
label = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("SEMICOLON")
statements.append(ContinueNode(label))
elif self.current_token[0] == "IF":
statements.append(self.parse_if_statement())
elif self.current_token[0] == "MATCH":
statements.append(self.parse_match_statement())
elif self.current_token[0] == "FOR":
statements.append(self.parse_for_loop())
elif self.current_token[0] == "WHILE":
statements.append(self.parse_while_loop())
elif self.current_token[0] == "LOOP":
statements.append(self.parse_loop())
else:
if self.current_token[0] == "EOF":
break
self.eat(self.current_token[0])
return statements
def parse_let_statement(self):
self.eat("LET")
is_mutable = False
if self.current_token[0] == "MUT":
is_mutable = True
self.eat("MUT")
name = self.current_token[1]
self.eat("IDENTIFIER")
var_type = None
if self.current_token[0] == "COLON":
self.eat("COLON")
var_type = self.parse_type()
value = None
if self.current_token[0] == "EQUALS":
self.eat("EQUALS")
value = self.parse_expression()
self.eat("SEMICOLON")
return LetNode(name, value, var_type, is_mutable)
def parse_if_statement(self):
self.eat("IF")
condition = self.parse_expression()
self.eat("LBRACE")
if_body = self.parse_block()
self.eat("RBRACE")
else_body = None
if self.current_token[0] == "ELSE":
self.eat("ELSE")
if self.current_token[0] == "IF":
else_body = [self.parse_if_statement()]
else:
# else block
self.eat("LBRACE")
else_body = self.parse_block()
self.eat("RBRACE")
return IfNode(condition, if_body, else_body)
def parse_match_statement(self):
self.eat("MATCH")
expression = self.parse_expression()
self.eat("LBRACE")
arms = []
while self.current_token[0] != "RBRACE" and self.current_token[0] != "EOF":
if self.current_token[0] == "UNDERSCORE":
pattern = "_"
self.eat("UNDERSCORE")
elif self.current_token[0] == "NUMBER":
pattern = self.current_token[1]
self.eat("NUMBER")
elif self.current_token[0] == "STRING":
pattern = self.current_token[1]
self.eat("STRING")
elif self.current_token[0] == "IDENTIFIER":
pattern = self.current_token[1]
self.eat("IDENTIFIER")
else:
# Fall back to full expression parsing for complex patterns
pattern = self.parse_expression()
guard = None
if self.current_token[0] == "IF":
self.eat("IF")
guard = self.parse_expression()
self.eat("FAT_ARROW")
if self.current_token[0] == "LBRACE":
self.eat("LBRACE")
body = self.parse_block()
self.eat("RBRACE")
else:
body = [self.parse_expression()]
if self.current_token[0] == "SEMICOLON":
self.eat("SEMICOLON")
if self.current_token[0] == "COMMA":
self.eat("COMMA")
arms.append(MatchArmNode(pattern, guard, body))
self.eat("RBRACE")
return MatchNode(expression, arms)
def parse_for_loop(self):
self.eat("FOR")
pattern = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("IN")
iterable = self.parse_expression()
self.eat("LBRACE")
body = self.parse_block()
self.eat("RBRACE")
return ForNode(pattern, iterable, body)
def parse_while_loop(self):
self.eat("WHILE")
condition = self.parse_expression()
self.eat("LBRACE")
body = self.parse_block()
self.eat("RBRACE")
return WhileNode(condition, body)
def parse_loop(self):
self.eat("LOOP")
self.eat("LBRACE")
body = self.parse_block()
self.eat("RBRACE")
return LoopNode(body)
def parse_expression(self):
return self.parse_conditional_expression()
def parse_conditional_expression(self):
# Handle ternary-like if expressions: if condition { true_expr } else { false_expr }
if self.current_token[0] == "IF":
self.eat("IF")
condition = self.parse_logical_or_expression()
self.eat("LBRACE")
true_expr = self.parse_expression()
self.eat("RBRACE")
self.eat("ELSE")
self.eat("LBRACE")
false_expr = self.parse_expression()
self.eat("RBRACE")
return TernaryOpNode(condition, true_expr, false_expr)
else:
return self.parse_logical_or_expression()
def parse_logical_or_expression(self):
left = self.parse_logical_and_expression()
while self.current_token[0] == "LOGICAL_OR":
op = self.current_token[1]
self.eat("LOGICAL_OR")
right = self.parse_logical_and_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_logical_and_expression(self):
left = self.parse_equality_expression()
while self.current_token[0] == "LOGICAL_AND":
op = self.current_token[1]
self.eat("LOGICAL_AND")
right = self.parse_equality_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_equality_expression(self):
left = self.parse_relational_expression()
while self.current_token[0] in ["EQUAL", "NOT_EQUAL"]:
op = self.current_token[1]
self.eat(self.current_token[0])
right = self.parse_relational_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_relational_expression(self):
left = self.parse_additive_expression()
while self.current_token[0] in [
"LESS_THAN",
"GREATER_THAN",
"LESS_EQUAL",
"GREATER_EQUAL",
]:
op = self.current_token[1]
self.eat(self.current_token[0])
right = self.parse_additive_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_additive_expression(self):
left = self.parse_range_expression()
while self.current_token[0] in ["PLUS", "MINUS"]:
op = self.current_token[1]
self.eat(self.current_token[0])
right = self.parse_range_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_range_expression(self):
left = self.parse_multiplicative_expression()
if self.current_token[0] in ["RANGE", "RANGE_INCLUSIVE"]:
op = self.current_token[1]
self.eat(self.current_token[0])
right = self.parse_multiplicative_expression()
return RangeNode(left, right, op == "..=")
return left
def parse_multiplicative_expression(self):
left = self.parse_cast_expression()
while self.current_token[0] in ["MULTIPLY", "DIVIDE", "MODULO"]:
op = self.current_token[1]
self.eat(self.current_token[0])
right = self.parse_cast_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_cast_expression(self):
left = self.parse_unary_expression()
if self.current_token[0] == "AS":
self.eat("AS")
target_type = self.parse_type()
return CastNode(target_type, left)
return left
def parse_unary_expression(self):
if self.current_token[0] in ["MINUS", "EXCLAMATION", "AMPERSAND", "MULTIPLY"]:
op = self.current_token[1]
self.eat(self.current_token[0])
if op == "&":
is_mutable = False
if self.current_token[0] == "MUT":
is_mutable = True
self.eat("MUT")
expr = self.parse_unary_expression()
return ReferenceNode(expr, is_mutable)
elif op == "*":
expr = self.parse_unary_expression()
return DereferenceNode(expr)
else:
operand = self.parse_unary_expression()
return UnaryOpNode(op, operand)
else:
return self.parse_postfix_expression()
def parse_postfix_expression(self):
left = self.parse_primary_expression()
while True:
if self.current_token[0] == "DOT":
self.eat("DOT")
member = self.current_token[1]
self.eat("IDENTIFIER")
left = MemberAccessNode(left, member)
elif self.current_token[0] == "LBRACKET":
self.eat("LBRACKET")
index = self.parse_expression()
self.eat("RBRACKET")
left = ArrayAccessNode(left, index)
elif self.current_token[0] == "EXCLAMATION":
self.eat("EXCLAMATION")
self.eat("LPAREN")
args = []
while self.current_token[0] != "RPAREN":
args.append(self.parse_expression())
if self.current_token[0] == "COMMA":
self.eat("COMMA")
else:
break
self.eat("RPAREN")
# Treat macro calls as function calls for code generation
left = FunctionCallNode(left, args)
elif self.current_token[0] == "LPAREN":
self.eat("LPAREN")
args = []
while self.current_token[0] != "RPAREN":
args.append(self.parse_expression())
if self.current_token[0] == "COMMA":
self.eat("COMMA")
else:
break
self.eat("RPAREN")
left = FunctionCallNode(left, args)
else:
break
return left
def parse_primary_expression(self):
if self.current_token[0] == "IDENTIFIER":
name = self.current_token[1]
self.eat("IDENTIFIER")
if self.current_token[0] == "DOUBLE_COLON":
return self.finish_path_or_call(name)
# Only if this identifier is likely a struct constructor (starts with uppercase)
if self.current_token[0] == "LBRACE" and name[0].isupper():
return self.parse_struct_initialization(name)
if (
name in ["Vec2", "Vec3", "Vec4", "Mat2", "Mat3", "Mat4"]
and self.current_token[0] == "LPAREN"
):
self.eat("LPAREN")
args = []
while self.current_token[0] != "RPAREN":
args.append(self.parse_expression())
if self.current_token[0] == "COMMA":
self.eat("COMMA")
else:
break
self.eat("RPAREN")
return VectorConstructorNode(name, args)
return name
elif self.current_token[0] in ["VEC2", "VEC3", "VEC4", "MAT2", "MAT3", "MAT4"]:
name = self.current_token[1]
self.eat(self.current_token[0])
if self.current_token[0] == "DOUBLE_COLON":
return self.finish_path_or_call(name)
return name
elif self.current_token[0] == "NUMBER":
value = self.current_token[1]
self.eat("NUMBER")
return value
elif self.current_token[0] == "STRING":
value = self.current_token[1]
self.eat("STRING")
return value
elif self.current_token[0] == "TRUE":
self.eat("TRUE")
return "true"
elif self.current_token[0] == "FALSE":
self.eat("FALSE")
return "false"
elif self.current_token[0] == "SELF":
name = self.current_token[1]
self.eat("SELF")
# Check for struct initialization: Self { ... }
if self.current_token[0] == "LBRACE":
return self.parse_struct_initialization(name)
if self.current_token[0] == "DOUBLE_COLON":
return self.finish_path_or_call(name)
return name
elif self.current_token[0] in {"CRATE", "SUPER"}:
name = self.current_token[1]
self.eat(self.current_token[0])
if self.current_token[0] == "DOUBLE_COLON":
return self.finish_path_or_call(name)
return name
elif self.current_token[0] == "LPAREN":
self.eat("LPAREN")
if self.current_token[0] == "RPAREN":
self.eat("RPAREN")
return "()"
expr = self.parse_expression()
if self.current_token[0] == "COMMA":
elements = [expr]
while self.current_token[0] == "COMMA":
self.eat("COMMA")
if self.current_token[0] != "RPAREN": # Handle trailing comma
elements.append(self.parse_expression())
self.eat("RPAREN")
return TupleNode(elements)
else:
self.eat("RPAREN")
return expr
elif self.current_token[0] == "LBRACKET":
self.eat("LBRACKET")
if self.current_token[0] == "RBRACKET":
self.eat("RBRACKET")
return ArrayNode([])
first_element = self.parse_expression()
if self.current_token[0] == "SEMICOLON":
self.eat("SEMICOLON")
size = self.parse_expression()
self.eat("RBRACKET")
return ArrayNode([first_element], size)
elements = [first_element]
while self.current_token[0] != "RBRACKET":
if self.current_token[0] == "COMMA":
self.eat("COMMA")
if self.current_token[0] != "RBRACKET":
elements.append(self.parse_expression())
else:
break
self.eat("RBRACKET")
return ArrayNode(elements)
elif self.current_token[0] == "LBRACE":
self.eat("LBRACE")
statements = []
expression = None
while self.current_token[0] != "RBRACE":
if self.peek_is_statement():
stmt = self.parse_statement()
statements.append(stmt)
else:
# Final expression (no semicolon)
expression = self.parse_expression()
break
self.eat("RBRACE")
return BlockNode(statements, expression)
else:
raise SyntaxError(
f"Unexpected token in primary expression: {self.current_token}"
)
def parse_path_expression(self, first_segment):
segments = [first_segment]
while self.current_token[0] == "DOUBLE_COLON":
self.eat("DOUBLE_COLON")
if self.current_token[0] == "LESS_THAN":
segments[-1] += self.parse_generic_argument_suffix()
continue
segments.append(self.current_token[1])
self.eat(self.current_token[0])
return "::".join(segments)
def finish_path_or_call(self, first_segment):
path = self.parse_path_expression(first_segment)
if self.current_token[0] == "LPAREN":
return FunctionCallNode(path, self.parse_call_arguments())
return path
def parse_call_arguments(self):
self.eat("LPAREN")
args = []
while self.current_token[0] != "RPAREN":
args.append(self.parse_expression())
if self.current_token[0] == "COMMA":
self.eat("COMMA")
else:
break
self.eat("RPAREN")
return args
def peek_is_statement(self):
return self.current_token[0] in [
"LET",
"IF",
"MATCH",
"FOR",
"WHILE",
"LOOP",
"RETURN",
"BREAK",
"CONTINUE",
]
def parse_statement(self):
if self.current_token[0] == "LET":
return self.parse_let_statement()
elif self.current_token[0] == "IF":
return self.parse_if_statement()
elif self.current_token[0] == "MATCH":
return self.parse_match_statement()
elif self.current_token[0] == "FOR":
return self.parse_for_loop()
elif self.current_token[0] == "WHILE":
return self.parse_while_loop()
elif self.current_token[0] == "LOOP":
return self.parse_loop()
elif self.current_token[0] == "RETURN":
self.eat("RETURN")
value = None
if self.current_token[0] != "SEMICOLON":
value = self.parse_expression()
self.eat("SEMICOLON")
return ReturnNode(value)
elif self.current_token[0] == "BREAK":
self.eat("BREAK")
label = None
value = None
if self.current_token[0] != "SEMICOLON":
if self.current_token[0] == "IDENTIFIER":
label = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("SEMICOLON")
return BreakNode(label, value)
elif self.current_token[0] == "CONTINUE":
self.eat("CONTINUE")
label = None
if self.current_token[0] == "IDENTIFIER":
label = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("SEMICOLON")
return ContinueNode(label)
else:
expr = self.parse_expression()
self.eat("SEMICOLON")
return expr