"""Parser that builds CrossGL AST nodes from lexer tokens."""
from .ast import (
ASTNode,
TypeNode,
PrimitiveType,
VectorType,
MatrixType,
ArrayType,
PointerType,
ReferenceType,
FunctionType,
GenericType,
NamedType,
ShaderNode,
StageNode,
ImportNode,
PreprocessorNode,
StructNode,
StructMemberNode,
EnumNode,
EnumVariantNode,
FunctionNode,
ParameterNode,
VariableNode,
ConstantNode,
GenericParameterNode,
AttributeNode,
StatementNode,
BlockNode,
ExpressionStatementNode,
AssignmentNode,
IfNode,
ForNode,
ForInNode,
WhileNode,
LoopNode,
MatchNode,
MatchArmNode,
SwitchNode,
CaseNode,
ReturnNode,
BreakNode,
ContinueNode,
ExpressionNode,
LiteralNode,
IdentifierNode,
RangeNode,
BinaryOpNode,
UnaryOpNode,
TernaryOpNode,
FunctionCallNode,
MemberAccessNode,
PointerAccessNode,
ArrayAccessNode,
ArrayLiteralNode,
SwizzleNode,
CastNode,
ConstructorNode,
LambdaNode,
PatternNode,
WildcardPatternNode,
IdentifierPatternNode,
LiteralPatternNode,
StructPatternNode,
TextureNode,
TextureOpNode,
AtomicOpNode,
SyncNode,
BuiltinVariableNode,
BufferNode,
TextureResourceNode,
SamplerNode,
BufferOpNode,
WaveOpNode,
RayTracingOpNode,
RayQueryOpNode,
MeshOpNode,
ArrayNode,
ShaderStage,
ExecutionModel,
create_legacy_shader_node,
)
from .lexer import Lexer
from .validation import validate_shader_cbuffers
import logging
# Known intrinsic / method names grouped by feature area. Each group is a
# plain ``set`` of strings (presumably used for membership tests when
# classifying calls -- verify against the code that consumes them).

# Wave-level and quad-level operations.
WAVE_INTRINSICS = set(
    """
    WaveGetLaneCount WaveGetLaneIndex WaveIsFirstLane
    WaveActiveSum WaveActiveProduct
    WaveActiveBitAnd WaveActiveBitOr WaveActiveBitXor
    WaveActiveMin WaveActiveMax
    WaveActiveAllTrue WaveActiveAnyTrue WaveActiveBallot
    WaveReadLaneAt WaveReadLaneFirst
    WavePrefixSum WavePrefixProduct
    QuadReadAcrossX QuadReadAcrossY QuadReadAcrossDiagonal QuadReadLaneAt
    WaveMatch
    WaveMultiPrefixSum WaveMultiPrefixProduct
    WaveMultiPrefixBitAnd WaveMultiPrefixBitOr WaveMultiPrefixBitXor
    """.split()
)

# Ray-tracing pipeline intrinsics.
RAYTRACING_INTRINSICS = set(
    "TraceRay ReportHit CallShader AcceptHitAndEndSearch IgnoreHit".split()
)

# Mesh / amplification shader intrinsics.
MESH_INTRINSICS = set("SetMeshOutputCounts DispatchMesh".split())

# Methods callable on a RayQuery object.
RAYQUERY_METHODS = set(
    """
    Proceed Abort
    CandidateType CommittedType
    CandidatePrimitiveIndex CommittedPrimitiveIndex
    CandidateInstanceID CommittedInstanceID
    CandidateGeometryIndex CommittedGeometryIndex
    CandidateObjectRayOrigin CandidateObjectRayDirection
    CommittedObjectRayOrigin CommittedObjectRayDirection
    CommittedRayT CandidateRayT
    """.split()
)
[docs]
class Parser:
"""Recursive-descent parser for CrossGL Universal IR tokens."""
def __init__(self, tokens):
"""Initialize parser state from a token sequence."""
self.tokens = tokens
self.pos = 0
self.current_token = (
self.tokens[self.pos] if self.pos < len(self.tokens) else ("EOF", None)
)
[docs]
def eat(self, token_type):
    """Consume one token of the expected type or raise ``SyntaxError``.

    On success the cursor advances by one. ``current_token`` becomes the next
    token, or ``("EOF", None)`` past the end, and ``skip_comments`` is then
    invoked.

    Raises:
        SyntaxError: if the current token's kind is not ``token_type``.
    """
    if self.current_token[0] != token_type:
        raise SyntaxError(
            f"Expected {token_type}, got {self.current_token[0]} '{self.current_token[1]}'"
        )
    next_pos = self.pos + 1
    self.pos = next_pos
    if next_pos < len(self.tokens):
        self.current_token = self.tokens[next_pos]
    else:
        self.current_token = ("EOF", None)
    self.skip_comments()
[docs]
def peek(self, offset=1):
    """Return a lookahead token without advancing the parser.

    Args:
        offset: Distance from the current position (default 1 = next token).

    Returns:
        The token at ``pos + offset``, or ``("EOF", None)`` when that index
        falls outside the token list on either side.
    """
    peek_pos = self.pos + offset
    # Guard the lower bound too: a negative index would otherwise wrap
    # around via Python negative indexing and return a token from the end.
    if 0 <= peek_pos < len(self.tokens):
        return self.tokens[peek_pos]
    return ("EOF", None)
[docs]
def finalize_shader(self, shader):
    """Run final validation hooks before returning a shader AST.

    Delegates to ``validate_shader_cbuffers`` and returns whatever it
    returns, so callers should use the returned object rather than the
    argument.
    """
    validated = validate_shader_cbuffers(shader)
    return validated
[docs]
def parse(self):
    """Parse a complete CrossGL translation unit into a ``ShaderNode``.

    Repeatedly calls ``parse_global`` until EOF. Each returned node is sorted
    by type:

    - structs go to ``structs``, or to ``cbuffers`` when ``is_cbuffer`` is set;
    - enums are kept alongside structs;
    - functions, global variables, constants, imports and preprocessor
      directives go to their own lists;
    - stages are keyed by ``stage``.

    ``None`` results are ignored.

    Parsing stops early, with a logged warning, if ``parse_global`` consumes
    no tokens (otherwise the loop would never terminate) or after
    ``max_loops`` top-level elements.

    Returns:
        A ``ShaderNode`` named ``"main"``, passed through ``finalize_shader``.
    """
    structs = []
    functions = []
    global_variables = []
    constants = []
    cbuffers = []
    stages = {}
    imports = []
    preprocessors = []
    logger = logging.getLogger(__name__)
    max_loops = 10000
    loop_count = 0
    while self.current_token[0] != "EOF":
        if loop_count >= max_loops:
            logger.warning(
                "Parser hit maximum loop limit (%d), stopping...", max_loops
            )
            break
        loop_count += 1
        start_pos = self.pos
        parsed_element = self.parse_global()
        # isinstance() dispatch; None (nothing produced) matches no branch.
        if isinstance(parsed_element, StructNode):
            if getattr(parsed_element, "is_cbuffer", False):
                cbuffers.append(parsed_element)
            else:
                structs.append(parsed_element)
        elif isinstance(parsed_element, FunctionNode):
            functions.append(parsed_element)
        elif isinstance(parsed_element, VariableNode):
            global_variables.append(parsed_element)
        elif isinstance(parsed_element, ConstantNode):
            constants.append(parsed_element)
        elif isinstance(parsed_element, StageNode):
            stages[parsed_element.stage] = parsed_element
        elif isinstance(parsed_element, ImportNode):
            imports.append(parsed_element)
        elif isinstance(parsed_element, PreprocessorNode):
            preprocessors.append(parsed_element)
        elif isinstance(parsed_element, EnumNode):
            structs.append(parsed_element)
        # Safety check: stop only when no progress was made. The previous
        # check fired unconditionally after 100 iterations, silently
        # truncating any file with more than 100 top-level declarations.
        if self.pos == start_pos:
            current_token_info = (
                f"{self.current_token[0]}:{self.current_token[1]}"
                if self.current_token[1]
                else self.current_token[0]
            )
            logger.warning(
                "Parser may be stuck at token %s, breaking...", current_token_info
            )
            break
    shader = ShaderNode(
        name="main",
        execution_model=ExecutionModel.GRAPHICS_PIPELINE,
        stages=stages,
        structs=structs,
        functions=functions,
        global_variables=global_variables,
        constants=constants,
        imports=imports,
        preprocessors=preprocessors,
    )
    if cbuffers:
        shader.cbuffers = cbuffers
    return self.finalize_shader(shader)
[docs]
def parse_program(self):
    """Parse the explicit program form used by legacy callers.

    Walks the token stream until EOF and collects:

    - imports;
    - structs and enums (enums are stored with structs, matching ``parse``);
    - cbuffers and constants;
    - functions, global variables and stage functions.

    Tokens that start none of these are discarded via ``skip_unknown_token``.

    If a ``shader`` declaration was seen (the last one wins), the top-level
    declarations are appended to it. Otherwise a new ``ShaderNode`` named
    ``"main"`` is built from them.

    Returns:
        The resulting shader node, passed through ``finalize_shader``.
    """
    stage_tokens = (
        "VERTEX",
        "FRAGMENT",
        "COMPUTE",
        "GEOMETRY",
        "TESSELLATION_CONTROL",
        "TESSELLATION_EVALUATION",
        "TASK",
        "MESH",
        "RAY_GENERATION",
        "RAY_INTERSECTION",
        "RAY_CLOSEST_HIT",
        "RAY_MISS",
        "RAY_ANY_HIT",
        "RAY_CALLABLE",
    )
    imports = []
    structs = []
    functions = []
    constants = []
    global_variables = []
    cbuffers = []
    shader_node = None
    while self.current_token[0] != "EOF":
        if self.current_token[0] in ("IMPORT", "USE"):
            imports.append(self.parse_import())
        elif self.current_token[0] == "SHADER":
            shader_node = self.parse_shader_declaration()
        elif self.current_token[0] == "STRUCT":
            struct_node = self.parse_struct()
            # parse_struct returns None for malformed declarations; keep
            # those placeholders out of the AST.
            if struct_node is not None:
                structs.append(struct_node)
        elif self.current_token[0] == "ENUM":
            # Enums used to be collected into a list that was never read,
            # so every enum was silently dropped.
            structs.append(self.parse_enum())
        elif self.is_cbuffer_declaration():
            cbuffers.append(self.parse_cbuffer_as_struct())
        elif self.current_token[0] == "CONST":
            constants.append(self.parse_constant())
        elif self.is_function_declaration():
            functions.append(self.parse_function())
        elif self.is_variable_declaration():
            global_variables.append(self.parse_variable_declaration())
        elif self.current_token[0] in stage_tokens:
            functions.append(self.parse_shader_stage())
        else:
            self.skip_unknown_token()
    if shader_node:
        shader_node.structs.extend(structs)
        shader_node.functions.extend(functions)
        shader_node.global_variables.extend(global_variables)
        shader_node.constants.extend(constants)
        if cbuffers:
            shader_node.cbuffers = getattr(shader_node, "cbuffers", []) + cbuffers
        shader_node.imports.extend(imports)
        return self.finalize_shader(shader_node)
    shader = ShaderNode(
        name="main",
        execution_model=ExecutionModel.GRAPHICS_PIPELINE,
        stages={},
        structs=structs,
        functions=functions,
        global_variables=global_variables,
        constants=constants,
        imports=imports,
    )
    if cbuffers:
        shader.cbuffers = cbuffers
    return self.finalize_shader(shader)
[docs]
def parse_shader_declaration(self):
    """Parse a named ``shader`` block and its contained declarations.

    Grammar: ``shader <Name> { ... }``. Inside the braces the parser accepts:

    - stage blocks (keyed by their ``ShaderStage``);
    - structs and enums (enums are stored with structs);
    - cbuffers and constants;
    - functions and variable declarations.

    Other tokens are skipped.

    Returns:
        A ``ShaderNode`` with the graphics-pipeline execution model, passed
        through ``finalize_shader``.

    Raises:
        SyntaxError: if the shader body is not closed before end of input.
    """
    stage_tokens = (
        "VERTEX",
        "FRAGMENT",
        "COMPUTE",
        "GEOMETRY",
        "TESSELLATION_CONTROL",
        "TESSELLATION_EVALUATION",
        "TASK",
        "MESH",
        "RAY_GENERATION",
        "RAY_INTERSECTION",
        "RAY_CLOSEST_HIT",
        "RAY_MISS",
        "RAY_ANY_HIT",
        "RAY_CALLABLE",
    )
    self.eat("SHADER")
    name = self.current_token[1]
    self.eat("IDENTIFIER")
    execution_model = ExecutionModel.GRAPHICS_PIPELINE
    stages = {}
    functions = []
    structs = []
    global_variables = []
    constants = []
    cbuffers = []
    self.eat("LBRACE")
    # Stop at EOF too; eat("RBRACE") below then reports the missing brace
    # instead of potentially spinning on skip_unknown_token at EOF.
    while self.current_token[0] not in ("RBRACE", "EOF"):
        if self.current_token[0] in stage_tokens:
            stage_node = self.parse_shader_stage_block()
            stages[stage_node.stage] = stage_node
        elif self.current_token[0] == "STRUCT":
            struct_node = self.parse_struct()
            if struct_node is not None:
                structs.append(struct_node)
        elif self.current_token[0] == "ENUM":
            # Without this branch the enum was skipped token by token and its
            # closing '}' ended the shader body early.
            structs.append(self.parse_enum())
        elif self.is_cbuffer_declaration():
            cbuffers.append(self.parse_cbuffer_as_struct())
        elif self.current_token[0] == "CONST":
            constants.append(self.parse_constant())
        elif self.is_function_declaration():
            functions.append(self.parse_function())
        elif self.is_variable_declaration():
            global_variables.append(self.parse_variable_declaration())
        else:
            self.skip_unknown_token()
    self.eat("RBRACE")
    shader = ShaderNode(
        name=name,
        execution_model=execution_model,
        stages=stages,
        structs=structs,
        functions=functions,
        global_variables=global_variables,
        constants=constants,
    )
    if cbuffers:
        shader.cbuffers = cbuffers
    return self.finalize_shader(shader)
[docs]
def parse_shader_stage_block(self):
    """Parse a stage-qualified block into a ``StageNode``.

    The current token is a stage keyword; its lexeme (e.g. ``"vertex"``,
    ``"closesthit"``) selects the ``ShaderStage``. Unknown lexemes fall back
    to ``ShaderStage.VERTEX``. An optional identifier after the keyword
    becomes the entry point's name.

    Inside the braces:

    - a function named ``main`` is the entry point;
    - other functions become ``local_functions``;
    - variable declarations become ``local_variables``;
    - any other token is skipped.

    Without a ``main``, an empty ``void main()`` placeholder is used.

    Raises:
        SyntaxError: if the block is not closed before end of input.
    """
    stage_by_keyword = {
        "vertex": ShaderStage.VERTEX,
        "fragment": ShaderStage.FRAGMENT,
        "compute": ShaderStage.COMPUTE,
        "geometry": ShaderStage.GEOMETRY,
        "tessellation_control": ShaderStage.TESSELLATION_CONTROL,
        "tessellation_evaluation": ShaderStage.TESSELLATION_EVALUATION,
        "task": ShaderStage.TASK,
        "amplification": ShaderStage.AMPLIFICATION,
        "object": ShaderStage.OBJECT,
        "mesh": ShaderStage.MESH,
        "ray_generation": ShaderStage.RAY_GENERATION,
        "ray_intersection": ShaderStage.RAY_INTERSECTION,
        "ray_closest_hit": ShaderStage.RAY_CLOSEST_HIT,
        "ray_miss": ShaderStage.RAY_MISS,
        "ray_any_hit": ShaderStage.RAY_ANY_HIT,
        "ray_callable": ShaderStage.RAY_CALLABLE,
        # Short aliases for ray-tracing stages.
        "intersection": ShaderStage.RAY_INTERSECTION,
        "closesthit": ShaderStage.RAY_CLOSEST_HIT,
        "anyhit": ShaderStage.RAY_ANY_HIT,
        "miss": ShaderStage.RAY_MISS,
        "callable": ShaderStage.RAY_CALLABLE,
    }
    stage_enum = stage_by_keyword.get(self.current_token[1], ShaderStage.VERTEX)
    self.eat(self.current_token[0])
    stage_name = None
    if self.current_token[0] == "IDENTIFIER":
        stage_name = self.current_token[1]
        self.eat("IDENTIFIER")
    self.eat("LBRACE")
    local_variables = []
    local_functions = []
    main_function = None
    # Stop at EOF so an unterminated block raises via eat("RBRACE") below
    # instead of potentially looping in skip_unknown_token.
    while self.current_token[0] not in ("RBRACE", "EOF"):
        if self.is_function_declaration():
            func = self.parse_function()
            if func.name == "main":
                main_function = func
            else:
                local_functions.append(func)
        elif self.is_variable_declaration():
            local_variables.append(self.parse_variable_declaration())
        else:
            self.skip_unknown_token()
    self.eat("RBRACE")
    if not main_function:
        main_function = FunctionNode(
            name="main",
            return_type=PrimitiveType("void"),
            parameters=[],
            body=BlockNode([]),
        )
    if stage_name:
        main_function.name = stage_name
    return StageNode(
        stage=stage_enum,
        entry_point=main_function,
        local_variables=local_variables,
        local_functions=local_functions,
    )
[docs]
def parse_import(self):
    """Parse an ``import`` or ``use`` declaration.

    Accepts ``import name;``, dotted paths such as ``import a.b.c;``, and an
    optional ``as alias`` before the terminating ``;``. Dotted segments are
    joined with ``.`` into ``path``. ``items`` is always ``None``: selective
    imports are not parsed here.

    Raises:
        SyntaxError: via ``eat`` on malformed input.
    """
    self.eat("IMPORT" if self.current_token[0] == "IMPORT" else "USE")
    segments = [self.current_token[1]]
    self.eat("IDENTIFIER")
    # Previously only a single identifier was accepted, so a dotted path
    # raised SyntaxError at the first '.'.
    while self.current_token[0] == "DOT":
        self.eat("DOT")
        segments.append(self.current_token[1])
        self.eat("IDENTIFIER")
    path = ".".join(segments)
    alias = None
    items = None
    if self.current_token[0] == "AS":
        self.eat("AS")
        alias = self.current_token[1]
        self.eat("IDENTIFIER")
    self.eat("SEMICOLON")
    return ImportNode(path=path, alias=alias, items=items)
[docs]
def parse_preprocessor_directive(self):
    """Parse a preprocessor token into a structured directive node.

    Leading ``#`` characters and surrounding whitespace are stripped. The
    first whitespace-delimited word becomes the directive and the remainder
    becomes its content; e.g. ``#define FOO 1`` -> ``("define", "FOO 1")``.
    An empty or ``None`` token value yields ``PreprocessorNode("", "")``.
    """
    raw_text = self.current_token[1] or ""
    self.eat("PREPROCESSOR")
    body = raw_text.lstrip("#").strip()
    if not body:
        return PreprocessorNode("", "")
    head, *tail = body.split(None, 1)
    return PreprocessorNode(head, tail[0] if tail else "")
[docs]
def parse_precision_statement(self):
    """Parse a GLSL-style precision statement as a preprocessor node.

    Every token after ``precision`` up to ``;`` (or EOF) is stringified and
    joined with single spaces. The ``;`` is consumed when present.

    Returns:
        ``PreprocessorNode("precision", <joined text>)``.
    """
    self.eat("PRECISION")
    words = []
    while self.current_token[0] != "SEMICOLON" and self.current_token[0] != "EOF":
        token_kind, token_value = self.current_token[0], self.current_token[1]
        words.append(str(token_value))
        self.eat(token_kind)
    if self.current_token[0] == "SEMICOLON":
        self.eat("SEMICOLON")
    return PreprocessorNode("precision", " ".join(words).strip())
[docs]
def parse_struct(self):
    """Parse a struct declaration and its member list.

    Grammar: ``struct Name [<generics>] { member* } [;]``.

    Returns ``None`` (instead of raising) in these cases:

    - the input is at EOF;
    - no name follows ``struct``;
    - no ``{`` follows the name (tokens up to and including the next ``;``
      are discarded first);
    - EOF is reached before the closing ``}``.

    Falsy results from ``parse_struct_member`` are not kept as members.
    """
    if self.current_token[0] == "EOF":
        return None
    self.eat("STRUCT")
    if self.current_token[0] != "IDENTIFIER":
        return None
    struct_name = self.current_token[1]
    self.eat("IDENTIFIER")
    generics = (
        self.parse_generic_parameters()
        if self.current_token[0] == "LESS_THAN"
        else []
    )
    if self.current_token[0] != "LBRACE":
        # Malformed declaration: throw away everything through the next ';'.
        while self.current_token[0] not in ("SEMICOLON", "EOF"):
            self.skip_unknown_token()
        if self.current_token[0] == "SEMICOLON":
            self.skip_unknown_token()
        return None
    self.eat("LBRACE")
    fields = []
    while self.current_token[0] not in ("RBRACE", "EOF"):
        parsed = self.parse_struct_member()
        if parsed:
            fields.append(parsed)
    if self.current_token[0] == "EOF":
        return None
    self.eat("RBRACE")
    if self.current_token[0] == "SEMICOLON":
        self.eat("SEMICOLON")
    return StructNode(name=struct_name, members=fields, generic_params=generics)
[docs]
def parse_struct_member(self):
    """Parse one struct member declaration.

    Normal form: ``<type> [ [N] ] name { [N] } [@attributes] ;``.

    Embedded preprocessor lines, precision statements and nested enums are
    delegated to their own parsers and returned as-is. Lines starting with
    the identifiers ``generic`` or ``trait``, and any malformed member, are
    discarded up to the next ``;`` / ``}`` / EOF (a ``;`` is consumed).
    Returns ``None`` in those cases and at EOF.

    Returns:
        ``StructMemberNode``, a delegated node, or ``None``.
    """

    def discard_member():
        # Skip to the end of the current member; consume a terminating ';'
        # but leave '}' for the caller.
        while self.current_token[0] not in ("SEMICOLON", "RBRACE", "EOF"):
            self.skip_unknown_token()
        if self.current_token[0] == "SEMICOLON":
            self.skip_unknown_token()
        return None

    def wrap_array_suffix(element_type):
        # Consume one '[size]' (size optional) and wrap the type.
        self.eat("LBRACKET")
        length = None
        if self.current_token[0] != "RBRACKET":
            length = self.parse_expression()
        self.eat("RBRACKET")
        return ArrayType(element_type, length)

    leading = self.current_token[0]
    if leading == "EOF":
        return None
    if leading == "PREPROCESSOR":
        return self.parse_preprocessor_directive()
    if leading == "PRECISION":
        return self.parse_precision_statement()
    if leading == "ENUM":
        return self.parse_enum()
    if leading == "IDENTIFIER" and self.current_token[1] in ("generic", "trait"):
        return discard_member()
    member_type = self.parse_type()
    if self.current_token[0] == "LBRACKET":
        member_type = wrap_array_suffix(member_type)
    if self.current_token[0] != "IDENTIFIER":
        return discard_member()
    member_name = self.current_token[1]
    self.eat("IDENTIFIER")
    while self.current_token[0] == "LBRACKET":
        member_type = wrap_array_suffix(member_type)
    member_attrs = (
        self.parse_attributes()
        if self.current_token[0] in ("AT", "ATTRIBUTE")
        else []
    )
    if self.current_token[0] != "SEMICOLON":
        return discard_member()
    self.eat("SEMICOLON")
    return StructMemberNode(
        name=member_name, member_type=member_type, attributes=member_attrs
    )
[docs]
def parse_enum(self):
    """Parse an enum declaration and its variants.

    Grammar: ``enum Name [: UnderlyingType] { Variant [, Variant]* [,] }``.

    Returns:
        ``EnumNode`` with the parsed variants and the optional underlying type.
    """
    self.eat("ENUM")
    enum_name = self.current_token[1]
    self.eat("IDENTIFIER")
    backing_type = None
    if self.current_token[0] == "COLON":
        self.eat("COLON")
        backing_type = self.parse_type()
    self.eat("LBRACE")
    members = []
    while self.current_token[0] != "RBRACE":
        members.append(self.parse_enum_variant())
        # Separating commas (including a trailing one) are optional.
        if self.current_token[0] == "COMMA":
            self.eat("COMMA")
    self.eat("RBRACE")
    return EnumNode(name=enum_name, variants=members, underlying_type=backing_type)
[docs]
def parse_enum_variant(self):
    """Parse one enum variant, including tuple or struct payloads.

    Supported shapes:

    - ``Name``;
    - ``Name = expr``: the expression becomes ``value``;
    - ``Name(T, ...)``: a tuple payload;
    - ``Name { field: T, ... }``: a struct payload.

    A non-empty payload is attached to the node as ``.data``. For a tuple it
    is a list of type or expression nodes; for a struct it is a list of
    ``(field_name, field_type)`` pairs.
    """

    def read_tuple_payload():
        self.eat("LPAREN")
        items = []
        while self.current_token[0] != "RPAREN":
            # Identifiers are read as type names; anything else as an expression.
            if self.current_token[0] == "IDENTIFIER":
                items.append(self.parse_type())
            else:
                items.append(self.parse_expression())
            if self.current_token[0] == "COMMA":
                self.eat("COMMA")
            elif self.current_token[0] != "RPAREN":
                break
        self.eat("RPAREN")
        return items

    def read_struct_payload():
        self.eat("LBRACE")
        fields = []
        while self.current_token[0] == "IDENTIFIER":
            field_name = self.current_token[1]
            self.eat("IDENTIFIER")
            self.eat("COLON")
            fields.append((field_name, self.parse_type()))
            if self.current_token[0] == "COMMA":
                self.eat("COMMA")
            elif self.current_token[0] != "RBRACE":
                self.skip_unknown_token()
        self.eat("RBRACE")
        return fields

    variant_name = self.current_token[1]
    self.eat("IDENTIFIER")
    discriminant = None
    payload = None
    follower = self.current_token[0]
    if follower == "LPAREN":
        payload = read_tuple_payload()
    elif follower == "EQUALS":
        self.eat("EQUALS")
        discriminant = self.parse_expression()
    elif follower == "LBRACE":
        payload = read_struct_payload()
    node = EnumVariantNode(name=variant_name, value=discriminant)
    if payload:
        node.data = payload
    return node
[docs]
def parse_function(self):
    """Parse a function declaration or definition.

    Grammar (bracketed parts optional)::

        {qualifier | attribute}* [FUNCTION] <return-type> <name> [<generics>]
            ( <params> ) [attributes] ( { body } | ; )

    The qualifiers are ``ASYNC``, ``UNSAFE``, ``GLOBAL`` and ``KERNEL``; their
    lexemes are recorded in ``qualifiers``. Attributes may be written as
    ``AT`` + identifier or as a single ``ATTRIBUTE`` token. All attribute
    groups, before and after the parameter list, are kept in source order.

    Returns:
        ``FunctionNode``. ``body`` is ``None`` for a ``;``-terminated
        declaration. ``is_async`` / ``is_unsafe`` are set from the
        lowercase qualifier lexemes.

    Raises:
        SyntaxError: via ``eat`` when required tokens are missing.
    """
    qualifiers = []
    attributes = []
    while self.current_token[0] in (
        "ASYNC",
        "UNSAFE",
        "GLOBAL",
        "KERNEL",
        "AT",
        "ATTRIBUTE",
    ):
        if self.current_token[0] in ("AT", "ATTRIBUTE"):
            # Accumulate: qualifiers and attributes may interleave
            # (e.g. `@a async @b ...`). Plain assignment dropped earlier
            # groups. AT was previously not accepted here at all.
            attributes.extend(self.parse_attributes())
        else:
            qualifiers.append(self.current_token[1])
            self.eat(self.current_token[0])
    if self.current_token[0] == "FUNCTION":
        self.eat("FUNCTION")
    return_type = self.parse_type()
    name = self.current_token[1]
    self.eat("IDENTIFIER")
    generic_params = []
    if self.current_token[0] == "LESS_THAN":
        generic_params = self.parse_generic_parameters()
    self.eat("LPAREN")
    parameters = self.parse_parameter_list()
    self.eat("RPAREN")
    post_attributes = []
    if self.current_token[0] in ("AT", "ATTRIBUTE"):
        post_attributes = self.parse_attributes()
    body = None
    if self.current_token[0] == "LBRACE":
        body = self.parse_block()
    else:
        self.eat("SEMICOLON")
    return FunctionNode(
        name=name,
        return_type=return_type,
        parameters=parameters,
        body=body,
        generic_params=generic_params,
        attributes=attributes + post_attributes,
        qualifiers=qualifiers,
        is_async="async" in qualifiers,
        is_unsafe="unsafe" in qualifiers,
    )
[docs]
def parse_parameter_list(self):
    """Parse a comma-separated function parameter list.

    Stops, without consuming it, at ``)`` or at any token other than a
    separating ``,``. The caller is responsible for eating the ``)``.

    Returns:
        List of ``ParameterNode`` in source order.
    """
    params = []
    while self.current_token[0] != "RPAREN":
        params.append(self.parse_parameter())
        if self.current_token[0] != "COMMA":
            break
        self.eat("COMMA")
    return params
[docs]
def parse_parameter(self):
    """Parse one function parameter declaration.

    Grammar: ``[attributes] [mut] <type> name {[N]} [attributes] [= default]``.

    Attributes may appear before the type and after the name / array suffix.
    Both ``AT`` + identifier and single ``ATTRIBUTE`` tokens are accepted in
    both places, matching ``parse_variable_declaration``.

    Returns:
        ``ParameterNode``.
    """
    attributes = []
    # Previously only ATTRIBUTE was recognized here, so `@attr T x` failed.
    if self.current_token[0] in ("AT", "ATTRIBUTE"):
        attributes = self.parse_attributes()
    is_mutable = False
    if self.current_token[0] == "MUT":
        is_mutable = True
        self.eat("MUT")
    param_type = self.parse_type()
    name = self.current_token[1]
    self.eat("IDENTIFIER")
    while self.current_token[0] == "LBRACKET":
        self.eat("LBRACKET")
        size = None
        if self.current_token[0] != "RBRACKET":
            size = self.parse_expression()
        self.eat("RBRACKET")
        param_type = ArrayType(param_type, size)
    if self.current_token[0] in ("AT", "ATTRIBUTE"):
        attributes.extend(self.parse_attributes())
    default_value = None
    if self.current_token[0] == "EQUALS":
        self.eat("EQUALS")
        default_value = self.parse_expression()
    return ParameterNode(
        name=name,
        param_type=param_type,
        default_value=default_value,
        attributes=attributes,
        is_mutable=is_mutable,
    )
[docs]
def parse_variable_declaration(self):
    """Parse a variable declaration, including qualifiers and attributes.

    Grammar::

        [attributes] {const|static|mut|shared|uniform|buffer}* <type> name
            [attributes] {[N]} [attributes] [= expr] ;

    Attributes are accepted both before and after the array suffix, so
    ``float w[4] @binding(0);`` and ``float w @binding(0) [4];`` both parse.

    Returns:
        ``VariableNode``. ``is_mutable`` is ``False`` only when a ``const``
        qualifier lexeme is present.
    """
    attributes = []
    if self.current_token[0] in ("AT", "ATTRIBUTE"):
        attributes = self.parse_attributes()
    qualifiers = []
    while self.current_token[0] in (
        "CONST",
        "STATIC",
        "MUT",
        "SHARED",
        "UNIFORM",
        "BUFFER",
    ):
        qualifiers.append(self.current_token[1])
        self.eat(self.current_token[0])
    var_type = self.parse_type()
    name = self.current_token[1]
    self.eat("IDENTIFIER")
    if self.current_token[0] in ("AT", "ATTRIBUTE"):
        attributes.extend(self.parse_attributes())
    while self.current_token[0] == "LBRACKET":
        self.eat("LBRACKET")
        size = None
        if self.current_token[0] != "RBRACKET":
            size = self.parse_expression()
        self.eat("RBRACKET")
        var_type = ArrayType(var_type, size)
    # Attributes after the array suffix (consistent with struct members and
    # parameters); previously this raised "Expected SEMICOLON".
    if self.current_token[0] in ("AT", "ATTRIBUTE"):
        attributes.extend(self.parse_attributes())
    initial_value = None
    if self.current_token[0] == "EQUALS":
        self.eat("EQUALS")
        initial_value = self.parse_expression()
    self.eat("SEMICOLON")
    return VariableNode(
        name=name,
        var_type=var_type,
        initial_value=initial_value,
        qualifiers=qualifiers,
        attributes=attributes,
        is_mutable="const" not in qualifiers,
    )
[docs]
def is_variable_declaration(self):
    """
    Lookahead check for variable declarations.

    Speculatively consumes ``[attributes] {qualifier}* <type> <identifier>``
    and classifies the token that follows:

    - ``.`` / ``(`` (member access or call) means this is not a declaration;
    - so does a binary or compound-assignment operator;
    - anything else (``[``, ``=``, ``;``, ``,``, ``)``, ...) means it is one.

    The parser position is always restored, so no tokens are consumed.

    Returns:
        ``True`` if the upcoming tokens look like a variable declaration.
    """
    # Followers that show the leading "type identifier" pair is really an
    # expression such as `a.b`, `f(x)`, `a + b` or `a += b`.
    not_declaration_followers = {
        "DOT",
        "LPAREN",
        "PLUS",
        "MINUS",
        "MULTIPLY",
        "DIVIDE",
        "MOD",
        "EQUAL",
        "NOT_EQUAL",
        "LESS_THAN",
        "GREATER_THAN",
        "LOGICAL_AND",
        "LOGICAL_OR",
        "BITWISE_AND",
        "BITWISE_OR",
        "ASSIGN_ADD",
        "ASSIGN_SUB",
        "ASSIGN_MUL",
        "ASSIGN_DIV",
        "ASSIGN_MOD",
        "ASSIGN_XOR",
        "ASSIGN_OR",
        "ASSIGN_AND",
        "ASSIGN_SHIFT_LEFT",
        "ASSIGN_SHIFT_RIGHT",
    }
    saved_pos = self.pos
    saved_token = self.current_token
    try:
        if self.current_token[0] in ("AT", "ATTRIBUTE"):
            self.parse_attributes()
        while self.current_token[0] in (
            "CONST",
            "STATIC",
            "MUT",
            "SHARED",
            "UNIFORM",
            "BUFFER",
        ):
            self.eat(self.current_token[0])
        if not self.is_type_token():
            return False
        self.advance_over_type()
        if self.current_token[0] != "IDENTIFIER":
            return False
        self.eat("IDENTIFIER")
        # The former explicit True cases ([, =, ;, ,, and ')' with or
        # without parameter context) all fall into "anything else", so a
        # single membership test gives the same answer.
        return self.current_token[0] not in not_declaration_followers
    except Exception:
        # Speculative parse: any failure simply means "not a declaration".
        return False
    finally:
        self.pos = saved_pos
        self.current_token = saved_token
[docs]
def in_parameter_context(self):
    """Check if we're currently parsing function parameters.

    Heuristic: returns ``True`` if any of the (up to) 10 tokens immediately
    before the current position is ``LPAREN``.
    """
    window_start = max(0, self.pos - 10)
    return any(
        idx < len(self.tokens) and self.tokens[idx][0] == "LPAREN"
        for idx in range(window_start, self.pos)
    )
[docs]
def parse_constant(self):
    """Parse a constant declaration.

    Grammar: ``const <type> name {[N]} = expr ;``.

    Array suffixes after the name wrap ``const_type`` in ``ArrayType``, as in
    ``parse_variable_declaration``; e.g. ``const float w[3] = ...;``.
    Previously that form raised ``Expected EQUALS``.

    Returns:
        ``ConstantNode``.
    """
    self.eat("CONST")
    const_type = self.parse_type()
    name = self.current_token[1]
    self.eat("IDENTIFIER")
    while self.current_token[0] == "LBRACKET":
        self.eat("LBRACKET")
        size = None
        if self.current_token[0] != "RBRACKET":
            size = self.parse_expression()
        self.eat("RBRACKET")
        const_type = ArrayType(const_type, size)
    self.eat("EQUALS")
    value = self.parse_expression()
    self.eat("SEMICOLON")
    return ConstantNode(name=name, const_type=const_type, value=value)
[docs]
def parse_type(self):
    """Parse a CrossGL type expression into a ``TypeNode``.

    Recognized forms:

    - an optional leading ``buffer`` keyword, applied to the final type
      (see the end of this method);
    - primitive keywords (``bool``, ``i32``, ``float``, ``void``, ...), which
      become ``PrimitiveType`` with the token lexeme;
    - vector keywords (``vec3``, ``ivec2``, ...), optionally with a generic
      element argument such as ``vec3<f32>``, which become ``VectorType``.
      Without a generic argument the element type comes from the lexeme
      prefix: ``ivec`` -> int, ``uvec`` -> uint, ``dvec`` -> double,
      ``bvec`` -> bool, otherwise float. The size is the lexeme's last
      character;
    - matrix keywords (``mat4``, ``mat2x3``, ``dmat3``, ...), which become
      ``MatrixType``;
    - sampler/image keywords, which become ``NamedType`` with a GLSL-style
      spelling;
    - a user identifier with optional generic arguments, which becomes
      ``NamedType``;
    - ``[`` element-type [size-expr] ``]``, which becomes ``ArrayType``.

    Trailing ``[N]`` / ``[]`` suffixes wrap the result in ``ArrayType``. A
    trailing ``*`` is consumed but not represented.

    If the current token starts none of these forms, ``PrimitiveType("float")``
    is returned without consuming anything.
    """
    primitive_tokens = {
        "BOOL", "I8", "I16", "I32", "I64", "U8", "U16", "U32", "U64",
        "F16", "F32", "F64", "INT", "UINT", "FLOAT", "DOUBLE", "HALF",
        "CHAR", "STRING", "VOID",
    }
    vector_tokens = {
        "VEC2", "VEC3", "VEC4", "IVEC2", "IVEC3", "IVEC4",
        "UVEC2", "UVEC3", "UVEC4", "DVEC2", "DVEC3", "DVEC4",
        "BVEC2", "BVEC3", "BVEC4",
    }
    matrix_tokens = {
        "MAT2", "MAT3", "MAT4",
        "MAT2X2", "MAT2X3", "MAT2X4", "MAT3X2", "MAT3X3", "MAT3X4",
        "MAT4X2", "MAT4X3", "MAT4X4",
        "DMAT2", "DMAT3", "DMAT4",
        "DMAT2X2", "DMAT2X3", "DMAT2X4", "DMAT3X2", "DMAT3X3", "DMAT3X4",
        "DMAT4X2", "DMAT4X3", "DMAT4X4",
    }
    opaque_type_names = {
        "SAMPLER": "sampler",
        "SAMPLER1D": "sampler1D",
        "SAMPLER2D": "sampler2D",
        "SAMPLER3D": "sampler3D",
        "SAMPLERCUBE": "samplerCube",
        "SAMPLER2DARRAY": "sampler2DArray",
        "SAMPLER2DSHADOW": "sampler2DShadow",
        "SAMPLER2DARRAYSHADOW": "sampler2DArrayShadow",
        "SAMPLERCUBESHADOW": "samplerCubeShadow",
        "SAMPLERCUBEARRAY": "samplerCubeArray",
        "SAMPLERCUBEARRAYSHADOW": "samplerCubeArrayShadow",
        "SAMPLER2DMS": "sampler2DMS",
        "SAMPLER2DMSARRAY": "sampler2DMSArray",
        "IIMAGE2D": "iimage2D",
        "IIMAGE3D": "iimage3D",
        "IIMAGE2DARRAY": "iimage2DArray",
        "IIMAGE2DMS": "iimage2DMS",
        "IIMAGE2DMSARRAY": "iimage2DMSArray",
        "UIMAGE2D": "uimage2D",
        "UIMAGE3D": "uimage3D",
        "UIMAGE2DARRAY": "uimage2DArray",
        "UIMAGE2DMS": "uimage2DMS",
        "UIMAGE2DMSARRAY": "uimage2DMSArray",
        "IMAGE2D": "image2D",
        "IMAGE3D": "image3D",
        "IMAGECUBE": "imageCube",
        "IMAGE2DARRAY": "image2DArray",
        "IMAGE2DMS": "image2DMS",
        "IMAGE2DMSARRAY": "image2DMSArray",
    }
    vector_prefix_elements = {
        "ivec": "int",
        "uvec": "uint",
        "dvec": "double",
        "bvec": "bool",
    }

    wrap_in_buffer = self.current_token[0] == "BUFFER"
    if wrap_in_buffer:
        self.eat("BUFFER")

    kind, spelling = self.current_token[0], self.current_token[1]
    if kind in primitive_tokens:
        self.eat(kind)
        result = PrimitiveType(spelling)
    elif kind in vector_tokens:
        self.eat(kind)
        type_args = (
            self.parse_generic_arguments()
            if self.current_token[0] == "LESS_THAN"
            else []
        )
        if type_args:
            lane_type = self.vector_element_type_from_generic(type_args[0])
        else:
            lane_type = PrimitiveType(
                vector_prefix_elements.get(spelling[:4], "float")
            )
        lane_count = int(spelling[-1])
        result = VectorType(lane_type, lane_count)
    elif kind in matrix_tokens:
        self.eat(kind)
        double_precision = spelling.startswith("dmat")
        dims = spelling[4:] if double_precision else spelling[3:]
        # NOTE(review): "2x3" is read as rows=2, cols=3. GLSL spells matCxR
        # (columns first) -- confirm which convention downstream expects.
        if "x" in dims:
            rows, cols = map(int, dims.split("x"))
        else:
            rows = cols = int(dims)
        scalar = "double" if double_precision else "float"
        result = MatrixType(PrimitiveType(scalar), rows, cols)
    elif kind in opaque_type_names:
        self.eat(kind)
        result = NamedType(opaque_type_names[kind])
    elif kind == "IDENTIFIER":
        self.eat("IDENTIFIER")
        type_args = (
            self.parse_generic_arguments()
            if self.current_token[0] == "LESS_THAN"
            else []
        )
        result = NamedType(spelling, type_args)
    elif kind == "LBRACKET":
        self.eat("LBRACKET")
        inner = self.parse_type()
        extent = (
            None if self.current_token[0] == "RBRACKET" else self.parse_expression()
        )
        self.eat("RBRACKET")
        result = ArrayType(inner, extent)
    else:
        # Unknown start of type: default to float and consume nothing.
        result = PrimitiveType("float")

    while self.current_token[0] == "LBRACKET":
        self.eat("LBRACKET")
        extent = (
            None if self.current_token[0] == "RBRACKET" else self.parse_expression()
        )
        self.eat("RBRACKET")
        result = ArrayType(result, extent)

    if self.current_token[0] == "MULTIPLY":
        # pointer handling deferred: PointerType not yet used by generators
        self.eat("MULTIPLY")

    if wrap_in_buffer:
        if hasattr(result, "qualifiers"):
            result.qualifiers = getattr(result, "qualifiers", []) + ["buffer"]
        else:
            result = NamedType(f"buffer_{result}", [])
    return result
[docs]
def parse_generic_parameters(self):
    """Parse generic parameter declarations after ``<``.

    Grammar: ``< Name [: Bound {+ Bound}] {, ...} >``.

    Returns:
        List of ``GenericParameterNode``; each node's ``constraints`` holds
        the parsed bound types.
    """
    self.eat("LESS_THAN")
    declared = []
    while self.current_token[0] != "GREATER_THAN":
        param_name = self.current_token[1]
        self.eat("IDENTIFIER")
        bounds = []
        # The first bound is introduced by ':', later ones by '+'.
        separator = "COLON"
        while self.current_token[0] == separator:
            self.eat(separator)
            bounds.append(self.parse_type())
            separator = "PLUS"
        declared.append(GenericParameterNode(name=param_name, constraints=bounds))
        if self.current_token[0] == "COMMA":
            self.eat("COMMA")
    self.eat("GREATER_THAN")
    return declared
[docs]
def parse_generic_arguments(self):
    """Parse generic type arguments after ``<``.

    Consumes ``< T1 [, T2 ...] >`` and returns the parsed type nodes in order.

    Raises:
        SyntaxError: if an iteration consumes no tokens. This happens when an
            argument does not start a type (e.g. a numeric literal, or end of
            input): ``parse_type`` returns a default without consuming
            anything, so the loop previously never terminated.
    """
    self.eat("LESS_THAN")
    args = []
    while self.current_token[0] != "GREATER_THAN":
        start_pos = self.pos
        args.append(self.parse_type())
        if self.current_token[0] == "COMMA":
            self.eat("COMMA")
        if self.pos == start_pos:
            raise SyntaxError(
                f"Expected type argument, got {self.current_token[0]} '{self.current_token[1]}'"
            )
    self.eat("GREATER_THAN")
    return args
[docs]
def vector_element_type_from_generic(self, type_node):
    """Resolve a vector generic argument to a primitive element type.

    The argument is rendered with ``format_type_argument``. The sized
    spellings ``f32``/``f64``/``i32``/``u32`` are mapped to
    ``float``/``double``/``int``/``uint``; any other spelling is used
    unchanged.
    """
    spelled = self.format_type_argument(type_node)
    sized_to_glsl = {"f32": "float", "f64": "double", "i32": "int", "u32": "uint"}
    return PrimitiveType(sized_to_glsl.get(spelled, spelled))
[docs]
def parse_attributes(self):
    """Parse one or more ``@`` attribute annotations.

    Each annotation is either an ``AT`` token followed by an identifier, or a
    single ``ATTRIBUTE`` token whose lexeme loses its first character (the
    ``@``). An optional parenthesized, comma-separated argument list follows.

    Returns:
        List of ``AttributeNode`` in source order; empty if the current token
        is not an attribute.
    """
    collected = []
    while self.current_token[0] in ("AT", "ATTRIBUTE"):
        if self.current_token[0] == "AT":
            self.eat("AT")
            attr_name = self.current_token[1]
            self.eat("IDENTIFIER")
        else:
            attr_name = self.current_token[1][1:]
            self.eat("ATTRIBUTE")
        args = []
        if self.current_token[0] == "LPAREN":
            self.eat("LPAREN")
            while self.current_token[0] != "RPAREN":
                args.append(self.parse_expression())
                if self.current_token[0] == "COMMA":
                    self.eat("COMMA")
            self.eat("RPAREN")
        collected.append(AttributeNode(name=attr_name, arguments=args))
    return collected
[docs]
def parse_block(self):
    """Parse a braced statement block.

    Falsy statement results are dropped.

    Returns:
        ``BlockNode`` holding the statements in order.

    Raises:
        SyntaxError: if end of input is reached before the closing ``}``.
            The loop stops at EOF so ``eat("RBRACE")`` reports the missing
            brace, instead of handing EOF to ``parse_statement``.
    """
    self.eat("LBRACE")
    statements = []
    while self.current_token[0] not in ("RBRACE", "EOF"):
        stmt = self.parse_statement()
        if stmt:
            statements.append(stmt)
    self.eat("RBRACE")
    return BlockNode(statements)
[docs]
def parse_statement(self):
    """Parse any statement form supported by CrossGL.

    Dispatch is on the current token:

    - ``if``, ``for``, ``while``, ``loop``, ``match``, ``switch``, ``return``,
      ``let`` and ``{`` go to their dedicated parse methods;
    - ``break`` / ``continue`` must be followed by ``;``.

    Otherwise a variable declaration is tried. As a last resort an expression
    statement terminated by ``;`` is parsed.
    """
    keyword_handlers = {
        "IF": "parse_if_statement",
        "FOR": "parse_for_statement",
        "WHILE": "parse_while_statement",
        "LOOP": "parse_loop_statement",
        "MATCH": "parse_match_statement",
        "SWITCH": "parse_switch_statement",
        "RETURN": "parse_return_statement",
        "LET": "parse_let_declaration",
        "LBRACE": "parse_block",
    }
    leading = self.current_token[0]
    handler = keyword_handlers.get(leading)
    if handler is not None:
        # Resolved lazily by name so only the handler actually needed is looked up.
        return getattr(self, handler)()
    if leading in ("BREAK", "CONTINUE"):
        self.eat(leading)
        self.eat("SEMICOLON")
        return BreakNode() if leading == "BREAK" else ContinueNode()
    if self.is_variable_declaration():
        return self.parse_variable_declaration()
    expr = self.parse_expression()
    self.eat("SEMICOLON")
    return ExpressionStatementNode(expr)
[docs]
def parse_let_declaration(self):
    """Parse Rust-style ``let [mut] name [: type] [= expr];`` declarations.

    An initializer is required when no type annotation is given: ``let x;``
    still raises ``SyntaxError``. With a type annotation the initializer is
    optional (``let x: f32;``); previously that form was rejected.

    Returns:
        ``VariableNode`` with ``is_let_declaration = True``. ``"mut"`` is
        appended to ``qualifiers`` for ``let mut``.
        NOTE(review): ``is_mutable`` is not passed to ``VariableNode`` here,
        so it keeps that class's default -- confirm that is intended for
        non-``mut`` lets.
    """
    self.eat("LET")
    is_mutable = False
    if self.current_token[0] == "MUT":
        is_mutable = True
        self.eat("MUT")
    name = self.current_token[1]
    self.eat("IDENTIFIER")
    var_type = None
    if self.current_token[0] == "COLON":
        self.eat("COLON")
        var_type = self.parse_type()
    initial_value = None
    if var_type is None or self.current_token[0] == "EQUALS":
        # Untyped lets must be initialized; eat() raises otherwise.
        self.eat("EQUALS")
        initial_value = self.parse_expression()
    self.eat("SEMICOLON")
    var_node = VariableNode(
        name=name, var_type=var_type, initial_value=initial_value
    )
    if is_mutable:
        var_node.qualifiers = getattr(var_node, "qualifiers", []) + ["mut"]
    var_node.is_let_declaration = True
    return var_node
[docs]
def parse_if_statement(self):
    """Parse an if/else statement chain.

    Grammar: ``if ( cond ) stmt [else (if ... | stmt)]``. The condition must
    be parenthesized. ``else if`` chains recurse into this method.

    Returns:
        ``IfNode``; ``else_branch`` is ``None`` when there is no ``else``.
    """
    self.eat("IF")
    self.eat("LPAREN")
    cond = self.parse_expression()
    self.eat("RPAREN")
    consequent = self.parse_statement()
    alternative = None
    if self.current_token[0] == "ELSE":
        self.eat("ELSE")
        alternative = (
            self.parse_if_statement()
            if self.current_token[0] == "IF"
            else self.parse_statement()
        )
    return IfNode(condition=cond, then_branch=consequent, else_branch=alternative)
[docs]
def parse_for_statement(self):
    """Parse a C-style ``for`` loop or dispatch to ``for-in`` parsing.

    ``for (init; cond; update) body``: each header part may be empty. The
    init clause is either a variable declaration or an expression (wrapped
    in ``ExpressionStatementNode``). Without a ``(`` after ``for``, the loop
    is parsed as ``for name in iterable body``.

    Returns:
        ``ForNode`` (or ``ForInNode`` for the ``for-in`` form).
    """
    self.eat("FOR")
    if self.current_token[0] != "LPAREN":
        return self.parse_for_in_statement_after_for()
    self.eat("LPAREN")
    init = None
    if self.current_token[0] != "SEMICOLON":
        if self.is_variable_declaration():
            init = self.parse_for_loop_variable_declaration()
        else:
            init = ExpressionStatementNode(self.parse_expression())
    self.eat("SEMICOLON")
    condition = (
        None if self.current_token[0] == "SEMICOLON" else self.parse_expression()
    )
    self.eat("SEMICOLON")
    update = None if self.current_token[0] == "RPAREN" else self.parse_expression()
    self.eat("RPAREN")
    return ForNode(
        init=init, condition=condition, update=update, body=self.parse_statement()
    )
[docs]
def parse_for_in_statement_after_for(self):
    """Parse ``NAME in ITERABLE BODY``; the ``for`` keyword is already consumed.

    Only a single identifier is supported as the loop pattern. It is stored
    on the ``ForInNode`` as a plain string.
    """
    loop_var = self.current_token[1]
    self.eat("IDENTIFIER")
    self.eat("IN")
    source = self.parse_expression()
    return ForInNode(pattern=loop_var, iterable=source, body=self.parse_statement())
[docs]
def parse_for_loop_variable_declaration(self):
    """Parse the declaration clause of a C-style ``for`` header.

    Accepted form: ``[qualifiers] TYPE NAME ([SIZE])* [= EXPR]``. The
    terminating semicolon is left in place for ``parse_for_statement``.

    Qualifier lexemes are recorded verbatim. The variable is marked mutable
    unless ``"const"`` is among them.
    """
    qualifier_kinds = ("CONST", "STATIC", "MUT", "SHARED", "UNIFORM", "BUFFER")
    qualifiers = []
    while self.current_token[0] in qualifier_kinds:
        qualifiers.append(self.current_token[1])
        self.eat(self.current_token[0])

    declared_type = self.parse_type()
    var_name = self.current_token[1]
    self.eat("IDENTIFIER")

    # Each [...] suffix wraps the type built so far, so a[2][3] becomes
    # ArrayType(ArrayType(T, 2), 3); an empty [] gives size None.
    # NOTE(review): C reads a[2][3] as "2 arrays of 3" - confirm consumers
    # expect this nesting order.
    while self.current_token[0] == "LBRACKET":
        self.eat("LBRACKET")
        dim = None if self.current_token[0] == "RBRACKET" else self.parse_expression()
        self.eat("RBRACKET")
        declared_type = ArrayType(declared_type, dim)

    init_expr = None
    if self.current_token[0] == "EQUALS":
        self.eat("EQUALS")
        init_expr = self.parse_expression()

    return VariableNode(
        name=var_name,
        var_type=declared_type,
        initial_value=init_expr,
        qualifiers=qualifiers,
        is_mutable="const" not in qualifiers,
    )
[docs]
def parse_while_statement(self):
    """Parse ``while COND BODY`` into a ``WhileNode``.

    No parentheses are required around the condition. A parenthesised
    condition still works because it is itself a valid expression.
    """
    self.eat("WHILE")
    loop_condition = self.parse_expression()
    return WhileNode(condition=loop_condition, body=self.parse_statement())
[docs]
def parse_loop_statement(self):
    """Parse ``loop BODY``, a loop with no condition, into a ``LoopNode``."""
    self.eat("LOOP")
    return LoopNode(body=self.parse_statement())
[docs]
def parse_match_statement(self):
    """Parse ``match EXPR { ARM ... }`` into a ``MatchNode``.

    Arms are parsed one after another until the closing brace.
    """
    self.eat("MATCH")
    scrutinee = self.parse_expression()
    self.eat("LBRACE")
    match_arms = []
    while self.current_token[0] != "RBRACE":
        match_arms.append(self.parse_match_arm())
    self.eat("RBRACE")
    return MatchNode(expression=scrutinee, arms=match_arms)
[docs]
def parse_match_arm(self):
    """Parse one arm: ``PATTERN [IF guard] FAT_ARROW body [COMMA]``.

    The guard is ``None`` when absent. A trailing comma is consumed if present.
    """
    pattern = self.parse_pattern()
    has_guard = self.current_token[0] == "IF"
    if has_guard:
        self.eat("IF")
    guard = self.parse_expression() if has_guard else None

    self.eat("FAT_ARROW")
    arm_body = self.parse_statement()

    if self.current_token[0] == "COMMA":
        self.eat("COMMA")
    return MatchArmNode(pattern=pattern, guard=guard, body=arm_body)
[docs]
def parse_pattern(self):
    """Parse a match-arm pattern.

    * identifier ``_`` -> ``WildcardPatternNode``
    * any other identifier -> ``IdentifierPatternNode``
    * anything else is handed to ``parse_literal`` -> ``LiteralPatternNode``
    """
    if self.current_token[0] != "IDENTIFIER":
        return LiteralPatternNode(self.parse_literal())

    ident = self.current_token[1]
    self.eat("IDENTIFIER")
    if ident == "_":
        return WildcardPatternNode()
    return IdentifierPatternNode(ident)
[docs]
def parse_return_statement(self):
    """Parse ``return [EXPR];``; a bare ``return`` yields ``value=None``."""
    self.eat("RETURN")
    result = None if self.current_token[0] == "SEMICOLON" else self.parse_expression()
    self.eat("SEMICOLON")
    return ReturnNode(value=result)
[docs]
def parse_expression(self):
    """Parse a full expression.

    This is the entry point of the expression grammar. It starts at the
    lowest-precedence level (assignment) and descends through range, ternary,
    logical, bitwise, comparison, shift, arithmetic, unary and postfix levels.
    """
    return self.parse_assignment_expression()
[docs]
def parse_range_expression(self):
    """Parse an optional ``RANGE`` / ``RANGE_INCLUSIVE`` between two ternaries.

    At most one range operator is consumed, so ranges do not chain. Returns
    a ``RangeNode`` when an operator is present, else the lone operand.
    """
    start = self.parse_ternary_expression()
    op_kind = self.current_token[0]
    if op_kind not in ("RANGE", "RANGE_INCLUSIVE"):
        return start
    self.eat(op_kind)
    end = self.parse_ternary_expression()
    return RangeNode(start, end, inclusive=op_kind == "RANGE_INCLUSIVE")
[docs]
def parse_assignment_expression(self):
    """Parse plain and compound assignment.

    The right-hand side recurses into this rule, which makes assignment
    right-associative (``a = b = c`` is ``a = (b = c)``). The operator's
    token text is stored on the resulting ``AssignmentNode``.
    """
    assignment_kinds = (
        "EQUALS",
        "ASSIGN_ADD",
        "ASSIGN_SUB",
        "ASSIGN_MUL",
        "ASSIGN_DIV",
        "ASSIGN_MOD",
        "ASSIGN_XOR",
        "ASSIGN_OR",
        "ASSIGN_AND",
        "ASSIGN_SHIFT_LEFT",
        "ASSIGN_SHIFT_RIGHT",
    )
    target = self.parse_range_expression()
    kind = self.current_token[0]
    if kind not in assignment_kinds:
        return target
    operator = self.current_token[1]
    self.eat(kind)
    value = self.parse_assignment_expression()
    return AssignmentNode(target, value, operator)
[docs]
def parse_ternary_expression(self):
    """Parse ``COND ? A : B`` on top of logical-OR expressions.

    The true branch may be any full expression. The false branch recurses
    into this rule, so ``a ? b : c ? d : e`` nests to the right.
    """
    cond = self.parse_logical_or_expression()
    if self.current_token[0] != "QUESTION":
        return cond
    self.eat("QUESTION")
    when_true = self.parse_expression()
    self.eat("COLON")
    when_false = self.parse_ternary_expression()
    return TernaryOpNode(cond, when_true, when_false)
[docs]
def parse_logical_or_expression(self):
    """Parse a left-associative chain of ``LOGICAL_OR`` over logical-AND operands."""
    expr = self.parse_logical_and_expression()
    while self.current_token[0] == "LOGICAL_OR":
        operator = self.current_token[1]
        self.eat("LOGICAL_OR")
        expr = BinaryOpNode(expr, operator, self.parse_logical_and_expression())
    return expr
[docs]
def parse_logical_and_expression(self):
    """Parse a left-associative chain of ``LOGICAL_AND`` over bitwise-OR operands."""
    expr = self.parse_bitwise_or_expression()
    while self.current_token[0] == "LOGICAL_AND":
        operator = self.current_token[1]
        self.eat("LOGICAL_AND")
        expr = BinaryOpNode(expr, operator, self.parse_bitwise_or_expression())
    return expr
[docs]
def parse_bitwise_or_expression(self):
    """Parse a left-associative chain of ``BITWISE_OR`` over XOR operands."""
    expr = self.parse_bitwise_xor_expression()
    while self.current_token[0] == "BITWISE_OR":
        operator = self.current_token[1]
        self.eat("BITWISE_OR")
        expr = BinaryOpNode(expr, operator, self.parse_bitwise_xor_expression())
    return expr
[docs]
def parse_bitwise_xor_expression(self):
    """Parse a left-associative chain of ``BITWISE_XOR`` over bitwise-AND operands."""
    expr = self.parse_bitwise_and_expression()
    while self.current_token[0] == "BITWISE_XOR":
        operator = self.current_token[1]
        self.eat("BITWISE_XOR")
        expr = BinaryOpNode(expr, operator, self.parse_bitwise_and_expression())
    return expr
[docs]
def parse_bitwise_and_expression(self):
    """Parse a left-associative chain of ``BITWISE_AND`` over equality operands."""
    expr = self.parse_equality_expression()
    while self.current_token[0] == "BITWISE_AND":
        operator = self.current_token[1]
        self.eat("BITWISE_AND")
        expr = BinaryOpNode(expr, operator, self.parse_equality_expression())
    return expr
[docs]
def parse_equality_expression(self):
    """Parse a left-associative chain of ``EQUAL`` / ``NOT_EQUAL`` comparisons."""
    expr = self.parse_relational_expression()
    while self.current_token[0] in ("EQUAL", "NOT_EQUAL"):
        kind, operator = self.current_token[0], self.current_token[1]
        self.eat(kind)
        expr = BinaryOpNode(expr, operator, self.parse_relational_expression())
    return expr
[docs]
def parse_relational_expression(self):
    """Parse a left-associative chain of ``<``, ``>``, ``<=``, ``>=`` token kinds."""
    relational_kinds = ("LESS_THAN", "GREATER_THAN", "LESS_EQUAL", "GREATER_EQUAL")
    expr = self.parse_shift_expression()
    while self.current_token[0] in relational_kinds:
        kind, operator = self.current_token[0], self.current_token[1]
        self.eat(kind)
        expr = BinaryOpNode(expr, operator, self.parse_shift_expression())
    return expr
[docs]
def parse_shift_expression(self):
    """Parse a left-associative chain of bit shifts over additive operands."""
    expr = self.parse_additive_expression()
    while self.current_token[0] in ("BITWISE_SHIFT_LEFT", "BITWISE_SHIFT_RIGHT"):
        kind, operator = self.current_token[0], self.current_token[1]
        self.eat(kind)
        expr = BinaryOpNode(expr, operator, self.parse_additive_expression())
    return expr
[docs]
def parse_additive_expression(self):
    """Parse a left-associative chain of ``PLUS`` / ``MINUS`` over multiplicative terms."""
    expr = self.parse_multiplicative_expression()
    while self.current_token[0] in ("PLUS", "MINUS"):
        kind, operator = self.current_token[0], self.current_token[1]
        self.eat(kind)
        expr = BinaryOpNode(expr, operator, self.parse_multiplicative_expression())
    return expr
[docs]
def parse_multiplicative_expression(self):
    """Parse a left-associative chain of ``MULTIPLY``/``DIVIDE``/``MOD`` over unary operands."""
    expr = self.parse_unary_expression()
    while self.current_token[0] in ("MULTIPLY", "DIVIDE", "MOD"):
        kind, operator = self.current_token[0], self.current_token[1]
        self.eat(kind)
        expr = BinaryOpNode(expr, operator, self.parse_unary_expression())
    return expr
[docs]
def parse_unary_expression(self):
    """Parse prefix operators, or fall through to postfix expressions.

    Prefix token kinds: NOT, MINUS, PLUS, BITWISE_NOT, INCREMENT, DECREMENT.
    The operand recurses into this rule, so prefix operators can stack.
    """
    prefix_kinds = ("NOT", "MINUS", "PLUS", "BITWISE_NOT", "INCREMENT", "DECREMENT")
    kind = self.current_token[0]
    if kind not in prefix_kinds:
        return self.parse_postfix_expression()
    symbol = self.current_token[1]
    self.eat(kind)
    return UnaryOpNode(symbol, self.parse_unary_expression())
[docs]
def parse_postfix_expression(self):
    """Parse a primary expression followed by any chain of postfix operators.

    Postfix forms are applied left to right until none match:

    * ``DOT`` member access -> ``MemberAccessNode``
    * ``[index]`` -> ``ArrayAccessNode``
    * ``vec2``/``vec3``/``vec4`` followed by ``<...>`` -> an ``IdentifierNode``
      whose name embeds the formatted generic arguments
    * ``(args)`` call -> ``WaveOpNode`` / ``RayTracingOpNode`` /
      ``MeshOpNode`` for names in the intrinsic tables, ``RayQueryOpNode``
      for ``obj.method(...)`` when ``method`` is in ``RAYQUERY_METHODS``,
      otherwise ``FunctionCallNode``
    * ``INCREMENT`` / ``DECREMENT`` -> postfix ``UnaryOpNode``

    Call arguments must be comma separated; a trailing comma is accepted.
    """
    left = self.parse_primary_expression()
    while True:
        token_type = self.current_token[0]
        if token_type == "DOT":
            self.eat("DOT")
            member = self.current_token[1]
            self.eat("IDENTIFIER")
            left = MemberAccessNode(left, member)
        elif token_type == "LBRACKET":
            self.eat("LBRACKET")
            index = self.parse_expression()
            self.eat("RBRACKET")
            left = ArrayAccessNode(left, index)
        elif (
            token_type == "LESS_THAN"
            and isinstance(left, IdentifierNode)
            and left.name in {"vec2", "vec3", "vec4"}
        ):
            generic_args = self.parse_generic_arguments()
            args = ", ".join(self.format_type_argument(arg) for arg in generic_args)
            left = IdentifierNode(f"{left.name}<{args}>")
        elif token_type == "LPAREN":
            self.eat("LPAREN")
            arguments = []
            while self.current_token[0] != "RPAREN":
                arguments.append(self.parse_expression())
                if self.current_token[0] != "COMMA":
                    # Only ',' or ')' may follow an argument. Stop here and let
                    # eat("RPAREN") see the offending token, instead of silently
                    # parsing juxtaposed expressions (``f(a b)``) as arguments.
                    break
                self.eat("COMMA")
            self.eat("RPAREN")
            if isinstance(left, IdentifierNode):
                if left.name in WAVE_INTRINSICS:
                    left = WaveOpNode(left.name, arguments)
                elif left.name in RAYTRACING_INTRINSICS:
                    left = RayTracingOpNode(left.name, arguments)
                elif left.name in MESH_INTRINSICS:
                    left = MeshOpNode(left.name, arguments)
                else:
                    left = FunctionCallNode(left, arguments)
            elif (
                isinstance(left, MemberAccessNode)
                and left.member in RAYQUERY_METHODS
            ):
                left = RayQueryOpNode(left.member, left.object, arguments)
            else:
                left = FunctionCallNode(left, arguments)
        elif token_type in ["INCREMENT", "DECREMENT"]:
            op = self.current_token[1]
            self.eat(token_type)
            left = UnaryOpNode(op, left, is_postfix=True)
        else:
            break
    return left
[docs]
def parse_primary_expression(self):
    """Parse the innermost expression forms.

    * ``IDENTIFIER`` -> ``IdentifierNode``
    * boolean, numeric, string and char literal tokens -> ``parse_literal``
    * ``( expr )`` -> the inner expression (no wrapper node)
    * ``{ ... }`` -> ``parse_array_literal``

    Any other token is consumed and wrapped in an ``IdentifierNode`` named
    after its text (``"unknown"`` if the text is empty). This is a lenient
    recovery path rather than an error.
    """
    literal_kinds = (
        "BOOLEAN_LITERAL",
        "NUMBER",
        "FLOAT_NUMBER",
        "HEX_NUMBER",
        "BIN_NUMBER",
        "OCT_NUMBER",
        "STRING_LITERAL",
        "CHAR_LITERAL",
    )
    kind = self.current_token[0]
    if kind == "IDENTIFIER":
        ident = self.current_token[1]
        self.eat("IDENTIFIER")
        return IdentifierNode(ident)
    if kind in literal_kinds:
        return self.parse_literal()
    if kind == "LPAREN":
        self.eat("LPAREN")
        inner = self.parse_expression()
        self.eat("RPAREN")
        return inner
    if kind == "LBRACE":
        return self.parse_array_literal()

    fallback = str(self.current_token[1]) if self.current_token[1] else "unknown"
    self.eat(kind)
    return IdentifierNode(fallback)
[docs]
def parse_array_literal(self):
    """Parse ``{ e1, e2, ... }`` into an ``ArrayLiteralNode``.

    Elements are comma separated and a trailing comma is allowed. Scanning
    stops after any element that is not followed by a comma, and the closing
    brace is then required.
    """
    self.eat("LBRACE")
    items = []
    while self.current_token[0] != "RBRACE":
        items.append(self.parse_expression())
        if self.current_token[0] != "COMMA":
            break
        self.eat("COMMA")
    self.eat("RBRACE")
    return ArrayLiteralNode(items)
[docs]
def parse_literal(self):
    """Consume one literal token and return it as a typed ``LiteralNode``.

    * ``NUMBER`` / ``HEX_NUMBER`` / ``BIN_NUMBER`` / ``OCT_NUMBER`` -> Python
      ``int`` (base 10/16/2/8), typed ``int`` or ``uint`` by
      ``parse_integer_literal_parts``
    * ``FLOAT_NUMBER`` -> ``float`` with trailing ``f``/``F`` stripped
    * ``BOOLEAN_LITERAL`` -> ``True`` iff the token text is ``"true"``
    * ``STRING_LITERAL`` / ``CHAR_LITERAL`` -> text minus its first and last
      characters (the quotes)
    * any other token -> its raw value, typed ``unknown``
    """
    token_type, value = self.current_token
    self.eat(token_type)

    integer_bases = {"NUMBER": 10, "HEX_NUMBER": 16, "BIN_NUMBER": 2, "OCT_NUMBER": 8}
    if token_type in integer_bases:
        digits, literal_type = self.parse_integer_literal_parts(value)
        return LiteralNode(int(digits, integer_bases[token_type]), literal_type)
    if token_type == "FLOAT_NUMBER":
        return LiteralNode(float(value.rstrip("fF")), PrimitiveType("float"))
    if token_type == "BOOLEAN_LITERAL":
        return LiteralNode(value == "true", PrimitiveType("bool"))
    if token_type == "STRING_LITERAL":
        return LiteralNode(value[1:-1], PrimitiveType("string"))
    if token_type == "CHAR_LITERAL":
        return LiteralNode(value[1:-1], PrimitiveType("char"))
    return LiteralNode(value, PrimitiveType("unknown"))
[docs]
def parse_integer_literal_parts(self, value):
    """Split an integer literal into ``(digit_text, PrimitiveType)``.

    A trailing ``u`` or ``U`` marks the literal unsigned: the suffix is
    dropped and the type is ``uint``. Otherwise the text is returned as-is
    with type ``int``. ``value`` is converted with ``str()`` first.
    """
    text = str(value)
    if text[-1:] in ("u", "U"):
        return text[:-1], PrimitiveType("uint")
    return text, PrimitiveType("int")
# Legacy compatibility methods
[docs]
def parse_legacy_shader(self):
    """Return an empty legacy-format shader root without consuming tokens.

    Each of the four collections passed to ``create_legacy_shader_node`` is a
    separate empty list.
    """
    empty_collections = [[] for _ in range(4)]
    return create_legacy_shader_node(*empty_collections)
[docs]
def parse_shader_stage(self):
    """Parse a legacy ``STAGE { ... }`` block into a synthetic ``FunctionNode``.

    Function declarations inside the braces are collected into the body of a
    ``void`` function named ``<stage>_main``, where ``<stage>`` is the stage
    keyword's token text. That text is also the function's only qualifier.
    Any other token inside the block is skipped.
    """
    stage_type = self.current_token[1]
    self.eat(self.current_token[0])
    self.eat("LBRACE")
    body = []
    # skip_unknown_token() does not advance at EOF, so EOF must end the loop
    # too; otherwise an unterminated stage block spins forever. The missing
    # '}' is then handed to eat("RBRACE").
    while self.current_token[0] not in ("RBRACE", "EOF"):
        if self.is_function_declaration():
            body.append(self.parse_function())
        else:
            self.skip_unknown_token()
    self.eat("RBRACE")
    return FunctionNode(
        name=f"{stage_type}_main",
        return_type=PrimitiveType("void"),
        parameters=[],
        body=BlockNode(body),
        qualifiers=[stage_type],
    )
[docs]
def parse_cbuffer_as_struct(self):
    """Parse a ``cbuffer``/``uniform`` block declaration into a ``StructNode``.

    Accepted form::

        (CBUFFER | UNIFORM) NAME { TYPE member ([SIZE])* ; ... } [;]

    Members may carry any number of array suffixes (e.g. ``m[2][3]``). The
    returned node is flagged with ``is_cbuffer = True``.
    """
    if self.current_token[0] == "CBUFFER":
        self.eat("CBUFFER")
    else:
        self.eat("UNIFORM")
    name = self.current_token[1]
    self.eat("IDENTIFIER")
    self.eat("LBRACE")
    members = []
    # Stop at EOF too, so an unterminated block goes to eat("RBRACE") below
    # instead of being fed to parse_type().
    while self.current_token[0] not in ("RBRACE", "EOF"):
        member_type = self.parse_type()
        member_name = self.current_token[1]
        self.eat("IDENTIFIER")
        # Previously only one suffix was accepted, so ``m[2][3];`` failed at
        # the ';'. Each suffix now wraps the type built so far, matching
        # parse_for_loop_variable_declaration.
        while self.current_token[0] == "LBRACKET":
            self.eat("LBRACKET")
            size = None
            if self.current_token[0] != "RBRACKET":
                size = self.parse_expression()
            self.eat("RBRACKET")
            member_type = ArrayType(member_type, size)
        self.eat("SEMICOLON")
        members.append(StructMemberNode(name=member_name, member_type=member_type))
    self.eat("RBRACE")
    if self.current_token[0] == "SEMICOLON":
        self.eat("SEMICOLON")
    node = StructNode(name=name, members=members)
    node.is_cbuffer = True
    return node
# Helper methods
[docs]
def is_cbuffer_declaration(self):
    """Return whether the upcoming tokens start a constant-buffer block.

    True for a ``CBUFFER`` token, or for ``UNIFORM IDENTIFIER LBRACE`` (a
    named uniform block). Uses ``peek`` for lookahead.
    """
    head = self.current_token[0]
    if head == "CBUFFER":
        return True
    if head != "UNIFORM":
        return False
    return self.peek()[0] == "IDENTIFIER" and self.peek(2)[0] == "LBRACE"
[docs]
def is_function_declaration(self):
    """Return whether the upcoming tokens look like a function declaration.

    Speculatively scans, without committing:

    1. any ``ASYNC``/``UNSAFE``/``GLOBAL``/``KERNEL`` modifiers,
    2. an optional ``FUNCTION`` keyword,
    3. an optional return type (via ``advance_over_type``),
    4. the function name and an optional ``<...>`` generic list,

    and reports ``True`` when ``(`` follows. The token position is always
    restored. Any error raised during the scan means "not a function".
    """
    saved_pos = self.pos
    saved_token = self.current_token
    try:
        while self.current_token[0] in ["ASYNC", "UNSAFE", "GLOBAL", "KERNEL"]:
            self.eat(self.current_token[0])
        has_fn_keyword = self.current_token[0] == "FUNCTION"
        if has_fn_keyword:
            self.eat("FUNCTION")
        if self.is_type_token():
            self.advance_over_type()
        # After FUNCTION the name may come with no return type in front of it.
        # is_type_token() accepts IDENTIFIER, so the name (and any generic
        # list) has just been consumed as a "type"; a '(' here means function.
        if has_fn_keyword and self.current_token[0] == "LPAREN":
            return True
        if self.current_token[0] == "IDENTIFIER":
            self.eat("IDENTIFIER")
            # Skip generic parameters if present
            if self.current_token[0] == "LESS_THAN":
                depth = 1
                self.eat("LESS_THAN")
                while depth > 0 and self.current_token[0] != "EOF":
                    if self.current_token[0] == "LESS_THAN":
                        depth += 1
                    elif self.current_token[0] == "GREATER_THAN":
                        depth -= 1
                    self.eat(self.current_token[0])
            if self.current_token[0] == "LPAREN":
                return True
    except Exception:
        # A failed lookahead just means "not a function declaration". Catch
        # Exception rather than using a bare except, so KeyboardInterrupt and
        # SystemExit still propagate.
        pass
    finally:
        self.pos = saved_pos
        self.current_token = saved_token
    return False
[docs]
def is_type_token(self):
    """Return whether the current token can start a type expression.

    Covers the scalar, vector, matrix, sampler and image keyword tokens, plus
    ``IDENTIFIER`` so that named types are accepted.
    """
    # A set literal used directly in a membership test compiles to a frozenset
    # constant. This makes the check a hash lookup rather than a linear scan
    # of ~90 strings; it is called repeatedly by the lookahead helpers.
    return self.current_token[0] in {
        "BOOL",
        "I8",
        "I16",
        "I32",
        "I64",
        "U8",
        "U16",
        "U32",
        "U64",
        "F16",
        "F32",
        "F64",
        "INT",
        "UINT",
        "FLOAT",
        "DOUBLE",
        "HALF",
        "CHAR",
        "STRING",
        "VOID",
        "VEC2",
        "VEC3",
        "VEC4",
        "IVEC2",
        "IVEC3",
        "IVEC4",
        "UVEC2",
        "UVEC3",
        "UVEC4",
        "DVEC2",
        "DVEC3",
        "DVEC4",
        "BVEC2",
        "BVEC3",
        "BVEC4",
        "MAT2",
        "MAT3",
        "MAT4",
        "MAT2X2",
        "MAT2X3",
        "MAT2X4",
        "MAT3X2",
        "MAT3X3",
        "MAT3X4",
        "MAT4X2",
        "MAT4X3",
        "MAT4X4",
        "DMAT2",
        "DMAT3",
        "DMAT4",
        "DMAT2X2",
        "DMAT2X3",
        "DMAT2X4",
        "DMAT3X2",
        "DMAT3X3",
        "DMAT3X4",
        "DMAT4X2",
        "DMAT4X3",
        "DMAT4X4",
        "SAMPLER",
        "SAMPLER1D",
        "SAMPLER2D",
        "SAMPLER3D",
        "SAMPLERCUBE",
        "SAMPLER2DARRAY",
        "SAMPLER2DSHADOW",
        "SAMPLER2DARRAYSHADOW",
        "SAMPLERCUBESHADOW",
        "SAMPLERCUBEARRAY",
        "SAMPLERCUBEARRAYSHADOW",
        "SAMPLER2DMS",
        "SAMPLER2DMSARRAY",
        "IIMAGE2D",
        "IIMAGE3D",
        "IIMAGE2DARRAY",
        "IIMAGE2DMS",
        "IIMAGE2DMSARRAY",
        "UIMAGE2D",
        "UIMAGE3D",
        "UIMAGE2DARRAY",
        "UIMAGE2DMS",
        "UIMAGE2DMSARRAY",
        "IMAGE2D",
        "IMAGE3D",
        "IMAGECUBE",
        "IMAGE2DARRAY",
        "IMAGE2DMS",
        "IMAGE2DMSARRAY",
        "IDENTIFIER",
    }
[docs]
def advance_over_type(self):
    """Skip one type token and any ``<...>`` generic argument list after it.

    Intended for lookahead only. Angle brackets are balanced by counting
    depth, and scanning stops early at ``EOF``. Nothing is consumed if the
    current token cannot start a type.
    """
    if not self.is_type_token():
        return
    self.eat(self.current_token[0])
    if self.current_token[0] != "LESS_THAN":
        return
    self.eat("LESS_THAN")
    open_brackets = 1
    while open_brackets > 0 and self.current_token[0] != "EOF":
        kind = self.current_token[0]
        if kind == "LESS_THAN":
            open_brackets += 1
        elif kind == "GREATER_THAN":
            open_brackets -= 1
        self.eat(kind)
[docs]
def skip_unknown_token(self):
    """Discard the current token during error recovery; does nothing at ``EOF``."""
    kind = self.current_token[0]
    if kind == "EOF":
        return
    self.eat(kind)
[docs]
def create_empty_shader(self):
    """Build the placeholder ``ShaderNode`` used by parser error recovery.

    The shader is named ``"error"``, uses the graphics-pipeline execution
    model, and has no stages, structs, functions, globals, constants or
    imports.
    """
    return ShaderNode(
        name="error",
        execution_model=ExecutionModel.GRAPHICS_PIPELINE,
        imports=[],
        stages={},
        structs=[],
        functions=[],
        global_variables=[],
        constants=[],
    )
[docs]
def report_error(self, message):
    """Log ``message`` as a warning on the root logger, prefixed with ``Parser: ``.

    This does not raise; parsing continues.
    """
    # Lazy %-style arguments: the message is only formatted if the record
    # is actually emitted.
    logging.warning("Parser: %s", message)
[docs]
def next_token(self):
    """Move the cursor to the next token without checking its type.

    The cursor never moves past the final token in ``self.tokens``, so
    calling this at the end of input leaves the position unchanged.
    """
    last_index = len(self.tokens) - 1
    if self.pos < last_index:
        self.pos += 1
    self.current_token = self.tokens[self.pos]
[docs]
def parse_global(self):
    """Parse one top-level declaration, or skip input that is not understood.

    Dispatches on the current token to the matching declaration parser.
    ``generic`` declarations are skipped up to and including the first ``;``
    or ``}``, and ``trait`` declarations up to and including the first ``}``;
    neither is parsed.
    NOTE(review): a trait whose methods have ``{...}`` bodies stops at the
    first inner ``}``.

    Returns the parsed node, or ``None`` at ``EOF`` and whenever input was
    skipped.
    """
    if self.current_token[0] == "EOF":
        return None
    if self.current_token[0] == "PREPROCESSOR":
        return self.parse_preprocessor_directive()
    if self.current_token[0] == "PRECISION":
        return self.parse_precision_statement()
    if self.current_token[0] == "IDENTIFIER" and self.current_token[1] == "generic":
        self.skip_unknown_token()
        while self.current_token[0] not in ["RBRACE", "SEMICOLON", "EOF"]:
            self.skip_unknown_token()
        if self.current_token[0] in ["RBRACE", "SEMICOLON"]:
            self.skip_unknown_token()
        return None
    if self.current_token[0] == "TRAIT":
        self.skip_unknown_token()
        while self.current_token[0] not in ["RBRACE", "EOF"]:
            self.skip_unknown_token()
        if self.current_token[0] == "RBRACE":
            self.skip_unknown_token()
        return None
    if self.current_token[0] == "ENUM":
        return self.parse_enum()
    if self.current_token[0] == "STRUCT":
        return self.parse_struct()
    if self.current_token[0] == "CONST":
        return self.parse_constant()
    if self.is_cbuffer_declaration():
        return self.parse_cbuffer_as_struct()
    if self.current_token[0] in [
        "VERTEX",
        "FRAGMENT",
        "COMPUTE",
        "GEOMETRY",
        "TESSELLATION_CONTROL",
        "TESSELLATION_EVALUATION",
        "TASK",
        "MESH",
        "OBJECT",
        "AMPLIFICATION",
        "RAY_GENERATION",
        "RAY_INTERSECTION",
        "RAY_CLOSEST_HIT",
        "RAY_MISS",
        "RAY_ANY_HIT",
        "RAY_CALLABLE",
    ]:
        return self.parse_shader_stage_block()
    # Functions take priority over variables to avoid misidentification.
    if self.is_function_declaration():
        return self.parse_function()
    if self.is_variable_declaration():
        return self.parse_variable_declaration()
    # is_function_declaration can fail on some edge-case type tokens; try anyway.
    if self.is_type_token():
        start_token = self.current_token
        try:
            return self.parse_function()
        except Exception as exc:
            # Best-effort recovery is kept, but the failure is now reported
            # instead of silently swallowed. Catching Exception rather than
            # using a bare except lets KeyboardInterrupt/SystemExit propagate.
            self.report_error(
                f"skipping declaration starting at {start_token!r}: {exc}"
            )
            self.skip_unknown_token()
            return None
    self.skip_unknown_token()
    return None
[docs]
def parse_generic_declaration(self):
    """Parse ``IDENT<T: A + B, U, ...>`` followed by a struct, enum or function.

    The leading identifier (e.g. ``generic``) is consumed first. Parsed
    parameters are attached to the resulting node as ``generic_params``: a
    list of ``(name, [constraint, ...])`` tuples.

    Returns ``None`` if input ends early or ``<`` does not follow the leading
    identifier. Also returns ``None`` when the wrapped declaration is not a
    struct, enum or function; in that case one token is skipped.
    """
    if self.current_token[0] == "EOF":
        return None
    self.eat("IDENTIFIER")
    if self.current_token[0] != "LESS_THAN":
        return None
    self.eat("LESS_THAN")

    params = []
    while self.current_token[0] not in ("GREATER_THAN", "EOF"):
        if self.current_token[0] == "IDENTIFIER":
            param_name = self.current_token[1]
            self.eat("IDENTIFIER")
            bounds = []
            if self.current_token[0] == "COLON":
                self.eat("COLON")
                # Constraints are IDENTIFIERs joined by PLUS.
                while self.current_token[0] == "IDENTIFIER":
                    bounds.append(self.current_token[1])
                    self.eat("IDENTIFIER")
                    if self.current_token[0] != "PLUS":
                        break
                    self.eat("PLUS")
            params.append((param_name, bounds))
        if self.current_token[0] == "COMMA":
            self.eat("COMMA")
        elif self.current_token[0] != "GREATER_THAN":
            break

    if self.current_token[0] == "EOF":
        return None
    self.eat("GREATER_THAN")

    head = self.current_token[0]
    if head == "EOF":
        return None
    if head == "STRUCT":
        node = self.parse_struct()
    elif head == "ENUM":
        node = self.parse_enum()
    elif head == "FUNCTION" or self.is_function_declaration():
        node = self.parse_function()
    else:
        self.skip_unknown_token()
        return None
    if node:
        node.generic_params = params
    return node
[docs]
def parse_trait(self):
    """Parse ``TRAIT NAME [<...>] { functions... }`` into a ``StructNode``.

    Traits are represented in simplified form: the trait's function
    declarations become the struct's ``members``. Tokens in the body that do
    not start a function are skipped. Returns ``None`` on a malformed header
    or if input ends before the closing brace.
    """
    if self.current_token[0] == "EOF":
        return None
    self.eat("TRAIT")
    if self.current_token[0] != "IDENTIFIER":
        return None
    trait_name = self.current_token[1]
    self.eat("IDENTIFIER")

    generics = (
        self.parse_generic_arguments() if self.current_token[0] == "LESS_THAN" else []
    )
    if self.current_token[0] != "LBRACE":
        return None
    self.eat("LBRACE")

    methods = []
    while self.current_token[0] not in ("RBRACE", "EOF"):
        if not self.is_function_declaration():
            self.skip_unknown_token()
            continue
        method = self.parse_function()
        if method:
            methods.append(method)

    if self.current_token[0] == "EOF":
        return None
    self.eat("RBRACE")
    return StructNode(
        name=trait_name, members=methods, attributes=[], generic_params=generics
    )
[docs]
def parse_switch_statement(self):
    """Parse ``switch (EXPR) { case ...: ... default: ... }`` into a ``SwitchNode``.

    ``case`` and ``default`` clauses are collected in source order. Other
    tokens directly inside the braces are skipped.
    """
    self.eat("SWITCH")
    self.eat("LPAREN")
    expression = self.parse_expression()
    self.eat("RPAREN")
    self.eat("LBRACE")
    cases = []
    # skip_unknown_token() does not advance at EOF, so EOF must end the loop
    # too; otherwise an unterminated switch spins forever. The missing '}'
    # is then handed to eat("RBRACE").
    while self.current_token[0] not in ("RBRACE", "EOF"):
        if self.current_token[0] == "CASE":
            cases.append(self.parse_case())
        elif self.current_token[0] == "DEFAULT":
            cases.append(self.parse_default_case())
        else:
            self.skip_unknown_token()
    self.eat("RBRACE")
    return SwitchNode(expression=expression, cases=cases)
[docs]
def parse_case(self):
    """Parse ``case VALUE: STATEMENT*`` up to the next ``case``/``default``/``}``/EOF."""
    self.eat("CASE")
    label = self.parse_expression()
    self.eat("COLON")
    terminators = ("CASE", "DEFAULT", "RBRACE", "EOF")
    body = []
    while self.current_token[0] not in terminators:
        body.append(self.parse_statement())
    return CaseNode(value=label, statements=body)
[docs]
def parse_default_case(self):
    """Parse ``default: STATEMENT*``; returned as a ``CaseNode`` with ``value=None``.

    Statements are collected up to the next ``case``/``default``/``}``/EOF.
    """
    self.eat("DEFAULT")
    self.eat("COLON")
    terminators = ("CASE", "DEFAULT", "RBRACE", "EOF")
    body = []
    while self.current_token[0] not in terminators:
        body.append(self.parse_statement())
    return CaseNode(value=None, statements=body)