Source code for crosstl.backend.slang.SlangParser

"""Parser for Slang source AST construction."""

from .SlangLexer import *
from .SlangAst import *


class SlangParser:
    """Parse Slang tokens into the Slang backend AST.

    Tokens are ``(type, value)`` pairs. Comment tokens are skipped
    transparently, and running past the end of the stream yields a synthetic
    ``("EOF", None)`` token.
    """

    # Token types that are dropped wherever they appear in the stream.
    _COMMENT_TOKENS = ("COMMENT_SINGLE", "COMMENT_MULTI")

    # Assignment operators accepted at statement and expression level.
    _ASSIGNMENT_TOKENS = (
        "EQUALS",
        "PLUS_EQUALS",
        "MINUS_EQUALS",
        "MULTIPLY_EQUALS",
        "DIVIDE_EQUALS",
    )

    # Tokens that may start a top-level function or global variable.
    _TOP_LEVEL_TYPE_TOKENS = (
        "VOID",
        "FLOAT",
        "FVECTOR",
        "IDENTIFIER",
        "GENERIC",
        "TEXTURE2D",
        "SAMPLER_STATE",
    )

    # Tokens that may start a declaration/assignment statement in a block.
    _STATEMENT_TYPE_TOKENS = (
        "FLOAT",
        "FVECTOR",
        "INT",
        "UINT",
        "BOOL",
        "IDENTIFIER",
    )

    def __init__(self, tokens):
        """Initialize the parser with a token stream from ``SlangLexer``.

        An empty stream is treated as if it contained only ``EOF`` instead of
        raising ``IndexError``.
        """
        self.tokens = tokens
        self.pos = 0
        self.current_token = (
            self.tokens[self.pos] if self.pos < len(self.tokens) else ("EOF", None)
        )
        self.skip_comments()

    # ------------------------------------------------------------------
    # Token handling
    # ------------------------------------------------------------------

    def _advance(self):
        """Move to the next raw token, substituting EOF past the end."""
        self.pos += 1
        self.current_token = (
            self.tokens[self.pos] if self.pos < len(self.tokens) else ("EOF", None)
        )

    def skip_comments(self):
        """Advance past comment tokens before parsing syntax."""
        # Iterative on purpose: skipping via ``eat`` would recurse once per
        # comment and overflow the stack on long comment runs.
        while self.current_token[0] in self._COMMENT_TOKENS:
            self._advance()

    def eat(self, token_type):
        """Consume the current token when it matches ``token_type``.

        Raises:
            SyntaxError: If the current token has a different type.
        """
        if self.current_token[0] != token_type:
            raise SyntaxError(f"Expected {token_type}, got {self.current_token[0]}")
        self._advance()
        self.skip_comments()

    def _parse_optional_semantic(self):
        """Parse an optional ``: SEMANTIC`` suffix; return its value or None."""
        if self.current_token[0] != "COLON":
            return None
        self.eat("COLON")
        semantic = self.current_token[1]
        # The semantic may be any token type, so consume whatever is there.
        self.eat(self.current_token[0])
        return semantic

    def _parse_argument_list(self):
        """Parse ``( expr, expr, ... )`` and return the list of expressions."""
        self.eat("LPAREN")
        args = []
        while self.current_token[0] != "RPAREN":
            args.append(self.parse_expression())
            if self.current_token[0] == "COMMA":
                self.eat("COMMA")
        self.eat("RPAREN")
        return args

    # ------------------------------------------------------------------
    # Top level
    # ------------------------------------------------------------------

    def parse(self):
        """Parse the complete Slang token stream into a shader AST."""
        shader = self.parse_shader()
        self.eat("EOF")
        return shader

    def parse_shader(self):
        """Parse top-level Slang declarations, functions, and cbuffers.

        Returns:
            A ``ShaderNode`` built from the collected imports, exports,
            structs, typedefs, functions, global variables and cbuffers.
            Tokens that start no known declaration are skipped.
        """
        imports = []
        exports = []
        functions = []
        structs = []
        typedefs = []
        cbuffers = []
        global_variables = []

        # Keyword-introduced declarations: token type -> (bucket, parser).
        collectors = {
            "IMPORT": (imports, self.parse_import),
            "EXPORT": (exports, self.parse_export),
            "STRUCT": (structs, self.parse_struct),
            "CBUFFER": (cbuffers, self.parse_cbuffer),
            "TYPEDEF": (typedefs, self.parse_typedef),
        }

        while self.current_token[0] != "EOF":
            token_type = self.current_token[0]
            if token_type in collectors:
                bucket, parse_declaration = collectors[token_type]
                bucket.append(parse_declaration())
            elif token_type == "TYPE_SHADER":
                # The shader type is the text between the first pair of
                # double quotes in the token value.
                type_shader = self.current_token[1].split('"')[1]
                self.eat("TYPE_SHADER")
                functions.append(self.parse_function(type_shader))
            elif token_type in self._TOP_LEVEL_TYPE_TOKENS:
                if self.is_function():
                    functions.append(self.parse_function())
                else:
                    global_variables.append(self.parse_global_variable())
            else:
                self.eat(token_type)  # Skip unknown tokens

        return ShaderNode(
            imports,
            exports,
            structs,
            typedefs,
            functions,
            global_variables,
            cbuffers,
        )

    def is_function(self):
        """Return True if a ``(`` appears before the next ``;`` or EOF.

        Looks ahead from the current position without consuming tokens. The
        scan is bounded by the token list, so a stream without an ``EOF``
        sentinel yields False instead of raising ``IndexError``.
        """
        for index in range(self.pos, len(self.tokens)):
            token_type = self.tokens[index][0]
            if token_type == "LPAREN":
                return True
            if token_type in ("SEMICOLON", "EOF"):
                return False
        return False

    def parse_cbuffer(self):
        """Parse ``cbuffer Name { type member[size]; ... }``.

        Returns:
            A ``StructNode`` named after the cbuffer. Array members keep
            their size as a ``name[size]`` suffix on the variable name.
        """
        self.eat("CBUFFER")
        name = self.current_token[1]
        self.eat("IDENTIFIER")
        self.eat("LBRACE")
        members = []
        while self.current_token[0] != "RBRACE":
            vtype = self.current_token[1]
            self.eat(self.current_token[0])
            var_name = self.current_token[1]
            self.eat("IDENTIFIER")
            if self.current_token[0] == "LBRACKET":
                self.eat("LBRACKET")
                size = self.current_token[1]
                self.eat("NUMBER")
                self.eat("RBRACKET")
                var_name += f"[{size}]"
            self.eat("SEMICOLON")
            members.append(VariableNode(vtype, var_name))
        self.eat("RBRACE")
        return StructNode(name, members)

    def parse_global_variable(self):
        """Parse ``type name;`` at file scope into a ``VariableNode``."""
        var_type = self.current_token[1]
        self.eat(self.current_token[0])
        var_name = self.current_token[1]
        self.eat("IDENTIFIER")
        self.eat("SEMICOLON")
        return VariableNode(var_type, var_name)

    def parse_import(self):
        """Parse ``import module;`` into an ``ImportNode``."""
        self.eat("IMPORT")
        module_name = self.current_token[1]
        self.eat("IDENTIFIER")
        self.eat("SEMICOLON")
        return ImportNode(module_name)

    def parse_export(self):
        """Parse ``export`` followed by a function or a struct."""
        self.eat("EXPORT")
        if self.current_token[0] in ("VOID", "FLOAT", "FVECTOR", "IDENTIFIER"):
            exported_item = self.parse_function()
        else:
            exported_item = self.parse_struct()
        return ExportNode(exported_item)

    def parse_struct(self):
        """Parse ``struct Name { type member [: semantic]; ... }``."""
        self.eat("STRUCT")
        name = self.current_token[1]
        self.eat("IDENTIFIER")
        self.eat("LBRACE")
        members = []
        while self.current_token[0] != "RBRACE":
            vtype = self.current_token[1]
            self.eat(self.current_token[0])  # Eat the type (FVECTOR, FLOAT, etc.)
            var_name = self.current_token[1]
            self.eat("IDENTIFIER")
            semantic = self._parse_optional_semantic()
            self.eat("SEMICOLON")
            members.append(VariableNode(vtype, var_name, semantic))
        self.eat("RBRACE")
        return StructNode(name, members)

    def parse_typedef(self):
        """Parse ``typedef original new_name;`` into a ``TypedefNode``."""
        self.eat("TYPEDEF")
        original_type = self.current_token[1]
        self.eat(self.current_token[0])
        new_type = self.current_token[1]
        self.eat("IDENTIFIER")
        self.eat("SEMICOLON")
        return TypedefNode(original_type, new_type)

    def parse_function(self, shader_type=None):
        """Parse a function definition.

        Args:
            shader_type: Value forwarded to ``FunctionNode``; ``parse_shader``
                passes the text extracted from a preceding ``TYPE_SHADER``
                token.

        Returns:
            A ``FunctionNode`` with return type, name, parameters, body,
            generic flag, shader type and optional return semantic.
        """
        is_generic = self.current_token[0] == "GENERIC"
        if is_generic:
            self.eat("GENERIC")
        return_type = self.current_token[1]
        self.eat(self.current_token[0])
        name = self.current_token[1]
        self.eat("IDENTIFIER")
        self.eat("LPAREN")
        params = self.parse_parameters()
        self.eat("RPAREN")
        semantic = self._parse_optional_semantic()
        body = self.parse_block()
        return FunctionNode(
            return_type, name, params, body, is_generic, shader_type, semantic
        )

    def parse_parameters(self):
        """Parse parameters up to (not including) the closing ``)``.

        Each parameter is ``type name [identifier] [: semantic]``.
        """
        params = []
        while self.current_token[0] != "RPAREN":
            # NOTE(review): without a trailing identifier the stored type is
            # ``vtype + " "`` (trailing space) - kept as-is for compatibility.
            struct_def = " "
            vtype = self.current_token[1]
            self.eat(self.current_token[0])
            name = self.current_token[1]
            self.eat("IDENTIFIER")
            if self.current_token[0] == "IDENTIFIER":
                struct_def = self.current_token[1]
                self.eat("IDENTIFIER")
            semantic = self._parse_optional_semantic()
            params.append(VariableNode(vtype + struct_def, name, semantic))
            if self.current_token[0] == "COMMA":
                self.eat("COMMA")
        return params

    # ------------------------------------------------------------------
    # Statements
    # ------------------------------------------------------------------

    def parse_block(self):
        """Parse ``{ statement* }`` and return the list of statements."""
        statements = []
        self.eat("LBRACE")
        while self.current_token[0] != "RBRACE":
            statements.append(self.parse_statement())
        self.eat("RBRACE")
        return statements

    def parse_statement(self):
        """Dispatch on the current token to the matching statement parser."""
        token_type = self.current_token[0]
        if token_type in self._STATEMENT_TYPE_TOKENS:
            return self.parse_variable_declaration_or_assignment()
        if token_type == "IF":
            return self.parse_if_statement()
        if token_type == "FOR":
            return self.parse_for_statement()
        if token_type == "RETURN":
            return self.parse_return_statement()
        return self.parse_expression_statement()

    def _make_assignment(self, target, value, operator_token):
        """Build an ``AssignmentNode``, keeping compound operators.

        Plain ``=`` uses the two-argument form; compound operators pass their
        token value as the operator, as ``parse_expression`` does.
        """
        if operator_token[0] == "EQUALS":
            return AssignmentNode(target, value)
        return AssignmentNode(target, value, operator_token[1])

    def _parse_declaration(self, type_name):
        """Parse the rest of ``type name ...;`` once the type is consumed."""
        name = self.current_token[1]
        self.eat("IDENTIFIER")

        if self.current_token[0] == "SEMICOLON":
            self.eat("SEMICOLON")
            return VariableNode(type_name, name)

        if self.current_token[0] not in self._ASSIGNMENT_TOKENS:
            # Fail loudly instead of silently returning None.
            raise SyntaxError(
                "Expected SEMICOLON or assignment after variable name, "
                f"got {self.current_token[0]}"
            )

        self.eat(self.current_token[0])
        value = self.parse_expression()

        if self.current_token[0] == "LPAREN":
            # This handles cases like "float3 test = float3(1.0, 1.0, 1.0);"
            args = self._parse_argument_list()
            self.eat("SEMICOLON")
            return AssignmentNode(
                VariableNode(type_name, name),
                VectorConstructorNode(type_name, args),
            )

        if self.current_token[0] == "LBRACKET":
            self.eat("LBRACKET")
            index = self.current_token[1]
            self.eat("IDENTIFIER")
            self.eat("RBRACKET")
            if self.current_token[0] == "SEMICOLON":
                self.eat("SEMICOLON")
                return AssignmentNode(
                    VariableNode(type_name, f"{name}[{index}]"), value
                )

        self.eat("SEMICOLON")
        return AssignmentNode(VariableNode(type_name, name), value)

    def parse_variable_declaration_or_assignment(self):
        """Parse a statement that starts with a type name or identifier.

        Handles declarations (``type name;`` / ``type name = expr;``),
        assignments to a bare name (``name op= expr;``), member-access
        statements (``a.b;`` / ``a.b op= expr;``), ``name == expr;``, and
        otherwise any expression statement starting with that token.

        Raises:
            SyntaxError: On an unexpected token.
        """
        if self.current_token[0] not in self._STATEMENT_TYPE_TOKENS:
            expr = self.parse_expression()
            self.eat("SEMICOLON")
            return expr

        start_pos = self.pos
        first_token = self.current_token
        self.eat(first_token[0])
        next_type = self.current_token[0]

        if next_type == "IDENTIFIER":
            return self._parse_declaration(first_token[1])

        if next_type in self._ASSIGNMENT_TOKENS:
            # This handles cases like "test = float3(1.0, 1.0, 1.0);"
            operator_token = self.current_token
            self.eat(next_type)
            value = self.parse_expression()
            self.eat("SEMICOLON")
            return self._make_assignment(
                VariableNode("", first_token[1]), value, operator_token
            )

        if next_type == "DOT":
            left = self.parse_member_access(first_token[1])
            if self.current_token[0] in self._ASSIGNMENT_TOKENS:
                operator_token = self.current_token
                self.eat(self.current_token[0])
                right = self.parse_expression()
                self.eat("SEMICOLON")
                return self._make_assignment(left, right, operator_token)
            self.eat("SEMICOLON")
            return left

        if next_type == "EQUAL":
            op = self.current_token[1]
            self.eat("EQUAL")
            expr = self.parse_expression()
            self.eat("SEMICOLON")
            return BinaryOpNode(VariableNode("", first_token[1]), op, expr)

        # Anything else (call, constructor, arithmetic, ...): rewind so the
        # expression parser sees the leading token instead of losing it.
        self.pos = start_pos
        self.current_token = first_token
        expr = self.parse_expression()
        self.eat("SEMICOLON")
        return expr

    def _parse_conditional(self, keyword):
        """Parse ``<keyword> (cond) { ... }`` plus an else / else-if tail."""
        self.eat(keyword)
        self.eat("LPAREN")
        condition = self.parse_expression()
        self.eat("RPAREN")
        if_body = self.parse_block()
        else_body = None
        if self.current_token[0] == "ELSE":
            self.eat("ELSE")
            else_body = self.parse_block()
        elif self.current_token[0] == "ELSE_IF":
            else_body = self.parse_else_if_statement()
        return IfNode(condition, if_body, else_body)

    def parse_if_statement(self):
        """Parse an ``if`` statement; else-if chains nest in ``else_body``."""
        return self._parse_conditional("IF")

    def parse_else_if_statement(self):
        """Parse an ``else if`` clause as a nested ``IfNode``."""
        return self._parse_conditional("ELSE_IF")

    def parse_for_statement(self):
        """Parse ``for (init; condition; update) { ... }`` into a ``ForNode``.

        A typed initializer (``int``/``float``/vector) becomes an
        ``AssignmentNode``; any other initializer is parsed as an expression.
        """
        self.eat("FOR")
        self.eat("LPAREN")
        if self.current_token[0] in ("INT", "FLOAT", "FVECTOR"):
            type_name = self.current_token[1]
            self.eat(self.current_token[0])
            var_name = self.current_token[1]
            self.eat("IDENTIFIER")
            self.eat("EQUALS")
            init_value = self.parse_expression()
            init = AssignmentNode(VariableNode(type_name, var_name), init_value)
        else:
            init = self.parse_expression()
        self.eat("SEMICOLON")
        condition = self.parse_expression()
        self.eat("SEMICOLON")
        update = self.parse_expression()
        self.eat("RPAREN")
        body = self.parse_block()
        return ForNode(init, condition, update, body)

    def parse_return_statement(self):
        """Parse ``return expr;`` into a ``ReturnNode``."""
        self.eat("RETURN")
        value = self.parse_expression()
        self.eat("SEMICOLON")
        return ReturnNode(value)

    def parse_expression_statement(self):
        """Parse ``expr;`` and return the expression."""
        expr = self.parse_expression()
        self.eat("SEMICOLON")
        return expr

    # ------------------------------------------------------------------
    # Expressions (lowest to highest precedence)
    # ------------------------------------------------------------------

    def parse_expression(self):
        """Parse an expression with optional assignment and ternary parts."""
        left = self.parse_logical_or()
        while self.current_token[0] in self._ASSIGNMENT_TOKENS:
            op = self.current_token[1]
            self.eat(self.current_token[0])
            right = self.parse_logical_or()
            left = AssignmentNode(left, right, op)

        if self.current_token[0] == "QUESTION":
            self.eat("QUESTION")
            true_expr = self.parse_expression()
            self.eat("COLON")
            false_expr = self.parse_expression()
            left = TernaryOpNode(left, true_expr, false_expr)

        return left

    def parse_assignment(self):
        """Parse a right-associative ``=`` chain over logical-or operands."""
        left = self.parse_logical_or()
        if self.current_token[0] == "EQUALS":
            self.eat("EQUALS")
            right = self.parse_assignment()
            return AssignmentNode(left, right)
        return left

    def _parse_binary_chain(self, parse_operand, operator_types):
        """Parse a left-associative chain ``operand (op operand)*``."""
        left = parse_operand()
        while self.current_token[0] in operator_types:
            op = self.current_token[1]
            self.eat(self.current_token[0])
            right = parse_operand()
            left = BinaryOpNode(left, op, right)
        return left

    def parse_logical_or(self):
        """Parse a chain of ``OR`` operators."""
        return self._parse_binary_chain(self.parse_logical_and, ("OR",))

    def parse_logical_and(self):
        """Parse a chain of ``AND`` operators."""
        return self._parse_binary_chain(self.parse_equality, ("AND",))

    def parse_equality(self):
        """Parse a chain of ``EQUAL`` / ``NOT_EQUAL`` operators."""
        return self._parse_binary_chain(self.parse_relational, ("EQUAL", "NOT_EQUAL"))

    def parse_relational(self):
        """Parse a chain of ordering comparison operators."""
        return self._parse_binary_chain(
            self.parse_additive,
            ("LESS_THAN", "GREATER_THAN", "LESS_EQUAL", "GREATER_EQUAL"),
        )

    def parse_additive(self):
        """Parse a chain of ``PLUS`` / ``MINUS`` operators."""
        return self._parse_binary_chain(self.parse_multiplicative, ("PLUS", "MINUS"))

    def parse_multiplicative(self):
        """Parse a chain of ``MULTIPLY`` / ``DIVIDE`` / ``MOD`` operators."""
        return self._parse_binary_chain(
            self.parse_unary, ("MULTIPLY", "DIVIDE", "MOD")
        )

    def parse_unary(self):
        """Parse prefix ``PLUS`` / ``MINUS`` / ``BITWISE_NOT`` operators."""
        if self.current_token[0] in ("PLUS", "MINUS", "BITWISE_NOT"):
            op = self.current_token[1]
            self.eat(self.current_token[0])
            operand = self.parse_unary()
            return UnaryOpNode(op, operand)
        return self.parse_primary()

    def parse_primary(self):
        """Parse a primary expression.

        Accepts identifiers and calls, typed variables (optionally indexed),
        vector constructors, numbers, and ``(...)`` / ``[...]`` groups. A
        ``NUMBER`` yields the raw token value rather than a node.

        Raises:
            SyntaxError: If the current token cannot start an expression.
        """
        token_type = self.current_token[0]

        if token_type in ("IDENTIFIER", "INT", "FLOAT", "FVECTOR", "GENERIC"):
            if token_type != "IDENTIFIER":
                type_name = self.current_token[1]
                self.eat(token_type)
                if self.current_token[0] == "IDENTIFIER":
                    var_name = self.current_token[1]
                    self.eat("IDENTIFIER")
                    if self.current_token[0] == "LBRACKET":
                        self.eat("LBRACKET")
                        index = self.current_token[1]
                        # The index is a number or an identifier; consume
                        # whichever is present.
                        if self.current_token[0] == "NUMBER":
                            self.eat("NUMBER")
                        else:
                            self.eat("IDENTIFIER")
                        self.eat("RBRACKET")
                        return VariableNode(type_name, f"{var_name}[{index}]")
                    return VariableNode(type_name, var_name)
                if self.current_token[0] == "LPAREN":
                    return self.parse_vector_constructor(type_name)
            return self.parse_function_call_or_identifier()

        if token_type == "LBRACKET":
            self.eat("LBRACKET")
            expr = self.parse_expression()
            self.eat("RBRACKET")
            return expr

        if token_type == "NUMBER":
            value = self.current_token[1]
            self.eat("NUMBER")
            return value

        if token_type == "LPAREN":
            self.eat("LPAREN")
            expr = self.parse_expression()
            self.eat("RPAREN")
            return expr

        raise SyntaxError(f"Unexpected token in expression: {token_type}")

    def parse_vector_constructor(self, type_name):
        """Parse ``(args...)`` after a type into a ``VectorConstructorNode``."""
        return VectorConstructorNode(type_name, self._parse_argument_list())

    def parse_function_call_or_identifier(self):
        """Parse an identifier followed by a call, member access, or nothing."""
        name = self.current_token[1]
        self.eat("IDENTIFIER")
        if self.current_token[0] == "LPAREN":
            return self.parse_function_call(name)
        if self.current_token[0] == "DOT":
            return self.parse_member_access(name)
        return VariableNode("", name)

    def parse_function_call(self, name):
        """Parse ``(args...)`` after ``name`` into a ``FunctionCallNode``."""
        return FunctionCallNode(name, self._parse_argument_list())

    def parse_member_access(self, object):
        """Parse ``.member`` (repeatedly) applied to ``object``.

        Args:
            object: Name or node on the left of the dot. The name shadows the
                builtin but is kept for interface compatibility.

        Raises:
            SyntaxError: If the dot is not followed by an identifier.
        """
        self.eat("DOT")
        if self.current_token[0] != "IDENTIFIER":
            raise SyntaxError(
                f"Expected identifier after dot, got {self.current_token[0]}"
            )
        member = self.current_token[1]
        self.eat("IDENTIFIER")
        access = MemberAccessNode(object, member)
        if self.current_token[0] == "DOT":
            return self.parse_member_access(access)
        return access