"""SPIR-V/Vulkan AST Node definitions"""
from ..common_ast import (
ASTNode,
AssignmentNode,
BinaryOpNode,
BreakNode,
CaseNode,
DoWhileNode,
ForNode,
FunctionCallNode,
FunctionNode,
IfNode,
MemberAccessNode,
ReturnNode,
ShaderNode,
StructNode,
SwitchNode,
TernaryOpNode,
UnaryOpNode,
VariableNode,
WhileNode,
)
# Reference each re-exported common AST class once so unused-import cleaners
# (e.g. autoflake) leave the import block above intact. ASTNode is not listed
# because the classes below already use it as their base class.
_COMMON_NODES = (
    AssignmentNode, BinaryOpNode, BreakNode, CaseNode, DoWhileNode, ForNode,
    FunctionCallNode, FunctionNode, IfNode, MemberAccessNode, ReturnNode,
    ShaderNode, StructNode, SwitchNode, TernaryOpNode, UnaryOpNode,
    VariableNode, WhileNode,
)
# SPIR-V/Vulkan-specific nodes
class DescriptorSetNode(ASTNode):
    """AST node for a descriptor set binding.

    Holds the four constructor arguments unchanged as attributes:
    ``set_number``, ``binding``, ``descriptor_type`` and ``name``.
    """

    def __init__(self, set_number, binding, descriptor_type, name):
        """Store the set number, binding, descriptor type and name as given."""
        self.set_number, self.binding = set_number, binding
        self.descriptor_type, self.name = descriptor_type, name

    def __repr__(self):
        """Return a one-line summary listing set, binding, type and name."""
        return (
            f"DescriptorSetNode(set={self.set_number}, "
            f"binding={self.binding}, "
            f"type={self.descriptor_type}, "
            f"name={self.name})"
        )
class LayoutNode(ASTNode):
    """AST node for layout qualifiers.

    Only ``qualifiers`` is required. ``declaration`` may be passed
    positionally; every other option is keyword-only and is stored on the
    instance under the same name.
    """

    def __init__(
        self,
        qualifiers,
        declaration=None,
        *,
        push_constant=False,
        layout_type=None,
        data_type=None,
        variable_name=None,
        struct_fields=None,
        block_name=None,
    ):
        """Store the qualifiers together with the optional details.

        Args:
            qualifiers: Layout qualifiers, stored as given.
            declaration: Optional declaration, stored as given.
            push_constant: Flag stored as given (defaults to ``False``).
            layout_type: Optional value, stored as given.
            data_type: Optional value, stored as given.
            variable_name: Optional value, stored as given.
            struct_fields: Optional fields; any falsy value (``None`` or an
                empty container) is replaced by a new empty list.
            block_name: Optional value, stored as given.
        """
        # Positional arguments.
        self.qualifiers = qualifiers
        self.declaration = declaration

        # Keyword-only options.
        self.push_constant = push_constant
        self.layout_type = layout_type
        self.data_type = data_type
        self.variable_name = variable_name
        # A falsy argument never ends up on the instance; a fresh list does.
        self.struct_fields = struct_fields if struct_fields else []
        self.block_name = block_name

    def __repr__(self):
        """Return a summary of qualifiers, layout type, data type and names.

        ``declaration``, ``push_constant`` and ``struct_fields`` are not
        part of the summary.
        """
        summary = (
            f"qualifiers={self.qualifiers}, "
            f"layout_type={self.layout_type}, "
            f"data_type={self.data_type}, "
            f"variable_name={self.variable_name}, "
            f"block_name={self.block_name}"
        )
        return "LayoutNode(" + summary + ")"
class ShaderStageNode(ASTNode):
    """AST node for a shader stage annotation.

    The single ``stage`` attribute holds the constructor argument unchanged.
    """

    def __init__(self, stage):
        """Store ``stage`` as given."""
        self.stage = stage

    def __repr__(self):
        """Return ``ShaderStageNode(stage=...)`` with the stored stage."""
        stage = self.stage
        return f"ShaderStageNode(stage={stage})"
class PushConstantNode(ASTNode):
    """AST node for a push constant block.

    The single ``members`` attribute holds the constructor argument unchanged.
    """

    def __init__(self, members):
        """Store ``members`` as given."""
        self.members = members

    def __repr__(self):
        """Return ``PushConstantNode(members=...)`` with the stored members."""
        members = self.members
        return f"PushConstantNode(members={members})"
class DefaultNode(ASTNode):
    """AST node for the default case of a switch.

    The single ``statements`` attribute holds the constructor argument
    unchanged.
    """

    def __init__(self, statements):
        """Store ``statements`` as given."""
        self.statements = statements

    def __repr__(self):
        """Return a summary showing how many statements are held.

        Only the count is shown, so ``statements`` must support ``len()``.
        """
        count = len(self.statements)
        return f"DefaultNode(statements={count})"