# Source code for crosstl.backend.Rust.RustParser

"""Parser for Rust source AST construction."""

from .RustAst import (
    AssignmentNode,
    BinaryOpNode,
    ForNode,
    WhileNode,
    LoopNode,
    MatchNode,
    MatchArmNode,
    FunctionCallNode,
    FunctionNode,
    IfNode,
    MemberAccessNode,
    ReturnNode,
    BreakNode,
    ContinueNode,
    ShaderNode,
    StructNode,
    StructInitializationNode,
    EnumNode,
    EnumVariantNode,
    AssociatedTypeNode,
    TypeAliasNode,
    ImplNode,
    TraitNode,
    UnaryOpNode,
    VariableNode,
    LetNode,
    VectorConstructorNode,
    TernaryOpNode,
    UseNode,
    AttributeNode,
    ConstNode,
    StaticNode,
    ArrayAccessNode,
    RangeNode,
    TupleNode,
    ArrayNode,
    ReferenceNode,
    DereferenceNode,
    CastNode,
    BlockNode,
)
from .RustLexer import RustLexer, Lexer, TokenType


class RustParser:
    """Parse Rust tokens into the Rust backend shader AST.

    The parser is a hand-written recursive-descent parser.  It consumes a
    list of ``(token_type, token_value)`` tuples and keeps a one-token
    lookahead in ``self.current_token``.  When the stream is exhausted the
    lookahead becomes the sentinel ``("EOF", "")``.
    """

    # Token types produced for the built-in vector / matrix type keywords.
    _VECTOR_TOKENS = ("VEC2", "VEC3", "VEC4", "MAT2", "MAT3", "MAT4")

    # Token types of the plain and compound assignment operators.
    _ASSIGNMENT_TOKENS = (
        "EQUALS",
        "PLUS_EQUALS",
        "MINUS_EQUALS",
        "MULTIPLY_EQUALS",
        "DIVIDE_EQUALS",
        "MOD_EQUALS",
    )

    # Keywords that introduce a statement handled by ``parse_statement``.
    _STATEMENT_TOKENS = (
        "LET",
        "IF",
        "MATCH",
        "FOR",
        "WHILE",
        "LOOP",
        "RETURN",
        "BREAK",
        "CONTINUE",
    )

    # Token types accepted as one segment of a ``use`` path.
    _PATH_SEGMENT_TOKENS = ("IDENTIFIER", "CRATE", "SUPER", "SELF")

    # Token types that ``parse_expression`` is able to start from.  Used by
    # ``parse_block`` so that statements such as ``self.x = y;`` or ``*p = 1;``
    # are parsed as a whole instead of being discarded token by token.
    _EXPRESSION_START = frozenset(
        (
            "IDENTIFIER",
            "SELF",
            "CRATE",
            "SUPER",
            "NUMBER",
            "STRING",
            "TRUE",
            "FALSE",
            "LPAREN",
            "LBRACKET",
            "MINUS",
            "EXCLAMATION",
            "AMPERSAND",
            "MULTIPLY",
        )
        + _VECTOR_TOKENS
    )

    def __init__(self, tokens):
        """Initialize the parser with a token stream from ``RustLexer``.

        An empty token list is treated as an immediately exhausted stream.
        """
        self.tokens = tokens
        self.current_index = 0
        # Use the same EOF sentinel as ``eat`` so an empty stream never
        # leaves ``current_token`` as ``None``.
        self.current_token = tokens[0] if tokens else ("EOF", "")

    def parse(self):
        """Parse the complete Rust token stream into a shader AST.

        Returns:
            ShaderNode holding structs, functions, global variables
            (consts and statics), impl blocks, use statements, traits,
            enums and type aliases, in that order.
        """
        structs = []
        functions = []
        global_variables = []
        impl_blocks = []
        use_statements = []
        traits = []
        enums = []
        type_aliases = []

        while self.current_token[0] != "EOF":
            # Optional ``#[...]`` attributes followed by optional ``pub``
            # may precede any item kind.
            attributes = None
            if self.current_token[0] == "POUND":
                attributes = self.parse_attributes()
            visibility = None
            if self.current_token[0] == "PUB":
                self.eat("PUB")
                visibility = "pub"

            kind = self.current_token[0]
            if kind == "USE":
                use_statements.append(self.parse_use_statement(visibility=visibility))
            elif kind == "TYPE":
                type_aliases.append(
                    self.parse_type_alias(visibility=visibility, attributes=attributes)
                )
            elif kind == "STRUCT":
                structs.append(
                    self.parse_struct(attributes=attributes, visibility=visibility)
                )
            elif kind == "ENUM":
                enums.append(
                    self.parse_enum(attributes=attributes, visibility=visibility)
                )
            elif kind == "IMPL":
                impl_blocks.append(self.parse_impl_block())
            elif kind == "TRAIT":
                traits.append(self.parse_trait(visibility=visibility))
            elif kind == "FN":
                functions.append(
                    self.parse_function(attributes=attributes, visibility=visibility)
                )
            elif kind == "CONST":
                global_variables.append(self.parse_const(visibility=visibility))
            elif kind == "STATIC":
                global_variables.append(self.parse_static(visibility=visibility))
            else:
                # Unsupported top-level token: skip it and keep going.
                self.eat(kind)

        return ShaderNode(
            structs,
            functions,
            global_variables,
            impl_blocks,
            use_statements,
            traits,
            enums,
            type_aliases,
        )

    def eat(self, expected_type):
        """Consume and return the current token when it matches ``expected_type``.

        Raises:
            SyntaxError: if the current token has a different type.
        """
        if self.current_token[0] != expected_type:
            raise SyntaxError(f"Expected {expected_type}, got {self.current_token[0]}")
        token = self.current_token
        self.current_index += 1
        if self.current_index < len(self.tokens):
            self.current_token = self.tokens[self.current_index]
        else:
            self.current_token = ("EOF", "")
        return token

    def skip_until(self, token_type):
        """Advance until the current token is ``token_type`` or EOF."""
        while self.current_token[0] not in (token_type, "EOF"):
            self.current_index += 1
            if self.current_index < len(self.tokens):
                self.current_token = self.tokens[self.current_index]
            else:
                self.current_token = ("EOF", "")

    def _eat_path_segment(self):
        """Consume one ``use`` path segment and return its text.

        Accepts identifiers as well as ``crate``, ``super`` and ``self``.
        Any other token raises the usual ``Expected IDENTIFIER`` error.
        """
        value = self.current_token[1]
        if self.current_token[0] in self._PATH_SEGMENT_TOKENS:
            self.eat(self.current_token[0])
        else:
            self.eat("IDENTIFIER")
        return value

    def parse_use_statement(self, visibility=None):
        """Parse ``use a::b::{c, d} as e;`` into a ``UseNode``."""
        self.eat("USE")
        path = [self._eat_path_segment()]

        while self.current_token[0] == "DOUBLE_COLON":
            self.eat("DOUBLE_COLON")
            if self.current_token[0] == "MULTIPLY":
                # Glob import: ``use foo::*;``
                path.append("*")
                self.eat("MULTIPLY")
            elif self.current_token[0] == "LBRACE":
                self.eat("LBRACE")
                items = []
                while self.current_token[0] not in ("RBRACE", "EOF"):
                    items.append(self._eat_path_segment())
                    if self.current_token[0] == "COMMA":
                        self.eat("COMMA")
                    else:
                        break
                self.eat("RBRACE")
                path.append("{" + ", ".join(items) + "}")
            else:
                path.append(self._eat_path_segment())

        alias = None
        if self.current_token[0] == "AS":
            self.eat("AS")
            alias = self.current_token[1]
            self.eat("IDENTIFIER")

        self.eat("SEMICOLON")
        return UseNode("::".join(path), alias, visibility)

    def parse_attributes(self):
        """Parse consecutive ``#[name(args)]`` / ``#![name(args)]`` attributes.

        Argument tokens are collected as a flat list of token values;
        commas that follow an argument are dropped.  Nested parentheses
        are kept in the list as ``(`` / ``)`` entries.

        Raises:
            SyntaxError: if the argument list is not closed before EOF.
        """
        attrs = []
        while self.current_token[0] == "POUND":
            self.eat("POUND")
            if self.current_token[0] == "EXCLAMATION":
                # Inner attribute: ``#![...]``
                self.eat("EXCLAMATION")
            self.eat("LBRACKET")
            attr_name = self.current_token[1]
            self.eat("IDENTIFIER")

            attr_args = []
            if self.current_token[0] == "LPAREN":
                self.eat("LPAREN")
                depth = 0
                while depth > 0 or self.current_token[0] != "RPAREN":
                    token_type = self.current_token[0]
                    if token_type == "EOF":
                        # ``eat("EOF")`` always succeeds, so without this
                        # guard the loop would never terminate.
                        raise SyntaxError("Unterminated attribute arguments")
                    if token_type == "LPAREN":
                        depth += 1
                    elif token_type == "RPAREN":
                        depth -= 1
                    attr_args.append(self.current_token[1])
                    self.eat(token_type)
                    if self.current_token[0] == "COMMA":
                        self.eat("COMMA")
                self.eat("RPAREN")

            self.eat("RBRACKET")
            attrs.append(AttributeNode(attr_name, attr_args))
        return attrs

    def parse_struct(self, attributes=None, visibility=None):
        """Parse a braced struct definition into a ``StructNode``."""
        self.eat("STRUCT")
        name = self.current_token[1]
        self.eat("IDENTIFIER")

        generics = []
        if self.current_token[0] == "LESS_THAN":
            generics = self.parse_generics()

        self.eat("LBRACE")
        members = []
        while self.current_token[0] not in ("RBRACE", "EOF"):
            member_attrs = []
            if self.current_token[0] == "POUND":
                member_attrs = self.parse_attributes()
            if self.current_token[0] == "PUB":
                # Field visibility is accepted but not recorded.
                self.eat("PUB")

            member_name = self.current_token[1]
            self.eat("IDENTIFIER")
            self.eat("COLON")
            member_type = self.parse_type()
            members.append(
                VariableNode(member_type, member_name, attributes=member_attrs)
            )

            if self.current_token[0] == "COMMA":
                self.eat("COMMA")
        self.eat("RBRACE")

        return StructNode(name, members, attributes, visibility, generics)

    def parse_enum(self, attributes=None, visibility=None):
        """Parse an enum definition (unit, tuple and struct variants)."""
        self.eat("ENUM")
        name = self.current_token[1]
        self.eat("IDENTIFIER")

        generics = []
        if self.current_token[0] == "LESS_THAN":
            generics = self.parse_generics()
        where_clauses = []
        if self.current_token[0] == "WHERE":
            where_clauses = self.parse_where_clause({"LBRACE"})

        self.eat("LBRACE")
        variants = []
        while self.current_token[0] not in ("RBRACE", "EOF"):
            variant_attrs = []
            if self.current_token[0] == "POUND":
                variant_attrs = self.parse_attributes()

            variant_name = self.current_token[1]
            self.eat("IDENTIFIER")

            kind = "unit"
            fields = []
            if self.current_token[0] == "LPAREN":
                kind = "tuple"
                fields = self.parse_tuple_variant_fields()
            elif self.current_token[0] == "LBRACE":
                kind = "struct"
                fields = self.parse_struct_variant_fields()

            # Optional explicit discriminant: ``Variant = expr``.
            value = None
            if self.current_token[0] == "EQUALS":
                self.eat("EQUALS")
                value = self.parse_expression()

            variants.append(
                EnumVariantNode(variant_name, kind, fields, value, variant_attrs)
            )

            if self.current_token[0] != "COMMA":
                break
            self.eat("COMMA")
        self.eat("RBRACE")

        return EnumNode(name, variants, attributes, visibility, generics, where_clauses)

    def parse_tuple_variant_fields(self):
        """Parse ``(T1, T2, ...)`` and return the list of type strings."""
        self.eat("LPAREN")
        fields = []
        while self.current_token[0] not in ("RPAREN", "EOF"):
            fields.append(self.parse_type())
            if self.current_token[0] != "COMMA":
                break
            self.eat("COMMA")
        self.eat("RPAREN")
        return fields

    def parse_struct_variant_fields(self):
        """Parse ``{ name: T, ... }`` and return a list of ``VariableNode``."""
        self.eat("LBRACE")
        fields = []
        while self.current_token[0] not in ("RBRACE", "EOF"):
            field_attrs = []
            if self.current_token[0] == "POUND":
                field_attrs = self.parse_attributes()
            if self.current_token[0] == "PUB":
                self.eat("PUB")

            field_name = self.current_token[1]
            self.eat("IDENTIFIER")
            self.eat("COLON")
            field_type = self.parse_type()
            fields.append(VariableNode(field_type, field_name, attributes=field_attrs))

            if self.current_token[0] != "COMMA":
                break
            self.eat("COMMA")
        self.eat("RBRACE")
        return fields

    def parse_impl_block(self):
        """Parse ``impl [Trait for] Type { ... }`` into an ``ImplNode``.

        Only functions and type aliases inside the block are recorded;
        other tokens are skipped.
        """
        self.eat("IMPL")
        generics = []
        if self.current_token[0] == "LESS_THAN":
            generics = self.parse_generics()

        trait_name = None
        struct_name = self.parse_type()
        if self.current_token[0] == "FOR":
            # ``impl Trait for Type``: the first type was the trait.
            trait_name = struct_name
            self.eat("FOR")
            struct_name = self.parse_type()

        where_clauses = []
        if self.current_token[0] == "WHERE":
            where_clauses = self.parse_where_clause({"LBRACE"})

        self.eat("LBRACE")
        functions = []
        type_aliases = []
        while self.current_token[0] not in ("RBRACE", "EOF"):
            attributes = None
            if self.current_token[0] == "POUND":
                attributes = self.parse_attributes()
            visibility = None
            if self.current_token[0] == "PUB":
                self.eat("PUB")
                visibility = "pub"

            kind = self.current_token[0]
            if kind == "FN":
                functions.append(
                    self.parse_function(attributes=attributes, visibility=visibility)
                )
            elif kind == "TYPE":
                type_aliases.append(
                    self.parse_type_alias(visibility=visibility, attributes=attributes)
                )
            elif kind in ("RBRACE", "EOF"):
                # Dangling attributes / ``pub`` right before the end.
                break
            else:
                self.eat(kind)
        self.eat("RBRACE")

        return ImplNode(
            struct_name,
            functions,
            trait_name,
            generics,
            where_clauses,
            type_aliases,
        )

    def parse_trait(self, visibility=None):
        """Parse a trait definition into a ``TraitNode``."""
        self.eat("TRAIT")
        name = self.current_token[1]
        self.eat("IDENTIFIER")

        generics = []
        if self.current_token[0] == "LESS_THAN":
            generics = self.parse_generics()
        where_clauses = []
        if self.current_token[0] == "WHERE":
            where_clauses = self.parse_where_clause({"LBRACE"})

        self.eat("LBRACE")
        functions = []
        associated_types = []
        while self.current_token[0] not in ("RBRACE", "EOF"):
            if self.current_token[0] == "FN":
                functions.append(self.parse_function_signature())
            elif self.current_token[0] == "TYPE":
                associated_types.append(self.parse_associated_type())
            else:
                self.eat(self.current_token[0])
        self.eat("RBRACE")

        return TraitNode(
            name,
            functions,
            generics,
            visibility,
            where_clauses,
            associated_types,
        )

    def parse_generics(self):
        """Parse ``<A, B: Bound, ...>`` and return each parameter as text."""
        self.eat("LESS_THAN")
        generics = []
        while self.current_token[0] not in ("GREATER_THAN", "EOF"):
            parameter = self.collect_token_text_until({"COMMA", "GREATER_THAN"})
            if parameter:
                generics.append(parameter)
            if self.current_token[0] != "COMMA":
                break
            self.eat("COMMA")
        self.eat("GREATER_THAN")
        return generics

    def parse_type(self):
        """Parse a type and return it as a string.

        Supported forms: references (``&T``, ``&mut T``), array types
        (``[T; N]`` rendered as ``T[N]``, ``[T]`` as ``T[]``), tuple and
        unit types (``(A, B)``, ``()``), and paths with optional generic
        arguments (``a::B<C>``), optionally followed by ``[N]``.
        """
        if self.current_token[0] == "AMPERSAND":
            self.eat("AMPERSAND")
            if self.current_token[0] == "MUT":
                self.eat("MUT")
                return f"&mut {self.parse_type()}"
            return f"&{self.parse_type()}"

        if self.current_token[0] == "LBRACKET":
            self.eat("LBRACKET")
            element_type = self.parse_type()
            size = None
            if self.current_token[0] == "SEMICOLON":
                self.eat("SEMICOLON")
                size = self.parse_array_type_size()
            self.eat("RBRACKET")
            return self.format_array_type(element_type, size)

        if self.current_token[0] == "LPAREN":
            # Tuple / unit type; nested parentheses are balanced by
            # ``collect_token_text_until``.
            self.eat("LPAREN")
            inner = self.collect_token_text_until({"RPAREN"})
            self.eat("RPAREN")
            return f"({inner})"

        type_parts = [self.current_token[1]]
        self.eat(self.current_token[0])

        while True:
            if self.current_token[0] == "DOUBLE_COLON":
                type_parts.append("::")
                self.eat("DOUBLE_COLON")
                type_parts.append(self.current_token[1])
                self.eat(self.current_token[0])
            elif self.current_token[0] == "LESS_THAN":
                type_parts.append(self.parse_generic_argument_suffix())
            else:
                break

        if self.current_token[0] == "LBRACKET":
            type_parts.append("[")
            self.eat("LBRACKET")
            if self.current_token[0] == "NUMBER":
                type_parts.append(self.current_token[1])
                self.eat("NUMBER")
            type_parts.append("]")
            self.eat("RBRACKET")

        return "".join(type_parts)

    def parse_generic_argument_suffix(self):
        """Parse ``<...>`` after a type name and return it including brackets."""
        self.eat("LESS_THAN")
        arguments = self.collect_token_text_until({"GREATER_THAN"})
        self.eat("GREATER_THAN")
        return f"<{arguments}>"

    def parse_array_type_size(self):
        """Collect the size expression of ``[T; size]`` as raw text.

        Stops at the closing ``]`` of the array type (not consumed).
        Returns ``None`` when the size is empty.
        """
        parts = []
        depth = 0
        while self.current_token[0] != "EOF":
            token_type, token_value = self.current_token
            if token_type == "RBRACKET" and depth == 0:
                break
            if token_type in ("LPAREN", "LBRACKET", "LBRACE"):
                depth += 1
            elif token_type in ("RPAREN", "RBRACKET", "RBRACE"):
                depth -= 1
            parts.append(str(token_value))
            self.eat(token_type)
        return "".join(parts).strip() or None

    def format_array_type(self, element_type, size):
        """Render an array type as ``elem[size]`` (``elem[]`` when unsized).

        For nested arrays the new dimension is inserted before the
        element type's existing dimensions, e.g. ``f32[4]`` with size
        ``3`` becomes ``f32[3][4]``.
        """
        suffix = f"[{size}]" if size is not None else "[]"
        if "[" not in element_type:
            return f"{element_type}{suffix}"
        base_type, existing_suffix = element_type.split("[", 1)
        return f"{base_type}{suffix}[{existing_suffix}"

    def parse_where_clause(self, terminators=None):
        """Parse ``where T: A + B, U: C`` up to one of ``terminators``.

        Args:
            terminators: token types that end the clause (not consumed).
                Defaults to ``{"LBRACE"}``.

        Returns:
            list of ``(type_param, [bound, ...])`` tuples of strings.
        """
        terminators = set(terminators or {"LBRACE"})
        self.eat("WHERE")
        predicates = []
        predicate_end = {"COMMA", *terminators}
        bound_terminators = {"PLUS", "COMMA", *terminators}

        while self.current_token[0] not in terminators and self.current_token[0] != "EOF":
            if self.current_token[0] == "COMMA":
                self.eat("COMMA")
                continue

            type_param = self.collect_token_text_until({"COLON", *terminators})
            if not type_param or self.current_token[0] in terminators:
                break
            self.eat("COLON")

            bounds = []
            while (
                self.current_token[0] not in predicate_end
                and self.current_token[0] != "EOF"
            ):
                bound = self.collect_token_text_until(bound_terminators)
                if bound:
                    bounds.append(bound)
                if self.current_token[0] != "PLUS":
                    break
                self.eat("PLUS")
            predicates.append((type_param, bounds))

            if self.current_token[0] != "COMMA":
                break
            self.eat("COMMA")
        return predicates

    def parse_associated_type(self):
        """Parse ``type Name[: Bounds] [where ...] [= Default];`` in a trait."""
        self.eat("TYPE")
        name = self.current_token[1]
        self.eat("IDENTIFIER")

        bounds = []
        if self.current_token[0] == "COLON":
            self.eat("COLON")
            bound_terminators = {"PLUS", "EQUALS", "WHERE", "SEMICOLON"}
            while self.current_token[0] not in {"EQUALS", "WHERE", "SEMICOLON"}:
                bound = self.collect_token_text_until(bound_terminators)
                if bound:
                    bounds.append(bound)
                if self.current_token[0] != "PLUS":
                    break
                self.eat("PLUS")

        where_clauses = []
        if self.current_token[0] == "WHERE":
            where_clauses = self.parse_where_clause({"EQUALS", "SEMICOLON"})

        default_type = None
        if self.current_token[0] == "EQUALS":
            self.eat("EQUALS")
            default_type = self.collect_token_text_until({"SEMICOLON"})

        self.eat("SEMICOLON")
        return AssociatedTypeNode(name, bounds, default_type, where_clauses)

    def parse_type_alias(self, visibility=None, attributes=None):
        """Parse ``type Name<G> [where ...] [= Type];`` into a ``TypeAliasNode``."""
        self.eat("TYPE")
        name = self.current_token[1]
        self.eat("IDENTIFIER")

        generics = []
        if self.current_token[0] == "LESS_THAN":
            generics = self.parse_generics()
        where_clauses = []
        if self.current_token[0] == "WHERE":
            where_clauses = self.parse_where_clause({"EQUALS", "SEMICOLON"})

        alias_type = None
        if self.current_token[0] == "EQUALS":
            self.eat("EQUALS")
            alias_type = self.parse_type()

        self.eat("SEMICOLON")
        return TypeAliasNode(
            name,
            alias_type,
            generics,
            visibility,
            where_clauses,
            attributes,
        )

    def collect_token_text_until(self, terminators):
        """Consume tokens up to a top-level terminator and return their text.

        ``<``, ``(`` and ``[`` open a nesting level; ``>``, ``)`` and ``]``
        close one.  A terminator only stops the scan at nesting depth 0
        and is left unconsumed.  A ``>>`` token closes two levels; when
        only one level is open it is split so that the second ``>``
        remains in the stream for the caller.
        """
        parts = []
        depth = 0
        while self.current_token[0] != "EOF":
            token_type, token_value = self.current_token
            if depth == 0 and token_type in terminators:
                break

            if token_type in {"LESS_THAN", "LPAREN", "LBRACKET"}:
                depth += 1
            elif token_type == "SHIFT_RIGHT":
                if depth == 1:
                    # Emit one ``>`` and leave the other as a fresh token.
                    depth = 0
                    parts.append(">")
                    self.split_shift_right_token()
                    continue
                if depth > 1:
                    depth -= 2
            elif token_type in {"GREATER_THAN", "RPAREN", "RBRACKET"}:
                depth = max(0, depth - 1)

            parts.append(str(token_value))
            self.eat(token_type)
        return self.format_token_parts(parts)

    def split_shift_right_token(self):
        """Replace the current ``>>`` token in place with a single ``>``."""
        self.tokens[self.current_index] = ("GREATER_THAN", ">")
        self.current_token = self.tokens[self.current_index]

    def format_token_parts(self, parts):
        """Join token texts into readable source text with minimal spacing."""
        formatted = []
        previous = None
        for part in parts:
            if part == ",":
                formatted.append(", ")
            else:
                if self.needs_token_part_space(previous, part):
                    formatted.append(" ")
                formatted.append(part)
            previous = part
        return "".join(formatted).strip()

    def needs_token_part_space(self, previous, current):
        """Return True when a space belongs between two adjacent token texts."""
        if previous is None:
            return False
        if previous in {"<", "[", "(", "::", ","}:
            return False
        if current in {">", ">>", "]", ")", ",", "::", "<", "[", "(", ":"}:
            return False
        if previous in {":", "+", "=", "->", "=>"}:
            return True
        if current in {"+", "=", "->", "=>"}:
            return True
        return self.is_token_word(previous) and self.is_token_word(current)

    def is_token_word(self, part):
        """Return True when ``part`` consists of alphanumerics/underscores."""
        return part.replace("_", "").isalnum()

    def parse_struct_initialization(self, struct_name):
        """Parse struct initialization syntax: Name { field: value, ... }

        The field shorthand ``Name { field }`` is accepted and treated as
        ``field: field``.
        """
        self.eat("LBRACE")
        fields = []
        while self.current_token[0] not in ("RBRACE", "EOF"):
            field_name = self.current_token[1]
            self.eat("IDENTIFIER")
            if self.current_token[0] == "COLON":
                self.eat("COLON")
                field_value = self.parse_expression()
            else:
                # Shorthand: identifiers are represented as plain names.
                field_value = field_name
            fields.append((field_name, field_value))

            if self.current_token[0] != "COMMA":
                break
            self.eat("COMMA")
        self.eat("RBRACE")
        return StructInitializationNode(struct_name, fields)

    def parse_function(self, attributes=None, visibility=None):
        """Parse a function definition with a body into a ``FunctionNode``.

        The return type defaults to ``"()"`` when no ``->`` is present.
        """
        self.eat("FN")
        name = self.current_token[1]
        self.eat("IDENTIFIER")

        generics = []
        if self.current_token[0] == "LESS_THAN":
            generics = self.parse_generics()

        params = self.parse_parameters()

        return_type = "()"
        if self.current_token[0] == "ARROW":
            self.eat("ARROW")
            return_type = self.parse_type()

        where_clauses = []
        if self.current_token[0] == "WHERE":
            where_clauses = self.parse_where_clause({"LBRACE"})

        self.eat("LBRACE")
        body = self.parse_block()
        self.eat("RBRACE")

        return FunctionNode(
            return_type,
            name,
            params,
            body,
            attributes,
            visibility,
            generics,
            where_clauses,
        )

    def parse_function_signature(self):
        """Parse a trait method: a signature ending in ``;`` or a default body.

        A bare signature yields a ``FunctionNode`` with an empty body.
        """
        self.eat("FN")
        name = self.current_token[1]
        self.eat("IDENTIFIER")

        generics = []
        if self.current_token[0] == "LESS_THAN":
            generics = self.parse_generics()

        params = self.parse_parameters()

        return_type = "()"
        if self.current_token[0] == "ARROW":
            self.eat("ARROW")
            return_type = self.parse_type()

        where_clauses = []
        if self.current_token[0] == "WHERE":
            where_clauses = self.parse_where_clause({"SEMICOLON", "LBRACE"})

        body = []
        if self.current_token[0] == "LBRACE":
            # Provided (default) method implementation.
            self.eat("LBRACE")
            body = self.parse_block()
            self.eat("RBRACE")
        else:
            self.eat("SEMICOLON")

        return FunctionNode(
            return_type,
            name,
            params,
            body,
            [],
            None,
            generics,
            where_clauses,
        )

    def parse_parameters(self):
        """Parse a parenthesised parameter list into ``VariableNode`` objects.

        A leading ``self``, ``&self`` or ``&mut self`` receiver is
        recorded as a parameter named ``self`` with type ``Self``,
        ``&Self`` or ``&mut Self``.
        """
        self.eat("LPAREN")
        params = []

        if self.current_token[0] == "SELF":
            params.append(VariableNode("Self", "self"))
            self.eat("SELF")
            if self.current_token[0] == "COMMA":
                self.eat("COMMA")
        elif self.current_token[0] == "AMPERSAND":
            self.eat("AMPERSAND")
            if self.current_token[0] == "MUT":
                self.eat("MUT")
                params.append(VariableNode("&mut Self", "self"))
            else:
                params.append(VariableNode("&Self", "self"))
            self.eat("SELF")
            if self.current_token[0] == "COMMA":
                self.eat("COMMA")

        while self.current_token[0] != "RPAREN":
            param_attrs = []
            while self.current_token[0] == "POUND":
                param_attrs.extend(self.parse_attributes())

            is_mutable = False
            if self.current_token[0] == "MUT":
                is_mutable = True
                self.eat("MUT")

            param_name = self.current_token[1]
            self.eat("IDENTIFIER")
            self.eat("COLON")
            param_type = self.parse_type()

            param = VariableNode(param_type, param_name, is_mutable)
            if param_attrs:
                param.attributes = param_attrs
            params.append(param)

            if self.current_token[0] != "COMMA":
                break
            self.eat("COMMA")

        self.eat("RPAREN")
        return params

    def parse_const(self, visibility=None):
        """Parse ``const NAME: Type = expr;`` into a ``ConstNode``."""
        self.eat("CONST")
        name = self.current_token[1]
        self.eat("IDENTIFIER")
        self.eat("COLON")
        const_type = self.parse_type()
        self.eat("EQUALS")
        value = self.parse_expression()
        self.eat("SEMICOLON")
        return ConstNode(name, const_type, value, visibility)

    def parse_static(self, visibility=None):
        """Parse ``static [mut] NAME: Type = expr;`` into a ``StaticNode``."""
        self.eat("STATIC")
        is_mutable = False
        if self.current_token[0] == "MUT":
            is_mutable = True
            self.eat("MUT")
        name = self.current_token[1]
        self.eat("IDENTIFIER")
        self.eat("COLON")
        static_type = self.parse_type()
        self.eat("EQUALS")
        value = self.parse_expression()
        self.eat("SEMICOLON")
        return StaticNode(name, static_type, value, is_mutable, visibility)

    def parse_block(self):
        """Parse statements up to (not including) the closing ``}``.

        Keyword statements are delegated to ``parse_statement``.  Any
        token that can begin an expression starts an expression or
        assignment statement.  Other tokens are skipped.

        Returns:
            list of statement / expression nodes.
        """
        statements = []
        while self.current_token[0] not in ("RBRACE", "EOF"):
            kind = self.current_token[0]
            if self.peek_is_statement():
                statements.append(self.parse_statement())
            elif kind in self._EXPRESSION_START:
                statements.append(self._parse_expression_statement())
            else:
                # Unsupported token inside a block: skip it.
                self.eat(kind)
        return statements

    def _parse_expression_statement(self):
        """Parse ``expr;``, a trailing ``expr``, or ``target op= expr;``."""
        left = self.parse_expression()
        if self.current_token[0] in self._ASSIGNMENT_TOKENS:
            op = self.current_token[1]
            self.eat(self.current_token[0])
            right = self.parse_expression()
            self.eat("SEMICOLON")
            return AssignmentNode(left, right, op)
        # The semicolon is optional so a block's tail expression is kept.
        if self.current_token[0] == "SEMICOLON":
            self.eat("SEMICOLON")
        return left

    def parse_let_statement(self):
        """Parse ``let [mut] name[: Type] [= expr];`` into a ``LetNode``."""
        self.eat("LET")
        is_mutable = False
        if self.current_token[0] == "MUT":
            is_mutable = True
            self.eat("MUT")

        name = self.current_token[1]
        self.eat("IDENTIFIER")

        var_type = None
        if self.current_token[0] == "COLON":
            self.eat("COLON")
            var_type = self.parse_type()

        value = None
        if self.current_token[0] == "EQUALS":
            self.eat("EQUALS")
            value = self.parse_expression()

        self.eat("SEMICOLON")
        return LetNode(name, value, var_type, is_mutable)

    def parse_if_statement(self):
        """Parse ``if cond { ... } [else if ... | else { ... }]``.

        An ``else if`` chain is stored as a one-element list holding the
        nested ``IfNode``.
        """
        self.eat("IF")
        condition = self.parse_expression()
        self.eat("LBRACE")
        if_body = self.parse_block()
        self.eat("RBRACE")

        else_body = None
        if self.current_token[0] == "ELSE":
            self.eat("ELSE")
            if self.current_token[0] == "IF":
                else_body = [self.parse_if_statement()]
            else:
                self.eat("LBRACE")
                else_body = self.parse_block()
                self.eat("RBRACE")

        return IfNode(condition, if_body, else_body)

    def parse_match_statement(self):
        """Parse ``match expr { pattern [if guard] => body, ... }``.

        Simple patterns (``_``, numbers, strings, identifiers and paths)
        are stored as strings.  ``Name(args)`` and ``a::B(args)`` patterns
        are stored as ``FunctionCallNode``.  Anything else is parsed as an
        expression.
        """
        self.eat("MATCH")
        expression = self.parse_expression()
        self.eat("LBRACE")

        arms = []
        while self.current_token[0] not in ("RBRACE", "EOF"):
            kind = self.current_token[0]
            if kind == "UNDERSCORE":
                pattern = "_"
                self.eat("UNDERSCORE")
            elif kind in ("NUMBER", "STRING"):
                pattern = self.current_token[1]
                self.eat(kind)
            elif kind == "IDENTIFIER":
                pattern = self.current_token[1]
                self.eat("IDENTIFIER")
                if self.current_token[0] == "DOUBLE_COLON":
                    # Path pattern such as ``Color::Red`` / ``Shape::Circle(r)``.
                    pattern = self.finish_path_or_call(pattern)
                elif self.current_token[0] == "LPAREN":
                    # Tuple-struct pattern such as ``Some(x)``.
                    pattern = FunctionCallNode(pattern, self.parse_call_arguments())
            else:
                # Fall back to full expression parsing for complex patterns
                pattern = self.parse_expression()

            guard = None
            if self.current_token[0] == "IF":
                self.eat("IF")
                guard = self.parse_expression()

            self.eat("FAT_ARROW")

            if self.current_token[0] == "LBRACE":
                self.eat("LBRACE")
                body = self.parse_block()
                self.eat("RBRACE")
            else:
                body = [self.parse_expression()]
                if self.current_token[0] == "SEMICOLON":
                    self.eat("SEMICOLON")

            if self.current_token[0] == "COMMA":
                self.eat("COMMA")

            arms.append(MatchArmNode(pattern, guard, body))

        self.eat("RBRACE")
        return MatchNode(expression, arms)

    def parse_for_loop(self):
        """Parse ``for name in iterable { ... }`` into a ``ForNode``."""
        self.eat("FOR")
        pattern = self.current_token[1]
        self.eat("IDENTIFIER")
        self.eat("IN")
        iterable = self.parse_expression()
        self.eat("LBRACE")
        body = self.parse_block()
        self.eat("RBRACE")
        return ForNode(pattern, iterable, body)

    def parse_while_loop(self):
        """Parse ``while cond { ... }`` into a ``WhileNode``."""
        self.eat("WHILE")
        condition = self.parse_expression()
        self.eat("LBRACE")
        body = self.parse_block()
        self.eat("RBRACE")
        return WhileNode(condition, body)

    def parse_loop(self):
        """Parse ``loop { ... }`` into a ``LoopNode``."""
        self.eat("LOOP")
        self.eat("LBRACE")
        body = self.parse_block()
        self.eat("RBRACE")
        return LoopNode(body)

    def parse_expression(self):
        """Parse one expression (entry point of the precedence chain)."""
        return self.parse_conditional_expression()

    def parse_conditional_expression(self):
        """Parse ``if c { a } else { b }`` as a ternary, else a range/binary expr."""
        if self.current_token[0] != "IF":
            return self.parse_range_expression()

        # Ternary-like if expression: if condition { true_expr } else { false_expr }
        self.eat("IF")
        condition = self.parse_logical_or_expression()
        self.eat("LBRACE")
        true_expr = self.parse_expression()
        self.eat("RBRACE")
        self.eat("ELSE")
        self.eat("LBRACE")
        false_expr = self.parse_expression()
        self.eat("RBRACE")
        return TernaryOpNode(condition, true_expr, false_expr)

    def parse_range_expression(self):
        """Parse ``a..b`` / ``a..=b``.

        Range operators bind more loosely than every arithmetic,
        comparison and logical operator, so both bounds are full
        logical-or expressions (``0..n + 1`` is ``0..(n + 1)``).
        """
        left = self.parse_logical_or_expression()
        if self.current_token[0] in ("RANGE", "RANGE_INCLUSIVE"):
            op = self.current_token[1]
            self.eat(self.current_token[0])
            right = self.parse_logical_or_expression()
            return RangeNode(left, right, op == "..=")
        return left

    def parse_logical_or_expression(self):
        """Parse a left-associative chain of ``||``."""
        left = self.parse_logical_and_expression()
        while self.current_token[0] == "LOGICAL_OR":
            op = self.current_token[1]
            self.eat("LOGICAL_OR")
            right = self.parse_logical_and_expression()
            left = BinaryOpNode(left, op, right)
        return left

    def parse_logical_and_expression(self):
        """Parse a left-associative chain of ``&&``."""
        left = self.parse_equality_expression()
        while self.current_token[0] == "LOGICAL_AND":
            op = self.current_token[1]
            self.eat("LOGICAL_AND")
            right = self.parse_equality_expression()
            left = BinaryOpNode(left, op, right)
        return left

    def parse_equality_expression(self):
        """Parse a left-associative chain of ``==`` / ``!=``."""
        left = self.parse_relational_expression()
        while self.current_token[0] in ("EQUAL", "NOT_EQUAL"):
            op = self.current_token[1]
            self.eat(self.current_token[0])
            right = self.parse_relational_expression()
            left = BinaryOpNode(left, op, right)
        return left

    def parse_relational_expression(self):
        """Parse a left-associative chain of ``<``, ``>``, ``<=``, ``>=``."""
        left = self.parse_additive_expression()
        while self.current_token[0] in (
            "LESS_THAN",
            "GREATER_THAN",
            "LESS_EQUAL",
            "GREATER_EQUAL",
        ):
            op = self.current_token[1]
            self.eat(self.current_token[0])
            right = self.parse_additive_expression()
            left = BinaryOpNode(left, op, right)
        return left

    def parse_additive_expression(self):
        """Parse a left-associative chain of ``+`` / ``-``."""
        left = self.parse_multiplicative_expression()
        while self.current_token[0] in ("PLUS", "MINUS"):
            op = self.current_token[1]
            self.eat(self.current_token[0])
            right = self.parse_multiplicative_expression()
            left = BinaryOpNode(left, op, right)
        return left

    def parse_multiplicative_expression(self):
        """Parse a left-associative chain of ``*``, ``/``, ``%``."""
        left = self.parse_cast_expression()
        while self.current_token[0] in ("MULTIPLY", "DIVIDE", "MODULO"):
            op = self.current_token[1]
            self.eat(self.current_token[0])
            right = self.parse_cast_expression()
            left = BinaryOpNode(left, op, right)
        return left

    def parse_cast_expression(self):
        """Parse ``expr as Type``, including chains like ``x as f32 as i32``."""
        left = self.parse_unary_expression()
        while self.current_token[0] == "AS":
            self.eat("AS")
            target_type = self.parse_type()
            left = CastNode(target_type, left)
        return left

    def parse_unary_expression(self):
        """Parse prefix ``-``, ``!``, ``&`` / ``&mut`` and ``*`` operators."""
        if self.current_token[0] not in ("MINUS", "EXCLAMATION", "AMPERSAND", "MULTIPLY"):
            return self.parse_postfix_expression()

        op = self.current_token[1]
        self.eat(self.current_token[0])
        if op == "&":
            is_mutable = False
            if self.current_token[0] == "MUT":
                is_mutable = True
                self.eat("MUT")
            return ReferenceNode(self.parse_unary_expression(), is_mutable)
        if op == "*":
            return DereferenceNode(self.parse_unary_expression())
        return UnaryOpNode(op, self.parse_unary_expression())

    def parse_postfix_expression(self):
        """Parse member access, indexing, calls and macro calls after a primary."""
        left = self.parse_primary_expression()
        while True:
            kind = self.current_token[0]
            if kind == "DOT":
                self.eat("DOT")
                member = self.current_token[1]
                self.eat("IDENTIFIER")
                left = MemberAccessNode(left, member)
            elif kind == "LBRACKET":
                self.eat("LBRACKET")
                index = self.parse_expression()
                self.eat("RBRACKET")
                left = ArrayAccessNode(left, index)
            elif kind == "EXCLAMATION":
                # Treat macro calls as function calls for code generation
                self.eat("EXCLAMATION")
                left = FunctionCallNode(left, self.parse_call_arguments())
            elif kind == "LPAREN":
                left = FunctionCallNode(left, self.parse_call_arguments())
            else:
                return left

    def parse_primary_expression(self):
        """Parse a literal, name, path, constructor, tuple, array or block.

        Identifiers and literals are returned as plain strings.

        Raises:
            SyntaxError: if the current token cannot start an expression.
        """
        kind = self.current_token[0]

        if kind == "IDENTIFIER":
            name = self.current_token[1]
            self.eat("IDENTIFIER")
            if self.current_token[0] == "DOUBLE_COLON":
                return self.finish_path_or_call(name)
            # Only if this identifier is likely a struct constructor (starts with uppercase)
            if self.current_token[0] == "LBRACE" and name[0].isupper():
                return self.parse_struct_initialization(name)
            if (
                name in ("Vec2", "Vec3", "Vec4", "Mat2", "Mat3", "Mat4")
                and self.current_token[0] == "LPAREN"
            ):
                return VectorConstructorNode(name, self.parse_call_arguments())
            return name

        if kind in self._VECTOR_TOKENS:
            name = self.current_token[1]
            self.eat(kind)
            if self.current_token[0] == "DOUBLE_COLON":
                return self.finish_path_or_call(name)
            return name

        if kind in ("NUMBER", "STRING"):
            value = self.current_token[1]
            self.eat(kind)
            return value

        if kind == "TRUE":
            self.eat("TRUE")
            return "true"

        if kind == "FALSE":
            self.eat("FALSE")
            return "false"

        if kind == "SELF":
            name = self.current_token[1]
            self.eat("SELF")
            # Check for struct initialization: Self { ... }
            if self.current_token[0] == "LBRACE":
                return self.parse_struct_initialization(name)
            if self.current_token[0] == "DOUBLE_COLON":
                return self.finish_path_or_call(name)
            return name

        if kind in ("CRATE", "SUPER"):
            name = self.current_token[1]
            self.eat(kind)
            if self.current_token[0] == "DOUBLE_COLON":
                return self.finish_path_or_call(name)
            return name

        if kind == "LPAREN":
            self.eat("LPAREN")
            if self.current_token[0] == "RPAREN":
                self.eat("RPAREN")
                return "()"
            expr = self.parse_expression()
            if self.current_token[0] != "COMMA":
                # Plain parenthesised expression.
                self.eat("RPAREN")
                return expr
            elements = [expr]
            while self.current_token[0] == "COMMA":
                self.eat("COMMA")
                if self.current_token[0] != "RPAREN":  # Handle trailing comma
                    elements.append(self.parse_expression())
            self.eat("RPAREN")
            return TupleNode(elements)

        if kind == "LBRACKET":
            self.eat("LBRACKET")
            if self.current_token[0] == "RBRACKET":
                self.eat("RBRACKET")
                return ArrayNode([])
            first_element = self.parse_expression()
            if self.current_token[0] == "SEMICOLON":
                # Repeat form: ``[value; count]``.
                self.eat("SEMICOLON")
                size = self.parse_expression()
                self.eat("RBRACKET")
                return ArrayNode([first_element], size)
            elements = [first_element]
            while self.current_token[0] == "COMMA":
                self.eat("COMMA")
                if self.current_token[0] != "RBRACKET":
                    elements.append(self.parse_expression())
            self.eat("RBRACKET")
            return ArrayNode(elements)

        if kind == "LBRACE":
            self.eat("LBRACE")
            statements = []
            expression = None
            while self.current_token[0] != "RBRACE":
                if self.peek_is_statement():
                    statements.append(self.parse_statement())
                else:
                    # Final expression (no semicolon)
                    expression = self.parse_expression()
                    break
            self.eat("RBRACE")
            return BlockNode(statements, expression)

        raise SyntaxError(
            f"Unexpected token in primary expression: {self.current_token}"
        )

    def parse_path_expression(self, first_segment):
        """Parse the ``::segment`` tail of a path and return it as a string.

        A turbofish (``::<T>``) is appended to the preceding segment.
        """
        segments = [first_segment]
        while self.current_token[0] == "DOUBLE_COLON":
            self.eat("DOUBLE_COLON")
            if self.current_token[0] == "LESS_THAN":
                segments[-1] += self.parse_generic_argument_suffix()
                continue
            segments.append(self.current_token[1])
            self.eat(self.current_token[0])
        return "::".join(segments)

    def finish_path_or_call(self, first_segment):
        """Parse a path and, if followed by ``(``, the call on that path."""
        path = self.parse_path_expression(first_segment)
        if self.current_token[0] == "LPAREN":
            return FunctionCallNode(path, self.parse_call_arguments())
        return path

    def parse_call_arguments(self):
        """Parse ``(arg, ...)`` and return the list of argument expressions."""
        self.eat("LPAREN")
        args = []
        while self.current_token[0] != "RPAREN":
            args.append(self.parse_expression())
            if self.current_token[0] != "COMMA":
                break
            self.eat("COMMA")
        self.eat("RPAREN")
        return args

    def peek_is_statement(self):
        """Return True when the current token is a statement keyword."""
        return self.current_token[0] in self._STATEMENT_TOKENS

    def parse_statement(self):
        """Parse one statement; non-keyword input is parsed as ``expr;``."""
        kind = self.current_token[0]
        if kind == "LET":
            return self.parse_let_statement()
        if kind == "IF":
            return self.parse_if_statement()
        if kind == "MATCH":
            return self.parse_match_statement()
        if kind == "FOR":
            return self.parse_for_loop()
        if kind == "WHILE":
            return self.parse_while_loop()
        if kind == "LOOP":
            return self.parse_loop()

        if kind == "RETURN":
            self.eat("RETURN")
            value = None
            if self.current_token[0] != "SEMICOLON":
                value = self.parse_expression()
            self.eat("SEMICOLON")
            return ReturnNode(value)

        if kind == "BREAK":
            self.eat("BREAK")
            label = None
            value = None
            if self.current_token[0] == "IDENTIFIER":
                label = self.current_token[1]
                self.eat("IDENTIFIER")
            self.eat("SEMICOLON")
            return BreakNode(label, value)

        if kind == "CONTINUE":
            self.eat("CONTINUE")
            label = None
            if self.current_token[0] == "IDENTIFIER":
                label = self.current_token[1]
                self.eat("IDENTIFIER")
            self.eat("SEMICOLON")
            return ContinueNode(label)

        expr = self.parse_expression()
        self.eat("SEMICOLON")
        return expr