# Source code for crosstl.backend.Metal.MetalParser
"""Parser for Metal source AST construction."""
from .MetalLexer import *
from .MetalAst import *
# Token-type groups consulted by the parser when classifying the current token.
# Every entry is a lexer token *type* (element 0 of a token tuple).

# Address-space, storage and access keywords that may precede a type.
QUALIFIER_TOKENS = {
    # address spaces
    "DEVICE",
    "CONSTANT",
    "THREAD",
    "THREADGROUP",
    "THREADGROUP_IMAGEBLOCK",
    # storage / cv specifiers
    "CONST",
    "CONSTEXPR",
    "STATIC",
    "INLINE",
    "VOLATILE",
    "RESTRICT",
    # read / write access keywords
    "READ",
    "WRITE",
    "READ_WRITE",
}

# Tokens that can begin a type name.
TYPE_TOKENS = {
    # scalar types
    "VOID",
    "BOOL",
    "CHAR",
    "UCHAR",
    "SHORT",
    "USHORT",
    "INT",
    "UINT",
    "LONG",
    "ULONG",
    "HALF",
    "FLOAT",
    "DOUBLE",
    # sized / fixed-width integer names
    "SIZE_T",
    "PTRDIFF_T",
    "INT8_T",
    "UINT8_T",
    "INT16_T",
    "UINT16_T",
    "INT32_T",
    "UINT32_T",
    "INT64_T",
    "UINT64_T",
    # vector and matrix families
    "VECTOR",
    "PACKED_VECTOR",
    "SIMD_VECTOR",
    "MATRIX",
    "SIMD_MATRIX",
    # atomics
    "ATOMIC_INT",
    "ATOMIC_UINT",
    "ATOMIC_BOOL",
    # textures
    "TEXTURE1D",
    "TEXTURE1D_ARRAY",
    "TEXTURE2D",
    "TEXTURE2D_ARRAY",
    "TEXTURE2D_MS",
    "TEXTURE2D_MS_ARRAY",
    "TEXTURE3D",
    "TEXTURECUBE",
    "TEXTURECUBE_ARRAY",
    "TEXTUREBUFFER",
    # depth textures
    "DEPTH2D",
    "DEPTH2D_ARRAY",
    "DEPTH2D_MS",
    "DEPTH2D_MS_ARRAY",
    "DEPTHCUBE",
    "DEPTHCUBE_ARRAY",
    # ray tracing, function tables, command buffers
    "ACCELERATION_STRUCTURE",
    "INTERSECTION_FUNCTION_TABLE",
    "VISIBLE_FUNCTION_TABLE",
    "INDIRECT_COMMAND_BUFFER",
    "SAMPLER",
    # User-defined names, the ``metal`` namespace prefix and the enum/typedef
    # keywords. IDENTIFIER is disambiguated by callers via lookahead and
    # ``known_types``.
    "IDENTIFIER",
    "METAL",
    "ENUM",
    "TYPEDEF",
}

# Shader-stage / function-kind keywords.
STAGE_TOKENS = {
    "VERTEX",
    "FRAGMENT",
    "KERNEL",
    "MESH",
    "OBJECT",
    "AMPLIFICATION",
    "INTERSECTION",
    "ANYHIT",
    "CLOSESTHIT",
    "MISS",
    "CALLABLE",
}

# Keyword operators parsed as prefix "calls" in parse_unary.
UNARY_KEYWORDS = {"SIZEOF", "ALIGNOF"}
[docs]
class MetalParser:
    """Recursive-descent parser that builds the Metal backend AST.

    Input is the list of ``(token_type, value)`` tuples produced by
    ``MetalLexer``; output nodes come from ``MetalAst``.
    """

    def __init__(self, tokens):
        """Prepare parser state over ``tokens``.

        Args:
            tokens: Token tuples from ``MetalLexer``; element 0 is the token
                type and element 1 its source text.
        """
        self.tokens = tokens
        self.pos = 0
        # Position on the first token (a synthetic EOF for an empty stream)
        # and step past any leading comments.
        if self.pos < len(self.tokens):
            self.current_token = self.tokens[self.pos]
        else:
            self.current_token = ("EOF", None)
        self.skip_comments()

        # Built-in type names. Struct, enum, typedef and ``using`` names are
        # added to this set while parsing.
        scalar_names = (
            "void", "bool", "char", "uchar", "short", "ushort", "int", "uint",
            "long", "ulong", "float", "half", "double", "size_t", "ptrdiff_t",
            "int64_t", "uint64_t", "int8_t", "uint8_t", "int16_t", "uint16_t",
            "int32_t", "uint32_t",
        )
        resource_names = (
            "sampler",
            "texture1d", "texture1d_array", "texture2d", "texture2d_array",
            "texture2d_ms", "texture2d_ms_array", "texture3d", "texturecube",
            "texturecube_array", "texture_buffer",
            "depth2d", "depth2d_array", "depth2d_ms", "depth2d_ms_array",
            "depthcube", "depthcube_array",
            "acceleration_structure", "intersection_function_table",
            "visible_function_table", "indirect_command_buffer",
            "atomic_int", "atomic_uint", "atomic_bool",
            "enum",
            "ray", "ray_data", "intersection_result", "intersection_params",
            "triangle_intersection_params", "intersector",
        )
        packed_names = {
            f"packed_{elem}{width}"
            for elem in ("float", "half", "int", "uint")
            for width in (2, 3, 4)
        }
        simd_names = {
            f"simd_{elem}{width}"
            for elem in ("float", "int", "uint")
            for width in (2, 3, 4)
        } | {f"simd_float{n}x{n}" for n in (2, 3, 4)}
        self.known_types = (
            set(scalar_names) | set(resource_names) | packed_names | simd_names
        )
[docs]
def skip_comments(self):
"""Advance past comment tokens before parsing syntax."""
while self.pos < len(self.tokens) and self.current_token[0] in [
"COMMENT_SINGLE",
"COMMENT_MULTI",
]:
self.pos += 1
self.current_token = (
self.tokens[self.pos] if self.pos < len(self.tokens) else ("EOF", None)
)
[docs]
def eat(self, token_type):
"""Consume the current token when it matches ``token_type``."""
if self.current_token[0] == token_type:
self.pos += 1
self.current_token = (
self.tokens[self.pos] if self.pos < len(self.tokens) else ("EOF", None)
)
self.skip_comments()
else:
raise SyntaxError(f"Expected {token_type}, got {self.current_token[0]}")
[docs]
def peek(self, offset=1):
"""Return a lookahead token without advancing the parser."""
idx = self.pos + offset
if idx < len(self.tokens):
return self.tokens[idx]
return ("EOF", None)
[docs]
def parse(self):
"""Parse the complete token stream into a shader node."""
shader = self.parse_shader()
self.eat("EOF")
return shader
[docs]
def parse_shader(self):
"""Parse top-level Metal declarations, functions, and preprocessor nodes."""
functions = []
preprocessors = []
structs = []
enums = []
typedefs = []
constants = []
global_variables = []
while self.current_token[0] != "EOF":
if self.current_token[0] == "PREPROCESSOR":
directive = self.parse_preprocessor_directive()
if directive is not None:
preprocessors.append(directive)
elif self.current_token[0] == "USING":
alias = self.parse_using_statement()
if alias is not None:
typedefs.append(alias)
elif self.current_token[0] == "STRUCT":
structs.append(self.parse_struct())
elif self.current_token[0] == "ALIGNAS":
alignas_specs = self.parse_alignas_specifiers()
if self.current_token[0] == "STRUCT":
structs.append(self.parse_struct(alignas_specs))
else:
# Treat as global variable with alignas
global_variables.append(
self.parse_global_variable(pre_alignas=alignas_specs)
)
elif self.current_token[0] == "ENUM":
enums.append(self.parse_enum())
elif self.current_token[0] == "TYPEDEF":
typedefs.append(self.parse_typedef())
elif self.current_token[0] == "STATIC_ASSERT":
global_variables.append(self.parse_static_assert())
elif self.current_token[0] == "CONSTANT":
if self.is_constant_buffer():
constants.append(self.parse_constant_buffer())
else:
global_variables.append(self.parse_global_variable())
elif self.current_token[0] in STAGE_TOKENS or (
self.current_token[0] in TYPE_TOKENS
or self.current_token[0] in QUALIFIER_TOKENS
):
if self.is_function_definition():
functions.append(self.parse_function())
else:
global_variables.append(self.parse_global_variable())
else:
self.eat(self.current_token[0]) # Skip unknown tokens
return ShaderNode(
includes=preprocessors,
functions=functions,
structs=structs,
global_variables=global_variables,
constant=constants,
enums=enums,
typedefs=typedefs,
)
def is_constant_buffer(self):
if self.current_token[0] != "CONSTANT":
return False
next_tok = self.peek(1)
next_next = self.peek(2)
return next_tok[0] == "IDENTIFIER" and next_next[0] == "LBRACE"
def is_function_definition(self):
idx = self.pos
while idx < len(self.tokens) and self.tokens[idx][0] == "ATTRIBUTE":
idx += 1
while idx < len(self.tokens) and self.tokens[idx][0] in QUALIFIER_TOKENS:
idx += 1
if idx >= len(self.tokens):
return False
tok_type = self.tokens[idx][0]
if tok_type in STAGE_TOKENS:
idx += 1
while idx < len(self.tokens) and self.tokens[idx][0] == "ATTRIBUTE":
idx += 1
while idx < len(self.tokens) and self.tokens[idx][0] in QUALIFIER_TOKENS:
idx += 1
if idx >= len(self.tokens):
return False
tok_type = self.tokens[idx][0]
if tok_type not in TYPE_TOKENS:
return False
idx += 1
if idx < len(self.tokens) and self.tokens[idx][0] == "LESS_THAN":
depth = 0
while idx < len(self.tokens):
if self.tokens[idx][0] == "LESS_THAN":
depth += 1
elif self.tokens[idx][0] == "GREATER_THAN":
depth -= 1
if depth == 0:
idx += 1
break
idx += 1
while idx < len(self.tokens) and self.tokens[idx][0] in [
"MULTIPLY",
"BITWISE_AND",
]:
idx += 1
if idx >= len(self.tokens) or self.tokens[idx][0] != "IDENTIFIER":
return False
idx += 1
return idx < len(self.tokens) and self.tokens[idx][0] == "LPAREN"
def parse_preprocessor_directive(self):
text = self.current_token[1] or ""
self.eat("PREPROCESSOR")
stripped = text.lstrip("#").strip()
if stripped:
parts = stripped.split(None, 1)
directive = f"#{parts[0]}"
content = parts[1] if len(parts) > 1 else ""
return PreprocessorNode(directive, content)
directive = text
content = ""
if self.current_token[0] == "LESS_THAN":
self.eat("LESS_THAN")
parts = []
while self.current_token[0] != "GREATER_THAN":
parts.append(self.current_token[1])
self.eat(self.current_token[0])
self.eat("GREATER_THAN")
content = "<" + "".join(parts) + ">"
elif self.current_token[0] == "STRING":
content = self.current_token[1]
self.eat("STRING")
elif self.current_token[0] not in ["EOF", "PREPROCESSOR"]:
content = str(self.current_token[1])
self.eat(self.current_token[0])
return PreprocessorNode(directive, content)
def parse_using_statement(self):
self.eat("USING")
if self.current_token[0] == "NAMESPACE":
self.eat("NAMESPACE")
self.eat("METAL")
self.eat("SEMICOLON")
return None
alias_name = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("EQUALS")
alias_type, _qualifiers = self.parse_type_specifier()
self.eat("SEMICOLON")
self.known_types.add(alias_name)
return TypeAliasNode(alias_type, alias_name)
def parse_enum(self):
self.eat("ENUM")
if self.current_token[0] == "CLASS":
self.eat("CLASS")
name = self.current_token[1]
self.eat("IDENTIFIER")
self.known_types.add(name)
self.eat("LBRACE")
members = []
while self.current_token[0] != "RBRACE":
member_name = self.current_token[1]
self.eat("IDENTIFIER")
member_value = None
if self.current_token[0] == "EQUALS":
self.eat("EQUALS")
member_value = self.parse_expression()
members.append((member_name, member_value))
if self.current_token[0] == "COMMA":
self.eat("COMMA")
elif self.current_token[0] == "RBRACE":
break
else:
raise SyntaxError(
f"Expected comma or closing brace in enum, got {self.current_token[0]}"
)
self.eat("RBRACE")
if self.current_token[0] == "SEMICOLON":
self.eat("SEMICOLON")
return EnumNode(name, members)
def parse_typedef(self):
self.eat("TYPEDEF")
alias_type, _qualifiers = self.parse_type_specifier()
alias_name, _array = self.parse_declarator()
self.eat("SEMICOLON")
self.known_types.add(alias_name)
return TypeAliasNode(alias_type, alias_name)
def parse_static_assert(self):
self.eat("STATIC_ASSERT")
self.eat("LPAREN")
condition = self.parse_expression()
message = None
if self.current_token[0] == "COMMA":
self.eat("COMMA")
if self.current_token[0] == "STRING":
message = self.current_token[1]
self.eat("STRING")
else:
message = self.parse_expression()
self.eat("RPAREN")
self.eat("SEMICOLON")
return StaticAssertNode(condition, message)
def parse_global_variable(self, pre_alignas=None):
attributes = self.parse_attributes()
alignas_specs = pre_alignas or self.parse_alignas_specifiers()
vtype, qualifiers = self.parse_type_specifier()
name, array_sizes = self.parse_declarator()
var_attributes = self.parse_attributes()
attributes.extend(var_attributes)
var_node = VariableNode(
vtype, name, qualifiers=qualifiers, attributes=attributes
)
var_node.array_sizes = array_sizes
var_node.alignas = alignas_specs
if "const" in qualifiers or "constexpr" in qualifiers:
var_node.is_const = True
if self.current_token[0] == "EQUALS":
self.eat("EQUALS")
value = self.parse_expression()
self.eat("SEMICOLON")
return AssignmentNode(var_node, value)
self.eat("SEMICOLON")
return var_node
def parse_type_specifier(self):
qualifiers = []
while self.current_token[0] in QUALIFIER_TOKENS:
qualifiers.append(self.current_token[1])
self.eat(self.current_token[0])
if self.current_token[0] not in TYPE_TOKENS:
raise SyntaxError(f"Expected type, got {self.current_token[0]}")
if self.current_token[0] == "METAL" or (
self.current_token[0] == "IDENTIFIER" and self.peek(1)[0] == "SCOPE"
):
base_type = self.parse_scoped_identifier()
else:
base_type = self.current_token[1]
self.eat(self.current_token[0])
if self.current_token[0] == "LESS_THAN":
depth = 0
inner = []
self.eat("LESS_THAN")
depth += 1
while depth > 0 and self.current_token[0] != "EOF":
if self.current_token[0] == "LESS_THAN":
depth += 1
inner.append(self.current_token[1])
self.eat("LESS_THAN")
elif self.current_token[0] == "GREATER_THAN":
depth -= 1
if depth == 0:
self.eat("GREATER_THAN")
break
inner.append(self.current_token[1])
self.eat("GREATER_THAN")
else:
inner.append(self.current_token[1])
self.eat(self.current_token[0])
base_type = f"{base_type}<{''.join(inner)}>"
pointer_suffix = ""
while self.current_token[0] in ["MULTIPLY", "BITWISE_AND"]:
pointer_suffix += "*" if self.current_token[0] == "MULTIPLY" else "&"
self.eat(self.current_token[0])
return base_type + pointer_suffix, qualifiers
def parse_alignas_specifiers(self):
specs = []
while self.current_token[0] == "ALIGNAS":
self.eat("ALIGNAS")
self.eat("LPAREN")
if self.is_type_start():
type_name, _quals = self.parse_type_specifier()
specs.append(("type", type_name))
else:
expr = self.parse_expression()
specs.append(expr)
self.eat("RPAREN")
return specs
def is_type_start(self):
if self.current_token[0] in QUALIFIER_TOKENS:
return True
if self.current_token[0] in TYPE_TOKENS:
if self.current_token[0] == "IDENTIFIER":
name = self.current_token[1]
if name in self.known_types:
return True
next_tok = self.peek(1)[0]
if next_tok == "SCOPE":
return True
return False
return True
return False
def parse_declarator(self):
name = ""
array_sizes = []
if self.current_token[0] == "IDENTIFIER":
name = self.current_token[1]
self.eat("IDENTIFIER")
while self.current_token[0] == "LBRACKET":
self.eat("LBRACKET")
size = None
if self.current_token[0] != "RBRACKET":
size = self.parse_expression()
self.eat("RBRACKET")
array_sizes.append(size)
return name, array_sizes
def parse_struct(self, pre_alignas=None):
alignas_specs = pre_alignas or self.parse_alignas_specifiers()
self.eat("STRUCT")
name = self.current_token[1]
self.eat("IDENTIFIER")
self.known_types.add(name)
self.eat("LBRACE")
members = []
while self.current_token[0] != "RBRACE":
member_alignas = self.parse_alignas_specifiers()
vtype, qualifiers = self.parse_type_specifier()
var_name, array_sizes = self.parse_declarator()
attributes = self.parse_attributes()
self.eat("SEMICOLON")
var_node = VariableNode(
vtype, var_name, qualifiers=qualifiers, attributes=attributes
)
var_node.array_sizes = array_sizes
var_node.alignas = member_alignas
members.append(var_node)
self.eat("RBRACE")
self.eat("SEMICOLON")
struct_node = StructNode(name, members)
struct_node.alignas = alignas_specs
return struct_node
def parse_constant_buffer(self):
self.eat("CONSTANT")
name = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("LBRACE")
members = []
while self.current_token[0] != "RBRACE":
member_alignas = self.parse_alignas_specifiers()
vtype, qualifiers = self.parse_type_specifier()
var_name, array_sizes = self.parse_declarator()
self.eat("SEMICOLON")
var_node = VariableNode(vtype, var_name, qualifiers=qualifiers)
var_node.array_sizes = array_sizes
var_node.alignas = member_alignas
members.append(var_node)
self.eat("RBRACE")
if self.current_token[0] == "SEMICOLON":
self.eat("SEMICOLON")
return ConstantBufferNode(name, members)
def parse_function(self):
attributes = self.parse_attributes()
qualifier = None
if self.current_token[0] in STAGE_TOKENS:
qualifier = self.current_token[1]
self.eat(self.current_token[0])
return_type, _return_qualifiers = self.parse_type_specifier()
if self.current_token[0] in STAGE_TOKENS:
if qualifier is None:
qualifier = self.current_token[1]
self.eat(self.current_token[0])
name = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("LPAREN")
params = self.parse_parameters()
self.eat("RPAREN")
post_attributes = self.parse_attributes()
attributes.extend(post_attributes)
body = self.parse_block()
return FunctionNode(
return_type=return_type,
name=name,
params=params,
body=body,
qualifiers=[qualifier] if qualifier else [],
attributes=attributes,
qualifier=qualifier, # Also store as single qualifier for backward compatibility
)
def parse_parameters(self):
params = []
while self.current_token[0] != "RPAREN":
attributes = self.parse_attributes()
vtype, qualifiers = self.parse_type_specifier()
name, array_sizes = self.parse_declarator()
param_attributes = self.parse_attributes()
attributes.extend(param_attributes)
var_node = VariableNode(
vtype, name, qualifiers=qualifiers, attributes=attributes
)
var_node.array_sizes = array_sizes
params.append(var_node)
if self.current_token[0] == "COMMA":
self.eat("COMMA")
elif self.current_token[0] == "RPAREN":
break
else:
raise SyntaxError(
f"Expected comma or closing parenthesis, got {self.current_token[0]}"
)
return params
def parse_attributes(self):
attributes = []
while self.current_token[0] == "ATTRIBUTE":
attr_content = self.current_token[1][2:-2].strip() # Remove [[ and ]]
def split_top_level(text):
parts = []
buf = ""
depth = 0
for ch in text:
if ch == "(":
depth += 1
elif ch == ")":
depth = max(0, depth - 1)
if ch == "," and depth == 0:
if buf.strip():
parts.append(buf.strip())
buf = ""
continue
buf += ch
if buf.strip():
parts.append(buf.strip())
return parts
for part in split_top_level(attr_content):
name = part
args = []
if "(" in part and part.endswith(")"):
name, arg_str = part.split("(", 1)
arg_str = arg_str[:-1] # remove trailing )
args = [arg.strip() for arg in split_top_level(arg_str)]
name = name.strip()
attributes.append(AttributeNode(name.strip(), args))
self.eat("ATTRIBUTE")
return attributes
def parse_block(self):
statements = []
self.eat("LBRACE")
while self.current_token[0] != "RBRACE":
statements.append(self.parse_statement())
self.eat("RBRACE")
return statements
def is_declaration_start(self):
if self.current_token[0] == "ALIGNAS":
return True
if self.current_token[0] in QUALIFIER_TOKENS:
return True
if self.current_token[0] in TYPE_TOKENS:
if self.current_token[0] == "IDENTIFIER":
next_tok = self.peek(1)[0]
if next_tok in [
"IDENTIFIER",
"SCOPE",
"LESS_THAN",
"MULTIPLY",
"BITWISE_AND",
]:
return True
return self.current_token[1] in self.known_types
return True
return False
def parse_statement(self):
if self.is_declaration_start():
return self.parse_variable_declaration_or_assignment()
elif self.current_token[0] == "IF":
return self.parse_if_statement()
elif self.current_token[0] == "FOR":
return self.parse_for_statement()
elif self.current_token[0] == "WHILE":
return self.parse_while_statement()
elif self.current_token[0] == "DO":
return self.parse_do_while_statement()
elif self.current_token[0] == "SWITCH":
return self.parse_switch_statement()
elif self.current_token[0] == "RETURN":
return self.parse_return_statement()
elif self.current_token[0] == "BREAK":
self.eat("BREAK")
self.eat("SEMICOLON")
return BreakNode()
elif self.current_token[0] == "CONTINUE":
self.eat("CONTINUE")
self.eat("SEMICOLON")
return ContinueNode()
elif self.current_token[0] == "DISCARD":
self.eat("DISCARD")
self.eat("SEMICOLON")
return DiscardNode()
elif self.current_token[0] == "STATIC_ASSERT":
return self.parse_static_assert()
else:
return self.parse_expression_statement()
def parse_variable_declaration_or_assignment(self):
alignas_specs = self.parse_alignas_specifiers()
vtype, qualifiers = self.parse_type_specifier()
name, array_sizes = self.parse_declarator()
attributes = self.parse_attributes()
var_node = VariableNode(
vtype, name, qualifiers=qualifiers, attributes=attributes
)
var_node.array_sizes = array_sizes
var_node.alignas = alignas_specs
if "const" in qualifiers or "constexpr" in qualifiers:
var_node.is_const = True
if self.current_token[0] == "SEMICOLON":
self.eat("SEMICOLON")
return var_node
if self.current_token[0] in [
"EQUALS",
"PLUS_EQUALS",
"MINUS_EQUALS",
"MULTIPLY_EQUALS",
"DIVIDE_EQUALS",
"ASSIGN_MOD",
"ASSIGN_AND",
"ASSIGN_OR",
"ASSIGN_XOR",
"ASSIGN_SHIFT_LEFT",
"ASSIGN_SHIFT_RIGHT",
]:
op = self.current_token[1]
self.eat(self.current_token[0])
value = self.parse_expression()
self.eat("SEMICOLON")
return AssignmentNode(var_node, value, op)
expr = self.parse_expression()
self.eat("SEMICOLON")
return expr
def parse_if_statement(self):
if_chain = []
else_if_chain = []
else_body = None
while self.current_token[0] == "IF":
self.eat("IF")
self.eat("LPAREN")
condition = self.parse_expression()
self.eat("RPAREN")
body = self.parse_block()
if_chain.append((condition, body))
while self.current_token[0] == "ELSE_IF":
self.eat("ELSE_IF")
self.eat("LPAREN")
condition = self.parse_expression()
self.eat("RPAREN")
body = self.parse_block()
else_if_chain.append((condition, body))
if self.current_token[0] == "ELSE":
self.eat("ELSE")
else_body = self.parse_block()
return IfNode(
if_chain=if_chain, else_if_chain=else_if_chain, else_body=else_body
)
def parse_for_statement(self):
self.eat("FOR")
self.eat("LPAREN")
init = None
if self.current_token[0] != "SEMICOLON":
init = self.parse_for_init()
self.eat("SEMICOLON")
condition = None
if self.current_token[0] != "SEMICOLON":
condition = self.parse_expression()
self.eat("SEMICOLON")
update = None
if self.current_token[0] != "RPAREN":
update = self.parse_expression()
self.eat("RPAREN")
body = self.parse_block()
return ForNode(init, condition, update, body)
def parse_for_init(self):
if self.is_declaration_start():
vtype, qualifiers = self.parse_type_specifier()
name, array_sizes = self.parse_declarator()
var_node = VariableNode(vtype, name, qualifiers=qualifiers)
var_node.array_sizes = array_sizes
if "const" in qualifiers or "constexpr" in qualifiers:
var_node.is_const = True
if self.current_token[0] == "EQUALS":
self.eat("EQUALS")
init_value = self.parse_expression()
return AssignmentNode(var_node, init_value)
return var_node
return self.parse_expression()
def parse_return_statement(self):
self.eat("RETURN")
if self.current_token[0] == "SEMICOLON":
self.eat("SEMICOLON")
return ReturnNode(None)
value = self.parse_expression()
self.eat("SEMICOLON")
return ReturnNode(value)
def parse_while_statement(self):
self.eat("WHILE")
self.eat("LPAREN")
condition = self.parse_expression()
self.eat("RPAREN")
body = self.parse_block()
return WhileNode(condition, body)
def parse_do_while_statement(self):
self.eat("DO")
body = self.parse_block()
self.eat("WHILE")
self.eat("LPAREN")
condition = self.parse_expression()
self.eat("RPAREN")
self.eat("SEMICOLON")
return DoWhileNode(body, condition)
def parse_expression_statement(self):
expr = self.parse_expression()
self.eat("SEMICOLON")
return expr
def parse_expression(self):
return self.parse_assignment()
def parse_assignment(self):
left = self.parse_conditional()
if self.current_token[0] in [
"EQUALS",
"PLUS_EQUALS",
"MINUS_EQUALS",
"MULTIPLY_EQUALS",
"DIVIDE_EQUALS",
"ASSIGN_MOD",
"ASSIGN_AND",
"ASSIGN_OR",
"ASSIGN_XOR",
"ASSIGN_SHIFT_LEFT",
"ASSIGN_SHIFT_RIGHT",
]:
op = self.current_token[1]
self.eat(self.current_token[0])
right = self.parse_assignment()
return AssignmentNode(left, right, op)
return left
def parse_conditional(self):
left = self.parse_logical_or()
if self.current_token[0] == "QUESTION":
self.eat("QUESTION")
true_expr = self.parse_expression()
self.eat("COLON")
false_expr = self.parse_expression()
return TernaryOpNode(left, true_expr, false_expr)
return left
def parse_logical_or(self):
left = self.parse_logical_and()
while self.current_token[0] == "OR":
op = self.current_token[1]
self.eat("OR")
right = self.parse_logical_and()
left = BinaryOpNode(left, op, right)
return left
def parse_logical_and(self):
left = self.parse_bitwise_or()
while self.current_token[0] == "AND":
op = self.current_token[1]
self.eat("AND")
right = self.parse_bitwise_or()
left = BinaryOpNode(left, op, right)
return left
def parse_bitwise_or(self):
left = self.parse_bitwise_xor()
while self.current_token[0] == "BITWISE_OR":
op = self.current_token[1]
self.eat("BITWISE_OR")
right = self.parse_bitwise_xor()
left = BinaryOpNode(left, op, right)
return left
def parse_bitwise_xor(self):
left = self.parse_bitwise_and()
while self.current_token[0] == "BITWISE_XOR":
op = self.current_token[1]
self.eat("BITWISE_XOR")
right = self.parse_bitwise_and()
left = BinaryOpNode(left, op, right)
return left
def parse_bitwise_and(self):
left = self.parse_equality()
while self.current_token[0] == "BITWISE_AND":
op = self.current_token[1]
self.eat("BITWISE_AND")
right = self.parse_equality()
left = BinaryOpNode(left, op, right)
return left
def parse_equality(self):
left = self.parse_relational()
while self.current_token[0] in ["EQUAL", "NOT_EQUAL"]:
op = self.current_token[1]
self.eat(self.current_token[0])
right = self.parse_relational()
left = BinaryOpNode(left, op, right)
return left
def parse_relational(self):
left = self.parse_shift()
while self.current_token[0] in [
"LESS_THAN",
"GREATER_THAN",
"LESS_EQUAL",
"GREATER_EQUAL",
]:
op = self.current_token[1]
self.eat(self.current_token[0])
right = self.parse_shift()
left = BinaryOpNode(left, op, right)
return left
def parse_shift(self):
left = self.parse_additive()
while self.current_token[0] in ["SHIFT_LEFT", "SHIFT_RIGHT"]:
op = self.current_token[1]
self.eat(self.current_token[0])
right = self.parse_additive()
left = BinaryOpNode(left, op, right)
return left
def parse_additive(self):
left = self.parse_multiplicative()
while self.current_token[0] in ["PLUS", "MINUS"]:
op = self.current_token[1]
self.eat(self.current_token[0])
right = self.parse_multiplicative()
left = BinaryOpNode(left, op, right)
return left
def parse_multiplicative(self):
left = self.parse_unary()
while self.current_token[0] in ["MULTIPLY", "DIVIDE", "MOD"]:
op = self.current_token[1]
self.eat(self.current_token[0])
right = self.parse_unary()
left = BinaryOpNode(left, op, right)
return left
def parse_unary(self):
if self.current_token[0] in UNARY_KEYWORDS:
op = self.current_token[1]
self.eat(self.current_token[0])
if self.current_token[0] == "LPAREN" and self.is_type_in_parens():
self.eat("LPAREN")
type_name, _quals = self.parse_type_specifier()
self.eat("RPAREN")
return FunctionCallNode(op, [type_name])
operand = self.parse_unary()
return FunctionCallNode(op, [operand])
if self.current_token[0] == "LPAREN" and self.is_type_in_parens():
self.eat("LPAREN")
type_name, _quals = self.parse_type_specifier()
self.eat("RPAREN")
operand = self.parse_unary()
return CastNode(type_name, operand)
if self.current_token[0] in [
"PLUS",
"MINUS",
"NOT",
"BITWISE_NOT",
"INCREMENT",
"DECREMENT",
"MULTIPLY",
"BITWISE_AND",
]:
op = self.current_token[1]
self.eat(self.current_token[0])
operand = self.parse_unary()
return UnaryOpNode(op, operand)
return self.parse_postfix()
def is_type_in_parens(self):
if self.current_token[0] != "LPAREN":
return False
idx = self.pos + 1
while idx < len(self.tokens) and self.tokens[idx][0] in QUALIFIER_TOKENS:
idx += 1
if idx >= len(self.tokens):
return False
tok_type = self.tokens[idx][0]
if tok_type not in TYPE_TOKENS:
return False
idx += 1
if idx < len(self.tokens) and self.tokens[idx][0] == "LESS_THAN":
depth = 0
while idx < len(self.tokens):
if self.tokens[idx][0] == "LESS_THAN":
depth += 1
elif self.tokens[idx][0] == "GREATER_THAN":
depth -= 1
if depth == 0:
idx += 1
break
idx += 1
while idx < len(self.tokens) and self.tokens[idx][0] in [
"MULTIPLY",
"BITWISE_AND",
]:
idx += 1
return idx < len(self.tokens) and self.tokens[idx][0] == "RPAREN"
def parse_postfix(self):
node = self.parse_primary()
while True:
if self.current_token[0] == "LPAREN":
node = self.parse_call(node)
continue
if self.current_token[0] == "DOT":
self.eat("DOT")
if self.current_token[0] not in [
"IDENTIFIER",
"READ",
"WRITE",
"READ_WRITE",
]:
raise SyntaxError(
f"Expected identifier after dot, got {self.current_token[0]}"
)
member = self.current_token[1]
self.eat(self.current_token[0])
node = MemberAccessNode(node, member)
continue
if self.current_token[0] == "LBRACKET":
self.eat("LBRACKET")
index = None
if self.current_token[0] != "RBRACKET":
index = self.parse_expression()
self.eat("RBRACKET")
node = ArrayAccessNode(node, index)
continue
if self.current_token[0] in ["INCREMENT", "DECREMENT"]:
op = self.current_token[1]
self.eat(self.current_token[0])
node = PostfixOpNode(node, op)
continue
break
return node
def parse_primary(self):
if self.current_token[0] == "NUMBER":
value = self.current_token[1]
self.eat("NUMBER")
return value
if self.current_token[0] in ["TRUE", "FALSE"]:
value = self.current_token[1]
self.eat(self.current_token[0])
return value
if self.current_token[0] == "LPAREN":
self.eat("LPAREN")
expr = self.parse_expression()
self.eat("RPAREN")
return expr
if self.current_token[0] in [
"VECTOR",
"MATRIX",
"SIMD_MATRIX",
"PACKED_VECTOR",
"SIMD_VECTOR",
"FLOAT",
"HALF",
"DOUBLE",
"INT",
"UINT",
"BOOL",
]:
type_name = self.current_token[1]
self.eat(self.current_token[0])
if self.current_token[0] == "LPAREN":
return self.parse_vector_constructor(type_name)
raise SyntaxError(f"Unexpected type in expression: {type_name}")
if self.current_token[0] in ["IDENTIFIER", "METAL"]:
name = self.parse_scoped_identifier()
return VariableNode("", name)
raise SyntaxError(f"Unexpected token in expression: {self.current_token[0]}")
def parse_scoped_identifier(self):
parts = []
if self.current_token[0] == "METAL":
parts.append("metal")
self.eat("METAL")
else:
parts.append(self.current_token[1])
self.eat("IDENTIFIER")
while self.current_token[0] == "SCOPE":
self.eat("SCOPE")
if (
self.current_token[0] not in TYPE_TOKENS
and self.current_token[0] != "METAL"
):
raise SyntaxError(
f"Expected identifier after '::', got {self.current_token[0]}"
)
if self.current_token[0] == "METAL":
parts.append("metal")
self.eat("METAL")
else:
parts.append(self.current_token[1])
self.eat(self.current_token[0])
return "::".join(parts)
def parse_vector_constructor(self, type_name):
self.eat("LPAREN")
args = []
while self.current_token[0] != "RPAREN":
args.append(self.parse_expression())
if self.current_token[0] == "COMMA":
self.eat("COMMA")
self.eat("RPAREN")
return VectorConstructorNode(type_name, args)
def parse_call(self, callee):
self.eat("LPAREN")
args = []
while self.current_token[0] != "RPAREN":
args.append(self.parse_expression())
if self.current_token[0] == "COMMA":
self.eat("COMMA")
self.eat("RPAREN")
if isinstance(callee, MemberAccessNode):
if callee.member == "sample":
return self.build_texture_sample(callee.object, args)
return MethodCallNode(callee.object, callee.member, args)
if isinstance(callee, VariableNode):
return FunctionCallNode(callee.name, args)
return CallNode(callee, args)
def build_texture_sample(self, texture, args):
sampler = args[0] if len(args) > 0 else None
coords = args[1] if len(args) > 1 else None
lod = args[2] if len(args) > 2 else None
if lod is not None:
return TextureSampleNode(texture, sampler, coords, lod)
return TextureSampleNode(texture, sampler, coords)
def parse_texture_sample_args(self, texture):
self.eat("LPAREN")
sampler = self.parse_expression()
self.eat("COMMA")
coordinates = self.parse_expression()
# Support for optional LOD parameter
lod = None
if self.current_token[0] == "COMMA":
self.eat("COMMA")
lod = self.parse_expression()
self.eat("RPAREN")
if lod is not None:
return TextureSampleNode(texture, sampler, coordinates, lod)
return TextureSampleNode(texture, sampler, coordinates)
def parse_texture_sample(self):
texture = self.parse_expression()
self.eat("DOT")
self.eat("IDENTIFIER") # 'sample' method
self.eat("LPAREN")
sampler = self.parse_expression()
self.eat("COMMA")
coordinates = self.parse_expression()
self.eat("RPAREN")
return TextureSampleNode(texture, sampler, coordinates)
def parse_switch_statement(self):
self.eat("SWITCH")
self.eat("LPAREN")
expression = self.parse_expression()
self.eat("RPAREN")
self.eat("LBRACE")
cases = []
default = None
while self.current_token[0] not in ["RBRACE", "EOF"]:
if self.current_token[0] == "CASE":
cases.append(self.parse_case_statement())
elif self.current_token[0] == "DEFAULT":
self.eat("DEFAULT")
self.eat("COLON")
default_statements = []
# Parse statements until next case/default or end of switch
while self.current_token[0] not in ["CASE", "DEFAULT", "RBRACE", "EOF"]:
if self.current_token[0] == "BREAK":
self.eat("BREAK")
self.eat("SEMICOLON")
break
else:
default_statements.append(self.parse_statement())
default = default_statements
else:
raise SyntaxError(
f"Unexpected token in switch statement: {self.current_token[0]}"
)
self.eat("RBRACE")
return SwitchNode(expression, cases, default)
def parse_case_statement(self):
self.eat("CASE")
value = self.parse_expression()
self.eat("COLON")
statements = []
# Parse statements until next case/default/break or end of switch
while self.current_token[0] not in ["CASE", "DEFAULT", "RBRACE", "EOF"]:
if self.current_token[0] == "BREAK":
self.eat("BREAK")
self.eat("SEMICOLON")
break
else:
statements.append(self.parse_statement())
return CaseNode(value, statements)