"""Canonical CrossGL abstract syntax tree node definitions."""
from typing import List, Optional, Any, Union, Dict
from enum import Enum
[docs]
class ASTNode:
    """Base class for all AST nodes with common functionality.

    Attributes:
        source_location: Value supplied by the caller, stored unchanged.
        annotations: Dict of metadata attached to this node.
        parent: Initialised to ``None``; this class never assigns it.
    """

    def __init__(self, source_location=None, annotations=None):
        """Initialize source metadata shared by all AST nodes.

        Args:
            source_location: Stored as-is on the node.
            annotations: Initial annotation mapping. A falsy value
                (``None`` or an empty mapping) is replaced by a new dict.
        """
        self.source_location = source_location
        self.annotations = annotations or {}
        self.parent = None

    def accept(self, visitor):
        """Dispatch this node to a visitor method named for its class.

        The method ``visit_<ClassName>`` is looked up on ``visitor``.
        ``visitor.generic_visit`` is only consulted when that method is
        missing, so a visitor that implements the specific method does not
        also need to provide ``generic_visit``.

        Args:
            visitor: Object exposing ``visit_<ClassName>`` and/or
                ``generic_visit``.

        Returns:
            Whatever the selected visitor method returns.

        Raises:
            AttributeError: If the visitor has neither the specific method
                nor ``generic_visit``.
        """
        method = getattr(visitor, f"visit_{self.__class__.__name__}", None)
        if method is None:
            # Resolve the fallback lazily; passing it as the getattr default
            # would evaluate it even when the specific method exists.
            method = visitor.generic_visit
        return method(self)

    def add_annotation(self, key: str, value: Any):
        """Attach backend or analysis metadata to this node.

        Args:
            key: Annotation name.
            value: Annotation payload; overwrites any existing entry.
        """
        self.annotations[key] = value

    def get_annotation(self, key: str, default=None):
        """Return an annotation value, or a default when it is absent.

        Args:
            key: Annotation name to look up.
            default: Value returned when ``key`` is not present.
        """
        return self.annotations.get(key, default)
# ============================================================================
# TYPE SYSTEM
# ============================================================================
[docs]
class TypeNode(ASTNode):
    """Common ancestor of every node that represents a type.

    Adds no state or behaviour of its own on top of ``ASTNode``.
    """
[docs]
class PrimitiveType(TypeNode):
    """Primitive type (int, float, bool, etc.).

    Attributes:
        name: Name of the primitive type.
        size_bits: Optional bit width; ``None`` when not supplied.
    """

    def __init__(
        self,
        name: str,
        size_bits: Optional[int] = None,
        **kwargs,
    ):
        """Record the type name and optional width.

        Remaining keyword arguments go to the parent initializer.
        """
        super().__init__(**kwargs)
        self.name, self.size_bits = name, size_bits

    def __repr__(self):
        """Return a debug string listing the name and bit width."""
        fields = f"name={self.name}, size_bits={self.size_bits}"
        return f"PrimitiveType({fields})"
[docs]
class VectorType(TypeNode):
    """Vector type (vec2, vec3, vec4, float3, etc.).

    Attributes:
        element_type: Type node of each component.
        size: Number of components.
    """

    def __init__(
        self,
        element_type: TypeNode,
        size: int,
        **kwargs,
    ):
        """Store the component type and count; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.element_type, self.size = element_type, size

    def __repr__(self):
        """Return a debug string listing the element type and size."""
        fields = f"element_type={self.element_type}, size={self.size}"
        return f"VectorType({fields})"
[docs]
class MatrixType(TypeNode):
    """Matrix type (mat4, float4x4, etc.).

    Attributes:
        element_type: Type node of each matrix element.
        rows: Row count.
        cols: Column count.
    """

    def __init__(
        self,
        element_type: TypeNode,
        rows: int,
        cols: int,
        **kwargs,
    ):
        """Store the element type and dimensions; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.element_type = element_type
        self.rows, self.cols = rows, cols

    def __repr__(self):
        """Return a debug string listing element type, rows and cols."""
        return (
            f"MatrixType(element_type={self.element_type}, "
            f"rows={self.rows}, cols={self.cols})"
        )
[docs]
class ArrayType(TypeNode):
    """Array type with static or dynamic sizing.

    Attributes:
        element_type: Type node of each array element.
        size: ``None`` for dynamic arrays; an int or an expression node
            for statically sized arrays.
    """

    def __init__(
        self,
        element_type: TypeNode,
        size: Optional[Union[int, ASTNode]] = None,
        **kwargs,
    ):
        """Store the element type and optional size; forward ``kwargs``."""
        super().__init__(**kwargs)
        self.element_type, self.size = element_type, size

    def __repr__(self):
        """Return a debug string listing the element type and size."""
        fields = f"element_type={self.element_type}, size={self.size}"
        return f"ArrayType({fields})"
[docs]
class PointerType(TypeNode):
    """Pointer type for languages that support pointers.

    Attributes:
        pointee_type: Type node being pointed to.
        is_mutable: Mutability flag; defaults to ``True``.
    """

    def __init__(
        self,
        pointee_type: TypeNode,
        is_mutable: bool = True,
        **kwargs,
    ):
        """Store the pointee type and mutability; forward ``kwargs``."""
        super().__init__(**kwargs)
        self.pointee_type, self.is_mutable = pointee_type, is_mutable

    def __repr__(self):
        """Return a debug string listing the pointee type and mutability."""
        return (
            f"PointerType(pointee_type={self.pointee_type}, "
            f"is_mutable={self.is_mutable})"
        )
[docs]
class ReferenceType(TypeNode):
    """Reference type for languages like Rust.

    Attributes:
        referenced_type: Type node being referenced.
        is_mutable: Mutability flag; defaults to ``False``.
    """

    def __init__(
        self,
        referenced_type: TypeNode,
        is_mutable: bool = False,
        **kwargs,
    ):
        """Store the referenced type and mutability; forward ``kwargs``."""
        super().__init__(**kwargs)
        self.referenced_type, self.is_mutable = referenced_type, is_mutable

    def __repr__(self):
        """Return a debug string listing the referenced type and mutability."""
        return (
            f"ReferenceType(referenced_type={self.referenced_type}, "
            f"is_mutable={self.is_mutable})"
        )
[docs]
class FunctionType(TypeNode):
    """Function pointer/reference type.

    Attributes:
        return_type: Type node of the return value.
        param_types: List of parameter type nodes.
    """

    def __init__(
        self,
        return_type: TypeNode,
        param_types: List[TypeNode],
        **kwargs,
    ):
        """Store the return and parameter types; forward ``kwargs``."""
        super().__init__(**kwargs)
        self.return_type, self.param_types = return_type, param_types

    def __repr__(self):
        """Return a debug string listing return and parameter types."""
        return (
            f"FunctionType(return_type={self.return_type}, "
            f"param_types={self.param_types})"
        )
[docs]
class GenericType(TypeNode):
    """Generic/template type parameter.

    Attributes:
        name: Name of the type parameter.
        constraints: List of constraint type nodes; a new empty list when
            the argument is falsy.
    """

    def __init__(
        self,
        name: str,
        constraints: List[TypeNode] = None,
        **kwargs,
    ):
        """Store the name and constraints; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.name = name
        self.constraints = constraints if constraints else []

    def __repr__(self):
        """Return a debug string listing the name and constraints."""
        fields = f"name={self.name}, constraints={self.constraints}"
        return f"GenericType({fields})"
[docs]
class NamedType(TypeNode):
    """User-defined type (struct, enum, etc.) referenced by name.

    Attributes:
        name: Name of the type.
        generic_args: List of generic argument type nodes; a new empty
            list when the argument is falsy.
    """

    def __init__(
        self,
        name: str,
        generic_args: List[TypeNode] = None,
        **kwargs,
    ):
        """Store the name and generic arguments; forward ``kwargs``."""
        super().__init__(**kwargs)
        self.name = name
        self.generic_args = generic_args if generic_args else []

    def __repr__(self):
        """Return a debug string listing the name and generic arguments."""
        fields = f"name={self.name}, generic_args={self.generic_args}"
        return f"NamedType({fields})"
# ============================================================================
# SHADER/PROGRAM STRUCTURE
# ============================================================================
[docs]
class ShaderStage(Enum):
    """Shader pipeline stages, each valued by its lowercase string name."""

    VERTEX = "vertex"
    FRAGMENT = "fragment"
    GEOMETRY = "geometry"

    TASK = "task"
    AMPLIFICATION = "amplification"
    OBJECT = "object"
    MESH = "mesh"

    TESSELLATION_CONTROL = "tessellation_control"
    TESSELLATION_EVALUATION = "tessellation_evaluation"

    COMPUTE = "compute"

    # Ray-tracing stages share the ``RAY_`` prefix.
    RAY_GENERATION = "ray_generation"
    RAY_INTERSECTION = "ray_intersection"
    RAY_CLOSEST_HIT = "ray_closest_hit"
    RAY_MISS = "ray_miss"
    RAY_ANY_HIT = "ray_any_hit"
    RAY_CALLABLE = "ray_callable"
[docs]
class ExecutionModel(Enum):
    """Execution models supported, each valued by its lowercase name."""

    GRAPHICS_PIPELINE = "graphics_pipeline"
    COMPUTE_KERNEL = "compute_kernel"
    RAY_TRACING = "ray_tracing"
    GENERAL_PURPOSE = "general_purpose"
[docs]
class ShaderNode(ASTNode):
    """Root node representing a complete shader program.

    Attributes:
        name: Shader name.
        execution_model: ``ExecutionModel`` member for this program.
        stages: Dict keyed by ``ShaderStage``; a new empty dict when falsy.
        structs, functions, global_variables, constants, imports,
        preprocessors: Lists; each becomes a new empty list when the
            corresponding argument is falsy.

    Note:
        ``cbuffers`` is only created as an attribute when a non-``None``
        value is passed. Instances built without it have no ``cbuffers``
        attribute at all, so readers must use ``getattr``/``hasattr``.
    """

    def __init__(
        self,
        name: str,
        execution_model: ExecutionModel,
        stages: Dict[ShaderStage, "StageNode"] = None,
        structs: List["StructNode"] = None,
        functions: List["FunctionNode"] = None,
        global_variables: List["VariableNode"] = None,
        constants: List["ConstantNode"] = None,
        cbuffers: List["StructNode"] = None,
        imports: List["ImportNode"] = None,
        preprocessors: List["PreprocessorNode"] = None,
        **kwargs,
    ):
        """Store the program contents; forward ``kwargs`` to the parent."""
        super().__init__(**kwargs)
        self.name, self.execution_model = name, execution_model
        self.stages = stages if stages else {}
        self.structs = structs if structs else []
        self.functions = functions if functions else []
        self.global_variables = global_variables if global_variables else []
        self.constants = constants if constants else []
        # Deliberately conditional: the attribute is absent unless supplied.
        if cbuffers is not None:
            self.cbuffers = cbuffers
        self.imports = imports if imports else []
        self.preprocessors = preprocessors if preprocessors else []

    def __repr__(self):
        """Return a debug string listing the name and execution model."""
        return (
            f"ShaderNode(name={self.name}, "
            f"execution_model={self.execution_model})"
        )
[docs]
class StageNode(ASTNode):
    """Individual shader stage (vertex, fragment, compute, etc.).

    Attributes:
        stage: ``ShaderStage`` member for this stage.
        entry_point: Function node acting as the entry point; its ``name``
            is read by ``__repr__``.
        local_variables: List; a new empty list when falsy.
        local_functions: List; a new empty list when falsy.
        execution_config: Dict; a new empty dict when falsy.
    """

    def __init__(
        self,
        stage: ShaderStage,
        entry_point: "FunctionNode",
        local_variables: List["VariableNode"] = None,
        local_functions: List["FunctionNode"] = None,
        execution_config: Dict[str, Any] = None,
        **kwargs,
    ):
        """Store the stage contents; forward ``kwargs`` to the parent."""
        super().__init__(**kwargs)
        self.stage, self.entry_point = stage, entry_point
        self.local_variables = local_variables if local_variables else []
        self.local_functions = local_functions if local_functions else []
        self.execution_config = execution_config if execution_config else {}

    def __repr__(self):
        """Return a debug string with the stage and entry point name."""
        entry_name = self.entry_point.name
        return f"StageNode(stage={self.stage}, entry_point={entry_name})"
[docs]
class ImportNode(ASTNode):
    """Import/include statement.

    Attributes:
        path: Path of the imported unit.
        alias: Optional alias; ``None`` when not supplied.
        items: Value passed by the caller, stored unchanged (so it stays
            ``None`` when omitted).
    """

    def __init__(
        self,
        path: str,
        alias: Optional[str] = None,
        items: List[str] = None,
        **kwargs,
    ):
        """Store the path, alias and items; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.path, self.alias, self.items = path, alias, items

    def __repr__(self):
        """Return a debug string listing path, alias and items."""
        return (
            f"ImportNode(path={self.path}, "
            f"alias={self.alias}, items={self.items})"
        )
[docs]
class PreprocessorNode(ASTNode):
    """Preprocessor directive (e.g. #version, #include).

    Attributes:
        directive: Directive name.
        content: Directive payload; defaults to the empty string.
    """

    def __init__(
        self,
        directive: str,
        content: str = "",
        **kwargs,
    ):
        """Store the directive and its content; forward ``kwargs``."""
        super().__init__(**kwargs)
        self.directive, self.content = directive, content

    def __repr__(self):
        """Return a debug string listing the directive and content."""
        return (
            f"PreprocessorNode(directive={self.directive}, "
            f"content={self.content})"
        )
# ============================================================================
# DECLARATIONS
# ============================================================================
[docs]
class StructNode(ASTNode):
    """Struct/class declaration.

    Attributes:
        name: Struct name.
        members: List of member nodes (stored as passed).
        generic_params, attributes, inheritance: Lists; each becomes a new
            empty list when the argument is falsy.
        visibility: Visibility string; defaults to ``"public"``.
    """

    def __init__(
        self,
        name: str,
        members: List["StructMemberNode"],
        generic_params: List["GenericParameterNode"] = None,
        attributes: List["AttributeNode"] = None,
        inheritance: List[NamedType] = None,
        visibility: str = "public",
        **kwargs,
    ):
        """Store the declaration parts; forward ``kwargs`` to the parent."""
        super().__init__(**kwargs)
        self.name, self.members = name, members
        self.generic_params = generic_params if generic_params else []
        self.attributes = attributes if attributes else []
        self.inheritance = inheritance if inheritance else []
        self.visibility = visibility

    def __repr__(self):
        """Return a debug string with the name and member count."""
        member_count = len(self.members)
        return f"StructNode(name={self.name}, members={member_count})"
[docs]
class StructMemberNode(ASTNode):
    """Individual struct member.

    Attributes:
        name: Member name.
        member_type: Type node of the member.
        default_value: Optional default expression; ``None`` when absent.
        attributes: List; a new empty list when the argument is falsy.
        visibility: Visibility string; defaults to ``"public"``.
    """

    def __init__(
        self,
        name: str,
        member_type: TypeNode,
        default_value: Optional["ExpressionNode"] = None,
        attributes: List["AttributeNode"] = None,
        visibility: str = "public",
        **kwargs,
    ):
        """Store the member description; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.name, self.member_type = name, member_type
        self.default_value = default_value
        self.attributes = attributes if attributes else []
        self.visibility = visibility

    def __repr__(self):
        """Return a debug string listing the name and member type."""
        return (
            f"StructMemberNode(name={self.name}, "
            f"member_type={self.member_type})"
        )
[docs]
class EnumNode(ASTNode):
    """Enumeration declaration.

    Attributes:
        name: Enum name.
        variants: List of variant nodes (stored as passed).
        underlying_type: Optional type node; ``None`` when absent.
        attributes: List; a new empty list when the argument is falsy.
    """

    def __init__(
        self,
        name: str,
        variants: List["EnumVariantNode"],
        underlying_type: Optional[TypeNode] = None,
        attributes: List["AttributeNode"] = None,
        **kwargs,
    ):
        """Store the enum description; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.name, self.variants = name, variants
        self.underlying_type = underlying_type
        self.attributes = attributes if attributes else []

    def __repr__(self):
        """Return a debug string with the name and variant count."""
        variant_count = len(self.variants)
        return f"EnumNode(name={self.name}, variants={variant_count})"
[docs]
class EnumVariantNode(ASTNode):
    """Individual enum variant.

    Attributes:
        name: Variant name.
        value: Optional value expression; ``None`` when absent.
        fields: List of type nodes; a new empty list when falsy.
    """

    def __init__(
        self,
        name: str,
        value: Optional["ExpressionNode"] = None,
        fields: List[TypeNode] = None,
        **kwargs,
    ):
        """Store the variant description; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.name, self.value = name, value
        self.fields = fields if fields else []

    def __repr__(self):
        """Return a debug string containing only the variant name."""
        variant_name = self.name
        return f"EnumVariantNode(name={variant_name})"
[docs]
class FunctionNode(ASTNode):
    """Function declaration.

    Attributes:
        name: Function name.
        return_type: Type node of the return value.
        parameters: List of parameter nodes (stored as passed).
        body: Optional block node; ``None`` when absent.
        generic_params, attributes, qualifiers: Lists; each becomes a new
            empty list when the argument is falsy.
        visibility: Visibility string; defaults to ``"public"``.
        is_unsafe: Flag; defaults to ``False``.
        is_async: Flag; defaults to ``False``.
    """

    def __init__(
        self,
        name: str,
        return_type: TypeNode,
        parameters: List["ParameterNode"],
        body: Optional["BlockNode"] = None,
        generic_params: List["GenericParameterNode"] = None,
        attributes: List["AttributeNode"] = None,
        visibility: str = "public",
        qualifiers: List[str] = None,
        is_unsafe: bool = False,
        is_async: bool = False,
        **kwargs,
    ):
        """Store the declaration parts; forward ``kwargs`` to the parent."""
        super().__init__(**kwargs)
        self.name, self.return_type = name, return_type
        self.parameters, self.body = parameters, body
        self.generic_params = generic_params if generic_params else []
        self.attributes = attributes if attributes else []
        self.visibility = visibility
        self.qualifiers = qualifiers if qualifiers else []
        self.is_unsafe, self.is_async = is_unsafe, is_async

    def __repr__(self):
        """Return a debug string listing the name and return type."""
        fields = f"name={self.name}, return_type={self.return_type}"
        return f"FunctionNode({fields})"
[docs]
class ParameterNode(ASTNode):
    """Function parameter.

    Attributes:
        name: Parameter name.
        param_type: Type node of the parameter.
        default_value: Optional default expression; ``None`` when absent.
        attributes: List; a new empty list when the argument is falsy.
        is_mutable: Flag; defaults to ``False``.
    """

    def __init__(
        self,
        name: str,
        param_type: TypeNode,
        default_value: Optional["ExpressionNode"] = None,
        attributes: List["AttributeNode"] = None,
        is_mutable: bool = False,
        **kwargs,
    ):
        """Store the parameter description; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.name, self.param_type = name, param_type
        self.default_value = default_value
        self.attributes = attributes if attributes else []
        self.is_mutable = is_mutable

    def __repr__(self):
        """Return a debug string listing the name and parameter type."""
        fields = f"name={self.name}, param_type={self.param_type}"
        return f"ParameterNode({fields})"
[docs]
class VariableNode(ASTNode):
    """Variable declaration.

    Attributes:
        name: Variable name.
        var_type: Type node of the variable.
        initial_value: Optional initializer expression.
        attributes, qualifiers: Lists; each becomes a new empty list when
            the argument is falsy.
        is_mutable: Flag; defaults to ``True``.
        visibility: Visibility string; defaults to ``"private"``.
        vtype: Legacy alias holding the same object as ``var_type``.
        semantic: Legacy alias computed once, at construction time, from
            ``attributes``.
    """

    def __init__(
        self,
        name: str,
        var_type: TypeNode,
        initial_value: Optional["ExpressionNode"] = None,
        attributes: List["AttributeNode"] = None,
        qualifiers: List[str] = None,
        is_mutable: bool = True,
        visibility: str = "private",
        **kwargs,
    ):
        """Store the declaration parts and derive the legacy aliases."""
        super().__init__(**kwargs)
        self.name, self.var_type = name, var_type
        self.initial_value = initial_value
        self.attributes = attributes if attributes else []
        self.qualifiers = qualifiers if qualifiers else []
        self.is_mutable, self.visibility = is_mutable, visibility
        # Legacy aliases
        self.vtype = var_type
        self.semantic = self.get_semantic_from_attributes()

    def get_semantic_from_attributes(self):
        """Return the legacy semantic name derived from variable attributes.

        Returns:
            The ``name`` of the first attribute whose name is one of
            ``position``, ``color``, ``texcoord`` or ``normal``; ``None``
            when no attribute matches.
        """
        recognised = ["position", "color", "texcoord", "normal"]
        matches = (
            attribute.name
            for attribute in self.attributes
            if attribute.name in recognised
        )
        return next(matches, None)

    def __repr__(self):
        """Return a debug string listing the name and variable type."""
        fields = f"name={self.name}, var_type={self.var_type}"
        return f"VariableNode({fields})"
[docs]
class ConstantNode(ASTNode):
    """Compile-time constant.

    Attributes:
        name: Constant name.
        const_type: Type node of the constant.
        value: Expression node giving the value.
        visibility: Visibility string; defaults to ``"public"``.
    """

    def __init__(
        self,
        name: str,
        const_type: TypeNode,
        value: "ExpressionNode",
        visibility: str = "public",
        **kwargs,
    ):
        """Store the constant description; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.name, self.const_type = name, const_type
        self.value, self.visibility = value, visibility

    def __repr__(self):
        """Return a debug string listing the name and constant type."""
        fields = f"name={self.name}, const_type={self.const_type}"
        return f"ConstantNode({fields})"
[docs]
class GenericParameterNode(ASTNode):
    """Generic/template parameter.

    Attributes:
        name: Parameter name.
        constraints: List of type nodes; a new empty list when falsy.
        default_type: Optional default type node.
    """

    def __init__(
        self,
        name: str,
        constraints: List[TypeNode] = None,
        default_type: Optional[TypeNode] = None,
        **kwargs,
    ):
        """Store the parameter description; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.name = name
        self.constraints = constraints if constraints else []
        self.default_type = default_type

    def __repr__(self):
        """Return a debug string containing only the parameter name."""
        param_name = self.name
        return f"GenericParameterNode(name={param_name})"
[docs]
class AttributeNode(ASTNode):
    """Attribute/annotation/decorator.

    Attributes:
        name: Attribute name.
        arguments: List of argument expressions; a new empty list when
            the argument is falsy.
    """

    def __init__(
        self,
        name: str,
        arguments: List["ExpressionNode"] = None,
        **kwargs,
    ):
        """Store the attribute name and arguments; forward ``kwargs``."""
        super().__init__(**kwargs)
        self.name = name
        self.arguments = arguments if arguments else []

    def __repr__(self):
        """Return a debug string containing only the attribute name."""
        attribute_name = self.name
        return f"AttributeNode(name={attribute_name})"
# ============================================================================
# STATEMENTS
# ============================================================================
[docs]
class StatementNode(ASTNode):
    """Common ancestor of every statement node.

    Adds no state or behaviour of its own on top of ``ASTNode``.
    """
[docs]
class BlockNode(StatementNode):
    """Block of statements.

    Attributes:
        statements: List of statement nodes (stored as passed).
    """

    def __init__(
        self,
        statements: List[StatementNode],
        **kwargs,
    ):
        """Store the statement list; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.statements = statements

    def __repr__(self):
        """Return a debug string with the number of statements."""
        statement_count = len(self.statements)
        return f"BlockNode(statements={statement_count})"
[docs]
class ExpressionStatementNode(StatementNode):
    """Expression used as a statement.

    Attributes:
        expression: The wrapped expression node.
    """

    def __init__(
        self,
        expression: "ExpressionNode",
        **kwargs,
    ):
        """Store the wrapped expression; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.expression = expression

    def __repr__(self):
        """Return a debug string showing the wrapped expression."""
        wrapped = self.expression
        return f"ExpressionStatementNode(expression={wrapped})"
[docs]
class AssignmentNode(StatementNode):
    """Assignment operation.

    Attributes:
        target: Expression being assigned to.
        value: Expression being assigned.
        operator: Assignment operator string; defaults to ``"="``.
        left: Legacy alias for ``target``.
        right: Legacy alias for ``value``.
    """

    def __init__(
        self,
        target: "ExpressionNode",
        value: "ExpressionNode",
        operator: str = "=",
        **kwargs,
    ):
        """Store the assignment parts and set the legacy aliases."""
        super().__init__(**kwargs)
        self.target, self.value, self.operator = target, value, operator
        # Legacy aliases
        self.left, self.right = target, value

    def __repr__(self):
        """Return a debug string listing target, operator and value."""
        return (
            f"AssignmentNode(target={self.target}, "
            f"operator={self.operator}, value={self.value})"
        )
[docs]
class IfNode(StatementNode):
    """Conditional statement.

    Attributes:
        condition: Condition expression.
        then_branch: Statement run when the condition holds.
        else_branch: Optional alternative statement.
        if_condition, if_body, else_body: Legacy aliases for
            ``condition``, ``then_branch`` and ``else_branch``.
        else_if_conditions, else_if_bodies: Legacy lists, initialised
            empty by this constructor.
    """

    def __init__(
        self,
        condition: "ExpressionNode",
        then_branch: StatementNode,
        else_branch: Optional[StatementNode] = None,
        **kwargs,
    ):
        """Store the branches and set the legacy aliases."""
        super().__init__(**kwargs)
        self.condition = condition
        self.then_branch, self.else_branch = then_branch, else_branch
        # Legacy aliases
        self.if_condition, self.if_body = condition, then_branch
        self.else_if_conditions, self.else_if_bodies = [], []
        self.else_body = else_branch

    def __repr__(self):
        """Return a debug string showing the condition."""
        cond = self.condition
        return f"IfNode(condition={cond})"
[docs]
class ForNode(StatementNode):
    """For loop statement.

    Attributes:
        init: Optional initialisation statement.
        condition: Optional loop condition expression.
        update: Optional update expression.
        body: Loop body statement.
    """

    def __init__(
        self,
        init: Optional[StatementNode],
        condition: Optional["ExpressionNode"],
        update: Optional["ExpressionNode"],
        body: StatementNode,
        **kwargs,
    ):
        """Store the loop parts; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.init, self.condition = init, condition
        self.update, self.body = update, body

    def __repr__(self):
        """Return a debug string showing the loop condition."""
        cond = self.condition
        return f"ForNode(condition={cond})"
[docs]
class ForInNode(StatementNode):
    """For-in loop (Rust, Python style).

    Attributes:
        pattern: Loop pattern string.
        iterable: Expression being iterated.
        body: Loop body statement.
    """

    def __init__(
        self,
        pattern: str,
        iterable: "ExpressionNode",
        body: StatementNode,
        **kwargs,
    ):
        """Store the loop parts; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.pattern, self.iterable, self.body = pattern, iterable, body

    def __repr__(self):
        """Return a debug string showing the loop pattern."""
        loop_pattern = self.pattern
        return f"ForInNode(pattern={loop_pattern})"
[docs]
class WhileNode(StatementNode):
    """While loop statement.

    Attributes:
        condition: Loop condition expression.
        body: Loop body statement.
    """

    def __init__(
        self,
        condition: "ExpressionNode",
        body: StatementNode,
        **kwargs,
    ):
        """Store the condition and body; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.condition, self.body = condition, body

    def __repr__(self):
        """Return a debug string showing the loop condition."""
        cond = self.condition
        return f"WhileNode(condition={cond})"
[docs]
class LoopNode(StatementNode):
    """Infinite loop (Rust style).

    Attributes:
        body: Loop body statement.
        label: Optional loop label; ``None`` when absent.
    """

    def __init__(
        self,
        body: StatementNode,
        label: Optional[str] = None,
        **kwargs,
    ):
        """Store the body and optional label; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.body, self.label = body, label

    def __repr__(self):
        """Return a debug string showing the loop label."""
        loop_label = self.label
        return f"LoopNode(label={loop_label})"
[docs]
class MatchNode(StatementNode):
    """Pattern matching statement (Rust, functional languages).

    Attributes:
        expression: Expression being matched.
        arms: List of match arm nodes (stored as passed).
    """

    def __init__(
        self,
        expression: "ExpressionNode",
        arms: List["MatchArmNode"],
        **kwargs,
    ):
        """Store the matched expression and arms; forward ``kwargs``."""
        super().__init__(**kwargs)
        self.expression, self.arms = expression, arms

    def __repr__(self):
        """Return a debug string with the number of arms."""
        arm_count = len(self.arms)
        return f"MatchNode(arms={arm_count})"
[docs]
class MatchArmNode(ASTNode):
    """Pattern matching arm.

    Attributes:
        pattern: Pattern node for this arm.
        guard: Optional guard expression.
        body: Statement run for this arm.
    """

    def __init__(
        self,
        pattern: "PatternNode",
        guard: Optional["ExpressionNode"],
        body: StatementNode,
        **kwargs,
    ):
        """Store the arm parts; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.pattern, self.guard, self.body = pattern, guard, body

    def __repr__(self):
        """Return a debug string showing the arm's pattern."""
        arm_pattern = self.pattern
        return f"MatchArmNode(pattern={arm_pattern})"
[docs]
class SwitchNode(StatementNode):
    """Switch statement.

    Attributes:
        expression: Expression being switched on.
        cases: List of case nodes (stored as passed).
        default_case: Optional default statement.
    """

    def __init__(
        self,
        expression: "ExpressionNode",
        cases: List["CaseNode"],
        default_case: Optional[StatementNode] = None,
        **kwargs,
    ):
        """Store the switch parts; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.expression, self.cases = expression, cases
        self.default_case = default_case

    def __repr__(self):
        """Return a debug string with the number of cases."""
        case_count = len(self.cases)
        return f"SwitchNode(cases={case_count})"
[docs]
class CaseNode(ASTNode):
    """Switch case.

    Attributes:
        value: Expression this case matches on.
        statements: List of statement nodes in the case body.
    """

    def __init__(
        self,
        value: "ExpressionNode",
        statements: List[StatementNode],
        **kwargs,
    ):
        """Store the case value and statements; forward ``kwargs``."""
        super().__init__(**kwargs)
        self.value, self.statements = value, statements

    def __repr__(self):
        """Return a debug string showing the case value."""
        case_value = self.value
        return f"CaseNode(value={case_value})"
[docs]
class ReturnNode(StatementNode):
    """Return statement.

    Attributes:
        value: Optional returned expression; ``None`` when absent.
    """

    def __init__(
        self,
        value: Optional["ExpressionNode"] = None,
        **kwargs,
    ):
        """Store the returned expression; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.value = value

    def __repr__(self):
        """Return a debug string showing the returned value."""
        returned = self.value
        return f"ReturnNode(value={returned})"
[docs]
class BreakNode(StatementNode):
    """Break statement.

    Attributes:
        label: Optional target label; ``None`` when absent.
    """

    def __init__(
        self,
        label: Optional[str] = None,
        **kwargs,
    ):
        """Store the optional label; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.label = label

    def __repr__(self):
        """Return a debug string showing the label."""
        target_label = self.label
        return f"BreakNode(label={target_label})"
[docs]
class ContinueNode(StatementNode):
    """Continue statement.

    Attributes:
        label: Optional target label; ``None`` when absent.
    """

    def __init__(
        self,
        label: Optional[str] = None,
        **kwargs,
    ):
        """Store the optional label; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.label = label

    def __repr__(self):
        """Return a debug string showing the label."""
        target_label = self.label
        return f"ContinueNode(label={target_label})"
# ============================================================================
# EXPRESSIONS
# ============================================================================
[docs]
class ExpressionNode(ASTNode):
    """Common ancestor of every expression node.

    Attributes:
        expression_type: Optional type node of the expression.
        vtype: Legacy alias holding the same object as ``expression_type``.
        name: Legacy alias; set to the instance's ``identifier`` attribute
            if one already exists when this initializer runs, else ``None``.
        semantic: Legacy alias, always initialised to ``None`` here.
    """

    def __init__(
        self,
        expression_type: Optional[TypeNode] = None,
        **kwargs,
    ):
        """Store the expression type and initialise the legacy aliases."""
        super().__init__(**kwargs)
        self.expression_type = expression_type
        # These legacy aliases are read by old code generators.
        self.vtype = expression_type
        existing_identifier = getattr(self, "identifier", None)
        self.name = existing_identifier
        self.semantic = None
[docs]
class LiteralNode(ExpressionNode):
    """Literal value.

    Attributes:
        value: The literal's value.
        literal_type: Type node of the literal; also passed to the parent
            initializer as the expression type.
    """

    def __init__(
        self,
        value: Any,
        literal_type: TypeNode,
        **kwargs,
    ):
        """Store the value and type; forward ``kwargs`` upward."""
        super().__init__(literal_type, **kwargs)
        self.value, self.literal_type = value, literal_type

    def __repr__(self):
        """Return a debug string listing the value and literal type."""
        fields = f"value={self.value}, literal_type={self.literal_type}"
        return f"LiteralNode({fields})"
[docs]
class IdentifierNode(ExpressionNode):
    """Variable/function identifier.

    Attributes:
        identifier: The identifier text.
        name: Same value as ``identifier``; assigned after the parent
            initializer has run.
    """

    def __init__(
        self,
        name: str,
        **kwargs,
    ):
        """Store the identifier under both attribute names."""
        super().__init__(**kwargs)
        self.identifier = name
        self.name = name

    def __repr__(self):
        """Return a debug string showing the identifier."""
        ident = self.identifier
        return f"IdentifierNode(name={ident})"
[docs]
class RangeNode(ExpressionNode):
    """Integer range expression.

    Attributes:
        start: Expression for the range start.
        end: Expression for the range end.
        inclusive: Flag; defaults to ``False``. Selects ``..=`` over
            ``..`` in ``__repr__``.
    """

    def __init__(
        self,
        start: ExpressionNode,
        end: ExpressionNode,
        inclusive: bool = False,
        **kwargs,
    ):
        """Store the range bounds and flag; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.start, self.end, self.inclusive = start, end, inclusive

    def __repr__(self):
        """Return a debug string with start, range operator and end."""
        if self.inclusive:
            operator = "..="
        else:
            operator = ".."
        return (
            f"RangeNode(start={self.start}, "
            f"operator={operator}, end={self.end})"
        )
[docs]
class BinaryOpNode(ExpressionNode):
    """Binary operation.

    Attributes:
        left: Left operand expression.
        operator: Operator string.
        right: Right operand expression.
        op: Same value as ``operator``.
    """

    def __init__(
        self,
        left: ExpressionNode,
        operator: str,
        right: ExpressionNode,
        **kwargs,
    ):
        """Store the operands and operator; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.left, self.operator, self.right = left, operator, right
        self.op = operator

    def __repr__(self):
        """Return a debug string listing left, operator and right."""
        return (
            f"BinaryOpNode(left={self.left}, "
            f"operator={self.operator}, right={self.right})"
        )
[docs]
class UnaryOpNode(ExpressionNode):
    """Unary operation.

    Attributes:
        operator: Operator string.
        operand: Operand expression.
        is_postfix: Flag; defaults to ``False``.
        op: Same value as ``operator``.
    """

    def __init__(
        self,
        operator: str,
        operand: ExpressionNode,
        is_postfix: bool = False,
        **kwargs,
    ):
        """Store the operator and operand; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.operator, self.operand = operator, operand
        self.is_postfix = is_postfix
        self.op = operator

    def __repr__(self):
        """Return a debug string listing operator, operand and postfix flag."""
        return (
            f"UnaryOpNode(operator={self.operator}, "
            f"operand={self.operand}, is_postfix={self.is_postfix})"
        )
[docs]
class TernaryOpNode(ExpressionNode):
    """Ternary conditional operator.

    Attributes:
        condition: Condition expression.
        true_expr: Expression for the true case.
        false_expr: Expression for the false case.
    """

    def __init__(
        self,
        condition: ExpressionNode,
        true_expr: ExpressionNode,
        false_expr: ExpressionNode,
        **kwargs,
    ):
        """Store the three sub-expressions; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.condition = condition
        self.true_expr, self.false_expr = true_expr, false_expr

    def __repr__(self):
        """Return a debug string listing all three sub-expressions."""
        return (
            f"TernaryOpNode(condition={self.condition}, "
            f"true_expr={self.true_expr}, false_expr={self.false_expr})"
        )
[docs]
class FunctionCallNode(ExpressionNode):
    """Function call.

    Attributes:
        function: Expression for the callee.
        arguments: List of argument expressions (stored as passed).
        generic_args: List of type nodes; a new empty list when falsy.
        name: Legacy alias holding the same object as ``function``.
        args: Legacy alias holding the same object as ``arguments``.
    """

    def __init__(
        self,
        function: ExpressionNode,
        arguments: List[ExpressionNode],
        generic_args: List[TypeNode] = None,
        **kwargs,
    ):
        """Store the call parts and set the legacy aliases."""
        super().__init__(**kwargs)
        self.function, self.arguments = function, arguments
        self.generic_args = generic_args if generic_args else []
        # These legacy aliases are read by old code generators.
        self.name, self.args = function, arguments

    def __repr__(self):
        """Return a debug string with the callee and argument count."""
        argument_count = len(self.arguments)
        return (
            f"FunctionCallNode(function={self.function}, "
            f"arguments={argument_count})"
        )
[docs]
class MemberAccessNode(ExpressionNode):
    """Member access (dot operator).

    Attributes:
        object_expr: Expression whose member is accessed.
        member: Member name.
        object: Same object as ``object_expr``.
    """

    def __init__(
        self,
        object_expr: ExpressionNode,
        member: str,
        **kwargs,
    ):
        """Store the object expression and member; forward ``kwargs``."""
        super().__init__(**kwargs)
        self.object_expr, self.member = object_expr, member
        self.object = object_expr

    def __repr__(self):
        """Return a debug string listing the object and member."""
        fields = f"object={self.object_expr}, member={self.member}"
        return f"MemberAccessNode({fields})"
[docs]
class PointerAccessNode(ExpressionNode):
    """Pointer member access (arrow operator).

    Attributes:
        pointer_expr: Expression for the pointer.
        member: Member name.
    """

    def __init__(
        self,
        pointer_expr: ExpressionNode,
        member: str,
        **kwargs,
    ):
        """Store the pointer expression and member; forward ``kwargs``."""
        super().__init__(**kwargs)
        self.pointer_expr, self.member = pointer_expr, member

    def __repr__(self):
        """Return a debug string listing the pointer and member."""
        fields = f"pointer={self.pointer_expr}, member={self.member}"
        return f"PointerAccessNode({fields})"
[docs]
class ArrayAccessNode(ExpressionNode):
    """Array indexing.

    Attributes:
        array_expr: Expression for the array.
        index_expr: Expression for the index.
        array: Same object as ``array_expr``.
        index: Same object as ``index_expr``.
    """

    def __init__(
        self,
        array_expr: ExpressionNode,
        index_expr: ExpressionNode,
        **kwargs,
    ):
        """Store the array and index expressions under both name pairs."""
        super().__init__(**kwargs)
        self.array_expr, self.index_expr = array_expr, index_expr
        self.array, self.index = array_expr, index_expr

    def __repr__(self):
        """Return a debug string listing the array and index."""
        fields = f"array={self.array_expr}, index={self.index_expr}"
        return f"ArrayAccessNode({fields})"
[docs]
class ArrayLiteralNode(ExpressionNode):
    """Array literal expression such as {1, 2, 3}.

    Attributes:
        elements: List of element expressions (stored as passed).
    """

    def __init__(
        self,
        elements: List[ExpressionNode],
        **kwargs,
    ):
        """Store the element list; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.elements = elements

    def __repr__(self):
        """Return a debug string with the number of elements."""
        element_count = len(self.elements)
        return f"ArrayLiteralNode(elements={element_count})"
[docs]
class SwizzleNode(ExpressionNode):
    """Vector swizzling (vec.xyz, vec.xxyy, etc.).

    Attributes:
        vector_expr: Expression for the vector being swizzled.
        components: Component string.
    """

    def __init__(
        self,
        vector_expr: ExpressionNode,
        components: str,
        **kwargs,
    ):
        """Store the vector and components; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.vector_expr, self.components = vector_expr, components

    def __repr__(self):
        """Return a debug string listing the vector and components."""
        fields = f"vector={self.vector_expr}, components={self.components}"
        return f"SwizzleNode({fields})"
[docs]
class CastNode(ExpressionNode):
    """Type cast.

    Attributes:
        expression: Expression being cast.
        target_type: Destination type node; also passed to the parent
            initializer as the expression type.
    """

    def __init__(
        self,
        expression: ExpressionNode,
        target_type: TypeNode,
        **kwargs,
    ):
        """Store the expression and target type; forward ``kwargs``."""
        super().__init__(target_type, **kwargs)
        self.expression, self.target_type = expression, target_type

    def __repr__(self):
        """Return a debug string listing the expression and target type."""
        return (
            f"CastNode(expression={self.expression}, "
            f"target_type={self.target_type})"
        )
[docs]
class ConstructorNode(ExpressionNode):
    """Type constructor (vec3(1,2,3), MyStruct{field: value}).

    Attributes:
        constructor_type: Type node being constructed; also passed to the
            parent initializer as the expression type.
        arguments: List of positional argument expressions.
        named_arguments: Dict of named argument expressions; a new empty
            dict when the argument is falsy.
    """

    def __init__(
        self,
        constructor_type: TypeNode,
        arguments: List[ExpressionNode],
        named_arguments: Dict[str, ExpressionNode] = None,
        **kwargs,
    ):
        """Store the constructor parts; forward ``kwargs`` upward."""
        super().__init__(constructor_type, **kwargs)
        self.constructor_type, self.arguments = constructor_type, arguments
        self.named_arguments = named_arguments if named_arguments else {}

    def __repr__(self):
        """Return a debug string with the type and argument count."""
        argument_count = len(self.arguments)
        return (
            f"ConstructorNode(constructor_type={self.constructor_type}, "
            f"arguments={argument_count})"
        )
[docs]
class LambdaNode(ExpressionNode):
    """Lambda/closure expression.

    Attributes:
        parameters: List of parameter nodes (stored as passed).
        body: Expression node or block node forming the body.
        captures: List of captured names; a new empty list when falsy.
    """

    def __init__(
        self,
        parameters: List[ParameterNode],
        body: Union[ExpressionNode, BlockNode],
        captures: List[str] = None,
        **kwargs,
    ):
        """Store the lambda parts; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.parameters, self.body = parameters, body
        self.captures = captures if captures else []

    def __repr__(self):
        """Return a debug string with the number of parameters."""
        parameter_count = len(self.parameters)
        return f"LambdaNode(parameters={parameter_count})"
# ============================================================================
# PATTERN MATCHING
# ============================================================================
[docs]
class PatternNode(ASTNode):
    """Common ancestor of every pattern used in pattern matching.

    Adds no state or behaviour of its own on top of ``ASTNode``.
    """
[docs]
class WildcardPatternNode(PatternNode):
    """Wildcard pattern (_); carries no fields of its own."""

    def __repr__(self):
        """Return the fixed debug string for a wildcard pattern."""
        text = "WildcardPatternNode()"
        return text
[docs]
class IdentifierPatternNode(PatternNode):
    """Identifier pattern (variable binding).

    Attributes:
        name: Name bound by the pattern.
    """

    def __init__(
        self,
        name: str,
        **kwargs,
    ):
        """Store the bound name; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.name = name

    def __repr__(self):
        """Return a debug string showing the bound name."""
        bound_name = self.name
        return f"IdentifierPatternNode(name={bound_name})"
[docs]
class LiteralPatternNode(PatternNode):
    """Literal pattern.

    Attributes:
        literal: Literal node the pattern matches against.
    """

    def __init__(
        self,
        literal: LiteralNode,
        **kwargs,
    ):
        """Store the literal; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.literal = literal

    def __repr__(self):
        """Return a debug string showing the literal."""
        matched = self.literal
        return f"LiteralPatternNode(literal={matched})"
[docs]
class StructPatternNode(PatternNode):
    """Struct destructuring pattern.

    Attributes:
        type_name: Name of the struct type.
        field_patterns: Dict mapping field names to pattern nodes.
    """

    def __init__(
        self,
        type_name: str,
        field_patterns: Dict[str, PatternNode],
        **kwargs,
    ):
        """Store the type name and field patterns; forward ``kwargs``."""
        super().__init__(**kwargs)
        self.type_name, self.field_patterns = type_name, field_patterns

    def __repr__(self):
        """Return a debug string showing the struct type name."""
        struct_name = self.type_name
        return f"StructPatternNode(type_name={struct_name})"
# ============================================================================
# GPU/GRAPHICS SPECIFIC NODES
# ============================================================================
[docs]
class TextureNode(ExpressionNode):
    """Texture sampling operation.

    Attributes:
        texture_expr: Expression for the texture.
        sampler_expr: Expression for the sampler.
        coordinates: Expression for the coordinates.
        level: Optional level expression; ``None`` when absent.
        offset: Optional offset expression; ``None`` when absent.
    """

    def __init__(
        self,
        texture_expr: ExpressionNode,
        sampler_expr: ExpressionNode,
        coordinates: ExpressionNode,
        level: Optional[ExpressionNode] = None,
        offset: Optional[ExpressionNode] = None,
        **kwargs,
    ):
        """Store the sampling operands; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.texture_expr, self.sampler_expr = texture_expr, sampler_expr
        self.coordinates = coordinates
        self.level, self.offset = level, offset

    def __repr__(self):
        """Return a debug string listing the texture and coordinates."""
        fields = f"texture={self.texture_expr}, coordinates={self.coordinates}"
        return f"TextureNode({fields})"
[docs]
class TextureOpNode(ExpressionNode):
    """Extended texture operation (Sample, Load, Gather, etc.).

    Attributes:
        operation: Operation name.
        texture_expr: Expression for the texture.
        sampler_expr: Optional sampler expression; ``None`` when absent.
        arguments: List of argument expressions (stored as passed).
    """

    def __init__(
        self,
        operation: str,
        texture_expr: ExpressionNode,
        arguments: List[ExpressionNode],
        sampler_expr: Optional[ExpressionNode] = None,
        **kwargs,
    ):
        """Store the operation parts; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.operation, self.texture_expr = operation, texture_expr
        self.sampler_expr, self.arguments = sampler_expr, arguments

    def __repr__(self):
        """Return a debug string listing the operation and texture."""
        fields = f"operation={self.operation}, texture={self.texture_expr}"
        return f"TextureOpNode({fields})"
[docs]
class AtomicOpNode(ExpressionNode):
    """Atomic operation for GPU computing.

    Attributes:
        operation: Operation name.
        target: Expression the operation targets.
        arguments: List of argument expressions (stored as passed).
    """

    def __init__(
        self,
        operation: str,
        target: ExpressionNode,
        arguments: List[ExpressionNode],
        **kwargs,
    ):
        """Store the operation parts; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.operation, self.target = operation, target
        self.arguments = arguments

    def __repr__(self):
        """Return a debug string listing the operation and target."""
        fields = f"operation={self.operation}, target={self.target}"
        return f"AtomicOpNode({fields})"
[docs]
class SyncNode(StatementNode):
    """Synchronization operation.

    Attributes:
        sync_type: Synchronization kind, as a string.
        arguments: List of argument expressions; a new empty list when
            the argument is falsy.
    """

    def __init__(
        self,
        sync_type: str,
        arguments: List[ExpressionNode] = None,
        **kwargs,
    ):
        """Store the sync kind and arguments; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.sync_type = sync_type
        self.arguments = arguments if arguments else []

    def __repr__(self):
        """Return a debug string showing the sync kind."""
        kind = self.sync_type
        return f"SyncNode(sync_type={kind})"
[docs]
class BuiltinVariableNode(ExpressionNode):
    """Built-in variable (gl_Position, threadIdx, etc.).

    Attributes:
        builtin_name: Name of the built-in.
        component: Optional component string; ``None`` when absent.
    """

    def __init__(
        self,
        builtin_name: str,
        component: Optional[str] = None,
        **kwargs,
    ):
        """Store the built-in name and component; forward ``kwargs``."""
        super().__init__(**kwargs)
        self.builtin_name, self.component = builtin_name, component

    def __repr__(self):
        """Return a debug string listing the built-in name and component."""
        return (
            f"BuiltinVariableNode(builtin_name={self.builtin_name}, "
            f"component={self.component})"
        )
# ============================================================================
# MEMORY AND RESOURCE MANAGEMENT
# ============================================================================
[docs]
class BufferNode(ASTNode):
    """Buffer resource declaration.

    Attributes:
        name: Buffer name.
        buffer_type: Type node of the buffer.
        binding: Optional binding index; ``None`` when absent.
        set: Value of the ``set_`` argument; ``None`` when absent.
        access: Access mode string; defaults to ``"read_write"``.
    """

    def __init__(
        self,
        name: str,
        buffer_type: TypeNode,
        binding: Optional[int] = None,
        set_: Optional[int] = None,
        access: str = "read_write",
        **kwargs,
    ):
        """Store the buffer description; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.name, self.buffer_type = name, buffer_type
        self.binding = binding
        # The parameter is spelled ``set_`` but stored as ``set``.
        self.set = set_
        self.access = access

    def __repr__(self):
        """Return a debug string listing the name and buffer type."""
        fields = f"name={self.name}, buffer_type={self.buffer_type}"
        return f"BufferNode({fields})"
[docs]
class TextureResourceNode(ASTNode):
    """Texture resource declaration.

    Attributes:
        name: Texture name.
        texture_type: Texture kind, as a string.
        format: Optional format string; ``None`` when absent.
        binding: Optional binding index; ``None`` when absent.
        set: Value of the ``set_`` argument; ``None`` when absent.
    """

    def __init__(
        self,
        name: str,
        texture_type: str,
        format: Optional[str] = None,
        binding: Optional[int] = None,
        set_: Optional[int] = None,
        **kwargs,
    ):
        """Store the texture description; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.name, self.texture_type = name, texture_type
        self.format, self.binding = format, binding
        # The parameter is spelled ``set_`` but stored as ``set``.
        self.set = set_

    def __repr__(self):
        """Return a debug string listing the name and texture type."""
        fields = f"name={self.name}, texture_type={self.texture_type}"
        return f"TextureResourceNode({fields})"
[docs]
class SamplerNode(ASTNode):
    """Sampler resource declaration.

    Attributes:
        name: Sampler name.
        filter_mode: Filter mode string; defaults to ``"linear"``.
        address_mode: Address mode string; defaults to ``"clamp"``.
        binding: Optional binding index; ``None`` when absent.
    """

    def __init__(
        self,
        name: str,
        filter_mode: str = "linear",
        address_mode: str = "clamp",
        binding: Optional[int] = None,
        **kwargs,
    ):
        """Store the sampler description; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.name = name
        self.filter_mode, self.address_mode = filter_mode, address_mode
        self.binding = binding

    def __repr__(self):
        """Return a debug string showing the sampler name."""
        sampler_name = self.name
        return f"SamplerNode(name={sampler_name})"
[docs]
class BufferOpNode(ExpressionNode):
    """Buffer operation like Load/Store/Append/Consume.

    Attributes:
        operation: Operation name.
        buffer_expr: Expression for the buffer.
        arguments: List of argument expressions (stored as passed).
    """

    def __init__(
        self,
        operation: str,
        buffer_expr: ExpressionNode,
        arguments: List[ExpressionNode],
        **kwargs,
    ):
        """Store the operation parts; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.operation, self.buffer_expr = operation, buffer_expr
        self.arguments = arguments

    def __repr__(self):
        """Return a debug string listing the operation and buffer."""
        fields = f"operation={self.operation}, buffer={self.buffer_expr}"
        return f"BufferOpNode({fields})"
[docs]
class WaveOpNode(ExpressionNode):
    """Wave/subgroup operation.

    Attributes:
        operation: Operation name.
        arguments: List of argument expressions (stored as passed).
    """

    def __init__(
        self,
        operation: str,
        arguments: List[ExpressionNode],
        **kwargs,
    ):
        """Store the operation and arguments; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.operation, self.arguments = operation, arguments

    def __repr__(self):
        """Return a debug string showing the operation name."""
        op_name = self.operation
        return f"WaveOpNode(operation={op_name})"
[docs]
class RayTracingOpNode(ExpressionNode):
    """Raytracing intrinsic like TraceRay, ReportHit, etc.

    Attributes:
        operation: Operation name.
        arguments: List of argument expressions (stored as passed).
    """

    def __init__(
        self,
        operation: str,
        arguments: List[ExpressionNode],
        **kwargs,
    ):
        """Store the operation and arguments; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.operation, self.arguments = operation, arguments

    def __repr__(self):
        """Return a debug string showing the operation name."""
        op_name = self.operation
        return f"RayTracingOpNode(operation={op_name})"
[docs]
class RayQueryOpNode(ExpressionNode):
    """RayQuery method call.

    Attributes:
        operation: Operation name.
        query_expr: Expression for the query object.
        arguments: List of argument expressions (stored as passed).
    """

    def __init__(
        self,
        operation: str,
        query_expr: ExpressionNode,
        arguments: List[ExpressionNode],
        **kwargs,
    ):
        """Store the call parts; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.operation, self.query_expr = operation, query_expr
        self.arguments = arguments

    def __repr__(self):
        """Return a debug string listing the operation and query."""
        fields = f"operation={self.operation}, query={self.query_expr}"
        return f"RayQueryOpNode({fields})"
[docs]
class MeshOpNode(ExpressionNode):
    """Mesh/task shader intrinsic (SetMeshOutputCounts, DispatchMesh).

    Attributes:
        operation: Operation name.
        arguments: List of argument expressions (stored as passed).
    """

    def __init__(
        self,
        operation: str,
        arguments: List[ExpressionNode],
        **kwargs,
    ):
        """Store the operation and arguments; forward ``kwargs`` upward."""
        super().__init__(**kwargs)
        self.operation, self.arguments = operation, arguments

    def __repr__(self):
        """Return a debug string showing the operation name."""
        op_name = self.operation
        return f"MeshOpNode(operation={op_name})"
# ============================================================================
# LEGACY COMPATIBILITY HELPERS
# ============================================================================
# Legacy name bound to the very same class object as ``StructNode``, so
# ``isinstance`` checks and ``repr`` output are those of ``StructNode``.
CbufferNode = StructNode
# Legacy name bound to the very same class object as ``ConstructorNode``.
VectorConstructorNode = ConstructorNode
[docs]
def create_legacy_shader_node(structs, functions, global_variables, cbuffers):
    """Create a root shader node for legacy tests and backend adapters.

    Args:
        structs: Struct nodes; a falsy value becomes a new empty list.
        functions: Function nodes; a falsy value becomes a new empty list.
        global_variables: Variable nodes; a falsy value becomes a new
            empty list.
        cbuffers: Passed to ``ShaderNode`` as its ``constants`` argument;
            a falsy value becomes a new empty list.

    Returns:
        A ``ShaderNode`` named ``"LegacyShader"`` using
        ``ExecutionModel.GRAPHICS_PIPELINE``.
    """
    legacy_fields = {
        "structs": structs or [],
        "functions": functions or [],
        "global_variables": global_variables or [],
        # NOTE(review): cbuffers travel through ``constants`` rather than
        # ShaderNode's ``cbuffers`` parameter - confirm this is intended.
        "constants": cbuffers or [],
    }
    return ShaderNode(
        name="LegacyShader",
        execution_model=ExecutionModel.GRAPHICS_PIPELINE,
        **legacy_fields,
    )
[docs]
class ArrayNode(VariableNode):
    """Legacy array node for backward compatibility.

    Attributes:
        element_type: Element type as passed by the caller.
        size: Array size as passed by the caller; ``None`` when absent.
        vtype: Legacy string form, ``"<element_type>[]"`` when ``size`` is
            ``None`` and ``"<element_type>[<size>]"`` otherwise. This
            replaces the ``vtype`` set by ``VariableNode``.
    """

    def __init__(self, element_type, name, size=None, semantic=None, **kwargs):
        """Initialize an array variable using the older ArrayNode shape.

        Args:
            element_type: Element type; wrapped in an ``ArrayType`` that
                becomes the variable's ``var_type``.
            name: Variable name.
            size: Optional array size forwarded to ``ArrayType``.
            semantic: Optional semantic name. When truthy it is added to
                the node's attributes as an ``AttributeNode`` before the
                parent initializer runs, so the ``semantic`` alias derived
                by ``VariableNode`` takes it into account.
            **kwargs: Forwarded to ``VariableNode.__init__``.
        """
        array_type = ArrayType(element_type, size)
        if semantic:
            # Build a private attribute list so the caller's list is not
            # mutated, and hand it to the parent so that the ``semantic``
            # alias is computed from the complete set of attributes.
            attributes = list(kwargs.pop("attributes", None) or [])
            attributes.append(AttributeNode(semantic))
            kwargs["attributes"] = attributes
        super().__init__(name, array_type, **kwargs)
        self.element_type = element_type
        self.size = size
        self.vtype = f"{element_type}[]" if size is None else f"{element_type}[{size}]"