Source code for crosstl.backend.common_ast

"""Common AST node definitions shared across backends."""


class ASTNode:
    """Common base type for every node emitted by the native source parsers.

    It defines no state or behaviour of its own; concrete subclasses set their
    fields in ``__init__`` and provide their own ``__repr__``.
    """
class ShaderNode(ASTNode):
    """Root node holding every top-level element of a parsed program.

    Args:
        includes, functions, structs, global_variables, kernels: Primary
            containers; any falsy value (including ``None``) becomes ``[]``.
        *args: Legacy positional extras applied, in order, to ``uniforms``,
            ``in_out`` and ``constant``. They override the keyword values of
            the same name; positions beyond the third are ignored.
        **kwargs: Optional backend containers (``uniforms``, ``in_out``,
            ``constant``, ``io_variables``, ``processors``, ``cbuffers``,
            ``imports``, ``exports``, ``typedefs``, ``extensions``), each
            defaulting to a fresh ``[]``. ``global_vars`` defaults to the very
            same list object as ``global_variables``. Any other key that is
            not already an attribute is stored verbatim on the node.
    """

    def __init__(
        self,
        includes=None,
        functions=None,
        structs=None,
        global_variables=None,
        kernels=None,
        *args,  # legacy positional extras: uniforms, in_out, constant
        **kwargs,  # backend-specific containers and arbitrary extras
    ):
        self.includes = includes or []
        self.functions = functions or []
        self.structs = structs or []
        self.global_variables = global_variables or []
        self.kernels = kernels or []

        # Backend containers, assigned in a fixed order; the ``[]`` default is
        # evaluated per iteration, so every node gets its own list.
        container_names = (
            "uniforms",
            "in_out",
            "constant",
            "io_variables",
            "processors",
            "cbuffers",
            "imports",
            "exports",
            "typedefs",
            "extensions",
        )
        for container in container_names:
            setattr(self, container, kwargs.get(container, []))
        # Alias kept for callers that use the shorter name.
        self.global_vars = kwargs.get("global_vars", self.global_variables)

        # zip() stops at the shorter sequence, so only the first three
        # positional extras are consumed.
        for container, positional in zip(("uniforms", "in_out", "constant"), args):
            setattr(self, container, positional)

        # Store unknown keywords without clobbering anything already set.
        for extra_name, extra_value in kwargs.items():
            if not hasattr(self, extra_name):
                setattr(self, extra_name, extra_value)

    def __repr__(self):
        shown = ("includes", "functions", "structs", "global_variables", "kernels")
        body = ", ".join(f"{field}={getattr(self, field)}" for field in shown)
        return f"ShaderNode({body})"
class FunctionNode(ASTNode):
    """A function declaration: signature plus body.

    Args:
        return_type: Declared return type.
        name: Function name.
        params: Parameter list.
        body: Function body.
        qualifiers: Optional qualifier list; falsy values become ``[]``.
        attributes: Optional attribute list; falsy values become ``[]``.
        *args: Legacy positional extras; only the first is used and it is
            stored as ``self.qualifier`` (singular).
        **kwargs: ``generics`` populates ``self.generics`` (default ``[]``);
            any other key that is not already an attribute is stored verbatim.
    """

    def __init__(
        self,
        return_type,
        name,
        params,
        body,
        qualifiers=None,
        attributes=None,
        *args,
        **kwargs,
    ):
        self.return_type, self.name = return_type, name
        self.params, self.body = params, body
        self.qualifiers = qualifiers if qualifiers else []
        self.attributes = attributes if attributes else []
        self.generics = kwargs.get("generics", [])
        if len(args) > 0:
            self.qualifier = args[0]
        for extra_name in kwargs:
            if not hasattr(self, extra_name):
                setattr(self, extra_name, kwargs[extra_name])

    def __repr__(self):
        return (
            "FunctionNode(return_type={}, name={}, params={}, body={}, "
            "qualifiers={})"
        ).format(self.return_type, self.name, self.params, self.body, self.qualifiers)
class StructNode(ASTNode):
    """A struct declaration.

    Args:
        name: Struct name.
        members: Member declarations.
        attributes: Optional attribute list; falsy values become ``[]``.
    """

    def __init__(self, name, members, attributes=None):
        self.name, self.members = name, members
        self.attributes = attributes if attributes else []

    def __repr__(self):
        return "StructNode(name={}, members={})".format(self.name, self.members)
class EnumNode(ASTNode):
    """An enum declaration.

    Args:
        name: Enum name.
        members: List of ``(name, value_or_None)`` entries.
    """

    def __init__(self, name, members):
        self.name, self.members = name, members

    def __repr__(self):
        return "EnumNode(name={}, members={})".format(self.name, self.members)
class TypeAliasNode(ASTNode):
    """A typedef / type alias.

    Args:
        alias_type: The aliased (source) type.
        name: The new alias name.
    """

    def __init__(self, alias_type, name):
        self.alias_type, self.name = alias_type, name

    def __repr__(self):
        return "TypeAliasNode(alias_type={}, name={})".format(self.alias_type, self.name)
class StaticAssertNode(ASTNode):
    """A ``static_assert`` declaration.

    Args:
        condition: The asserted expression.
        message: Optional diagnostic message (``None`` when absent).
    """

    def __init__(self, condition, message=None):
        self.condition, self.message = condition, message

    def __repr__(self):
        return "StaticAssertNode(condition={}, message={})".format(
            self.condition, self.message
        )
class VariableNode(ASTNode):
    """A variable declaration.

    Args:
        vtype: Declared type.
        name: Variable name.
        value: Optional initializer (``None`` when absent).
        qualifiers: Optional qualifier list; falsy values become ``[]``.
        attributes: Optional attribute list; falsy values become ``[]``.
        is_const: Whether the declaration is constant.
        **kwargs: ``semantic`` populates ``self.semantic`` (default ``None``);
            any other key that is not already an attribute is stored verbatim.
    """

    def __init__(
        self,
        vtype,
        name,
        value=None,
        qualifiers=None,
        attributes=None,
        is_const=False,
        **kwargs,
    ):
        self.vtype, self.name, self.value = vtype, name, value
        self.qualifiers = qualifiers if qualifiers else []
        self.attributes = attributes if attributes else []
        self.is_const = is_const
        self.semantic = kwargs.get("semantic")
        for extra_name in kwargs:
            if not hasattr(self, extra_name):
                setattr(self, extra_name, kwargs[extra_name])

    def __repr__(self):
        return "VariableNode(vtype={}, name={}, value={}, qualifiers={})".format(
            self.vtype, self.name, self.value, self.qualifiers
        )
class InitializerListNode(ASTNode):
    """A C-style brace initializer list.

    Args:
        elements: Initializer entries; falsy values become ``[]``.
    """

    def __init__(self, elements=None):
        self.elements = elements if elements else []

    def __repr__(self):
        return "InitializerListNode(elements={})".format(self.elements)
class DesignatedInitializerNode(ASTNode):
    """One C99-style designated initializer entry.

    Args:
        designators: Designator chain; falsy values become ``[]``.
        value: The initializer value.
    """

    def __init__(self, designators=None, value=None):
        self.designators = designators if designators else []
        self.value = value

    def __repr__(self):
        return "DesignatedInitializerNode(designators={}, value={})".format(
            self.designators, self.value
        )
class AssignmentNode(ASTNode):
    """An assignment, plain or compound.

    Args:
        left: Assignment target.
        right: Assigned expression.
        operator: Assignment operator text; defaults to ``"="``.
    """

    def __init__(self, left, right, operator="="):
        self.left, self.right, self.operator = left, right, operator

    def __repr__(self):
        return "AssignmentNode(left={}, operator={}, right={})".format(
            self.left, self.operator, self.right
        )
class BinaryOpNode(ASTNode):
    """A binary operation ``left op right``.

    Args:
        left: Left operand.
        op: Operator.
        right: Right operand.
    """

    def __init__(self, left, op, right):
        self.left, self.op, self.right = left, op, right

    def __repr__(self):
        return "BinaryOpNode(left={}, op={}, right={})".format(
            self.left, self.op, self.right
        )
class UnaryOpNode(ASTNode):
    """A prefix unary operation ``op operand``.

    Args:
        op: Operator.
        operand: The operand expression.
    """

    def __init__(self, op, operand):
        self.op, self.operand = op, operand

    def __repr__(self):
        return "UnaryOpNode(op={}, operand={})".format(self.op, self.operand)
class PostfixOpNode(ASTNode):
    """A postfix operation such as ``i++`` or ``i--``.

    Args:
        operand: The operand expression.
        op: Operator.
    """

    def __init__(self, operand, op):
        self.operand, self.op = operand, op

    def __repr__(self):
        return "PostfixOpNode(operand={}, op={})".format(self.operand, self.op)
class FunctionCallNode(ASTNode):
    """A call to a named function.

    Args:
        name: Called function name.
        args: Argument expressions.
    """

    def __init__(self, name, args):
        self.name, self.args = name, args

    def __repr__(self):
        return "FunctionCallNode(name={}, args={})".format(self.name, self.args)
class MethodCallNode(ASTNode):
    """A method call on an object expression.

    Args:
        object: Receiver expression. The name shadows the builtin but is part
            of the keyword interface, so it is kept.
        method: Method name.
        args: Argument expressions.
    """

    def __init__(self, object, method, args):
        self.object, self.method, self.args = object, method, args

    def __repr__(self):
        return "MethodCallNode(object={}, method={}, args={})".format(
            self.object, self.method, self.args
        )
class CallNode(ASTNode):
    """A call whose callee is an arbitrary expression.

    Args:
        callee: Expression being called.
        args: Argument expressions.
    """

    def __init__(self, callee, args):
        self.callee, self.args = callee, args

    def __repr__(self):
        return "CallNode(callee={}, args={})".format(self.callee, self.args)
class MemberAccessNode(ASTNode):
    """A member access using ``.`` or ``->``.

    Args:
        object: Accessed expression (name kept for keyword compatibility).
        member: Member name.
        is_pointer: ``True`` for ``->`` access, ``False`` for ``.`` access;
            only affects how the operator is shown in ``repr``.
    """

    def __init__(self, object, member, is_pointer=False):
        self.object, self.member, self.is_pointer = object, member, is_pointer

    def __repr__(self):
        if self.is_pointer:
            shown_operator = "->"
        else:
            shown_operator = "."
        return "MemberAccessNode(object={}, member={}, operator={})".format(
            self.object, self.member, shown_operator
        )
class ArrayAccessNode(ASTNode):
    """A subscript expression ``array[index]``.

    Args:
        array: Indexed expression.
        index: Index expression.
    """

    def __init__(self, array, index):
        self.array, self.index = array, index

    def __repr__(self):
        return "ArrayAccessNode(array={}, index={})".format(self.array, self.index)
class IfNode(ASTNode):
    """An ``if`` statement, in either plain or chained form.

    Plain form (neither chain given): ``condition``, ``if_body`` and
    ``else_body`` are stored as-is, and both chains are empty lists.

    Chained form (``if_chain`` and/or ``else_if_chain`` given): falsy chains
    become ``[]``. When ``if_chain`` is non-empty, a falsy ``condition`` or
    ``if_body`` falls back to ``if_chain[0][0]`` / ``if_chain[0][1]``.
    """

    def __init__(
        self,
        condition=None,
        if_body=None,
        else_body=None,
        if_chain=None,
        else_if_chain=None,
    ):
        if if_chain is None and else_if_chain is None:
            self.condition = condition
            self.if_body = if_body
            self.else_body = else_body
            self.if_chain = []
            self.else_if_chain = []
            return

        self.if_chain = if_chain or []
        self.else_if_chain = else_if_chain or []
        self.else_body = else_body
        if not self.if_chain:
            self.condition = condition
            self.if_body = if_body
            return
        # Fill missing pieces from the first (condition, body) chain entry;
        # ``or`` keeps the lookup lazy when the explicit value is present.
        self.condition = condition or self.if_chain[0][0]
        self.if_body = if_body or self.if_chain[0][1]

    def __repr__(self):
        return "IfNode(condition={}, if_body={}, else_body={})".format(
            self.condition, self.if_body, self.else_body
        )
class ForNode(ASTNode):
    """A classic three-clause ``for`` loop.

    Args:
        init: Initialisation clause.
        condition: Loop condition.
        update: Update clause.
        body: Loop body.
    """

    def __init__(self, init, condition, update, body):
        self.init, self.condition = init, condition
        self.update, self.body = update, body

    def __repr__(self):
        return "ForNode(init={}, condition={}, update={}, body={})".format(
            self.init, self.condition, self.update, self.body
        )
class RangeForNode(ASTNode):
    """A C++ range-based ``for (vtype name : iterable)`` loop.

    Args:
        vtype: Declared element type.
        name: Loop variable name.
        iterable: Range expression.
        body: Loop body.
    """

    def __init__(self, vtype, name, iterable, body):
        self.vtype, self.name = vtype, name
        self.iterable, self.body = iterable, body

    def __repr__(self):
        return "RangeForNode(vtype={}, name={}, iterable={}, body={})".format(
            self.vtype, self.name, self.iterable, self.body
        )
class WhileNode(ASTNode):
    """A ``while`` loop.

    Args:
        condition: Loop condition.
        body: Loop body.
    """

    def __init__(self, condition, body):
        self.condition, self.body = condition, body

    def __repr__(self):
        return "WhileNode(condition={}, body={})".format(self.condition, self.body)
class DoWhileNode(ASTNode):
    """A ``do { body } while (condition)`` loop.

    Args:
        body: Loop body.
        condition: Loop condition.
    """

    def __init__(self, body, condition):
        self.body, self.condition = body, condition

    def __repr__(self):
        return "DoWhileNode(body={}, condition={})".format(self.body, self.condition)
class SwitchNode(ASTNode):
    """A ``switch`` statement.

    Args:
        expression: The switched-on expression.
        cases: Case entries.
        default_case: Default branch.
        default: Alias for ``default_case``, used when ``default_case`` is
            falsy. The resolved value is exposed under both names.
    """

    def __init__(self, expression, cases, default_case=None, default=None):
        self.expression, self.cases = expression, cases
        self.default_case = default_case if default_case else default
        self.default = self.default_case

    def __repr__(self):
        return "SwitchNode(expression={}, cases={}, default_case={})".format(
            self.expression, self.cases, self.default_case
        )
class CaseNode(ASTNode):
    """One ``case`` of a ``switch`` statement.

    Args:
        value: The case label value.
        body: Case statements.
        statements: Alias for ``body``, used when ``body`` is falsy. If both
            are falsy, the body is ``[]``. Exposed under both names.
    """

    def __init__(self, value, body=None, statements=None):
        self.value = value
        if body:
            self.body = body
        elif statements:
            self.body = statements
        else:
            self.body = []
        self.statements = self.body

    def __repr__(self):
        return "CaseNode(value={}, body={})".format(self.value, self.body)
class ReturnNode(ASTNode):
    """A ``return`` statement.

    Args:
        value: Returned expression, or ``None`` for a bare ``return``.
    """

    def __init__(self, value=None):
        self.value = value

    def __repr__(self):
        return "ReturnNode(value={})".format(self.value)
class ContinueNode(ASTNode):
    """A ``continue`` statement.

    Positional arguments are accepted and ignored; keyword arguments are
    stored as attributes on the node.
    """

    def __init__(self, *args, **kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

    def __repr__(self):
        return "ContinueNode()"
class BreakNode(ASTNode):
    """A ``break`` statement.

    Positional arguments are accepted and ignored; keyword arguments are
    stored as attributes on the node.
    """

    def __init__(self, *args, **kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

    def __repr__(self):
        return "BreakNode()"
class DiscardNode(ASTNode):
    """A ``discard`` statement.

    Positional arguments are accepted and ignored; keyword arguments are
    stored as attributes on the node.
    """

    def __init__(self, *args, **kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

    def __repr__(self):
        return "DiscardNode()"
class VectorConstructorNode(ASTNode):
    """A vector constructor expression.

    The vector type may be given as ``vector_type`` or through the
    ``type_name`` alias; ``vector_type`` wins when it is truthy. The resolved
    value is exposed under both attribute names.

    Args:
        vector_type: The constructed vector type. Optional, so that callers
            can supply only ``type_name=...``; previously it was a required
            positional, which made the alias unusable on its own.
        args: Constructor arguments. ``None`` (the default) becomes a fresh
            empty list.
        type_name: Alias for ``vector_type``, used when ``vector_type`` is
            falsy.
    """

    def __init__(self, vector_type=None, args=None, type_name=None):
        self.vector_type = vector_type or type_name
        self.type_name = self.vector_type
        # Fresh list per instance so nodes never share an argument container.
        self.args = [] if args is None else args

    def __repr__(self):
        return (
            f"VectorConstructorNode(vector_type={self.vector_type}, args={self.args})"
        )
class TernaryOpNode(ASTNode):
    """A conditional expression ``condition ? true_expr : false_expr``.

    Args:
        condition: Selector expression.
        true_expr: Value when the condition holds.
        false_expr: Value otherwise.
    """

    def __init__(self, condition, true_expr, false_expr):
        self.condition = condition
        self.true_expr, self.false_expr = true_expr, false_expr

    def __repr__(self):
        return "TernaryOpNode(condition={}, true_expr={}, false_expr={})".format(
            self.condition, self.true_expr, self.false_expr
        )
class CastNode(ASTNode):
    """A type cast.

    Args:
        target_type: Type being cast to.
        expression: Expression being cast.
    """

    def __init__(self, target_type, expression):
        self.target_type, self.expression = target_type, expression

    def __repr__(self):
        return "CastNode(target_type={}, expression={})".format(
            self.target_type, self.expression
        )
class NewNode(ASTNode):
    """A C++ ``new`` expression.

    Args:
        target_type: Allocated type.
        size: Optional element count (``None`` when absent).
        args: Constructor arguments; falsy values become ``[]``.
        is_array: Whether this is an array ``new[]``.
    """

    def __init__(self, target_type, size=None, args=None, is_array=False):
        self.target_type, self.size = target_type, size
        self.args = args if args else []
        self.is_array = is_array

    def __repr__(self):
        return "NewNode(target_type={}, size={}, args={}, is_array={})".format(
            self.target_type, self.size, self.args, self.is_array
        )
class DeleteNode(ASTNode):
    """A C++ ``delete`` expression statement.

    Args:
        expression: Pointer expression being deleted.
        is_array: Whether this is an array ``delete[]``.
    """

    def __init__(self, expression, is_array=False):
        self.expression, self.is_array = expression, is_array

    def __repr__(self):
        return "DeleteNode(expression={}, is_array={})".format(
            self.expression, self.is_array
        )
class PreprocessorNode(ASTNode):
    """A preprocessor directive.

    Args:
        directive: Directive name.
        content: Remaining directive text.
    """

    def __init__(self, directive, content):
        self.directive, self.content = directive, content

    def __repr__(self):
        return "PreprocessorNode(directive={}, content={})".format(
            self.directive, self.content
        )
class AttributeNode(ASTNode):
    """An attribute / annotation.

    Args:
        name: Attribute name.
        args: Attribute arguments.
        arguments: Alias for ``args``, used when ``args`` is falsy. If both
            are falsy the list is ``[]``. Exposed under both names.
    """

    def __init__(self, name, args=None, arguments=None):
        self.name = name
        if args:
            self.args = args
        elif arguments:
            self.args = arguments
        else:
            self.args = []
        self.arguments = self.args

    def __repr__(self):
        return "AttributeNode(name='{}', args={})".format(self.name, self.args)
class TextureSampleNode(ASTNode):
    """A texture sampling operation.

    Args:
        texture: Sampled texture.
        sampler: Sampler used.
        coordinates: Sample coordinates.
        lod: Optional level of detail; omitted from ``repr`` when ``None``.
    """

    def __init__(self, texture, sampler, coordinates, lod=None):
        self.texture, self.sampler = texture, sampler
        self.coordinates, self.lod = coordinates, lod

    def __repr__(self):
        text = (
            f"TextureSampleNode(texture={self.texture}, sampler={self.sampler}, "
            f"coordinates={self.coordinates}"
        )
        if self.lod is not None:
            text += f", lod={self.lod}"
        return text + ")"
class SyncNode(ASTNode):
    """A synchronization operation.

    Args:
        sync_type: Kind of synchronization.
        args: Operation arguments.
        arguments: Alias for ``args``, used when ``args`` is falsy. If both
            are falsy the list is ``[]``. Exposed under both names.
    """

    def __init__(self, sync_type, args=None, arguments=None):
        self.sync_type = sync_type
        if args:
            self.args = args
        elif arguments:
            self.args = arguments
        else:
            self.args = []
        self.arguments = self.args

    def __repr__(self):
        return "SyncNode(sync_type={}, args={})".format(self.sync_type, self.args)
class ThreadgroupSyncNode(ASTNode):
    """A threadgroup synchronization point; carries no data."""

    def __init__(self):
        """Create the node; no arguments are accepted."""

    def __repr__(self):
        return "ThreadgroupSyncNode()"
class ConstantBufferNode(ASTNode):
    """A constant buffer declaration.

    Args:
        name: Buffer name.
        members: Member declarations.
    """

    def __init__(self, name, members):
        self.name, self.members = name, members

    def __repr__(self):
        return "ConstantBufferNode(name={}, members={})".format(self.name, self.members)