"""Parser for Vulkan SPIR-V source AST construction."""
from .VulkanLexer import *
from .VulkanAst import *
[docs]
class VulkanParser:
"""Parse Vulkan/SPIR-V style tokens into the Vulkan backend AST."""
def __init__(self, tokens):
    """Initialize the parser with a token stream from ``VulkanLexer``.

    Args:
        tokens: Indexable sequence of tokens. Each token is indexed with
            ``[0]`` for its type and ``[1]`` for its value.

    An empty ``tokens`` sequence no longer raises ``IndexError``; the parser
    starts on a synthetic ``("EOF", None)`` token instead.
    """
    self.tokens = tokens
    self.pos = 0
    # Use the same synthetic EOF token that eat() produces once the stream
    # is exhausted, so an empty stream behaves like an already-finished one.
    if self.pos < len(self.tokens):
        self.current_token = self.tokens[self.pos]
    else:
        self.current_token = ("EOF", None)
    self.skip_comments()
[docs]
def peek(self, offset):
    """Look ahead by ``offset`` tokens without consuming them.

    Args:
        offset: Number of tokens relative to the current position.

    Returns:
        The type (element ``[0]``) of the token at ``self.pos + offset``, or
        ``None`` when that position lies outside the token list.
    """
    peek_index = self.pos + offset
    # Reject negative positions explicitly: a negative list index would
    # silently wrap around and report a token from the end of the stream.
    if 0 <= peek_index < len(self.tokens):
        return self.tokens[peek_index][0]
    return None
[docs]
def skip_until(self, token_type):
    """Advance until the current token has type ``token_type`` or is EOF.

    The matching token itself is not consumed. When the position runs past
    the end of the token list, the current token becomes ``("EOF", None)``.
    """
    stop_types = (token_type, "EOF")
    while self.current_token[0] not in stop_types:
        self.pos += 1
        if self.pos >= len(self.tokens):
            # Ran off the end of the stream: park on a synthetic EOF token.
            self.current_token = ("EOF", None)
            return
        self.current_token = self.tokens[self.pos]
[docs]
def eat(self, token_type):
    """Consume the current token when it matches ``token_type``.

    Raises:
        SyntaxError: If the current token has a different type. The parser
            position is left unchanged in that case.
    """
    found = self.current_token[0]
    if found != token_type:
        raise SyntaxError(f"Expected {token_type}, got {found}")

    self.pos += 1
    if self.pos < len(self.tokens):
        self.current_token = self.tokens[self.pos]
    else:
        # Past the last token: substitute a synthetic EOF token.
        self.current_token = ("EOF", None)
    self.skip_comments()
[docs]
def parse(self):
    """Parse the complete token stream into a module AST.

    Returns:
        Whatever ``parse_module`` produced, after checking that the stream
        ends with an ``EOF`` token.
    """
    tree = self.parse_module()
    # The module must account for every token up to the end of input.
    self.eat("EOF")
    return tree
[docs]
def parse_module(self):
    """Parse top-level Vulkan/SPIR-V declarations and functions.

    Dispatches on the current token until ``EOF``:

    * ``LAYOUT`` and ``UNIFORM`` declarations are appended to the global
      variables.
    * ``STRUCT`` definitions are collected as structs.
    * A type token followed by ``IDENTIFIER`` and ``LPAREN`` starts a function.
    * Any other type token or identifier starts a global variable.
    * Every other token is consumed and ignored.

    Returns:
        ShaderNode: Built from the collected functions, structs and global
        variables.
    """
    value_types = (
        "FLOAT",
        "INT",
        "UINT",
        "BOOL",
        "VEC2",
        "VEC3",
        "VEC4",
        "MAT2",
        "MAT3",
        "MAT4",
    )
    return_types = ("VOID",) + value_types

    functions = []
    structs = []
    global_variables = []

    while self.current_token[0] != "EOF":
        kind = self.current_token[0]
        text = self.current_token[1]

        if kind == "LAYOUT":
            global_variables.append(self.parse_layout())
        elif kind == "STRUCT":
            structs.append(self.parse_struct())
        elif kind == "UNIFORM":
            global_variables.append(self.parse_uniform())
        elif (
            (kind in return_types or text in VALID_DATA_TYPES)
            and self.peek(1) == "IDENTIFIER"
            and self.peek(2) == "LPAREN"
        ):
            functions.append(self.parse_function())
        elif kind == "IDENTIFIER" or kind in value_types or text in VALID_DATA_TYPES:
            # parse_variable() expects to start on the variable *name*, so the
            # type token has to be consumed first. A lone identifier is only
            # treated as a (user-defined) type when a name follows it;
            # otherwise it is left in place and handled as before.
            if kind != "IDENTIFIER" or self.peek(1) == "IDENTIFIER":
                self.eat(kind)
            global_variables.append(self.parse_variable(text))
        else:
            # Unrecognised top-level token: skip it.
            self.eat(kind)

    return ShaderNode(
        functions=functions,
        structs=structs,
        global_variables=global_variables,
    )
def parse_layout(self):
    """Parse a ``layout(...)`` declaration into a ``LayoutNode``.

    The qualifier list may start with ``PUSH_CONSTANT`` and otherwise holds
    comma-separated ``IDENTIFIER`` entries with an optional ``= NUMBER``
    value. After the closing parenthesis an optional storage keyword
    (``IN``/``OUT``/``UNIFORM``/``BUFFER``) and an optional block name are
    read. ``UNIFORM``/``BUFFER`` layouts must be followed by a braced field
    list; all other layouts must be followed by a single data type.

    Raises:
        SyntaxError: If a required token, data type or field block is missing.
    """
    self.eat("LAYOUT")
    self.eat("LPAREN")

    push_constant = self.current_token[0] == "PUSH_CONSTANT"
    if push_constant:
        self.eat("PUSH_CONSTANT")
        if self.current_token[0] == "COMMA":
            self.eat("COMMA")

    # Collect (qualifier, value) pairs; value is None for bare qualifiers.
    bindings = []
    while self.current_token[0] != "RPAREN":
        qualifier = self.current_token[1]
        self.eat("IDENTIFIER")
        qualifier_value = None
        if self.current_token[0] == "EQUALS":
            self.eat("EQUALS")
            qualifier_value = self.current_token[1]
            self.eat("NUMBER")
        bindings.append((qualifier, qualifier_value))
        if self.current_token[0] == "COMMA":
            self.eat("COMMA")
    self.eat("RPAREN")

    layout_type = None
    if self.current_token[0] in ("IN", "OUT", "UNIFORM", "BUFFER"):
        layout_type = self.current_token[0]
        self.eat(layout_type)

    block_name = None
    if self.current_token[0] == "IDENTIFIER":
        block_name = self.current_token[1]
        self.eat("IDENTIFIER")

    data_type = None
    struct_fields = None
    if layout_type in ("UNIFORM", "BUFFER"):
        if self.current_token[0] != "LBRACE":
            raise SyntaxError(
                "Expected structured data block after 'uniform' or 'buffer'"
            )
        self.eat("LBRACE")
        struct_fields = []
        # Each field is "<type> <name>;" inside the block.
        while self.current_token[0] != "RBRACE":
            field_type = self.current_token[1]
            if field_type not in VALID_DATA_TYPES:
                raise SyntaxError("Expected some data type before an identifier")
            self.eat(self.current_token[0])
            field_name = self.current_token[1]
            self.eat("IDENTIFIER")
            self.eat("SEMICOLON")
            struct_fields.append((field_type, field_name))
        self.eat("RBRACE")
        data_type = "struct"
    else:
        candidate = self.current_token[1]
        if candidate not in VALID_DATA_TYPES:
            raise SyntaxError(f"Unexpected type: {self.current_token[1]}")
        data_type = candidate
        self.eat(self.current_token[0])

    variable_name = None
    if self.current_token[0] == "IDENTIFIER":
        variable_name = self.current_token[1]
        self.eat("IDENTIFIER")
    self.eat("SEMICOLON")

    return LayoutNode(
        bindings,
        push_constant=push_constant,
        layout_type=layout_type,
        data_type=data_type,
        variable_name=variable_name,
        struct_fields=struct_fields,
        block_name=block_name,
    )
def parse_push_constant(self):
    """Parse ``PUSH_CONSTANT { <type> <name> ...; ... }``.

    Returns:
        PushConstantNode: Wraps the list of members, each parsed by
        ``parse_variable``.
    """
    self.eat("PUSH_CONSTANT")
    self.eat("LBRACE")
    members = []
    while self.current_token[0] != "RBRACE":
        # parse_variable() requires the declared type and expects to start on
        # the member name, so read and consume the type token here first.
        member_type = self.current_token[1]
        self.eat(self.current_token[0])
        members.append(self.parse_variable(member_type))
    self.eat("RBRACE")
    return PushConstantNode(members)
def parse_descriptor_set(self):
    """Parse ``DESCRIPTOR_SET <number> { <type> <name> ...; ... }``.

    Returns:
        DescriptorSetNode: Built from the set number (the ``NUMBER`` token's
        value) and the list of bindings, each parsed by ``parse_variable``.
    """
    self.eat("DESCRIPTOR_SET")
    set_number = self.current_token[1]
    self.eat("NUMBER")
    self.eat("LBRACE")
    bindings = []
    while self.current_token[0] != "RBRACE":
        # parse_variable() requires the declared type and expects to start on
        # the binding name, so read and consume the type token here first.
        binding_type = self.current_token[1]
        self.eat(self.current_token[0])
        bindings.append(self.parse_variable(binding_type))
    self.eat("RBRACE")
    return DescriptorSetNode(set_number, bindings)
def parse_struct(self):
    """Parse ``struct <name> { <type> <member>; ... };`` into a ``StructNode``.

    A member type may be one of the built-in type tokens listed below, any
    token whose value is in ``VALID_DATA_TYPES``, or an ``IDENTIFIER``.

    Raises:
        SyntaxError: If a member does not start with an accepted type token.
    """
    builtin_kinds = (
        "VEC2",
        "VEC3",
        "VEC4",
        "IVEC2",
        "IVEC3",
        "IVEC4",
        "UVEC2",
        "UVEC3",
        "UVEC4",
        "FLOAT",
        "INT",
        "UINT",
        "BOOL",
        "MAT2",
        "MAT3",
        "MAT4",
    )

    self.eat("STRUCT")
    name = self.current_token[1]
    self.eat("IDENTIFIER")
    self.eat("LBRACE")

    members = []
    while self.current_token[0] != "RBRACE":
        kind = self.current_token[0]
        type_name = self.current_token[1]
        accepted = (
            kind in builtin_kinds
            or type_name in VALID_DATA_TYPES
            or kind == "IDENTIFIER"
        )
        if not accepted:
            raise SyntaxError(
                f"Unexpected token in struct member: {self.current_token}"
            )
        self.eat(kind)

        member_name = self.current_token[1]
        self.eat("IDENTIFIER")
        self.eat("SEMICOLON")
        members.append(VariableNode(type_name, member_name))

    self.eat("RBRACE")
    self.eat("SEMICOLON")
    return StructNode(name, members)
def parse_function(self):
    """Parse ``<type> <name>(<params>) { ... }`` into a ``FunctionNode``.

    Raises:
        SyntaxError: If the return type's value is not in ``VALID_DATA_TYPES``.
    """
    return_type = self.current_token[1]
    if return_type not in VALID_DATA_TYPES:
        raise SyntaxError(f"Unexpected type: {self.current_token[1]}")
    self.eat(self.current_token[0])

    func_name = self.current_token[1]
    self.eat("IDENTIFIER")

    self.eat("LPAREN")
    params = self.parse_parameters()
    self.eat("RPAREN")

    body = self.parse_block()
    return FunctionNode(return_type, func_name, params, body)
def parse_parameters(self):
    """Parse a comma-separated parameter list up to (not including) ``RPAREN``.

    Each parameter is one token taken as the type, followed by an
    ``IDENTIFIER``. The type token is consumed whatever its kind.

    Returns:
        list: One ``VariableNode(type, name)`` per parameter.
    """
    parameters = []
    while self.current_token[0] != "RPAREN":
        param_type = self.current_token[1]
        self.eat(self.current_token[0])
        param_name = self.current_token[1]
        self.eat("IDENTIFIER")
        parameters.append(VariableNode(param_type, param_name))
        if self.current_token[0] == "COMMA":
            self.eat("COMMA")
    return parameters
def parse_block(self):
    """Parse ``{ ... }`` and return the list of statements inside it.

    Each statement is whatever ``parse_body`` returns.
    """
    self.eat("LBRACE")
    body = []
    while True:
        if self.current_token[0] == "RBRACE":
            break
        body.append(self.parse_body())
    self.eat("RBRACE")
    return body
def parse_body(self):
    """Parse one statement inside a block, dispatching on the current token.

    Identifiers and data-type tokens go to
    ``parse_assignment_or_function_call``; control-flow keywords go to their
    dedicated parsers; ``break;`` yields a ``BreakNode``; anything else is
    parsed as an expression statement.
    """
    token_type = self.current_token[0]

    if token_type == "IDENTIFIER" or self.current_token[1] in VALID_DATA_TYPES:
        return self.parse_assignment_or_function_call()

    if token_type == "BREAK":
        self.eat("BREAK")
        self.eat("SEMICOLON")
        return BreakNode()

    # Keyword -> name of the method that parses that statement form.
    keyword_parsers = {
        "IF": "parse_if_statement",
        "FOR": "parse_for_statement",
        "WHILE": "parse_while_statement",
        "DO": "parse_do_while_statement",
        "SWITCH": "parse_switch_statement",
    }
    parser_name = keyword_parsers.get(token_type, "parse_expression_statement")
    return getattr(self, parser_name)()
def parse_update(self):
    """Parse the update clause of a ``for`` loop.

    Accepted forms, with ``name`` an ``IDENTIFIER``:

    * ``++name`` / ``--name``, giving ``UnaryOpNode("PRE_...", ...)``.
    * ``name++`` / ``name--``, giving ``UnaryOpNode("POST_...", ...)``.
    * ``name = expr``, giving ``AssignmentNode(name, expr)``.
    * ``name += expr`` (also ``-=``, ``*=``, ``/=``), giving an
      ``AssignmentNode`` whose value is the expanded ``BinaryOpNode``.

    Raises:
        SyntaxError: For any other token sequence.
    """
    kind = self.current_token[0]

    if kind in ("PRE_INCREMENT", "PRE_DECREMENT"):
        self.eat(kind)
        if self.current_token[0] != "IDENTIFIER":
            raise SyntaxError(
                f"Expected IDENTIFIER after {kind}, got {self.current_token[0]}"
            )
        name = self.current_token[1]
        self.eat("IDENTIFIER")
        return UnaryOpNode(kind, VariableNode("", name))

    if kind != "IDENTIFIER":
        raise SyntaxError(f"Unexpected token in update: {self.current_token[0]}")

    name = self.current_token[1]
    self.eat("IDENTIFIER")
    follow = self.current_token[0]

    if follow in ("POST_INCREMENT", "POST_DECREMENT"):
        self.eat(follow)
        return UnaryOpNode(follow, VariableNode("", name))

    # Compound assignment token -> arithmetic operator it expands to.
    compound_ops = {
        "ASSIGN_ADD": "+",
        "ASSIGN_SUB": "-",
        "ASSIGN_MUL": "*",
        "ASSIGN_DIV": "/",
    }
    if follow == "EQUALS" or follow in compound_ops:
        self.eat(follow)
        value = self.parse_expression()
        if follow == "EQUALS":
            return AssignmentNode(name, value)
        return AssignmentNode(
            name, BinaryOpNode(VariableNode("", name), compound_ops[follow], value)
        )

    raise SyntaxError(
        f"Expected INCREMENT or DECREMENT, got {self.current_token[0]}"
    )
def parse_if_statement(self):
    """Parse ``if (...) {...}`` with optional ``else if`` and ``else`` parts.

    Returns:
        IfNode: Built from the condition, the ``if`` body, the ``else`` body
        (``None`` when absent) and a list of ``(condition, body)`` tuples, one
        per ``else if`` branch.
    """
    self.eat("IF")
    self.eat("LPAREN")
    condition = self.parse_expression()
    self.eat("RPAREN")
    then_block = self.parse_block()

    else_if_chain = []
    # "else" immediately followed by "if" continues the chain.
    while self.current_token[0] == "ELSE" and self.peek(1) == "IF":
        self.eat("ELSE")
        self.eat("IF")
        self.eat("LPAREN")
        branch_condition = self.parse_expression()
        self.eat("RPAREN")
        branch_body = self.parse_block()
        else_if_chain.append((branch_condition, branch_body))

    else_block = None
    if self.current_token[0] == "ELSE":
        self.eat("ELSE")
        else_block = self.parse_block()

    return IfNode(
        condition,
        then_block,
        else_block,
        else_if_chain=else_if_chain,
    )
def parse_for_statement(self):
    """Parse ``for (<init> <condition>; <update>) {...}`` into a ``ForNode``.

    No semicolon is consumed here after the initializer; it is left to
    ``parse_assignment_or_function_call``.
    """
    self.eat("FOR")
    self.eat("LPAREN")

    init_stmt = self.parse_assignment_or_function_call()
    loop_condition = self.parse_expression()
    self.eat("SEMICOLON")
    update_stmt = self.parse_update()

    self.eat("RPAREN")
    loop_body = self.parse_block()
    return ForNode(init_stmt, loop_condition, update_stmt, loop_body)
def parse_variable(self, type_name):
    """Parse a variable declaration or statement starting at its name.

    The name may be a dotted chain (``a.b.c``), which is joined into one
    string. What follows the name decides the result:

    * ``;`` gives ``VariableNode(type_name, name)``.
    * ``= expr`` gives ``AssignmentNode(VariableNode(...), expr)``.
    * ``&``, ``|``, ``^`` and the other listed operator tokens give a
      ``BinaryOpNode`` with the variable on the left.
    * anything else is skipped up to the next ``;`` and gives
      ``VariableNode(type_name, name)``.

    After an expression, any tokens left before the ``;`` are skipped rather
    than reported.

    Args:
        type_name: Declared type to attach to the ``VariableNode``.
    """
    name = self.current_token[1]
    self.eat("IDENTIFIER")
    while self.current_token[0] == "DOT":
        self.eat("DOT")
        member_name = self.current_token[1]
        self.eat("IDENTIFIER")
        name = name + "." + member_name

    kind = self.current_token[0]

    if kind == "SEMICOLON":
        self.eat("SEMICOLON")
        return VariableNode(type_name, name)

    if kind == "EQUALS":
        self.eat("EQUALS")
        value = self.parse_expression()
        if self.current_token[0] != "SEMICOLON":
            self.skip_until("SEMICOLON")
        self.eat("SEMICOLON")
        return AssignmentNode(VariableNode(type_name, name), value)

    bitwise_symbols = {"BINARY_AND": "&", "BINARY_OR": "|", "BINARY_XOR": "^"}
    if kind in bitwise_symbols:
        op_symbol = bitwise_symbols[kind]
        self.eat(kind)
        right = self.parse_expression()
        if self.current_token[0] != "SEMICOLON":
            self.skip_until("SEMICOLON")
        self.eat("SEMICOLON")
        return BinaryOpNode(VariableNode(type_name, name), op_symbol, right)

    # "EQUALS" is handled above, so it is not repeated in this list.
    operator_kinds = (
        "PLUS_EQUALS",
        "MINUS_EQUALS",
        "MULTIPLY_EQUALS",
        "DIVIDE_EQUALS",
        "EQUAL",
        "LESS_THAN",
        "GREATER_THAN",
        "LESS_EQUAL",
        "GREATER_EQUAL",
        "ASSIGN_AND",
        "ASSIGN_OR",
        "ASSIGN_XOR",
        "ASSIGN_MOD",
        "BITWISE_SHIFT_RIGHT",
        "BITWISE_SHIFT_LEFT",
        "BITWISE_XOR",
        "ASSIGN_SHIFT_LEFT",
        "ASSIGN_SHIFT_RIGHT",
    )
    if kind in operator_kinds:
        # The operator's source text (token value) is stored in the node.
        op_name = self.current_token[1]
        self.eat(kind)
        value = self.parse_expression()
        if self.current_token[0] != "SEMICOLON":
            self.skip_until("SEMICOLON")
        self.eat("SEMICOLON")
        return BinaryOpNode(VariableNode(type_name, name), op_name, value)

    # Unrecognised continuation: drop everything up to the semicolon.
    self.skip_until("SEMICOLON")
    self.eat("SEMICOLON")
    return VariableNode(type_name, name)
def parse_member_access(self, object):
    """Parse one or more ``.member`` accesses applied to ``object``.

    Chained accesses nest left to right, so ``a.b.c`` becomes
    ``MemberAccessNode(MemberAccessNode(a, "b"), "c")``.

    Raises:
        SyntaxError: If a dot is not followed by an ``IDENTIFIER``.
    """
    target = object
    while True:
        self.eat("DOT")
        if self.current_token[0] != "IDENTIFIER":
            raise SyntaxError(
                f"Expected identifier after dot, got {self.current_token[0]}"
            )
        member = self.current_token[1]
        self.eat("IDENTIFIER")
        target = MemberAccessNode(target, member)
        if self.current_token[0] != "DOT":
            return target
def parse_function_call(self, name):
    """Parse ``( arg, arg, ... )`` and return ``FunctionCallNode(name, args)``.

    Args:
        name: Callee name; the caller has already consumed it.
    """
    self.eat("LPAREN")
    arguments = []
    if self.current_token[0] != "RPAREN":
        while True:
            arguments.append(self.parse_expression())
            if self.current_token[0] != "COMMA":
                break
            self.eat("COMMA")
    self.eat("RPAREN")
    return FunctionCallNode(name, arguments)
def parse_function_call_or_identifier(self):
    """Parse a name followed by a call, a member access, or nothing.

    The current token (whatever its type) is consumed and its value used as
    the name. ``(`` starts a function call, ``.`` a member access; otherwise a
    ``VariableNode`` with an empty type is returned.
    """
    identifier = self.current_token[1]
    self.eat(self.current_token[0])

    next_kind = self.current_token[0]
    if next_kind == "LPAREN":
        return self.parse_function_call(identifier)
    if next_kind == "DOT":
        return self.parse_member_access(identifier)
    return VariableNode("", identifier)
def parse_primary(self):
    """Parse a primary expression.

    Handles unary ``-``, bitwise not (``BITWISE_NOT``/``BINARY_NOT``, both
    recorded as ``"~"``), identifiers and data-type names (possibly followed
    by a call or member access), number literals, and parenthesised
    expressions.

    Returns:
        An AST node, or for ``NUMBER`` tokens the literal's text with one
        trailing ``"u"`` removed.

    Raises:
        SyntaxError: If the current token cannot start an expression.
    """
    kind = self.current_token[0]

    if kind == "MINUS":
        self.eat("MINUS")
        operand = self.parse_primary()
        return UnaryOpNode("-", operand)

    if kind in ("BITWISE_NOT", "BINARY_NOT"):
        self.eat(kind)
        operand = self.parse_primary()
        return UnaryOpNode("~", operand)

    if kind == "IDENTIFIER" or self.current_token[1] in VALID_DATA_TYPES:
        return self.parse_function_call_or_identifier()

    if kind == "NUMBER":
        literal = self.current_token[1]
        self.eat("NUMBER")
        # Drop the unsigned suffix, e.g. "3u" -> "3".
        return literal[:-1] if literal.endswith("u") else literal

    if kind == "LPAREN":
        self.eat("LPAREN")
        inner = self.parse_expression()
        self.eat("RPAREN")
        return inner

    raise SyntaxError(f"Unexpected token in expression: {self.current_token[0]}")
def parse_multiplicative(self):
    """Parse a left-associative chain of ``*``, ``/`` and ``%`` operations.

    Operands are parsed with ``parse_primary``; the operator stored in each
    ``BinaryOpNode`` is the operator token's value.
    """
    result = self.parse_primary()
    while self.current_token[0] in ("MULTIPLY", "DIVIDE", "MOD"):
        operator_kind = self.current_token[0]
        operator_text = self.current_token[1]
        self.eat(operator_kind)
        rhs = self.parse_primary()
        result = BinaryOpNode(result, operator_text, rhs)
    return result
def parse_additive(self):
    """Parse a left-associative chain of ``+`` and ``-`` operations.

    Operands are parsed with ``parse_multiplicative``; the operator stored in
    each ``BinaryOpNode`` is the operator token's value.
    """
    result = self.parse_multiplicative()
    while self.current_token[0] in ("PLUS", "MINUS"):
        operator_kind = self.current_token[0]
        operator_text = self.current_token[1]
        self.eat(operator_kind)
        rhs = self.parse_multiplicative()
        result = BinaryOpNode(result, operator_text, rhs)
    return result
def parse_assignment(self, name):
    """Parse ``<operator> <expression>`` following an already-consumed name.

    A trailing ``;`` is consumed when present.

    Args:
        name: Left-hand side, passed through to the ``BinaryOpNode``.

    Returns:
        BinaryOpNode: ``(name, operator text, value)``.

    Raises:
        SyntaxError: If the current token is not one of the accepted operators.
    """
    accepted_operators = (
        "EQUALS",
        "PLUS_EQUALS",
        "MINUS_EQUALS",
        "MULTIPLY_EQUALS",
        "DIVIDE_EQUALS",
        "LESS_THAN",
        "GREATER_THAN",
        "LESS_EQUAL",
        "GREATER_EQUAL",
        "ASSIGN_AND",
        "ASSIGN_OR",
        "ASSIGN_XOR",
        "ASSIGN_MOD",
        "BITWISE_SHIFT_RIGHT",
        "BITWISE_SHIFT_LEFT",
        "BITWISE_XOR",
        "ASSIGN_SHIFT_LEFT",
        "ASSIGN_SHIFT_RIGHT",
    )
    operator_kind = self.current_token[0]
    if operator_kind not in accepted_operators:
        raise SyntaxError(
            f"Expected assignment operator, found: {self.current_token[0]}"
        )

    operator_text = self.current_token[1]
    self.eat(operator_kind)
    value = self.parse_expression()
    if self.current_token[0] == "SEMICOLON":
        self.eat("SEMICOLON")
    return BinaryOpNode(name, operator_text, value)
def parse_assignment_or_function_call(self):
    """Parse a statement that starts with an identifier or a data type.

    * ``name++;`` / ``name--;`` give an ``AssignmentNode`` wrapping the
      postfix ``UnaryOpNode``.
    * Otherwise an optional data-type token is consumed and, if an
      ``IDENTIFIER`` follows, the rest is delegated to ``parse_variable``.

    Returns:
        The parsed node, or ``None`` when no identifier follows.

    Raises:
        SyntaxError: If a postfix-looking statement continues unexpectedly.
    """
    postfix_kinds = ("POST_INCREMENT", "POST_DECREMENT")
    assignment_kinds = (
        "EQUALS",
        "PLUS_EQUALS",
        "MINUS_EQUALS",
        "MULTIPLY_EQUALS",
        "DIVIDE_EQUALS",
        "LESS_THAN",
        "GREATER_THAN",
        "LESS_EQUAL",
        "GREATER_EQUAL",
        "ASSIGN_AND",
        "ASSIGN_OR",
        "ASSIGN_XOR",
        "ASSIGN_MOD",
        "BITWISE_SHIFT_RIGHT",
        "BITWISE_SHIFT_LEFT",
        "BITWISE_XOR",
        "ASSIGN_SHIFT_LEFT",
        "ASSIGN_SHIFT_RIGHT",
    )

    if self.current_token[0] == "IDENTIFIER" and self.peek(1) in postfix_kinds:
        name = self.current_token[1]
        self.eat("IDENTIFIER")
        follow = self.current_token[0]
        if follow in assignment_kinds:
            return self.parse_assignment(name)  # todo
        if follow in postfix_kinds:
            self.eat(follow)
            self.eat("SEMICOLON")
            return AssignmentNode(name, UnaryOpNode(follow, VariableNode("", name)))
        if follow == "LPAREN":
            return self.parse_function_call(name)
        raise SyntaxError(
            f"Unexpected token after identifier: {self.current_token[0]}"
        )

    # Declarations carry a leading type; plain statements keep type "".
    type_name = ""
    if self.current_token[1] in VALID_DATA_TYPES:
        type_name = self.current_token[1]
        self.eat(self.current_token[0])

    if self.current_token[0] == "IDENTIFIER":
        return self.parse_variable(type_name)
    return None
def parse_expression(self):
    """Parse comparison/logical chains and an optional ternary.

    Comparison and logical operators are all folded left to right at one
    precedence level over ``parse_bitwise_expression`` operands. A following
    ``?`` turns the result into a ``TernaryOpNode``.
    """
    chain_operators = (
        "LESS_THAN",
        "GREATER_THAN",
        "LESS_EQUAL",
        "GREATER_EQUAL",
        "EQUAL",
        "NOT_EQUAL",
        "AND",
        "OR",
    )

    result = self.parse_bitwise_expression()
    while self.current_token[0] in chain_operators:
        token = self.current_token
        operator_kind = token[0]
        # Prefer the operator's source text; fall back to its token type.
        operator_text = token[1] if len(token) > 1 else operator_kind
        self.eat(operator_kind)
        rhs = self.parse_bitwise_expression()
        result = BinaryOpNode(result, operator_text, rhs)

    if self.current_token[0] == "QUESTION":
        self.eat("QUESTION")
        when_true = self.parse_expression()
        self.eat("COLON")
        when_false = self.parse_expression()
        result = TernaryOpNode(result, when_true, when_false)

    return result
def parse_bitwise_expression(self):
    """Parse a left-associative chain of bitwise and shift operations.

    Operands are parsed with ``parse_additive``. The operator recorded in
    each ``BinaryOpNode`` is the symbol from the table below, not the token
    value.
    """
    symbol_for = {
        "BINARY_AND": "&",
        "BINARY_OR": "|",
        "BINARY_XOR": "^",
        "BITWISE_SHIFT_LEFT": "<<",
        "BITWISE_SHIFT_RIGHT": ">>",
    }

    result = self.parse_additive()
    while self.current_token[0] in symbol_for:
        operator_kind = self.current_token[0]
        self.eat(operator_kind)
        rhs = self.parse_additive()
        result = BinaryOpNode(result, symbol_for[operator_kind], rhs)
    return result
def parse_expression_statement(self):
    """Parse a bare expression used as a statement.

    A trailing ``SEMICOLON`` is not consumed here.
    """
    return self.parse_expression()
def parse_while_statement(self):
    """Parse ``while (<condition>) {...}`` into a ``WhileNode``."""
    self.eat("WHILE")
    self.eat("LPAREN")
    loop_condition = self.parse_expression()
    self.eat("RPAREN")
    loop_body = self.parse_block()
    return WhileNode(loop_condition, loop_body)
def parse_do_while_statement(self):
    """Parse ``do {...} while (<condition>);`` into a ``DoWhileNode``."""
    self.eat("DO")
    loop_body = self.parse_block()

    self.eat("WHILE")
    self.eat("LPAREN")
    loop_condition = self.parse_expression()
    self.eat("RPAREN")
    self.eat("SEMICOLON")

    return DoWhileNode(loop_body, loop_condition)
def parse_switch_statement(self):
    """Parse ``switch (<expr>) { <cases> }`` into a ``SwitchNode``.

    Every entry inside the braces is parsed by ``parse_case_statement``.
    """
    self.eat("SWITCH")
    self.eat("LPAREN")
    subject = self.parse_expression()
    self.eat("RPAREN")

    self.eat("LBRACE")
    case_nodes = []
    while self.current_token[0] != "RBRACE":
        case_nodes.append(self.parse_case_statement())
    self.eat("RBRACE")

    return SwitchNode(subject, case_nodes)
def parse_case_statement(self):
    """Parse one ``case <expr>:`` or ``default:`` arm of a switch.

    Statements are collected until the next ``CASE``, ``DEFAULT`` or
    ``RBRACE`` token.

    Returns:
        CaseNode: ``(value, statements)``, where ``value`` is the parsed case
        expression or ``None`` for ``default``.

    Raises:
        SyntaxError: If the arm does not start with ``CASE`` or ``DEFAULT``.
    """
    kind = self.current_token[0]
    if kind == "CASE":
        self.eat("CASE")
        value = self.parse_expression()
    elif kind == "DEFAULT":
        self.eat("DEFAULT")
        value = None
    else:
        # Fail here with a parse error; otherwise ``value`` would be unbound
        # and the method would end in an UnboundLocalError.
        raise SyntaxError(f"Expected CASE or DEFAULT, got {kind}")
    self.eat("COLON")

    statements = []
    while self.current_token[0] not in ("CASE", "DEFAULT", "RBRACE"):
        statements.append(self.parse_body())
    return CaseNode(value, statements)
def parse_default_statement(self):
    """Parse ``default:`` and its statements into a ``DefaultNode``.

    Statements are collected until a ``CASE`` or ``RBRACE`` token.
    """
    self.eat("DEFAULT")
    self.eat("COLON")
    body = []
    while self.current_token[0] not in ("CASE", "RBRACE"):
        body.append(self.parse_body())
    return DefaultNode(body)
def parse_uniform(self):
    """Parse ``uniform <type> <name>;`` into ``UniformNode(name, type)``.

    Raises:
        SyntaxError: If the type's value is not in ``VALID_DATA_TYPES``.
    """
    self.eat("UNIFORM")

    var_type = self.current_token[1]
    if var_type not in VALID_DATA_TYPES:
        raise SyntaxError(f"Unexpected type: {self.current_token[1]}")
    self.eat(self.current_token[0])

    uniform_name = self.current_token[1]
    self.eat("IDENTIFIER")
    self.eat("SEMICOLON")
    return UniformNode(uniform_name, var_type)
def parse_unary(self):
    """Parse prefix ``+``, ``-`` or ``~`` operators, then a primary expression.

    Prefix operators nest right to left; the operator stored in each
    ``UnaryOpNode`` is the operator token's value.
    """
    operator_kind = self.current_token[0]
    if operator_kind not in ("PLUS", "MINUS", "BITWISE_NOT"):
        return self.parse_primary()

    operator_text = self.current_token[1]
    self.eat(operator_kind)
    operand = self.parse_unary()
    return UnaryOpNode(operator_text, operand)