"""Parser for DirectX HLSL source AST construction."""
import re
from .DirectxAst import (
AssignmentNode,
BinaryOpNode,
ForNode,
WhileNode,
DoWhileNode,
FunctionCallNode,
FunctionNode,
IfNode,
MemberAccessNode,
ReturnNode,
ShaderNode,
StructNode,
EnumNode,
TypeAliasNode,
UnaryOpNode,
VariableNode,
VectorConstructorNode,
PragmaNode,
IncludeNode,
SwitchNode,
CaseNode,
TernaryOpNode,
)
from ..common_ast import (
BreakNode,
ContinueNode,
ArrayAccessNode,
CastNode,
AttributeNode,
PreprocessorNode,
TextureSampleNode,
)
# Token types accepted wherever the parser expects a type name.
# ``IDENTIFIER`` is a member, so any identifier token passes the type check.
TYPE_TOKENS = {
    # scalar / minimum-precision / 64-bit token names
    "FLOAT", "HALF", "DOUBLE", "INT", "UINT", "BOOL", "VOID", "DWORD",
    "MIN16FLOAT", "MIN10FLOAT", "MIN16INT", "MIN12INT", "MIN16UINT",
    "INT64_T", "UINT64_T",
    # vector / matrix token names
    "FVECTOR", "IVECTOR", "UVECTOR", "BVECTOR", "MATRIX",
    # texture token names
    "TEXTURE1D", "TEXTURE1DARRAY",
    "TEXTURE2D", "TEXTURE2DARRAY",
    "TEXTURE2DMS", "TEXTURE2DMSARRAY",
    "TEXTURE3D",
    "TEXTURECUBE", "TEXTURECUBEARRAY",
    "FEEDBACKTEXTURE2D", "FEEDBACKTEXTURE2DARRAY",
    # read/write texture token names
    "RWTEXTURE1D", "RWTEXTURE1DARRAY",
    "RWTEXTURE2D", "RWTEXTURE2DARRAY",
    "RWTEXTURE2DMS", "RWTEXTURE2DMSARRAY",
    "RWTEXTURE3D",
    "RWTEXTURECUBE", "RWTEXTURECUBEARRAY",
    # buffer token names
    "STRUCTUREDBUFFER", "RWSTRUCTUREDBUFFER",
    "APPENDSTRUCTUREDBUFFER", "CONSUMESTRUCTUREDBUFFER",
    "BYTEADDRESSBUFFER", "RWBYTEADDRESSBUFFER",
    "RAYTRACING_ACCELERATION_STRUCTURE", "RAYQUERY",
    "BUFFER", "RWBUFFER",
    # sampler token names
    "SAMPLER_STATE", "SAMPLER_COMPARISON_STATE",
    # patch / stream token names
    "INPUTPATCH", "OUTPUTPATCH",
    "POINTSTREAM", "LINESTREAM", "TRIANGLESTREAM",
    # any identifier token
    "IDENTIFIER",
}

# Token types collected as qualifiers in front of a declaration.
QUALIFIER_TOKENS = {
    "STATIC", "CONST", "INLINE", "EXTERN", "VOLATILE", "PRECISE",
    "ROW_MAJOR", "COLUMN_MAJOR",
    "NOINTERPOLATION", "LINEAR", "CENTROID", "SAMPLE",
    "IN", "OUT", "INOUT",
    "UNIFORM", "GROUPSHARED",
}

# Token types that turn the preceding expression into an assignment target.
ASSIGNMENT_TOKENS = {
    "EQUALS",
    "PLUS_EQUALS", "MINUS_EQUALS", "MULTIPLY_EQUALS", "DIVIDE_EQUALS",
    "MOD_EQUALS",
    "ASSIGN_AND", "ASSIGN_OR", "ASSIGN_XOR",
    "ASSIGN_SHIFT_LEFT", "ASSIGN_SHIFT_RIGHT",
}
[docs]
class HLSLParser:
    """Parse HLSL tokens into the DirectX backend shader AST."""

    def __init__(self, tokens):
        """Store the token stream and position the cursor on its first token.

        Args:
            tokens: Sequence of ``(token_type, value)`` pairs. When it is
                empty the current token is the synthetic ``("EOF", "")``.
        """
        self.tokens = tokens
        self.current_index = 0
        if tokens:
            self.current_token = tokens[0]
        else:
            self.current_token = ("EOF", "")
[docs]
def parse(self):
    """Parse the complete token stream into a ``ShaderNode``.

    Top-level constructs are routed by their leading token: ``struct``,
    ``enum``, ``typedef`` and ``cbuffer`` go to their dedicated parsers;
    anything else is read as ``[attributes] qualifiers type name`` followed
    by either a function (next token is ``(``) or a global variable.

    Note that nodes produced for preprocessor directives are appended to
    the ``structs`` list, and the returned ``ShaderNode`` always receives
    an empty ``includes`` list.

    Raises:
        SyntaxError: If a type is not followed by an identifier.
    """
    structs, enums, typedefs, cbuffers = [], [], [], []
    functions, global_variables = [], []

    # Leading keyword token -> (destination list, parser for the construct).
    keyword_handlers = {
        "STRUCT": (structs, self.parse_struct),
        "ENUM": (enums, self.parse_enum),
        "TYPEDEF": (typedefs, self.parse_typedef),
        "CBUFFER": (cbuffers, self.parse_cbuffer),
    }

    while True:
        kind = self.current_token[0]
        if kind == "EOF":
            break

        if kind == "PREPROCESSOR":
            node = self.parse_preprocessor_directive()
            if node is not None:
                structs.append(node)
            continue

        handler = keyword_handlers.get(kind)
        if handler is not None:
            destination, parser = handler
            destination.append(parser())
            continue

        attributes = self.parse_attribute_list()
        qualifiers = self.parse_qualifiers()

        kind = self.current_token[0]
        if not self.is_type_token(kind):
            # Not the start of a declaration: drop the token and move on.
            self.eat(kind)
            continue

        return_type = self.parse_type()
        if self.current_token[0] != "IDENTIFIER":
            raise SyntaxError(
                f"Expected identifier after type, got {self.current_token[0]}"
            )
        name = self.eat("IDENTIFIER")[1]

        if self.current_token[0] == "LPAREN":
            functions.append(
                self.parse_function(return_type, name, qualifiers, attributes)
            )
            continue

        global_variables.append(
            self.parse_variable_declaration_rest(
                return_type,
                name,
                qualifiers=qualifiers,
                attributes=attributes,
                allow_semantic=True,
                consume_semicolon=True,
            )
        )

    return ShaderNode(
        includes=[],
        functions=functions,
        structs=structs,
        global_variables=global_variables,
        cbuffers=cbuffers,
        enums=enums,
        typedefs=typedefs,
    )
[docs]
def eat(self, expected_type):
    """Consume and return the current token when it matches ``expected_type``.

    After a successful match the cursor advances by one; once it moves past
    the last token the current token becomes ``("EOF", "")``.

    Raises:
        SyntaxError: If the current token has a different type.
    """
    consumed = self.current_token
    if consumed[0] != expected_type:
        raise SyntaxError(f"Expected {expected_type}, got {consumed[0]}")

    self.current_index += 1
    if self.current_index >= len(self.tokens):
        self.current_token = ("EOF", "")
    else:
        self.current_token = self.tokens[self.current_index]
    return consumed
[docs]
def peek(self, offset=1):
    """Return a lookahead token without advancing the parser.

    Args:
        offset: Distance from the current token (``1`` is the next token,
            ``0`` the current one).

    Returns:
        The token at ``current_index + offset``, or ``("EOF", "")`` when
        that position lies outside the token list on either side.
    """
    idx = self.current_index + offset
    # Guard the lower bound too: a negative index would otherwise wrap
    # around and return a token from the end of the stream.
    if 0 <= idx < len(self.tokens):
        return self.tokens[idx]
    return ("EOF", "")
def is_type_token(self, token_type):
    """Return ``True`` when ``token_type`` is a member of ``TYPE_TOKENS``."""
    recognised = TYPE_TOKENS
    return token_type in recognised
def parse_preprocessor_directive(self):
    """Consume one ``PREPROCESSOR`` token and wrap it in an AST node.

    Returns:
        * ``PragmaNode(directive, value)`` for text starting with
          ``#pragma``; ``value`` is ``None`` when nothing follows the
          pragma name.
        * ``IncludeNode(path, is_system)`` for text starting with
          ``#include``; ``is_system`` is true for the ``<...>`` form.
        * ``PreprocessorNode(directive, content)`` for everything else.

    Raises:
        SyntaxError: If the current token is not ``PREPROCESSOR``.
    """
    token = self.eat("PREPROCESSOR")
    text = token[1].strip()

    if text.startswith("#pragma"):
        parts = text.split(None, 2)
        directive = parts[1] if len(parts) > 1 else ""
        value = parts[2] if len(parts) > 2 else None
        return PragmaNode(directive, value)

    if text.startswith("#include"):
        match = re.search(r"#include\s*([<\"])([^>\"]+)[>\"]", text)
        if match:
            return IncludeNode(match.group(2), match.group(1) == "<")
        # No recognisable delimiters: keep whatever follows the keyword.
        return IncludeNode(text[len("#include") :].strip(), False)

    if not text.startswith("#"):
        # Without a leading '#', the whole text is the directive name.
        return PreprocessorNode(text, "")

    # Split on whitespace instead of slicing by length so that a bare "#"
    # (no directive name) and "#  define X" (blanks after '#') are handled
    # without an IndexError or a mis-sliced content string.
    head = text[1:].split(None, 1)
    directive = head[0] if head else ""
    content = head[1].strip() if len(head) > 1 else ""
    return PreprocessorNode(directive, content)
def parse_attribute_list(self):
    """Parse zero or more consecutive ``[name(args...)]`` attribute groups.

    A bracket group whose first token is not an identifier is skipped
    entirely. Any tokens left between the argument list and the closing
    ``]`` are discarded.

    Returns:
        A list of ``AttributeNode`` objects, one per named group.

    Raises:
        SyntaxError: If a group is not closed by ``]``.
    """

    def close_group():
        # Discard everything up to the closing bracket, then require it.
        while self.current_token[0] not in ("RBRACKET", "EOF"):
            self.eat(self.current_token[0])
        self.eat("RBRACKET")

    parsed = []
    while self.current_token[0] == "LBRACKET":
        self.eat("LBRACKET")

        if self.current_token[0] == "IDENTIFIER":
            attr_name = self.eat("IDENTIFIER")[1]
            attr_args = []
            if self.current_token[0] == "LPAREN":
                self.eat("LPAREN")
                if self.current_token[0] != "RPAREN":
                    attr_args.append(self.parse_expression())
                    while self.current_token[0] == "COMMA":
                        self.eat("COMMA")
                        attr_args.append(self.parse_expression())
                self.eat("RPAREN")
            close_group()
            parsed.append(AttributeNode(attr_name, attr_args))
        else:
            close_group()

    return parsed
def parse_qualifiers(self):
    """Consume leading qualifier tokens and return their text values in order."""
    collected = []
    while True:
        kind = self.current_token[0]
        if kind not in QUALIFIER_TOKENS:
            return collected
        collected.append(self.current_token[1])
        self.eat(kind)
def parse_type(self):
    """Parse a type name, including an optional ``<...>`` argument list.

    Returns:
        The type's text; generic types are rendered as
        ``base<arg1, arg2>``.

    Raises:
        SyntaxError: If the current token is not a type token.
    """
    kind = self.current_token[0]
    if not self.is_type_token(kind):
        raise SyntaxError(f"Expected type, got {kind}")

    base = self.eat(kind)[1]
    if self.current_token[0] != "LESS_THAN":
        return base

    args = self.parse_generic_arguments()
    return f"{base}<{', '.join(args)}>"
def parse_generic_arguments(self):
    """Parse the ``<a, b, ...>`` list that follows a generic type name.

    Type tokens are parsed recursively via ``parse_type``; any other token
    contributes its raw text. Commas between arguments are optional as far
    as this loop is concerned.

    Raises:
        SyntaxError: If the list is not opened by ``<`` or closed by ``>``.
    """
    self.eat("LESS_THAN")
    collected = []
    while self.current_token[0] not in ("GREATER_THAN", "EOF"):
        kind = self.current_token[0]
        if self.is_type_token(kind):
            collected.append(self.parse_type())
        else:
            collected.append(self.eat(kind)[1])
        if self.current_token[0] == "COMMA":
            self.eat("COMMA")
    self.eat("GREATER_THAN")
    return collected
def parse_array_suffixes(self):
    """Parse consecutive ``[size]`` suffixes after a declarator name.

    Returns:
        One entry per suffix: the parsed size expression, or ``None`` for
        an empty ``[]``.
    """
    dims = []
    while self.current_token[0] == "LBRACKET":
        self.eat("LBRACKET")
        unsized = self.current_token[0] == "RBRACKET"
        dims.append(None if unsized else self.parse_expression())
        self.eat("RBRACKET")
    return dims
def parse_semantic_or_register(self):
    """Parse an optional ``: ...`` suffix after a declarator.

    Returns:
        A ``(semantic, register, packoffset)`` triple. At most one entry is
        non-``None``; all three are ``None`` when no colon is present.

    Raises:
        SyntaxError: If the colon is followed by something other than
            ``register(...)``, ``packoffset(...)`` or an identifier.
    """
    if self.current_token[0] != "COLON":
        return None, None, None
    self.eat("COLON")

    kind = self.current_token[0]
    if kind == "REGISTER":
        return None, self.parse_register_binding("REGISTER"), None
    if kind == "PACKOFFSET":
        return None, None, self.parse_register_binding("PACKOFFSET")

    semantic = self.current_token[1]
    self.eat("IDENTIFIER")
    return semantic, None, None
def parse_register_binding(self, token_type):
    """Parse ``<token_type>( ... )`` and return the parenthesised text.

    Token values are concatenated without separators, except commas, which
    are rendered as ``", "``. The result is stripped of surrounding
    whitespace.

    Args:
        token_type: The keyword token to consume first (the callers in
            this file pass ``"REGISTER"`` or ``"PACKOFFSET"``).
    """
    self.eat(token_type)
    self.eat("LPAREN")
    pieces = []
    while self.current_token[0] not in ("RPAREN", "EOF"):
        kind = self.current_token[0]
        pieces.append(", " if kind == "COMMA" else str(self.current_token[1]))
        self.eat(kind)
    self.eat("RPAREN")
    return "".join(pieces).strip()
def parse_struct(self):
    """Parse a ``struct`` definition into a ``StructNode``.

    Each member is read as ``[attributes] qualifiers type name [dims]
    [: semantic];``. Identifiers listed after the closing brace are stored
    on the node as ``variables``; an optional ``: name`` after the struct
    name is stored as ``semantic``.

    Raises:
        SyntaxError: If a member does not start with a type, or any
            required punctuation is missing.
    """
    self.eat("STRUCT")
    struct_name = self.eat("IDENTIFIER")[1]

    struct_semantic = None
    if self.current_token[0] == "COLON":
        struct_semantic = self.parse_semantic_or_register()[0]

    self.eat("LBRACE")
    fields = []
    while self.current_token[0] not in ("RBRACE", "EOF"):
        attributes = self.parse_attribute_list()
        qualifiers = self.parse_qualifiers()
        if not self.is_type_token(self.current_token[0]):
            raise SyntaxError(
                f"Expected type in struct member, got {self.current_token[0]}"
            )
        field_type = self.parse_type()
        field_name = self.eat("IDENTIFIER")[1]
        dims = self.parse_array_suffixes()
        field_semantic = self.parse_semantic_or_register()[0]
        self.eat("SEMICOLON")

        field = VariableNode(
            field_type,
            field_name,
            qualifiers=qualifiers,
            attributes=attributes,
            semantic=field_semantic,
        )
        field.array_sizes = dims
        fields.append(field)
    self.eat("RBRACE")

    # Optional declarators following the body: ``struct S { ... } a, b;``
    instance_names = []
    if self.current_token[0] == "IDENTIFIER":
        instance_names.append(self.eat("IDENTIFIER")[1])
        while self.current_token[0] == "COMMA":
            self.eat("COMMA")
            instance_names.append(self.eat("IDENTIFIER")[1])
    if self.current_token[0] == "SEMICOLON":
        self.eat("SEMICOLON")

    node = StructNode(struct_name, fields)
    node.variables = instance_names
    node.semantic = struct_semantic
    return node
def parse_enum(self):
    """Parse an ``enum`` (or ``enum class``) definition into an ``EnumNode``.

    Members are collected as ``(name, value)`` pairs where ``value`` is the
    parsed initializer expression or ``None``. A trailing comma before the
    closing brace is accepted, as is an optional ``;`` after it.

    Raises:
        SyntaxError: If a member is followed by anything other than a
            comma or the closing brace.
    """
    self.eat("ENUM")
    # ``class`` arrives as an ordinary identifier token.
    if self.current_token[0] == "IDENTIFIER" and self.current_token[1] == "class":
        self.eat("IDENTIFIER")
    enum_name = self.eat("IDENTIFIER")[1]
    self.eat("LBRACE")

    entries = []
    while self.current_token[0] != "RBRACE":
        entry_name = self.eat("IDENTIFIER")[1]
        entry_value = None
        if self.current_token[0] == "EQUALS":
            self.eat("EQUALS")
            entry_value = self.parse_expression()
        entries.append((entry_name, entry_value))

        separator = self.current_token[0]
        if separator == "RBRACE":
            break
        if separator != "COMMA":
            raise SyntaxError(
                f"Expected comma or closing brace in enum, got {self.current_token[0]}"
            )
        self.eat("COMMA")

    self.eat("RBRACE")
    if self.current_token[0] == "SEMICOLON":
        self.eat("SEMICOLON")
    return EnumNode(enum_name, entries)
def parse_typedef(self):
    """Parse ``typedef <type> <name>;`` into a ``TypeAliasNode``."""
    self.eat("TYPEDEF")
    aliased_type = self.parse_type()
    alias_name = self.eat("IDENTIFIER")[1]
    self.eat("SEMICOLON")
    return TypeAliasNode(aliased_type, alias_name)
def parse_cbuffer(self):
    """Parse a ``cbuffer`` block into a ``StructNode`` flagged ``is_cbuffer``.

    The optional ``: register(...)`` / ``: packoffset(...)`` suffix after
    the buffer name is stored on the returned node. Each member is read as
    ``qualifiers type name [dims] [: semantic|register|packoffset];`` and
    carries its own ``register`` / ``packoffset`` attributes.

    Raises:
        SyntaxError: If a member does not start with a type or required
            punctuation is missing.
    """
    self.eat("CBUFFER")
    name = self.current_token[1]
    self.eat("IDENTIFIER")

    # Buffer-level bindings live in their own variables: the member loop
    # below must not overwrite them with per-member values.
    buffer_register = None
    buffer_packoffset = None
    if self.current_token[0] == "COLON":
        _, buffer_register, buffer_packoffset = self.parse_semantic_or_register()

    self.eat("LBRACE")
    members = []
    while self.current_token[0] != "RBRACE" and self.current_token[0] != "EOF":
        qualifiers = self.parse_qualifiers()
        member_type = self.parse_type()
        member_name = self.current_token[1]
        self.eat("IDENTIFIER")
        array_sizes = self.parse_array_suffixes()

        member_semantic = None
        member_register = None
        member_packoffset = None
        if self.current_token[0] == "COLON":
            (
                member_semantic,
                member_register,
                member_packoffset,
            ) = self.parse_semantic_or_register()
        self.eat("SEMICOLON")

        member = VariableNode(
            member_type,
            member_name,
            qualifiers=qualifiers,
            semantic=member_semantic,
        )
        member.register = member_register
        member.packoffset = member_packoffset
        member.array_sizes = array_sizes
        members.append(member)
    self.eat("RBRACE")

    if self.current_token[0] == "SEMICOLON":
        self.eat("SEMICOLON")

    cbuffer_node = StructNode(name, members)
    cbuffer_node.is_cbuffer = True
    cbuffer_node.register = buffer_register
    cbuffer_node.packoffset = buffer_packoffset
    return cbuffer_node
def parse_function(self, return_type, name, qualifiers, attributes):
    """Parse a function from its parameter list onward.

    The caller has already consumed the return type and the name; the
    current token is expected to be ``(``.

    Args:
        return_type: Text of the already-parsed return type.
        name: Function name.
        qualifiers: Qualifier strings parsed before the return type.
        attributes: ``AttributeNode`` list parsed before the qualifiers.

    Returns:
        A ``FunctionNode``. The qualifier returned by
        ``infer_function_qualifier`` (if any) is stored as ``qualifier``
        and also appended to a copy of ``qualifiers``.
    """
    params = self.parse_parameters()

    semantic = None
    if self.current_token[0] == "COLON":
        semantic = self.parse_semantic_or_register()[0]

    body = self.parse_block()
    stage = self.infer_function_qualifier(name, attributes, params, semantic)
    stage_as_list = [stage] if stage else []

    return FunctionNode(
        return_type=return_type,
        name=name,
        params=params,
        body=body,
        qualifiers=qualifiers + stage_as_list,
        attributes=attributes,
        qualifier=stage,
        semantic=semantic,
    )
def infer_function_qualifier(self, name, attributes, params, semantic):
    """Guess the shader-stage qualifier of a function.

    Evidence is checked in this order and the first hit wins:

    1. A ``[shader("stage")]`` attribute whose stage is in the map below.
    2. A two-letter name prefix (``vs``, ``ps``, ``cs``, ``gs``, ``hs``,
       ``ds``, ``ms``, ``as``), compared case-insensitively.
    3. A ``numthreads`` attribute -> ``"compute"``.
    4. The return semantic: ``SV_TARGET*`` -> ``"fragment"``,
       ``SV_POSITION`` -> ``"vertex"``.
    5. Any parameter with semantic ``SV_DispatchThreadID`` -> ``"compute"``.

    Attribute names and semantics are compared case-insensitively in every
    step.

    Args:
        name: Function name.
        attributes: Objects exposing ``name`` and ``args``.
        params: Parameter nodes; a ``semantic`` attribute is read if present.
        semantic: Return-value semantic string, or ``None``.

    Returns:
        The stage name, or ``None`` when nothing matches.
    """
    stage_map = {
        "vertex": "vertex",
        "pixel": "fragment",
        "fragment": "fragment",
        "compute": "compute",
        "geometry": "geometry",
        "hull": "tessellation_control",
        "domain": "tessellation_evaluation",
        "mesh": "mesh",
        "amplification": "task",
        "task": "task",
        "raygeneration": "ray_generation",
        "intersection": "ray_intersection",
        "closesthit": "ray_closest_hit",
        "anyhit": "ray_any_hit",
        "miss": "ray_miss",
        "callable": "ray_callable",
    }
    for attr in attributes:
        if attr.name.lower() != "shader" or not attr.args:
            continue
        raw = attr.args[0]
        if isinstance(raw, str):
            # String arguments may still carry their surrounding quotes.
            stage_name = raw.strip().strip("\"'").lower()
        else:
            stage_name = str(raw).lower()
        if stage_name in stage_map:
            return stage_map[stage_name]

    # NOTE(review): this is a plain prefix match, so an ordinary helper
    # whose name happens to start with one of these pairs (e.g. "hs...")
    # is classified as a stage entry point as well.
    prefix_stages = (
        ("vs", "vertex"),
        ("ps", "fragment"),
        ("cs", "compute"),
        ("gs", "geometry"),
        ("hs", "tessellation_control"),
        ("ds", "tessellation_evaluation"),
        ("ms", "mesh"),
        ("as", "task"),
    )
    name_lower = name.lower()
    for prefix, stage in prefix_stages:
        if name_lower.startswith(prefix):
            return stage

    if any(attr.name.lower() == "numthreads" for attr in attributes):
        return "compute"

    if semantic:
        semantic_upper = semantic.upper()
        if semantic_upper.startswith("SV_TARGET"):
            return "fragment"
        if semantic_upper == "SV_POSITION":
            return "vertex"

    for param in params:
        param_semantic = getattr(param, "semantic", None)
        if (
            isinstance(param_semantic, str)
            and param_semantic.upper() == "SV_DISPATCHTHREADID"
        ):
            return "compute"
    return None
def parse_parameters(self):
    """Parse a parenthesised, comma-separated parameter list.

    Each parameter is read as ``[attributes] qualifiers [primitive] type
    [name] [dims] [: semantic]``. A leading identifier equal to one of the
    primitive keywords below is recorded as an
    ``AttributeNode("primitive", [keyword])``. A parameter without a name
    gets the empty string as its name.

    Returns:
        A list of ``VariableNode`` objects.

    Raises:
        SyntaxError: If a parameter does not start with a type token.
    """
    self.eat("LPAREN")
    primitive_qualifiers = {
        "point",
        "line",
        "triangle",
        "lineadj",
        "triangleadj",
    }

    parsed = []
    pending = self.current_token[0] != "RPAREN"
    while pending:
        attributes = self.parse_attribute_list()
        qualifiers = self.parse_qualifiers()

        if (
            self.current_token[0] == "IDENTIFIER"
            and self.current_token[1] in primitive_qualifiers
        ):
            primitive = self.eat("IDENTIFIER")[1]
            attributes.append(AttributeNode("primitive", [primitive]))

        if not self.is_type_token(self.current_token[0]):
            raise SyntaxError(
                f"Unexpected token in parameter list: {self.current_token[0]}"
            )
        param_type = self.parse_type()

        param_name = ""
        if self.current_token[0] == "IDENTIFIER":
            param_name = self.eat("IDENTIFIER")[1]

        dims = self.parse_array_suffixes()
        semantic = self.parse_semantic_or_register()[0]

        node = VariableNode(
            param_type,
            param_name,
            qualifiers=qualifiers,
            attributes=attributes,
            is_const="const" in qualifiers,
            semantic=semantic,
        )
        node.array_sizes = dims
        parsed.append(node)

        # Another parameter follows only after a comma.
        pending = self.current_token[0] == "COMMA"
        if pending:
            self.eat("COMMA")

    self.eat("RPAREN")
    return parsed
def parse_block(self):
    """Parse a ``{ ... }`` block and return its statements as a flat list.

    ``None`` results from ``parse_statement`` are dropped and list results
    are spliced into the output.
    """
    self.eat("LBRACE")
    body = []
    while self.current_token[0] not in ("RBRACE", "EOF"):
        result = self.parse_statement()
        if isinstance(result, list):
            body.extend(result)
        elif result is not None:
            body.append(result)
    self.eat("RBRACE")
    return body
def parse_statement(self):
    """Parse a single statement.

    Leading ``[...]`` attribute groups (for example ``[unroll]`` before a
    loop or ``[branch]`` before an ``if``) are parsed first. They are
    passed on to a variable declaration, stored as ``attributes`` on a
    control-flow node that does not already carry any, and otherwise
    dropped.

    Returns:
        * ``None`` for an empty statement or a preprocessor directive,
        * the string ``"discard"`` for a ``discard`` statement,
        * an AST node (or raw expression value) for everything else.

    Raises:
        SyntaxError: On any token that cannot start or finish a statement.
    """
    # No expression in this grammar starts with '[', so a leading bracket
    # can only introduce statement attributes.
    attributes = self.parse_attribute_list()
    kind = self.current_token[0]

    if kind == "SEMICOLON":
        self.eat("SEMICOLON")
        return None

    if kind == "RETURN":
        self.eat("RETURN")
        value = None
        if self.current_token[0] != "SEMICOLON":
            value = self.parse_expression()
        self.eat("SEMICOLON")
        return ReturnNode(value)

    control_flow = {
        "IF": self.parse_if_statement,
        "FOR": self.parse_for_loop,
        "WHILE": self.parse_while_loop,
        "DO": self.parse_do_while_loop,
        "SWITCH": self.parse_switch_statement,
    }
    if kind in control_flow:
        node = control_flow[kind]()
        # Keep the attributes with the statement they decorate, unless
        # the node already carries its own.
        if attributes and not getattr(node, "attributes", None):
            node.attributes = attributes
        return node

    if kind == "BREAK":
        self.eat("BREAK")
        self.eat("SEMICOLON")
        return BreakNode()

    if kind == "CONTINUE":
        self.eat("CONTINUE")
        self.eat("SEMICOLON")
        return ContinueNode()

    if kind == "DISCARD":
        self.eat("DISCARD")
        self.eat("SEMICOLON")
        return "discard"

    if kind == "PREPROCESSOR":
        # Directives inside a body are consumed but not kept.
        self.parse_preprocessor_directive()
        return None

    if self.looks_like_declaration():
        qualifiers = self.parse_qualifiers()
        return self.parse_variable_declaration(
            qualifiers=qualifiers,
            attributes=attributes,
            allow_semantic=False,
            consume_semicolon=True,
        )

    expr = self.parse_expression()
    self.eat("SEMICOLON")
    return expr
def looks_like_declaration(self):
    """Report whether the upcoming tokens have the shape of a declaration.

    Pure lookahead: the cursor is not moved. The accepted shape is
    ``qualifier* type [<...>] IDENTIFIER``, where the ``<...>`` part is
    skipped by balancing ``<`` and ``>`` tokens.
    """
    tokens = self.tokens
    total = len(tokens)
    pos = self.current_index

    while pos < total and tokens[pos][0] in QUALIFIER_TOKENS:
        pos += 1

    if not (pos < total and tokens[pos][0] in TYPE_TOKENS):
        return False
    pos += 1

    if pos < total and tokens[pos][0] == "LESS_THAN":
        nesting = 0
        while pos < total:
            kind = tokens[pos][0]
            pos += 1
            if kind == "LESS_THAN":
                nesting += 1
            elif kind == "GREATER_THAN":
                nesting -= 1
                if nesting == 0:
                    # ``pos`` now points just past the matching '>'.
                    break

    return pos < total and tokens[pos][0] == "IDENTIFIER"
def parse_variable_declaration(
    self,
    qualifiers=None,
    attributes=None,
    allow_semantic=True,
    consume_semicolon=True,
):
    """Parse ``type name`` and delegate the remainder of the declaration.

    Args:
        qualifiers: Already-parsed qualifier strings (``None`` -> ``[]``).
        attributes: Already-parsed attributes (``None`` -> ``[]``).
        allow_semantic: Forwarded to ``parse_variable_declaration_rest``.
        consume_semicolon: Forwarded to ``parse_variable_declaration_rest``.

    Returns:
        The ``VariableNode`` built by ``parse_variable_declaration_rest``.
    """
    declared_type = self.parse_type()
    declared_name = self.eat("IDENTIFIER")[1]
    return self.parse_variable_declaration_rest(
        declared_type,
        declared_name,
        qualifiers=qualifiers or [],
        attributes=attributes or [],
        allow_semantic=allow_semantic,
        consume_semicolon=consume_semicolon,
    )
def parse_variable_declaration_rest(
    self,
    vtype,
    name,
    qualifiers=None,
    attributes=None,
    allow_semantic=True,
    consume_semicolon=True,
):
    """Finish a declaration whose type and name were already consumed.

    Parses, in order: array suffixes, an optional ``: semantic`` /
    ``register`` / ``packoffset`` suffix (only when ``allow_semantic``),
    an optional ``= initializer``, and the terminating ``;`` (only when
    ``consume_semicolon``).

    Returns:
        A ``VariableNode`` with ``array_sizes``, ``register`` and
        ``packoffset`` set as attributes.
    """
    qualifiers = qualifiers or []
    attributes = attributes or []

    dims = self.parse_array_suffixes()

    if allow_semantic:
        semantic, register, packoffset = self.parse_semantic_or_register()
    else:
        semantic = register = packoffset = None

    initializer = None
    if self.current_token[0] == "EQUALS":
        self.eat("EQUALS")
        initializer = self.parse_expression()

    if consume_semicolon:
        self.eat("SEMICOLON")

    node = VariableNode(
        vtype,
        name,
        value=initializer,
        qualifiers=qualifiers,
        attributes=attributes,
        is_const="const" in qualifiers,
        semantic=semantic,
    )
    node.array_sizes = dims
    node.register = register
    node.packoffset = packoffset
    return node
def parse_if_statement(self):
    """Parse an ``if`` statement together with its whole ``else`` chain.

    Both spellings of a chained branch are supported: a single ``ELSE_IF``
    token, and an ``ELSE`` token followed by ``IF``. In either case the
    chained branch becomes a nested ``IfNode`` in the ``else`` slot, and
    any further ``else`` / ``else if`` after it is attached to that nested
    node rather than being left unconsumed.

    Returns:
        ``IfNode(condition, if_body, else_body)`` where ``else_body`` is
        ``None``, a statement list, or a nested ``IfNode``.
    """

    def parse_branch():
        # Parses "( condition ) body [else-chain]"; the introducing
        # IF / ELSE_IF token has already been consumed by the caller.
        self.eat("LPAREN")
        condition = self.parse_expression()
        self.eat("RPAREN")
        if_body = self.parse_statement_or_block()

        else_body = None
        if self.current_token[0] == "ELSE_IF":
            self.eat("ELSE_IF")
            # Recurse so that branches following this one are picked up.
            else_body = parse_branch()
        elif self.current_token[0] == "ELSE":
            self.eat("ELSE")
            if self.current_token[0] == "IF":
                else_body = self.parse_if_statement()
            else:
                else_body = self.parse_statement_or_block()
        return IfNode(condition, if_body, else_body)

    self.eat("IF")
    return parse_branch()
def parse_statement_or_block(self):
    """Parse a braced block or a single statement; always return a list.

    A single statement that parses to ``None`` yields an empty list.
    """
    if self.current_token[0] != "LBRACE":
        single = self.parse_statement()
        return [] if single is None else [single]
    return self.parse_block()
def parse_for_loop(self):
    """Parse ``for (init; condition; update) body`` into a ``ForNode``.

    Each of the three header clauses may be empty, in which case ``None``
    is stored. The init clause is parsed as a variable declaration when it
    looks like one, otherwise as an expression.
    """
    self.eat("FOR")
    self.eat("LPAREN")

    if self.current_token[0] == "SEMICOLON":
        init = None
    elif self.looks_like_declaration():
        init = self.parse_variable_declaration(
            qualifiers=self.parse_qualifiers(),
            attributes=[],
            allow_semantic=False,
            consume_semicolon=False,
        )
    else:
        init = self.parse_expression()
    self.eat("SEMICOLON")

    condition = (
        None if self.current_token[0] == "SEMICOLON" else self.parse_expression()
    )
    self.eat("SEMICOLON")

    update = None if self.current_token[0] == "RPAREN" else self.parse_expression()
    self.eat("RPAREN")

    return ForNode(init, condition, update, self.parse_statement_or_block())
def parse_while_loop(self):
    """Parse ``while (condition) body`` into a ``WhileNode``."""
    for expected in ("WHILE", "LPAREN"):
        self.eat(expected)
    condition = self.parse_expression()
    self.eat("RPAREN")
    return WhileNode(condition, self.parse_statement_or_block())
def parse_do_while_loop(self):
    """Parse ``do body while (condition);`` into a ``DoWhileNode``."""
    self.eat("DO")
    loop_body = self.parse_statement_or_block()
    for expected in ("WHILE", "LPAREN"):
        self.eat(expected)
    condition = self.parse_expression()
    for expected in ("RPAREN", "SEMICOLON"):
        self.eat(expected)
    return DoWhileNode(loop_body, condition)
def parse_switch_statement(self):
    """Parse ``switch (expr) { case ...: ... default: ... }``.

    Returns:
        ``SwitchNode(condition, cases, default_body)`` where
        ``default_body`` is ``None`` if no ``default`` label was seen and a
        (possibly empty) statement list otherwise. If several ``default``
        labels occur, the last one wins.
    """
    self.eat("SWITCH")
    self.eat("LPAREN")
    condition = self.parse_expression()
    self.eat("RPAREN")
    self.eat("LBRACE")

    section_end = ("CASE", "DEFAULT", "RBRACE", "EOF")
    cases = []
    default_body = None
    while self.current_token[0] in ("CASE", "DEFAULT"):
        if self.current_token[0] == "CASE":
            cases.append(self.parse_switch_case())
            continue

        self.eat("DEFAULT")
        self.eat("COLON")
        default_body = []
        while self.current_token[0] not in section_end:
            stmt = self.parse_statement()
            if isinstance(stmt, list):
                default_body.extend(stmt)
            elif stmt is not None:
                default_body.append(stmt)

    self.eat("RBRACE")
    return SwitchNode(condition, cases, default_body)
def parse_switch_case(self):
    """Parse one ``case value:`` section up to the next label or ``}``.

    Returns:
        ``CaseNode(value, body)`` with ``body`` as a flat statement list.
    """
    self.eat("CASE")
    label = self.parse_expression()
    self.eat("COLON")

    statements = []
    while self.current_token[0] not in ("CASE", "DEFAULT", "RBRACE", "EOF"):
        stmt = self.parse_statement()
        if isinstance(stmt, list):
            statements.extend(stmt)
        elif stmt is not None:
            statements.append(stmt)
    return CaseNode(label, statements)
def parse_expression(self):
    """Parse a full expression, starting at the assignment level."""
    result = self.parse_assignment_expression()
    return result
def parse_assignment_expression(self):
    """Parse an assignment, or fall through to a conditional expression.

    The right-hand side is parsed recursively, so chained assignments
    nest to the right.
    """
    target = self.parse_conditional_expression()
    kind = self.current_token[0]
    if kind not in ASSIGNMENT_TOKENS:
        return target
    operator = self.eat(kind)[1]
    return AssignmentNode(target, self.parse_assignment_expression(), operator)
def parse_conditional_expression(self):
    """Parse ``cond ? a : b``, or fall through to a logical-or expression.

    The true branch is a full expression; the false branch is another
    conditional expression.
    """
    condition = self.parse_logical_or_expression()
    if self.current_token[0] != "QUESTION":
        return condition
    self.eat("QUESTION")
    when_true = self.parse_expression()
    self.eat("COLON")
    when_false = self.parse_conditional_expression()
    return TernaryOpNode(condition, when_true, when_false)
def parse_logical_or_expression(self):
    """Parse a left-associative chain of ``LOGICAL_OR`` operators."""
    left = self.parse_logical_and_expression()
    while self.current_token[0] == "LOGICAL_OR":
        operator = self.eat("LOGICAL_OR")[1]
        left = BinaryOpNode(left, operator, self.parse_logical_and_expression())
    return left
def parse_logical_and_expression(self):
    """Parse a left-associative chain of ``LOGICAL_AND`` operators."""
    left = self.parse_bitwise_or_expression()
    while self.current_token[0] == "LOGICAL_AND":
        operator = self.eat("LOGICAL_AND")[1]
        left = BinaryOpNode(left, operator, self.parse_bitwise_or_expression())
    return left
def parse_bitwise_or_expression(self):
    """Parse a left-associative chain of ``BITWISE_OR`` operators."""
    left = self.parse_bitwise_xor_expression()
    while self.current_token[0] == "BITWISE_OR":
        operator = self.eat("BITWISE_OR")[1]
        left = BinaryOpNode(left, operator, self.parse_bitwise_xor_expression())
    return left
def parse_bitwise_xor_expression(self):
    """Parse a left-associative chain of ``BITWISE_XOR`` operators."""
    left = self.parse_bitwise_and_expression()
    while self.current_token[0] == "BITWISE_XOR":
        operator = self.eat("BITWISE_XOR")[1]
        left = BinaryOpNode(left, operator, self.parse_bitwise_and_expression())
    return left
def parse_bitwise_and_expression(self):
    """Parse a left-associative chain of ``BITWISE_AND`` operators."""
    left = self.parse_equality_expression()
    while self.current_token[0] == "BITWISE_AND":
        operator = self.eat("BITWISE_AND")[1]
        left = BinaryOpNode(left, operator, self.parse_equality_expression())
    return left
def parse_equality_expression(self):
    """Parse a left-associative chain of ``EQUAL`` / ``NOT_EQUAL`` operators."""
    left = self.parse_relational_expression()
    while self.current_token[0] in ("EQUAL", "NOT_EQUAL"):
        operator = self.eat(self.current_token[0])[1]
        left = BinaryOpNode(left, operator, self.parse_relational_expression())
    return left
def parse_relational_expression(self):
    """Parse a left-associative chain of ``<``, ``>``, ``<=``, ``>=`` operators."""
    relational = ("LESS_THAN", "GREATER_THAN", "LESS_EQUAL", "GREATER_EQUAL")
    left = self.parse_shift_expression()
    while self.current_token[0] in relational:
        operator = self.eat(self.current_token[0])[1]
        left = BinaryOpNode(left, operator, self.parse_shift_expression())
    return left
def parse_shift_expression(self):
    """Parse a left-associative chain of ``SHIFT_LEFT`` / ``SHIFT_RIGHT`` operators."""
    left = self.parse_additive_expression()
    while self.current_token[0] in ("SHIFT_LEFT", "SHIFT_RIGHT"):
        operator = self.eat(self.current_token[0])[1]
        left = BinaryOpNode(left, operator, self.parse_additive_expression())
    return left
def parse_additive_expression(self):
    """Parse a left-associative chain of ``PLUS`` / ``MINUS`` operators."""
    left = self.parse_multiplicative_expression()
    while self.current_token[0] in ("PLUS", "MINUS"):
        operator = self.eat(self.current_token[0])[1]
        left = BinaryOpNode(left, operator, self.parse_multiplicative_expression())
    return left
def parse_multiplicative_expression(self):
    """Parse a left-associative chain of ``MULTIPLY`` / ``DIVIDE`` / ``MOD`` operators."""
    left = self.parse_unary_expression()
    while self.current_token[0] in ("MULTIPLY", "DIVIDE", "MOD"):
        operator = self.eat(self.current_token[0])[1]
        left = BinaryOpNode(left, operator, self.parse_unary_expression())
    return left
def parse_unary_expression(self):
    """Parse a cast, a prefix operator, or fall through to a postfix expression.

    Returns:
        ``CastNode(type, operand)`` for ``(type) operand``,
        ``UnaryOpNode(op, operand)`` for a prefix operator, otherwise the
        result of ``parse_postfix_expression``.
    """
    kind = self.current_token[0]

    if kind == "LPAREN" and self.looks_like_cast():
        self.eat("LPAREN")
        target_type = self.parse_type()
        self.eat("RPAREN")
        return CastNode(target_type, self.parse_unary_expression())

    prefix_operators = (
        "PLUS",
        "MINUS",
        "LOGICAL_NOT",
        "BITWISE_NOT",
        "INCREMENT",
        "DECREMENT",
    )
    if kind in prefix_operators:
        operator = self.eat(kind)[1]
        return UnaryOpNode(operator, self.parse_unary_expression())

    return self.parse_postfix_expression()
def looks_like_cast(self):
    """Report whether the tokens at the cursor form a ``(type)`` cast prefix.

    Pure lookahead: the cursor is not moved. The basic shape is
    ``( type [<...>] )``.

    Because any identifier counts as a type token, ``(name)`` alone is
    ambiguous: it may be a cast to a user-defined type or simply a
    parenthesised variable. For that case the token after ``)`` decides:
    it is treated as a cast only when that token can start an operand and
    cannot continue a parenthesised expression (an identifier, a literal,
    ``!``, ``~``, a type token, or ``(``). Binary operators, ``;``, ``)``,
    ``,`` and the like therefore leave ``(name)`` as a plain expression.
    """
    tokens = self.tokens
    total = len(tokens)

    if self.current_token[0] != "LPAREN":
        return False

    idx = self.current_index + 1
    if idx >= total or not self.is_type_token(tokens[idx][0]):
        return False
    type_kind = tokens[idx][0]
    idx += 1

    has_generic_args = False
    if idx < total and tokens[idx][0] == "LESS_THAN":
        has_generic_args = True
        depth = 0
        while idx < total:
            kind = tokens[idx][0]
            idx += 1
            if kind == "LESS_THAN":
                depth += 1
            elif kind == "GREATER_THAN":
                depth -= 1
                if depth == 0:
                    break

    if idx >= total or tokens[idx][0] != "RPAREN":
        return False

    # Built-in type keywords and generic types cannot be variables.
    if type_kind != "IDENTIFIER" or has_generic_args:
        return True

    # ``(identifier)``: disambiguate using the token after ')'.
    idx += 1
    if idx >= total:
        return False
    follower = tokens[idx][0]
    operand_starts = (
        "NUMBER",
        "HEX_NUMBER",
        "BINARY_NUMBER",
        "OCT_NUMBER",
        "TRUE",
        "FALSE",
        "STRING",
        "CHAR_LITERAL",
        "LOGICAL_NOT",
        "BITWISE_NOT",
        "LPAREN",
    )
    # '+', '-', '++' and '--' are deliberately absent: after ``(name)``
    # they are far more likely binary / postfix operators than the start
    # of a cast operand.
    return follower in operand_starts or self.is_type_token(follower)
def parse_postfix_expression(self):
    """Parse a primary expression followed by any number of postfix forms.

    Handled suffixes: ``[index]``, ``.member``, ``(args)`` and postfix
    ``++`` / ``--``. A call on a member named ``Sample`` or
    ``SampleLevel`` with at least two arguments becomes a
    ``TextureSampleNode`` built from the member's object and the first two
    arguments (plus the third argument as LOD for ``SampleLevel``); every
    other call becomes a ``FunctionCallNode``.
    """
    node = self.parse_primary_expression()

    while True:
        kind = self.current_token[0]

        if kind == "LBRACKET":
            self.eat("LBRACKET")
            subscript = self.parse_expression()
            self.eat("RBRACKET")
            node = ArrayAccessNode(node, subscript)

        elif kind == "DOT":
            self.eat("DOT")
            member_name = self.current_token[1]
            self.eat("IDENTIFIER")
            node = MemberAccessNode(node, member_name)

        elif kind == "LPAREN":
            self.eat("LPAREN")
            call_args = []
            if self.current_token[0] != "RPAREN":
                call_args.append(self.parse_expression())
                while self.current_token[0] == "COMMA":
                    self.eat("COMMA")
                    call_args.append(self.parse_expression())
            self.eat("RPAREN")

            is_texture_sample = (
                isinstance(node, MemberAccessNode)
                and isinstance(node.member, str)
                and node.member in ("Sample", "SampleLevel")
                and len(call_args) >= 2
            )
            if is_texture_sample:
                lod = None
                if node.member == "SampleLevel" and len(call_args) > 2:
                    lod = call_args[2]
                node = TextureSampleNode(
                    node.object, call_args[0], call_args[1], lod
                )
            else:
                node = FunctionCallNode(node, call_args)

        elif kind in ("INCREMENT", "DECREMENT"):
            operator = self.eat(kind)[1]
            postfix = UnaryOpNode(operator, node)
            postfix.is_postfix = True
            node = postfix

        else:
            return node
def parse_primary_expression(self):
    """Parse the smallest expression unit.

    Returns:
        * the raw text for identifiers, strings and character literals,
        * an ``int`` / ``float`` for numeric literals,
        * a ``bool`` for ``TRUE`` / ``FALSE`` tokens,
        * the inner expression for ``( expr )``,
        * ``VectorConstructorNode(type, args)`` for a constructor-style
          call on one of the type tokens listed below, or the bare type
          text when no ``(`` follows.

    Raises:
        SyntaxError: If the current token cannot start an expression.
    """
    kind, text = self.current_token

    if kind in ("IDENTIFIER", "STRING", "CHAR_LITERAL"):
        self.eat(kind)
        return text

    if kind in ("NUMBER", "HEX_NUMBER", "BINARY_NUMBER", "OCT_NUMBER"):
        self.eat(kind)
        return self.parse_numeric_literal(kind, text)

    if kind in ("TRUE", "FALSE"):
        self.eat(kind)
        return kind == "TRUE"

    if kind == "LPAREN":
        self.eat("LPAREN")
        inner = self.parse_expression()
        self.eat("RPAREN")
        return inner

    constructor_types = (
        "FLOAT",
        "HALF",
        "DOUBLE",
        "INT",
        "UINT",
        "BOOL",
        "FVECTOR",
        "IVECTOR",
        "UVECTOR",
        "BVECTOR",
        "MATRIX",
    )
    if kind in constructor_types:
        self.eat(kind)
        if self.current_token[0] != "LPAREN":
            return text
        self.eat("LPAREN")
        ctor_args = []
        if self.current_token[0] != "RPAREN":
            ctor_args.append(self.parse_expression())
            while self.current_token[0] == "COMMA":
                self.eat("COMMA")
                ctor_args.append(self.parse_expression())
        self.eat("RPAREN")
        return VectorConstructorNode(text, ctor_args)

    raise SyntaxError(
        f"Unexpected token in primary expression: {self.current_token}"
    )
def parse_numeric_literal(self, token_type, value):
    """Convert the text of a numeric token into a Python number.

    Args:
        token_type: ``"HEX_NUMBER"``, ``"BINARY_NUMBER"`` or
            ``"OCT_NUMBER"`` select base 16, 2 or 8; any other value is
            treated as a decimal literal.
        value: The literal's source text, optionally carrying trailing
            suffix letters (``u``/``l`` for based literals, additionally
            ``f``/``h`` for decimal ones, in either case).

    Returns:
        An ``int``, or a ``float`` when the decimal text contains ``.``,
        ``e`` or ``E``.
    """
    radix = {"HEX_NUMBER": 16, "BINARY_NUMBER": 2, "OCT_NUMBER": 8}.get(token_type)
    if radix is not None:
        return int(re.sub(r"[uUlL]+$", "", value), radix)

    digits = re.sub(r"[fFhHuUlL]+$", "", value)
    if any(marker in digits for marker in (".", "e", "E")):
        return float(digits)
    return int(digits)