"""Common AST node definitions shared across backends."""
class ASTNode:
    """Common base type for every node emitted by the native source parsers.

    It defines no attributes or behavior of its own; each subclass sets up
    its own fields in ``__init__``.
    """
class ShaderNode(ASTNode):
    """Root node representing a complete program.

    The named parameters hold the common top-level collections; each falls
    back to a fresh empty list when falsy.

    For compatibility with different backends the constructor also accepts:

    * Keywords ``uniforms``, ``in_out``, ``constant``, ``io_variables``,
      ``processors``, ``cbuffers``, ``imports``, ``exports``, ``typedefs``
      and ``extensions``, stored as attributes that default to ``[]`` when
      omitted or passed as ``None``.
    * Keyword ``global_vars``, an alias that refers to the same object as
      ``global_variables`` unless a non-None value is given explicitly.
    * Up to three extra positional arguments after ``kernels``, assigned in
      order to ``uniforms``, ``in_out`` and ``constant``. They take
      precedence over the keyword forms; further positionals are ignored.
    * Any other keyword, stored as an attribute unless that name already
      exists on the instance.
    """

    def __init__(
        self,
        includes=None,
        functions=None,
        structs=None,
        global_variables=None,
        kernels=None,
        *args,  # Accept extra positional args
        **kwargs,  # Accept extra keyword args for compatibility
    ):
        self.includes = includes or []
        self.functions = functions or []
        self.structs = structs or []
        self.global_variables = global_variables or []
        self.kernels = kernels or []

        # Backend-specific collections. An explicit None is normalized to an
        # empty list so every collection attribute is safe to iterate,
        # matching the handling of the named parameters above.
        for attr in (
            "uniforms",
            "in_out",
            "constant",
            "io_variables",
            "processors",
            "cbuffers",
            "imports",
            "exports",
            "typedefs",
            "extensions",
        ):
            value = kwargs.get(attr)
            setattr(self, attr, [] if value is None else value)

        # ``global_vars`` aliases ``global_variables`` (same list object)
        # unless the caller supplied a non-None value of its own.
        global_vars = kwargs.get("global_vars")
        self.global_vars = (
            self.global_variables if global_vars is None else global_vars
        )

        # Extra positionals override the keyword forms, in this fixed order.
        for attr, value in zip(("uniforms", "in_out", "constant"), args):
            setattr(self, attr, [] if value is None else value)

        # Remaining unknown keywords become attributes without clobbering
        # anything initialized above.
        for key, value in kwargs.items():
            if not hasattr(self, key):
                setattr(self, key, value)

    def __repr__(self):
        return (
            f"ShaderNode(includes={self.includes}, functions={self.functions}, "
            f"structs={self.structs}, global_variables={self.global_variables}, "
            f"kernels={self.kernels})"
        )
class FunctionNode(ASTNode):
    """Node representing a function declaration.

    Beyond the named parameters, the constructor tolerates extra arguments
    for compatibility with different backends:

    * ``generics`` (keyword) is stored as an attribute and defaults to
      ``[]`` when omitted or passed as ``None``.
    * The first extra positional argument, if any, is stored as
      ``qualifier`` (singular), separate from ``qualifiers``; further
      positionals are ignored.
    * Any other unknown keyword is stored as an attribute unless that name
      already exists on the instance.
    """

    def __init__(
        self,
        return_type,
        name,
        params,
        body,
        qualifiers=None,
        attributes=None,
        *args,
        **kwargs,
    ):
        self.return_type = return_type
        self.name = name
        self.params = params
        self.body = body
        self.qualifiers = qualifiers or []
        self.attributes = attributes or []
        # Normalize an explicit ``generics=None`` to an empty list so callers
        # can always iterate it, like qualifiers/attributes above.
        generics = kwargs.get("generics")
        self.generics = [] if generics is None else generics
        if args:
            # NOTE(review): stored under the singular name ``qualifier``;
            # confirm which backends read this attribute.
            self.qualifier = args[0]
        for key, value in kwargs.items():
            if not hasattr(self, key):
                setattr(self, key, value)

    def __repr__(self):
        return (
            f"FunctionNode(return_type={self.return_type}, name={self.name}, "
            f"params={self.params}, body={self.body}, "
            f"qualifiers={self.qualifiers})"
        )
class StructNode(ASTNode):
    """Declaration of a struct type: its name, member list and attributes."""

    def __init__(self, name, members, attributes=None):
        self.name, self.members = name, members
        # Any falsy ``attributes`` (None, empty) becomes a fresh list.
        self.attributes = attributes if attributes else []

    def __repr__(self):
        return "StructNode(name={}, members={})".format(self.name, self.members)
class EnumNode(ASTNode):
    """Declaration of an enum type.

    ``members`` is stored unchanged. It is expected to be a list of
    ``(name, value_or_None)`` pairs.
    """

    def __init__(self, name, members):
        self.name, self.members = name, members

    def __repr__(self):
        return "EnumNode(name={}, members={})".format(self.name, self.members)
class TypeAliasNode(ASTNode):
    """A typedef/alias declaration that binds ``name`` to ``alias_type``."""

    def __init__(self, alias_type, name):
        self.alias_type, self.name = alias_type, name

    def __repr__(self):
        return "TypeAliasNode(alias_type={}, name={})".format(
            self.alias_type, self.name
        )
class StaticAssertNode(ASTNode):
    """A ``static_assert``: its condition plus an optional message (None if absent)."""

    def __init__(self, condition, message=None):
        self.condition, self.message = condition, message

    def __repr__(self):
        return "StaticAssertNode(condition={}, message={})".format(
            self.condition, self.message
        )
class VariableNode(ASTNode):
    """A variable declaration with an optional initializer ``value``.

    ``semantic`` may be passed as a keyword and defaults to None. Any other
    unrecognized keyword is stored as an attribute, unless that name already
    exists on the instance.
    """

    def __init__(
        self,
        vtype,
        name,
        value=None,
        qualifiers=None,
        attributes=None,
        is_const=False,
        **kwargs,
    ):
        self.vtype = vtype
        self.name = name
        self.value = value
        self.qualifiers = qualifiers if qualifiers else []
        self.attributes = attributes if attributes else []
        self.is_const = is_const
        self.semantic = kwargs.get("semantic")
        # Keyword names are unique, so filtering up front is equivalent to
        # checking each one just before it is set.
        extras = {k: v for k, v in kwargs.items() if not hasattr(self, k)}
        for attr_name, attr_value in extras.items():
            setattr(self, attr_name, attr_value)

    def __repr__(self):
        fields = (
            "vtype={}".format(self.vtype),
            "name={}".format(self.name),
            "value={}".format(self.value),
            "qualifiers={}".format(self.qualifiers),
        )
        return "VariableNode(" + ", ".join(fields) + ")"
class InitializerListNode(ASTNode):
    """A C-style brace initializer list; ``elements`` defaults to an empty list."""

    def __init__(self, elements=None):
        self.elements = elements if elements else []

    def __repr__(self):
        return "InitializerListNode(elements={})".format(self.elements)
class DesignatedInitializerNode(ASTNode):
    """One C99 designated-initializer entry: its designators and its value."""

    def __init__(self, designators=None, value=None):
        self.designators = designators if designators else []
        self.value = value

    def __repr__(self):
        return "DesignatedInitializerNode(designators={}, value={})".format(
            self.designators, self.value
        )
class AssignmentNode(ASTNode):
    """An assignment ``left <operator> right``; ``operator`` defaults to ``"="``."""

    def __init__(self, left, right, operator="="):
        self.left, self.right, self.operator = left, right, operator

    def __repr__(self):
        return "AssignmentNode(left={}, operator={}, right={})".format(
            self.left, self.operator, self.right
        )
class BinaryOpNode(ASTNode):
    """An infix binary expression ``left op right``."""

    def __init__(self, left, op, right):
        self.left, self.op, self.right = left, op, right

    def __repr__(self):
        return "BinaryOpNode(left={}, op={}, right={})".format(
            self.left, self.op, self.right
        )
class UnaryOpNode(ASTNode):
    """A prefix unary expression: operator ``op`` applied to ``operand``."""

    def __init__(self, op, operand):
        self.op, self.operand = op, operand

    def __repr__(self):
        return "UnaryOpNode(op={}, operand={})".format(self.op, self.operand)
class PostfixOpNode(ASTNode):
    """A postfix expression where ``op`` follows ``operand`` (e.g. ``i++``, ``i--``)."""

    def __init__(self, operand, op):
        self.operand, self.op = operand, op

    def __repr__(self):
        return "PostfixOpNode(operand={}, op={})".format(self.operand, self.op)
class FunctionCallNode(ASTNode):
    """A call to a function identified by ``name`` with argument list ``args``."""

    def __init__(self, name, args):
        self.name, self.args = name, args

    def __repr__(self):
        return "FunctionCallNode(name={}, args={})".format(self.name, self.args)
class MethodCallNode(ASTNode):
    """A call of ``method`` on the receiver expression ``object`` with ``args``."""

    # The parameter name ``object`` shadows the builtin but is kept because
    # callers may pass it by keyword.
    def __init__(self, object, method, args):
        self.object, self.method, self.args = object, method, args

    def __repr__(self):
        return "MethodCallNode(object={}, method={}, args={})".format(
            self.object, self.method, self.args
        )
class CallNode(ASTNode):
    """A call through an arbitrary ``callee`` expression with argument list ``args``."""

    def __init__(self, callee, args):
        self.callee, self.args = callee, args

    def __repr__(self):
        return "CallNode(callee={}, args={})".format(self.callee, self.args)
class MemberAccessNode(ASTNode):
    """Member access on ``object``: ``.`` normally, ``->`` when ``is_pointer`` is true."""

    def __init__(self, object, member, is_pointer=False):
        self.object, self.member, self.is_pointer = object, member, is_pointer

    def __repr__(self):
        if self.is_pointer:
            arrow = "->"
        else:
            arrow = "."
        return "MemberAccessNode(object={}, member={}, operator={})".format(
            self.object, self.member, arrow
        )
class ArrayAccessNode(ASTNode):
    """Subscript expression ``array[index]``."""

    def __init__(self, array, index):
        self.array, self.index = array, index

    def __repr__(self):
        return "ArrayAccessNode(array={}, index={})".format(self.array, self.index)
class IfNode(ASTNode):
    """Node representing an if statement.

    Two construction styles are supported:

    * Plain: ``condition``, ``if_body`` and ``else_body`` are stored as
      given, and ``if_chain`` / ``else_if_chain`` become empty lists.
    * Chained: used when ``if_chain`` or ``else_if_chain`` is supplied. Both
      are stored, with falsy values replaced by ``[]``. If ``if_chain`` is
      non-empty, its first entry is read as a ``(condition, body)`` pair that
      supplies ``condition`` / ``if_body`` whenever those arguments are falsy.
    """

    def __init__(
        self,
        condition=None,
        if_body=None,
        else_body=None,
        if_chain=None,
        else_if_chain=None,
    ):
        if if_chain is None and else_if_chain is None:
            # Plain form: no chain information supplied.
            self.condition = condition
            self.if_body = if_body
            self.else_body = else_body
            self.if_chain = []
            self.else_if_chain = []
            return

        # Chained form.
        self.if_chain = if_chain or []
        self.else_if_chain = else_if_chain or []
        self.else_body = else_body
        if self.if_chain:
            leading = self.if_chain[0]
            condition = condition or leading[0]
            if_body = if_body or leading[1]
        self.condition = condition
        self.if_body = if_body

    def __repr__(self):
        return "IfNode(condition={}, if_body={}, else_body={})".format(
            self.condition, self.if_body, self.else_body
        )
class ForNode(ASTNode):
    """A C-style ``for (init; condition; update) body`` loop."""

    def __init__(self, init, condition, update, body):
        self.init, self.condition = init, condition
        self.update, self.body = update, body

    def __repr__(self):
        return "ForNode(init={}, condition={}, update={}, body={})".format(
            self.init, self.condition, self.update, self.body
        )
class RangeForNode(ASTNode):
    """A C++ range-based ``for (vtype name : iterable) body`` loop."""

    def __init__(self, vtype, name, iterable, body):
        self.vtype, self.name = vtype, name
        self.iterable, self.body = iterable, body

    def __repr__(self):
        return "RangeForNode(vtype={}, name={}, iterable={}, body={})".format(
            self.vtype, self.name, self.iterable, self.body
        )
class WhileNode(ASTNode):
    """A ``while (condition) body`` loop."""

    def __init__(self, condition, body):
        self.condition, self.body = condition, body

    def __repr__(self):
        return "WhileNode(condition={}, body={})".format(self.condition, self.body)
class DoWhileNode(ASTNode):
    """A ``do body while (condition)`` loop."""

    def __init__(self, body, condition):
        self.body, self.condition = body, condition

    def __repr__(self):
        return "DoWhileNode(body={}, condition={})".format(self.body, self.condition)
class SwitchNode(ASTNode):
    """A switch statement over ``expression`` with its list of ``cases``.

    ``default`` is accepted as an alternative name for ``default_case``. It
    is used only when ``default_case`` is falsy. Afterwards both attributes
    refer to the same value.
    """

    def __init__(self, expression, cases, default_case=None, default=None):
        self.expression = expression
        self.cases = cases
        chosen = default_case if default_case else default
        self.default_case = chosen
        self.default = chosen

    def __repr__(self):
        return "SwitchNode(expression={}, cases={}, default_case={})".format(
            self.expression, self.cases, self.default_case
        )
class CaseNode(ASTNode):
    """One ``case value:`` arm of a switch statement.

    ``statements`` is an alternative name for ``body``. The first truthy
    one of ``body`` and ``statements`` is used, otherwise an empty list.
    Both attributes end up referring to the same object.
    """

    def __init__(self, value, body=None, statements=None):
        self.value = value
        if body:
            stmts = body
        elif statements:
            stmts = statements
        else:
            stmts = []
        self.body = stmts
        self.statements = stmts

    def __repr__(self):
        return "CaseNode(value={}, body={})".format(self.value, self.body)
class ReturnNode(ASTNode):
    """A ``return`` statement; ``value`` is None when no expression is returned."""

    def __init__(self, value=None):
        self.value = value

    def __repr__(self):
        return "ReturnNode(value={})".format(self.value)
class ContinueNode(ASTNode):
    """A ``continue`` statement.

    Positional arguments are accepted and discarded. Keyword arguments are
    copied onto the instance as attributes.
    """

    def __init__(self, *args, **kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

    def __repr__(self):
        return "ContinueNode()"
class BreakNode(ASTNode):
    """A ``break`` statement.

    Positional arguments are accepted and discarded. Keyword arguments are
    copied onto the instance as attributes.
    """

    def __init__(self, *args, **kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

    def __repr__(self):
        return "BreakNode()"
class DiscardNode(ASTNode):
    """A ``discard`` statement.

    Positional arguments are accepted and discarded. Keyword arguments are
    copied onto the instance as attributes.
    """

    def __init__(self, *args, **kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

    def __repr__(self):
        return "DiscardNode()"
class VectorConstructorNode(ASTNode):
    """A vector constructor expression such as ``vec3(x, y, z)``.

    ``type_name`` is an alternative name for ``vector_type``, used only when
    ``vector_type`` is falsy. Both attributes hold the resolved value.
    """

    def __init__(self, vector_type, args, type_name=None):
        resolved = vector_type if vector_type else type_name
        self.vector_type = resolved
        self.type_name = resolved
        self.args = args

    def __repr__(self):
        return "VectorConstructorNode(vector_type={}, args={})".format(
            self.vector_type, self.args
        )
class TernaryOpNode(ASTNode):
    """A conditional expression ``condition ? true_expr : false_expr``."""

    def __init__(self, condition, true_expr, false_expr):
        self.condition = condition
        self.true_expr, self.false_expr = true_expr, false_expr

    def __repr__(self):
        return "TernaryOpNode(condition={}, true_expr={}, false_expr={})".format(
            self.condition, self.true_expr, self.false_expr
        )
class CastNode(ASTNode):
    """A type cast of ``expression`` to ``target_type``."""

    def __init__(self, target_type, expression):
        self.target_type, self.expression = target_type, expression

    def __repr__(self):
        return "CastNode(target_type={}, expression={})".format(
            self.target_type, self.expression
        )
class NewNode(ASTNode):
    """A C++ ``new`` expression, optionally an array form with a ``size``.

    ``args`` falls back to an empty list when falsy.
    """

    def __init__(self, target_type, size=None, args=None, is_array=False):
        self.target_type = target_type
        self.size = size
        self.args = args if args else []
        self.is_array = is_array

    def __repr__(self):
        return "NewNode(target_type={}, size={}, args={}, is_array={})".format(
            self.target_type, self.size, self.args, self.is_array
        )
class DeleteNode(ASTNode):
    """A C++ ``delete`` (or ``delete[]`` when ``is_array``) statement."""

    def __init__(self, expression, is_array=False):
        self.expression, self.is_array = expression, is_array

    def __repr__(self):
        return "DeleteNode(expression={}, is_array={})".format(
            self.expression, self.is_array
        )
class PreprocessorNode(ASTNode):
    """A preprocessor directive together with its raw ``content``."""

    def __init__(self, directive, content):
        self.directive, self.content = directive, content

    def __repr__(self):
        return "PreprocessorNode(directive={}, content={})".format(
            self.directive, self.content
        )
class AttributeNode(ASTNode):
    """An attribute/annotation with a name and an argument list.

    ``arguments`` is an alternative name for ``args``. The first truthy
    one is used, otherwise an empty list. Both attributes end up referring
    to the same object.
    """

    def __init__(self, name, args=None, arguments=None):
        self.name = name
        collected = args if args else (arguments if arguments else [])
        self.args = collected
        self.arguments = collected

    def __repr__(self):
        return "AttributeNode(name='{}', args={})".format(self.name, self.args)
class TextureSampleNode(ASTNode):
    """A texture sampling operation with an optional explicit ``lod``.

    ``lod`` is included in the repr only when it is not None.
    """

    def __init__(self, texture, sampler, coordinates, lod=None):
        self.texture = texture
        self.sampler = sampler
        self.coordinates = coordinates
        self.lod = lod

    def __repr__(self):
        common = "texture={}, sampler={}, coordinates={}".format(
            self.texture, self.sampler, self.coordinates
        )
        if self.lod is None:
            return "TextureSampleNode({})".format(common)
        return "TextureSampleNode({}, lod={})".format(common, self.lod)
class SyncNode(ASTNode):
    """A synchronization operation identified by ``sync_type``.

    ``arguments`` is an alternative name for ``args``. The first truthy
    one is used, otherwise an empty list. Both attributes end up referring
    to the same object.
    """

    def __init__(self, sync_type, args=None, arguments=None):
        self.sync_type = sync_type
        collected = args if args else (arguments if arguments else [])
        self.args = collected
        self.arguments = collected

    def __repr__(self):
        return "SyncNode(sync_type={}, args={})".format(self.sync_type, self.args)
class ThreadgroupSyncNode(ASTNode):
    """A threadgroup synchronization point; it carries no operands."""

    def __init__(self):
        """Take no arguments; there is no state to initialize."""

    def __repr__(self):
        return "ThreadgroupSyncNode()"
class ConstantBufferNode(ASTNode):
    """A named constant buffer and its list of members."""

    def __init__(self, name, members):
        self.name, self.members = name, members

    def __repr__(self):
        return "ConstantBufferNode(name={}, members={})".format(
            self.name, self.members
        )