# Source code for crosstl.translator.codegen.SPIRV_codegen

"""CrossGL-to-Vulkan SPIR-V code generator."""

import re
from typing import List, Optional, Tuple, Union

from .array_utils import parse_array_type, detect_array_element_type
from ..ast import (
    AssignmentNode,
    ArrayAccessNode,
    ArrayLiteralNode,
    ArrayNode,
    BinaryOpNode,
    BreakNode,
    ContinueNode,
    ForNode,
    FunctionCallNode,
    IdentifierNode,
    IfNode,
    LiteralNode,
    MemberAccessNode,
    ReturnNode,
    ShaderNode,
    StructNode,
    TernaryOpNode,
    UnaryOpNode,
    VariableNode,
    WhileNode,
)


class SpirvType:
    """A SPIR-V type label, optionally qualified by a storage class."""

    def __init__(self, base_type: str, storage_class: Optional[str] = None):
        """Keep the type label and, if supplied, the pointer storage class."""
        self.base_type, self.storage_class = base_type, storage_class

    def __str__(self) -> str:
        """Render ``base`` or ``base (StorageClass)`` for debug output."""
        suffix = f" ({self.storage_class})" if self.storage_class else ""
        return f"{self.base_type}{suffix}"


class SpirvId:
    """A numeric SPIR-V result id paired with its type and an optional name."""

    def __init__(
        self, id_value: int, spirv_type: SpirvType, name: Optional[str] = None
    ):
        """Record the result id, its type metadata, and an optional debug name."""
        self.name = name
        self.type = spirv_type
        self.id = id_value

    def __str__(self) -> str:
        """Render ``%id (name: type)`` or ``%id (type)`` for debug output."""
        label = f"{self.name}: {self.type}" if self.name else f"{self.type}"
        return f"%{self.id} ({label})"


class VulkanSPIRVCodeGen:
    """Generates SPIR-V code from a CrossGL shader AST."""

    def __init__(self):
        """Initialize an empty SPIR-V module-generation state."""
        self.reset_generation_state()

    def reset_generation_state(self):
        """Reset per-module SPIR-V ids, declarations, and symbol caches."""
        self.next_id = 1
        self.code_lines = []
        self.decorations = []
        self.required_extensions = set()
        self.primitive_types = {}
        self.vector_types = {}
        self.matrix_types = {}
        self.struct_types = {}
        self.pointer_types = {}
        self.function_types = {}
        self.array_types = {}
        self.resource_types = {}
        self.resource_image_types = {}
        self.required_capabilities = set()
        self.global_variables = {}
        self.local_variables = {}
        self.variable_value_types = {}
        self.value_types = {}
        self.constants = {}
        self.vector_constants = {}
        self.composite_constants = {}
        self.resource_type_metadata = {}
        self.functions = {}
        self.function_signatures = {}
        self.function_resource_array_params = {}
        self.function_resource_array_type_hints = {}
        self.function_execution_models = {}
        self.current_execution_model = None
        self.current_return_type = None
        self.glsl_std450_id = None
        self.main_fn_id = None
        self.requires_compute_derivatives = False
        self.current_label = None
        self.loop_merge_labels = []
        self.loop_continue_labels = []
        self.defined_functions = set()
        self.current_struct_members = {}
        self.inputs = []
        self.outputs = []
        self.uniform_buffers = []
        self.next_input_location = 0
        self.next_output_location = 0
        self.next_resource_binding = 0
        self.is_vertex_shader = False
        self.bound_id = 0

    def get_id(self) -> int:
        """Get the next available SPIR-V ID.

        ``bound_id`` tracks the largest id handed out so far.
        """
        id_value = self.next_id
        self.next_id += 1
        if id_value > self.bound_id:
            self.bound_id = id_value
        return id_value

    def emit(self, instruction: str):
        """Add a SPIR-V instruction to the code."""
        self.code_lines.append(instruction)

    def require_capability(self, capability: str):
        """Request a SPIR-V capability for instructions emitted later."""
        # "Shader" is never recorded here; presumably the module header
        # declares it unconditionally (not visible from this class section).
        if capability != "Shader":
            self.required_capabilities.add(capability)

    def require_extension(self, extension: str):
        """Request a SPIR-V extension for instructions emitted later."""
        self.required_extensions.add(extension)

    def require_compute_derivatives(self):
        """Enable compute shader derivatives for derivative-dependent image ops."""
        self.requires_compute_derivatives = True
        self.require_capability("ComputeDerivativeGroupQuadsKHR")
        self.require_extension("SPV_KHR_compute_shader_derivatives")

    def register_primitive_type(self, name: str) -> SpirvId:
        """Create and register a primitive type.

        The name is normalized first and the result is cached by the
        normalized name. ``double`` additionally requests the Float64
        capability.
        """
        name = self.normalize_primitive_name(name)
        if name in self.primitive_types:
            return self.primitive_types[name]
        id_value = self.get_id()
        if name == "void":
            self.emit(f"%{id_value} = OpTypeVoid")
        elif name == "bool":
            self.emit(f"%{id_value} = OpTypeBool")
        elif name == "float":
            self.emit(f"%{id_value} = OpTypeFloat 32")
        elif name == "double":
            self.require_capability("Float64")
            self.emit(f"%{id_value} = OpTypeFloat 64")
        elif name == "int":
            self.emit(f"%{id_value} = OpTypeInt 32 1")
        elif name == "uint":
            self.emit(f"%{id_value} = OpTypeInt 32 0")
        # NOTE(review): any other name still gets an id and is registered
        # without a declaring instruction; consider raising instead.
        spirv_type = SpirvType(name)
        spirv_id = SpirvId(id_value, spirv_type, name)
        self.primitive_types[name] = spirv_id
        return spirv_id

    def register_vector_type(self, component_type: SpirvId, count: int) -> SpirvId:
        """Create and register a vector type named like ``v4float``."""
        key = (component_type.id, count)
        if key in self.vector_types:
            return self.vector_types[key]
        id_value = self.get_id()
        self.emit(f"%{id_value} = OpTypeVector %{component_type.id} {count}")
        type_name = f"v{count}{component_type.type.base_type}"
        spirv_type = SpirvType(type_name)
        spirv_id = SpirvId(id_value, spirv_type, type_name)
        self.vector_types[key] = spirv_id
        return spirv_id

    def register_matrix_type(self, column_type: SpirvId, count: int) -> SpirvId:
        """Create and register a matrix type with ``count`` columns."""
        key = (column_type.id, count)
        if key in self.matrix_types:
            return self.matrix_types[key]
        id_value = self.get_id()
        self.emit(f"%{id_value} = OpTypeMatrix %{column_type.id} {count}")
        # Column vector names look like "v4float", so index 1 is the row count.
        type_name = f"mat{count}x{column_type.type.base_type[1]}"
        spirv_type = SpirvType(type_name)
        spirv_id = SpirvId(id_value, spirv_type, type_name)
        self.matrix_types[key] = spirv_id
        return spirv_id

    def register_pointer_type(
        self, pointed_type: SpirvId, storage_class: str
    ) -> SpirvId:
        """Create and register a pointer type."""
        key = (pointed_type.id, storage_class)
        if key in self.pointer_types:
            return self.pointer_types[key]
        id_value = self.get_id()
        self.emit(f"%{id_value} = OpTypePointer {storage_class} %{pointed_type.id}")
        spirv_type = SpirvType(f"ptr_{pointed_type.type.base_type}", storage_class)
        spirv_id = SpirvId(id_value, spirv_type)
        self.pointer_types[key] = spirv_id
        return spirv_id

    def register_image_type(
        self,
        type_name: str,
        component_type: SpirvId,
        dim: str,
        depth: int,
        arrayed: int,
        multisampled: int,
        sampled: int,
        image_format: str = "Unknown",
    ) -> SpirvId:
        """Create or reuse an ``OpTypeImage`` for the given image operands.

        The cache key is every operand except ``type_name``. The first
        caller's name therefore labels a shared image type.
        """
        key = (
            component_type.id,
            dim,
            depth,
            arrayed,
            multisampled,
            sampled,
            image_format,
        )
        if key in self.resource_image_types:
            return self.resource_image_types[key]
        # sampled == 2 marks a storage image; multisampled storage images need
        # StorageImageMultisample, and arrayed ones additionally ImageMSArray.
        if sampled == 2 and multisampled:
            self.require_capability("StorageImageMultisample")
            if arrayed:
                self.require_capability("ImageMSArray")
        id_value = self.get_id()
        self.emit(
            f"%{id_value} = OpTypeImage %{component_type.id} {dim} "
            f"{depth} {arrayed} {multisampled} {sampled} {image_format}"
        )
        spirv_type = SpirvType(type_name)
        spirv_id = SpirvId(id_value, spirv_type, type_name)
        self.resource_image_types[key] = spirv_id
        return spirv_id

    def register_resource_type(
        self, type_name: str, image_format: Optional[str] = None
    ) -> SpirvId:
        """Create or reuse the SPIR-V type for a named resource type.

        The resource kind comes from ``resource_type_info``. Handled kinds
        are ``sampler``, ``sampled_image`` and plain image kinds. For
        ``storage_image`` resources a recognized ``image_format`` overrides
        the image format and component type. Per-type metadata is recorded
        in ``resource_type_metadata``.

        Raises:
            ValueError: if ``type_name`` is not a known resource type.
        """
        if image_format is None and type_name in self.resource_types:
            return self.resource_types[type_name]
        info = self.resource_type_info(type_name)
        if info is None:
            raise ValueError(f"Unknown SPIR-V resource type {type_name}")
        # Copy so the format overrides below don't mutate what
        # resource_type_info returned.
        info = dict(info)
        source_format = None
        if info["kind"] == "storage_image" and image_format:
            spirv_format = self.spirv_image_format_name(image_format)
            if spirv_format:
                source_format = str(image_format).lower()
                info["format"] = spirv_format
                info["component_type"] = self.image_format_component_type(
                    image_format
                )
        cache_key = (
            type_name,
            info.get("kind"),
            info.get("component_type"),
            info.get("format"),
        )
        if cache_key in self.resource_types:
            return self.resource_types[cache_key]
        if info["kind"] == "sampler":
            id_value = self.get_id()
            self.emit(f"%{id_value} = OpTypeSampler")
            spirv_id = SpirvId(id_value, SpirvType(type_name), type_name)
        else:
            component_type = self.register_primitive_type(info["component_type"])
            image_type = self.register_image_type(
                f"{type_name}_image",
                component_type,
                info["dim"],
                info["depth"],
                info["arrayed"],
                info["multisampled"],
                info["sampled"],
                info["format"],
            )
            if info["kind"] == "sampled_image":
                id_value = self.get_id()
                self.emit(f"%{id_value} = OpTypeSampledImage %{image_type.id}")
                spirv_id = SpirvId(id_value, SpirvType(type_name), type_name)
            else:
                # NOTE(review): this relabels the SpirvId object cached in
                # resource_image_types, so resource names sharing one image
                # type overwrite each other's label.
                spirv_id = image_type
                spirv_id.type = SpirvType(type_name)
                spirv_id.name = type_name
            metadata = dict(info)
            metadata["type_name"] = type_name
            metadata["source_format"] = source_format
            metadata["image_type_id"] = image_type.id
            metadata["component_count"] = self.image_format_component_count(
                source_format
            )
            self.resource_type_metadata[image_type.id] = metadata
        if info["kind"] == "sampler":
            self.resource_type_metadata[spirv_id.id] = {
                "kind": "sampler",
                "type_name": type_name,
                "component_type": "float",
                "component_count": 0,
            }
        else:
            self.resource_type_metadata[spirv_id.id] = metadata
        self.resource_types[cache_key] = spirv_id
        if image_format is None:
            self.resource_types[type_name] = spirv_id
        return spirv_id

    def register_struct_type(
        self, name: str, members: List[Tuple[SpirvId, str]]
    ) -> SpirvId:
        """Create and register a struct type along with its debug names.

        The member list is also remembered in ``current_struct_members``
        for later member lookups.
        """
        if name in self.struct_types:
            return self.struct_types[name]
        id_value = self.get_id()
        member_types = " ".join([f"%{member[0].id}" for member in members])
        self.emit(f"%{id_value} = OpTypeStruct {member_types}")
        self.emit(f'OpName %{id_value} "{name}"')
        for i, (_, member_name) in enumerate(members):
            self.emit(f'OpMemberName %{id_value} {i} "{member_name}"')
        spirv_type = SpirvType(name)
        spirv_id = SpirvId(id_value, spirv_type, name)
        self.struct_types[name] = spirv_id
        self.current_struct_members[name] = members
        return spirv_id

    def register_function_type(
        self, return_type: SpirvId, param_types: List[SpirvId]
    ) -> SpirvId:
        """Create and register a function type."""
        key = (return_type.id, tuple(p.id for p in param_types))
        if key in self.function_types:
            return self.function_types[key]
        id_value = self.get_id()
        params = " ".join([f"%{param.id}" for param in param_types])
        if params:
            self.emit(f"%{id_value} = OpTypeFunction %{return_type.id} {params}")
        else:
            self.emit(f"%{id_value} = OpTypeFunction %{return_type.id}")
        spirv_type = SpirvType(f"fn_{return_type.type.base_type}")
        spirv_id = SpirvId(id_value, spirv_type)
        self.function_types[key] = spirv_id
        return spirv_id

    def register_constant(
        self, value: Union[bool, int, float], type_id: SpirvId
    ) -> SpirvId:
        """Create and register a scalar constant, cached by (value, type id).

        Bool-typed constants become ``OpConstantTrue``/``OpConstantFalse``.
        All other types use ``OpConstant`` with the value as a literal.
        """
        key = (value, type_id.id)
        if key in self.constants:
            return self.constants[key]
        id_value = self.get_id()
        type_name = type_id.type.base_type
        if type_name == "bool":
            opcode = "OpConstantTrue" if value else "OpConstantFalse"
            self.emit(f"%{id_value} = {opcode} %{type_id.id}")
        else:
            # str(True) is "True", which is not a valid SPIR-V numeric
            # literal, so Python bools given a numeric type are emitted as 1/0.
            constant_value = str(int(value)) if isinstance(value, bool) else str(value)
            self.emit(f"%{id_value} = OpConstant %{type_id.id} {constant_value}")
        spirv_id = SpirvId(id_value, type_id.type, f"{type_name}_{value}")
        self.value_types[id_value] = type_id
        self.constants[key] = spirv_id
        return spirv_id

    def register_vector_constant(
        self, vector_type: SpirvId, components: List[SpirvId]
    ) -> SpirvId:
        """Create and register a composite vector constant."""
        key = (vector_type.id, tuple(c.id for c in components))
        if key in self.vector_constants:
            return self.vector_constants[key]
        id_value = self.get_id()
        component_list = " ".join([f"%{component.id}" for component in components])
        self.emit(
            f"%{id_value} = OpConstantComposite %{vector_type.id} {component_list}"
        )
        spirv_id = SpirvId(id_value, vector_type.type)
        self.value_types[id_value] = vector_type
        self.vector_constants[key] = spirv_id
        return spirv_id

    def register_composite_constant(
        self, composite_type: SpirvId, components: List[SpirvId]
    ) -> SpirvId:
        """Create and register a composite constant."""
        key = (composite_type.id, tuple(component.id for component in components))
        if key in self.composite_constants:
            return self.composite_constants[key]
        id_value = self.get_id()
        component_list = " ".join(f"%{component.id}" for component in components)
        self.emit(
            f"%{id_value} = OpConstantComposite %{composite_type.id} {component_list}"
        )
        spirv_id = SpirvId(id_value, composite_type.type)
        self.value_types[id_value] = composite_type
        self.composite_constants[key] = spirv_id
        return spirv_id

    def is_constant_instruction(self, value_id: SpirvId) -> bool:
        """Return whether ``value_id`` came from one of the constant caches."""
        return (
            any(constant.id == value_id.id for constant in self.constants.values())
            or any(
                constant.id == value_id.id
                for constant in self.vector_constants.values()
            )
            or any(
                constant.id == value_id.id
                for constant in self.composite_constants.values()
            )
        )

    def image_offset_operand(self, offset_id: SpirvId) -> str:
        """Build the image operand for a texel offset.

        Constant offsets use ``ConstOffset``. Non-constant offsets use
        ``Offset``, which requests the ImageGatherExtended capability.
        """
        if self.is_constant_instruction(offset_id):
            return f"ConstOffset %{offset_id.id}"
        self.require_capability("ImageGatherExtended")
        return f"Offset %{offset_id.id}"

    def image_operands(self, *operands: str) -> str:
        """Combine ``"Mask value..."`` operand strings into one operand list.

        Empty and blank entries are skipped. Masks are joined with ``|`` and
        followed by all their values. Returns ``""`` when nothing remains.

        SPIR-V requires operand values to appear in ascending mask-bit order.
        When every mask is a known Image Operands name, the entries are
        sorted into that order. Otherwise the caller's order is kept.
        """
        # Bit positions from the SPIR-V "Image Operands" enumeration.
        bit_order = {
            "Bias": 0x1,
            "Lod": 0x2,
            "Grad": 0x4,
            "ConstOffset": 0x8,
            "Offset": 0x10,
            "ConstOffsets": 0x20,
            "Sample": 0x40,
            "MinLod": 0x80,
            "MakeTexelAvailable": 0x100,
            "MakeTexelVisible": 0x200,
            "NonPrivateTexel": 0x400,
            "VolatileTexel": 0x800,
            "SignExtend": 0x1000,
            "ZeroExtend": 0x2000,
            "Nontemporal": 0x4000,
            "Offsets": 0x10000,
        }
        entries = []
        for operand in operands:
            if not operand:
                continue
            parts = operand.split()
            if not parts:
                # Whitespace-only operand: nothing to add.
                continue
            entries.append((parts[0], parts[1:]))
        if not entries:
            return ""
        if all(mask in bit_order for mask, _ in entries):
            entries.sort(key=lambda entry: bit_order[entry[0]])
        masks = "|".join(mask for mask, _ in entries)
        values = [value for _, entry_values in entries for value in entry_values]
        return " ".join([masks] + values)

    def create_variable(
        self,
        type_id: SpirvId,
        storage_class: str,
        name: Optional[str] = None,
        initializer: Optional[SpirvId] = None,
    ) -> SpirvId:
        """Create a new variable.

        The pointee type is remembered in ``variable_value_types`` under
        the variable's id.
        """
        pointer_type = self.register_pointer_type(type_id, storage_class)
        id_value = self.get_id()
        initializer_operand = f" %{initializer.id}" if initializer is not None else ""
        self.emit(
            f"%{id_value} = OpVariable %{pointer_type.id} "
            f"{storage_class}{initializer_operand}"
        )
        if name:
            self.emit(f'OpName %{id_value} "{name}"')
        spirv_id = SpirvId(id_value, pointer_type.type, name)
        self.variable_value_types[id_value] = type_id
        return spirv_id

    def store_to_variable(self, variable_id: SpirvId, value_id: SpirvId):
        """Store a value to a variable."""
        self.emit(f"OpStore %{variable_id.id} %{value_id.id}")

    def load_from_variable(self, variable_id: SpirvId, result_type: SpirvId) -> SpirvId:
        """Load a value from a variable."""
        id_value = self.get_id()
        self.emit(f"%{id_value} = OpLoad %{result_type.id} %{variable_id.id}")
        spirv_id = SpirvId(id_value, result_type.type)
        self.value_types[id_value] = result_type
        return spirv_id

    def access_chain(
        self, base_id: SpirvId, indices: List[SpirvId], result_type: SpirvId
    ) -> SpirvId:
        """Create an access chain to a struct or array member."""
        id_value = self.get_id()
        index_list = " ".join([f"%{index.id}" for index in indices])
        self.emit(
            f"%{id_value} = OpAccessChain %{result_type.id} %{base_id.id} {index_list}"
        )
        spirv_id = SpirvId(id_value, result_type.type)
        return spirv_id

    def composite_extract(
        self, composite: SpirvId, member_type: SpirvId, member_index: int
    ) -> SpirvId:
        """Extract a member from a composite value."""
        id_value = self.get_id()
        self.emit(
            f"%{id_value} = OpCompositeExtract %{member_type.id} "
            f"%{composite.id} {member_index}"
        )
        spirv_id = SpirvId(id_value, member_type.type)
        self.value_types[id_value] = member_type
        return spirv_id

    def struct_member_info(self, struct_type: str, member_name: str):
        """Return ``(index, member_type)`` for a struct member, or ``None``."""
        members = self.current_struct_members.get(struct_type)
        if not members:
            return None
        for index, (member_type, name) in enumerate(members):
            if name == member_name:
                return index, member_type
        return None

    def struct_type_name_from_pointer(self, pointer: SpirvId):
        """Return the pointee type name recorded for ``pointer``, or ``None``."""
        struct_type_id = self.variable_value_types.get(pointer.id)
        return struct_type_id.type.base_type if struct_type_id else None

    def create_member_access_pointer(
        self, base_pointer: SpirvId, member_name: str
    ) -> Optional[SpirvId]:
        """Emit an access chain to ``member_name`` of the struct behind a pointer.

        Returns ``None`` when the pointee is not a known struct or has no
        such member. The member pointer keeps the base pointer's storage
        class, defaulting to ``Function``.
        """
        struct_type = self.struct_type_name_from_pointer(base_pointer)
        member_info = self.struct_member_info(struct_type, member_name)
        if member_info is None:
            return None
        member_index, member_type = member_info
        # Register lazily: the int type may not have been declared yet, and
        # indexing primitive_types directly would raise KeyError.
        int_type = self.register_primitive_type("int")
        index = self.register_constant(member_index, int_type)
        storage_class = base_pointer.type.storage_class or "Function"
        ptr_type = self.register_pointer_type(member_type, storage_class)
        access = self.access_chain(base_pointer, [index], ptr_type)
        self.variable_value_types[access.id] = member_type
        return access

    def create_function(
        self, name: str, return_type: SpirvId, param_types: List[SpirvId]
    ) -> SpirvId:
        """Create a function declaration and record its signature."""
        function_type = self.register_function_type(return_type, param_types)
        id_value = self.get_id()
        self.emit(
            f"%{id_value} = OpFunction %{return_type.id} None %{function_type.id}"
        )
        spirv_id = SpirvId(id_value, return_type.type, name)
        self.functions[name] = spirv_id
        self.function_signatures[name] = (return_type, param_types)
        return spirv_id

    def create_function_parameter(
        self, param_type: SpirvId, name: Optional[str] = None
    ) -> SpirvId:
        """Create a function parameter."""
        id_value = self.get_id()
        self.emit(f"%{id_value} = OpFunctionParameter %{param_type.id}")
        if name:
            self.emit(f'OpName %{id_value} "{name}"')
        spirv_id = SpirvId(id_value, param_type.type, name)
        self.value_types[id_value] = param_type
        return spirv_id

    def begin_block(self) -> SpirvId:
        """Begin a new basic block and make it the current label."""
        id_value = self.get_id()
        self.emit(f"%{id_value} = OpLabel")
        self.current_label = id_value
        return SpirvId(id_value, SpirvType("label"))

    def end_function(self):
        """End the current function."""
        self.emit("OpFunctionEnd")
        self.current_label = None

    def binary_operation(
        self, op: str, result_type: SpirvId, left: SpirvId, right: SpirvId
    ) -> SpirvId:
        """Create a binary operation.

        Logical and comparison operators force a ``bool`` result type.
        Arithmetic operands are first aligned, splatting a scalar to match a
        vector. For comparisons, arithmetic, and ``>>``, the float, signed or
        unsigned opcode is chosen from the left operand's scalar or vector
        component type. Unrecognized operators are emitted as ``Op<op>``.
        """
        arithmetic_ops = {
            "+": ("OpFAdd", "OpIAdd", "OpIAdd"),
            "-": ("OpFSub", "OpISub", "OpISub"),
            "*": ("OpFMul", "OpIMul", "OpIMul"),
            "MULTIPLY": ("OpFMul", "OpIMul", "OpIMul"),
            "/": ("OpFDiv", "OpSDiv", "OpUDiv"),
            "%": ("OpFMod", "OpSMod", "OpUMod"),
        }
        comparison_ops = {
            "==": ("OpFOrdEqual", "OpIEqual", "OpIEqual"),
            "!=": ("OpFOrdNotEqual", "OpINotEqual", "OpINotEqual"),
            "<": ("OpFOrdLessThan", "OpSLessThan", "OpULessThan"),
            ">": ("OpFOrdGreaterThan", "OpSGreaterThan", "OpUGreaterThan"),
            "<=": ("OpFOrdLessThanEqual", "OpSLessThanEqual", "OpULessThanEqual"),
            ">=": (
                "OpFOrdGreaterThanEqual",
                "OpSGreaterThanEqual",
                "OpUGreaterThanEqual",
            ),
        }
        if op in {"&&", "||"}:
            result_type = self.register_primitive_type("bool")
            spv_op = "OpLogicalAnd" if op == "&&" else "OpLogicalOr"
        elif op in comparison_ops:
            result_type = self.register_primitive_type("bool")
            # Use the component type so int/uint vectors get integer
            # compares; is_integer_type only recognizes scalar names.
            component_type = self.scalar_or_vector_component_type(left.type)
            if component_type == "bool" and op in {"==", "!="}:
                # Bool equality has dedicated opcodes; float compares are
                # invalid on bool operands.
                spv_op = "OpLogicalEqual" if op == "==" else "OpLogicalNotEqual"
            else:
                float_op, signed_op, unsigned_op = comparison_ops[op]
                spv_op = (
                    unsigned_op
                    if component_type == "uint"
                    else signed_op if component_type == "int" else float_op
                )
            # NOTE(review): vector comparisons still get a scalar bool result
            # type; component-wise SPIR-V compares need a bool vector.
        elif op in arithmetic_ops:
            result_type, left, right = self.align_binary_arithmetic_operands(
                result_type, left, right
            )
            float_op, signed_op, unsigned_op = arithmetic_ops[op]
            component_type = self.scalar_or_vector_component_type(left.type)
            spv_op = (
                unsigned_op
                if component_type == "uint"
                else signed_op if component_type == "int" else float_op
            )
        else:
            spv_op = {
                "&": "OpBitwiseAnd",
                "|": "OpBitwiseOr",
                "^": "OpBitwiseXor",
                "<<": "OpShiftLeftLogical",
                ">>": "OpShiftRightLogical",
            }.get(op, f"Op{op}")
            # Right-shifting a signed integer must sign-extend.
            if op == ">>" and self.scalar_or_vector_component_type(left.type) == "int":
                spv_op = "OpShiftRightArithmetic"
        id_value = self.get_id()
        self.emit(f"%{id_value} = {spv_op} %{result_type.id} %{left.id} %{right.id}")
        spirv_id = SpirvId(id_value, result_type.type)
        self.value_types[id_value] = result_type
        return spirv_id

    def align_binary_arithmetic_operands(
        self, result_type: SpirvId, left: SpirvId, right: SpirvId
    ) -> Tuple[SpirvId, SpirvId, SpirvId]:
        """Splat a scalar operand to match a vector operand of the same component type.

        Returns the possibly updated ``(result_type, left, right)``. Operands
        are returned unchanged when neither or both are vectors, or when the
        component types differ.
        """
        left_vector = self.vector_component_type_and_count(left.type.base_type)
        right_vector = self.vector_component_type_and_count(right.type.base_type)
        if left_vector is not None and right_vector is None:
            vector_type = self.ensure_registered_type(left.type)
            if self.scalar_or_vector_component_type(right.type) == left_vector[0]:
                return (
                    vector_type,
                    left,
                    self.splat_scalar_to_vector(right, vector_type),
                )
        if right_vector is not None and left_vector is None:
            vector_type = self.ensure_registered_type(right.type)
            if self.scalar_or_vector_component_type(left.type) == right_vector[0]:
                return (
                    vector_type,
                    self.splat_scalar_to_vector(left, vector_type),
                    right,
                )
        return result_type, left, right

    def splat_scalar_to_vector(
        self, scalar_id: SpirvId, vector_type: SpirvId
    ) -> SpirvId:
        """Build a vector whose every component is ``scalar_id``.

        Returns ``scalar_id`` unchanged when ``vector_type`` is not a vector.
        """
        vector_info = self.vector_component_type_and_count(vector_type.type.base_type)
        if vector_info is None:
            return scalar_id
        _, component_count = vector_info
        id_value = self.get_id()
        component_list = " ".join(f"%{scalar_id.id}" for _ in range(component_count))
        self.emit(
            f"%{id_value} = OpCompositeConstruct %{vector_type.id} {component_list}"
        )
        self.value_types[id_value] = vector_type
        return SpirvId(id_value, vector_type.type)

    def is_integer_type(self, spirv_type: SpirvType) -> bool:
        """Return whether the type is the scalar ``int`` or ``uint``."""
        return spirv_type.base_type in {"int", "uint"}

    def is_unsigned_type(self, spirv_type: SpirvType) -> bool:
        """Return whether the type is the scalar ``uint``."""
        return spirv_type.base_type == "uint"

    def unary_operation(
        self, op: str, result_type: Union[SpirvId, SpirvType], operand: SpirvId
    ) -> SpirvId:
        """Create a unary operation.

        ``+`` and unrecognized operators return ``operand`` unchanged. A
        result id has already been reserved by then. ``-`` picks signed or
        float negation from the result's component type.
        """
        result_type = self.ensure_registered_type(result_type)
        id_value = self.get_id()
        if op == "+":
            spv_op = None
        elif op == "-":
            component_type = self.scalar_or_vector_component_type(result_type.type)
            spv_op = "OpSNegate" if component_type in {"int", "uint"} else "OpFNegate"
        else:
            spv_op = {
                "!": "OpLogicalNot",
                "~": "OpNot",
            }.get(op)
        if spv_op is None:
            return operand
        self.emit(f"%{id_value} = {spv_op} %{result_type.id} %{operand.id}")
        spirv_id = SpirvId(id_value, result_type.type)
        self.value_types[id_value] = result_type
        return spirv_id
def select_operation( self, result_type: SpirvId, condition: SpirvId, true_value: SpirvId, false_value: SpirvId, ) -> SpirvId: """Create a SPIR-V select operation for ternary expressions.""" if not self.can_use_select_operation(result_type, condition): return self.select_composite_operation( result_type, condition, true_value, false_value ) id_value = self.get_id() self.emit( f"%{id_value} = OpSelect %{result_type.id} %{condition.id} " f"%{true_value.id} %{false_value.id}" ) spirv_id = SpirvId(id_value, result_type.type) self.value_types[id_value] = result_type return spirv_id def is_select_result_type(self, result_type: SpirvId) -> bool: """Return whether OpSelect can directly produce this result type.""" base_type = result_type.type.base_type if base_type in {"bool", "int", "uint", "float", "double"}: return True return any( vector_type.type.base_type == base_type for vector_type in self.vector_types.values() ) def can_use_select_operation( self, result_type: SpirvId, condition: SpirvId ) -> bool: if not self.is_select_result_type(result_type): return False result_vector = self.vector_component_type_and_count(result_type.type.base_type) condition_vector = self.vector_component_type_and_count( condition.type.base_type ) if result_vector is not None: return ( condition_vector is not None and condition_vector[0] == "bool" and condition_vector[1] == result_vector[1] ) return condition.type.base_type == "bool" def select_composite_operation( self, result_type: SpirvId, condition: SpirvId, true_value: SpirvId, false_value: SpirvId, ) -> SpirvId: """Select composite values through control flow instead of OpSelect.""" result_variable = self.create_variable(result_type, "Function") merge_label = SpirvId(self.get_id(), SpirvType("label")) then_label = SpirvId(self.get_id(), SpirvType("label")) else_label = SpirvId(self.get_id(), SpirvType("label")) self.create_selection_merge(merge_label) self.create_conditional_branch(condition, then_label, else_label) 
self.emit(f"%{then_label.id} = OpLabel") self.current_label = then_label.id self.store_to_variable(result_variable, true_value) if not self.current_block_has_terminator(): self.create_branch(merge_label) self.emit(f"%{else_label.id} = OpLabel") self.current_label = else_label.id self.store_to_variable(result_variable, false_value) if not self.current_block_has_terminator(): self.create_branch(merge_label) self.emit(f"%{merge_label.id} = OpLabel") self.current_label = merge_label.id return self.load_from_variable(result_variable, result_type) def call_function( self, function_name: str, args: List[SpirvId] ) -> Optional[SpirvId]: """Call a function with arguments.""" if function_name not in self.functions: # Handle built-in function return self.call_builtin_function(function_name, args) function_id = self.functions[function_name] return_type, _ = self.function_signatures[function_name] id_value = self.get_id() arg_list = " ".join([f"%{arg.id}" for arg in args]) self.emit( f"%{id_value} = OpFunctionCall %{return_type.id} %{function_id.id} {arg_list}" ) spirv_id = SpirvId(id_value, return_type.type) self.value_types[id_value] = return_type return spirv_id def resource_function_names(self): return { "imageLoad", "imageStore", "texture", "texture2D", "textureCube", "textureCompare", "textureCompareLod", "textureCompareGrad", "textureCompareOffset", "textureGatherCompare", "textureGatherCompareOffset", "textureLod", "textureGrad", "textureOffset", "textureGather", "textureGatherOffset", "textureGatherOffsets", "texelFetch", "textureSize", "imageSize", "textureSamples", "imageSamples", "textureQueryLevels", "textureQueryLod", } def resource_query_size_result_type(self, metadata) -> SpirvId: dim = metadata.get("dim", "2D") if metadata else "2D" component_count = { "1D": 1, "Buffer": 1, "2D": 2, "Rect": 2, "Cube": 2, "3D": 3, }.get(dim, 2) if metadata and metadata.get("arrayed"): component_count += 1 component_count = min(max(component_count, 1), 4) int_type = 
self.register_primitive_type("int") if component_count <= 1: return int_type return self.register_vector_type(int_type, component_count) def resource_query_lod_coordinate_components(self, metadata) -> int: dim = metadata.get("dim", "2D") if metadata else "2D" return { "1D": 1, "Buffer": 1, "2D": 2, "Rect": 2, "3D": 3, "Cube": 3, }.get(dim, 2) def trim_image_query_lod_coordinate(self, coord_id: SpirvId, metadata) -> SpirvId: required_count = self.resource_query_lod_coordinate_components(metadata) vector_info = self.vector_component_type_and_count(coord_id.type.base_type) if vector_info is None: return coord_id component_type_name, source_count = vector_info if source_count <= required_count: return coord_id component_type = self.register_primitive_type(component_type_name) components = [ self.composite_extract(coord_id, component_type, index) for index in range(required_count) ] if required_count <= 1: return components[0] result_type = self.register_vector_type(component_type, required_count) id_value = self.get_id() component_list = " ".join(f"%{component.id}" for component in components) self.emit( f"%{id_value} = OpCompositeConstruct %{result_type.id} {component_list}" ) self.value_types[id_value] = result_type return SpirvId(id_value, result_type.type) def texture_gather_offsets_arguments( self, extra_args: List[SpirvId] ) -> Tuple[List[SpirvId], Optional[SpirvId]]: if len(extra_args) >= 4: component_id = extra_args[4] if len(extra_args) >= 5 else None return extra_args[:4], component_id component_id = extra_args[1] if len(extra_args) >= 2 else None if not extra_args: return [], component_id offsets_value = extra_args[0] offsets_type = self.value_types.get(offsets_value.id) element_type = self.array_element_type_from_type(offsets_type) if element_type is not None: return [ self.composite_extract(offsets_value, element_type, index) for index in range(4) ], component_id return [offsets_value] * 4, component_id def emit_image_gather( self, sampled_image_id: 
SpirvId, coord_id: SpirvId, component_id: SpirvId, result_type: SpirvId, offset_id: Optional[SpirvId] = None, ) -> SpirvId: image_operands = "" if offset_id is not None: image_operands = f" {self.image_offset_operand(offset_id)}" id_value = self.get_id() self.emit( f"%{id_value} = OpImageGather %{result_type.id} " f"%{sampled_image_id.id} %{coord_id.id} %{component_id.id}" f"{image_operands}" ) self.value_types[id_value] = result_type return SpirvId(id_value, result_type.type) def emit_texture_gather_offsets( self, sampled_image_id: SpirvId, coord_id: SpirvId, extra_args: List[SpirvId], metadata, int_type: SpirvId, ) -> SpirvId: offsets, component_id = self.texture_gather_offsets_arguments(extra_args) if len(offsets) != 4: self.emit("; WARNING: textureGatherOffsets requires four offset operands") return self.register_constant(0.0, self.register_primitive_type("float")) if component_id is None: component_id = self.register_constant(0, int_type) result_type = self.resource_access_result_type(metadata) component_type = self.register_primitive_type( metadata.get("component_type", "float") ) gathered_components = [] for index, offset_id in enumerate(offsets): gathered = self.emit_image_gather( sampled_image_id, coord_id, component_id, result_type, offset_id, ) gathered_components.append( self.composite_extract(gathered, component_type, index) ) id_value = self.get_id() component_list = " ".join( f"%{component.id}" for component in gathered_components ) self.emit( f"%{id_value} = OpCompositeConstruct %{result_type.id} " f"{component_list}" ) self.value_types[id_value] = result_type return SpirvId(id_value, result_type.type) def extract_image_from_sampled_image( self, sampled_image_id: SpirvId, metadata ) -> Optional[SpirvId]: image_type_id = metadata.get("image_type_id") if metadata else None image_type = ( self.find_registered_type_by_id(image_type_id) if image_type_id is not None else None ) if image_type is None: self.emit("; WARNING: Could not determine image type 
for sampled image") return None id_value = self.get_id() self.emit(f"%{id_value} = OpImage %{image_type.id} %{sampled_image_id.id}") self.value_types[id_value] = image_type return SpirvId(id_value, image_type.type) def image_operand_for_query( self, resource_id: SpirvId, metadata ) -> Optional[SpirvId]: if not metadata: return None if metadata.get("kind") == "sampled_image": return self.extract_image_from_sampled_image(resource_id, metadata) if metadata.get("kind") == "storage_image": return resource_id return None def shadow_compare_operands( self, function_name: str, args: List[SpirvId], extra_arg_count: int ): coord_index = 1 if len(args) > 1: sampler_metadata = self.resource_metadata_for_value(args[1]) if sampler_metadata and sampler_metadata.get("kind") == "sampler": coord_index = 2 required_arg_count = coord_index + 2 + extra_arg_count if len(args) < required_arg_count: self.emit( f"; WARNING: {function_name} requires a shadow texture, " "coordinate, depth, and operation operands" ) return None sampled_image_id = args[0] coord_id = args[coord_index] depth_id = args[coord_index + 1] extra_args = args[coord_index + 2 : required_arg_count] metadata = self.resource_metadata_for_value(sampled_image_id) if ( not metadata or metadata.get("kind") != "sampled_image" or int(metadata.get("depth", 0)) != 1 ): self.emit( f"; WARNING: {function_name} requires a shadow sampled image operand" ) return None return sampled_image_id, coord_id, depth_id, extra_args def sampled_texture_operands( self, function_name: str, args: List[SpirvId], extra_arg_count: int = 0 ): coord_index = 1 if len(args) > 1: sampler_metadata = self.resource_metadata_for_value(args[1]) if sampler_metadata and sampler_metadata.get("kind") == "sampler": coord_index = 2 required_arg_count = coord_index + 1 + extra_arg_count if len(args) < required_arg_count: self.emit( f"; WARNING: {function_name} requires a texture, coordinate, " "and operation operands" ) return None sampled_image_id = args[0] coord_id = 
args[coord_index] extra_args = args[coord_index + 1 :] metadata = self.resource_metadata_for_value(sampled_image_id) if not metadata or metadata.get("kind") != "sampled_image": self.emit(f"; WARNING: {function_name} requires a sampled image operand") return None return sampled_image_id, coord_id, extra_args, metadata def requires_explicit_lod_sampling(self) -> bool: return self.current_execution_model in {"GLCompute", "MeshEXT", "TaskEXT"} def default_lod_operand(self) -> str: lod_id = self.register_constant(0.0, self.register_primitive_type("float")) return f"Lod %{lod_id.id}" def call_resource_function( self, function_name: str, args: List[SpirvId] ) -> Optional[SpirvId]: if function_name == "imageLoad": if len(args) < 2: self.emit("; WARNING: imageLoad requires image and coordinate operands") return self.register_constant( 0.0, self.register_primitive_type("float") ) image_id, coord_id = args[0], args[1] metadata = self.resource_metadata_for_value(image_id) if not metadata or metadata.get("kind") != "storage_image": self.emit("; WARNING: imageLoad requires a storage image operand") return self.register_constant( 0.0, self.register_primitive_type("float") ) image_operands = "" if metadata.get("multisampled"): if len(args) < 3: self.emit("; WARNING: imageLoad requires a sample operand") return self.register_constant( 0.0, self.register_primitive_type("float") ) image_operands = f" Sample %{args[2].id}" result_type = self.resource_access_result_type(metadata) id_value = self.get_id() self.emit( f"%{id_value} = OpImageRead %{result_type.id} " f"%{image_id.id} %{coord_id.id}{image_operands}" ) self.value_types[id_value] = result_type return SpirvId(id_value, result_type.type) if function_name == "imageStore": if len(args) < 3: self.emit( "; WARNING: imageStore requires image, coordinate, and value operands" ) return None image_id, coord_id, texel_id = args[0], args[1], args[2] metadata = self.resource_metadata_for_value(image_id) if not metadata or 
metadata.get("kind") != "storage_image": self.emit("; WARNING: imageStore requires a storage image operand") return None image_operands = "" if metadata.get("multisampled"): if len(args) < 4: self.emit("; WARNING: imageStore requires a sample operand") return None sample_id, texel_id = args[2], args[3] image_operands = f" Sample %{sample_id.id}" self.emit( f"OpImageWrite %{image_id.id} %{coord_id.id} %{texel_id.id}" f"{image_operands}" ) return None if function_name in {"texture", "texture2D", "textureCube"}: sample_args = self.sampled_texture_operands(function_name, args) if sample_args is None: return self.register_constant( 0.0, self.register_primitive_type("float") ) sampled_image_id, coord_id, _, metadata = sample_args result_type = self.resource_access_result_type(metadata) id_value = self.get_id() if self.requires_explicit_lod_sampling(): self.emit( f"%{id_value} = OpImageSampleExplicitLod %{result_type.id} " f"%{sampled_image_id.id} %{coord_id.id} " f"{self.default_lod_operand()}" ) else: self.emit( f"%{id_value} = OpImageSampleImplicitLod %{result_type.id} " f"%{sampled_image_id.id} %{coord_id.id}" ) self.value_types[id_value] = result_type return SpirvId(id_value, result_type.type) if function_name in { "textureCompare", "textureCompareLod", "textureCompareGrad", "textureCompareOffset", }: extra_arg_count = { "textureCompare": 0, "textureCompareLod": 1, "textureCompareGrad": 2, "textureCompareOffset": 1, }[function_name] compare_args = self.shadow_compare_operands( function_name, args, extra_arg_count ) if compare_args is None: return self.register_constant( 0.0, self.register_primitive_type("float") ) sampled_image_id, coord_id, depth_id, extra_args = compare_args result_type = self.register_primitive_type("float") id_value = self.get_id() if function_name == "textureCompare": if self.requires_explicit_lod_sampling(): self.emit( f"%{id_value} = OpImageSampleDrefExplicitLod " f"%{result_type.id} %{sampled_image_id.id} " f"%{coord_id.id} %{depth_id.id} 
{self.default_lod_operand()}" ) else: self.emit( f"%{id_value} = OpImageSampleDrefImplicitLod " f"%{result_type.id} %{sampled_image_id.id} " f"%{coord_id.id} %{depth_id.id}" ) elif function_name == "textureCompareOffset": offset_operand = self.image_offset_operand(extra_args[0]) if self.requires_explicit_lod_sampling(): self.emit( f"%{id_value} = OpImageSampleDrefExplicitLod " f"%{result_type.id} %{sampled_image_id.id} " f"%{coord_id.id} %{depth_id.id} " f"{self.image_operands(self.default_lod_operand(), offset_operand)}" ) else: self.emit( f"%{id_value} = OpImageSampleDrefImplicitLod " f"%{result_type.id} %{sampled_image_id.id} " f"%{coord_id.id} %{depth_id.id} {offset_operand}" ) elif function_name == "textureCompareLod": self.emit( f"%{id_value} = OpImageSampleDrefExplicitLod %{result_type.id} " f"%{sampled_image_id.id} %{coord_id.id} %{depth_id.id} " f"Lod %{extra_args[0].id}" ) else: self.emit( f"%{id_value} = OpImageSampleDrefExplicitLod %{result_type.id} " f"%{sampled_image_id.id} %{coord_id.id} %{depth_id.id} " f"Grad %{extra_args[0].id} %{extra_args[1].id}" ) self.value_types[id_value] = result_type return SpirvId(id_value, result_type.type) if function_name in {"textureGatherCompare", "textureGatherCompareOffset"}: extra_arg_count = 1 if function_name == "textureGatherCompareOffset" else 0 compare_args = self.shadow_compare_operands( function_name, args, extra_arg_count ) if compare_args is None: return self.register_constant( 0.0, self.register_primitive_type("float") ) sampled_image_id, coord_id, depth_id, extra_args = compare_args float_type = self.register_primitive_type("float") result_type = self.register_vector_type(float_type, 4) id_value = self.get_id() image_operands = ( f" {self.image_offset_operand(extra_args[0])}" if function_name == "textureGatherCompareOffset" else "" ) self.emit( f"%{id_value} = OpImageDrefGather %{result_type.id} " f"%{sampled_image_id.id} %{coord_id.id} %{depth_id.id}" f"{image_operands}" ) self.value_types[id_value] = 
result_type return SpirvId(id_value, result_type.type) if function_name == "textureOffset": sample_args = self.sampled_texture_operands(function_name, args, 1) if sample_args is None: return self.register_constant( 0.0, self.register_primitive_type("float") ) sampled_image_id, coord_id, extra_args, metadata = sample_args offset_id = extra_args[0] result_type = self.resource_access_result_type(metadata) id_value = self.get_id() offset_operand = self.image_offset_operand(offset_id) if self.requires_explicit_lod_sampling(): self.emit( f"%{id_value} = OpImageSampleExplicitLod %{result_type.id} " f"%{sampled_image_id.id} %{coord_id.id} " f"{self.image_operands(self.default_lod_operand(), offset_operand)}" ) else: self.emit( f"%{id_value} = OpImageSampleImplicitLod %{result_type.id} " f"%{sampled_image_id.id} %{coord_id.id} {offset_operand}" ) self.value_types[id_value] = result_type return SpirvId(id_value, result_type.type) if function_name in { "textureGather", "textureGatherOffset", "textureGatherOffsets", }: required_extra_count = 0 if function_name == "textureGather" else 1 sample_args = self.sampled_texture_operands( function_name, args, required_extra_count ) if sample_args is None: return self.register_constant( 0.0, self.register_primitive_type("float") ) sampled_image_id, coord_id, extra_args, metadata = sample_args int_type = self.register_primitive_type("int") if function_name == "textureGather": component_id = ( extra_args[0] if extra_args else self.register_constant(0, int_type) ) offset_id = None elif function_name == "textureGatherOffsets": return self.emit_texture_gather_offsets( sampled_image_id, coord_id, extra_args, metadata, int_type, ) else: offset_id = extra_args[0] component_id = ( extra_args[1] if len(extra_args) >= 2 else self.register_constant(0, int_type) ) result_type = self.resource_access_result_type(metadata) return self.emit_image_gather( sampled_image_id, coord_id, component_id, result_type, offset_id, ) if function_name == 
"texelFetch": sample_args = self.sampled_texture_operands(function_name, args, 1) if sample_args is None: return self.register_constant( 0.0, self.register_primitive_type("float") ) sampled_image_id, coord_id, extra_args, metadata = sample_args operand_id = extra_args[0] image_id = self.extract_image_from_sampled_image(sampled_image_id, metadata) if image_id is None: return self.register_constant( 0.0, self.register_primitive_type("float") ) result_type = self.resource_access_result_type(metadata) id_value = self.get_id() image_operand = "Sample" if metadata.get("multisampled") else "Lod" self.emit( f"%{id_value} = OpImageFetch %{result_type.id} " f"%{image_id.id} %{coord_id.id} {image_operand} %{operand_id.id}" ) self.value_types[id_value] = result_type return SpirvId(id_value, result_type.type) if function_name in {"textureSize", "imageSize"}: if not args: self.emit(f"; WARNING: {function_name} requires an image operand") return self.register_constant(0, self.register_primitive_type("int")) resource_id = args[0] metadata = self.resource_metadata_for_value(resource_id) expected_kind = ( "sampled_image" if function_name == "textureSize" else "storage_image" ) if not metadata or metadata.get("kind") != expected_kind: self.emit( f"; WARNING: {function_name} requires a {expected_kind} operand" ) return self.register_constant(0, self.register_primitive_type("int")) image_id = self.image_operand_for_query(resource_id, metadata) if image_id is None: return self.register_constant(0, self.register_primitive_type("int")) result_type = self.resource_query_size_result_type(metadata) id_value = self.get_id() self.require_capability("ImageQuery") if ( function_name == "textureSize" and len(args) >= 2 and not metadata.get("multisampled") ): self.emit( f"%{id_value} = OpImageQuerySizeLod %{result_type.id} " f"%{image_id.id} %{args[1].id}" ) else: self.emit( f"%{id_value} = OpImageQuerySize %{result_type.id} " f"%{image_id.id}" ) self.value_types[id_value] = result_type return 
SpirvId(id_value, result_type.type) if function_name in {"textureSamples", "imageSamples"}: if not args: self.emit(f"; WARNING: {function_name} requires an image operand") return self.register_constant(0, self.register_primitive_type("int")) resource_id = args[0] metadata = self.resource_metadata_for_value(resource_id) expected_kind = ( "sampled_image" if function_name == "textureSamples" else "storage_image" ) if not metadata or metadata.get("kind") != expected_kind: self.emit( f"; WARNING: {function_name} requires a {expected_kind} operand" ) return self.register_constant(0, self.register_primitive_type("int")) if metadata.get("dim") != "2D" or not metadata.get("multisampled"): self.emit(f"; WARNING: {function_name} requires a multisample 2D image") return self.register_constant(0, self.register_primitive_type("int")) image_id = self.image_operand_for_query(resource_id, metadata) if image_id is None: return self.register_constant(0, self.register_primitive_type("int")) result_type = self.register_primitive_type("int") id_value = self.get_id() self.require_capability("ImageQuery") self.emit( f"%{id_value} = OpImageQuerySamples %{result_type.id} " f"%{image_id.id}" ) self.value_types[id_value] = result_type return SpirvId(id_value, result_type.type) if function_name == "textureQueryLevels": if not args: self.emit("; WARNING: textureQueryLevels requires a texture operand") return self.register_constant(0, self.register_primitive_type("int")) sampled_image_id = args[0] metadata = self.resource_metadata_for_value(sampled_image_id) if not metadata or metadata.get("kind") != "sampled_image": self.emit( "; WARNING: textureQueryLevels requires a sampled image operand" ) return self.register_constant(0, self.register_primitive_type("int")) image_id = self.extract_image_from_sampled_image(sampled_image_id, metadata) if image_id is None: return self.register_constant(0, self.register_primitive_type("int")) result_type = self.register_primitive_type("int") id_value = 
self.get_id() self.require_capability("ImageQuery") self.emit( f"%{id_value} = OpImageQueryLevels %{result_type.id} " f"%{image_id.id}" ) self.value_types[id_value] = result_type return SpirvId(id_value, result_type.type) if function_name == "textureQueryLod": if len(args) < 2: self.emit( "; WARNING: textureQueryLod requires texture and coordinate operands" ) return self.register_constant( 0.0, self.register_primitive_type("float") ) sampled_image_id = args[0] coord_id = args[1] if len(args) >= 3: sampler_metadata = self.resource_metadata_for_value(args[1]) if sampler_metadata and sampler_metadata.get("kind") == "sampler": coord_id = args[2] metadata = self.resource_metadata_for_value(sampled_image_id) if not metadata or metadata.get("kind") != "sampled_image": self.emit("; WARNING: textureQueryLod requires a sampled image operand") return self.register_constant( 0.0, self.register_primitive_type("float") ) coord_id = self.trim_image_query_lod_coordinate(coord_id, metadata) float_type = self.register_primitive_type("float") result_type = self.register_vector_type(float_type, 2) id_value = self.get_id() self.require_capability("ImageQuery") if self.requires_explicit_lod_sampling(): self.require_compute_derivatives() self.emit( f"%{id_value} = OpImageQueryLod %{result_type.id} " f"%{sampled_image_id.id} %{coord_id.id}" ) self.value_types[id_value] = result_type return SpirvId(id_value, result_type.type) if function_name in {"textureLod", "textureGrad"}: required_arg_count = 3 if function_name == "textureLod" else 4 extra_arg_count = required_arg_count - 2 sample_args = self.sampled_texture_operands( function_name, args, extra_arg_count ) if sample_args is None: return self.register_constant( 0.0, self.register_primitive_type("float") ) sampled_image_id, coord_id, extra_args, metadata = sample_args if function_name == "textureLod": image_operands = f"Lod %{extra_args[0].id}" else: image_operands = f"Grad %{extra_args[0].id} %{extra_args[1].id}" result_type = 
self.resource_access_result_type(metadata) id_value = self.get_id() self.emit( f"%{id_value} = OpImageSampleExplicitLod %{result_type.id} " f"%{sampled_image_id.id} %{coord_id.id} {image_operands}" ) self.value_types[id_value] = result_type return SpirvId(id_value, result_type.type) return None def resource_offset_argument_indices(self, function_name: str): return { "textureOffset": {2, 3}, "textureGatherOffset": {2, 3}, "textureGatherOffsets": {2, 3, 4, 5, 6}, "textureCompareOffset": {3, 4}, "textureGatherCompareOffset": {3, 4}, }.get(function_name, set()) def literal_integer_vector_constant(self, expr) -> Optional[SpirvId]: if not isinstance(expr, FunctionCallNode): return None callee_expr = getattr(expr, "function", getattr(expr, "name", None)) if hasattr(callee_expr, "name"): function_name = callee_expr.name elif isinstance(callee_expr, str): function_name = callee_expr else: return None vector_info = self.vector_component_type_and_count(function_name) if vector_info is None: return None component_type_name, component_count = vector_info if component_type_name not in {"int", "uint"}: return None if len(expr.args) != component_count: return None component_values = [ self.literal_integer_value(arg, component_type_name) for arg in expr.args ] if any(value is None for value in component_values): return None component_type = self.register_primitive_type(component_type_name) vector_type = self.register_vector_type(component_type, component_count) components = [ self.register_constant(value, component_type) for value in component_values ] return self.register_vector_constant(vector_type, components) def literal_integer_value(self, expr, component_type_name: str) -> Optional[int]: if isinstance(expr, UnaryOpNode): value = self.literal_integer_value(expr.operand, component_type_name) if value is None: return None if expr.op == "-": value = -value elif expr.op != "+": return None if component_type_name == "uint" and value < 0: return None return value if not 
isinstance(expr, LiteralNode): return None literal_type = self.normalize_primitive_name( self.convert_type_node_to_string(expr.literal_type) ) if literal_type not in {"int", "uint"}: return None value = int(expr.value) if component_type_name == "uint" and value < 0: return None return value def process_call_argument(self, function_name, arg, arg_index): if arg_index in self.resource_offset_argument_indices(function_name): offset_constant = self.literal_integer_vector_constant(arg) if offset_constant is not None: return offset_constant resource_array_params = self.function_resource_array_params.get( function_name, set() ) if arg_index in resource_array_params: pointer_arg = self.variable_pointer_from_expression(arg) if pointer_arg is not None: return pointer_arg return self.process_expression(arg) def call_builtin_function( self, function_name: str, args: List[SpirvId] ) -> Optional[SpirvId]: """Call a built-in function.""" if self.glsl_std450_id is None: self.glsl_std450_id = self.get_id() self.emit(f'%{self.glsl_std450_id} = OpExtInstImport "GLSL.std.450"') vector_info = self.vector_component_type_and_count(function_name) if vector_info: component_type_name, component_count = vector_info component_type = self.register_primitive_type(component_type_name) vector_type = self.register_vector_type(component_type, component_count) id_value = self.get_id() # If no arguments are provided, construct a default vector if not args: if component_type_name == "bool": zero_value = False one_value = True elif component_type_name in {"int", "uint"}: zero_value = 0 one_value = 1 else: zero_value = 0.0 one_value = 1.0 # Preserve old defaults: first component one, rest zero. 
component_zero = self.register_constant(zero_value, component_type) component_one = self.register_constant(one_value, component_type) # Create default vector components default_args = [component_zero] * component_count if component_count > 0: default_args[0] = component_one arg_list = " ".join([f"%{arg.id}" for arg in default_args]) elif ( len(args) == 1 and self.vector_component_type_and_count(args[0].type.base_type) is None ): arg_list = " ".join(f"%{args[0].id}" for _ in range(component_count)) else: arg_list = " ".join([f"%{arg.id}" for arg in args]) self.emit( f"%{id_value} = OpCompositeConstruct %{vector_type.id} {arg_list}" ) return SpirvId(id_value, vector_type.type) if function_name in self.struct_types: struct_type = self.struct_types[function_name] id_value = self.get_id() arg_list = " ".join([f"%{arg.id}" for arg in args]) self.emit( f"%{id_value} = OpCompositeConstruct %{struct_type.id} {arg_list}" ) spirv_id = SpirvId(id_value, struct_type.type) self.value_types[id_value] = struct_type return spirv_id # Matrix constructors elif re.match(r"mat(\d)x\d", function_name) or re.match( r"mat\d", function_name ): if "x" in function_name: match = re.match(r"mat(\d)x(\d)", function_name) cols, rows = int(match.group(1)), int(match.group(2)) else: match = re.match(r"mat(\d)", function_name) cols = rows = int(match.group(1)) float_type = self.register_primitive_type("float") vector_type = self.register_vector_type(float_type, rows) matrix_type = self.register_matrix_type(vector_type, cols) id_value = self.get_id() # If no arguments provided, create identity matrix if not args: # Create identity matrix: 1's on diagonal, 0's elsewhere float_zero = self.register_constant(0.0, float_type) float_one = self.register_constant(1.0, float_type) # Create column vectors col_vectors = [] for col in range(cols): col_components = [] for row in range(rows): if col == row: col_components.append(float_one) else: col_components.append(float_zero) col_id = self.get_id() col_args = 
" ".join([f"%{comp.id}" for comp in col_components]) self.emit( f"%{col_id} = OpCompositeConstruct %{vector_type.id} {col_args}" ) col_vectors.append(SpirvId(col_id, vector_type.type)) arg_list = " ".join([f"%{vec.id}" for vec in col_vectors]) else: arg_list = " ".join([f"%{arg.id}" for arg in args]) self.emit( f"%{id_value} = OpCompositeConstruct %{matrix_type.id} {arg_list}" ) return SpirvId(id_value, matrix_type.type) # Special case for dot product - use OpDot instead of OpExtInst elif function_name == "dot" and len(args) == 2: # Get the result type (always float) float_type = self.register_primitive_type("float") result_type_id = float_type.id # Generate a direct OpDot instruction id_value = self.get_id() self.emit( f"%{id_value} = OpDot %{result_type_id} %{args[0].id} %{args[1].id}" ) return SpirvId(id_value, float_type.type) # GLSL standard library functions else: # Determine result type based on the function name float_type = self.primitive_types["float"] # Default to float if we can't determine or no args provided result_type = float_type.type # Try to infer result type from arguments if available if args: if function_name in [ "sin", "cos", "tan", "asin", "acos", "atan", "sinh", "cosh", "tanh", "exp", "log", "exp2", "log2", "sqrt", "inversesqrt", "abs", "sign", "floor", "ceil", "fract", "trunc", "round", "roundEven", ]: # These functions return the same type as their first argument result_type = args[0].type elif function_name in ["length", "distance"]: # These functions return a float regardless of input result_type = float_type.type elif function_name in ["normalize", "reflect", "refract"]: # These functions return the same vector type as their first argument result_type = args[0].type elif function_name in ["cross"]: # cross product returns a vec3 vector_type = self.register_vector_type(float_type, 3) result_type = vector_type.type id_value = self.get_id() arg_list = " ".join([f"%{arg.id}" for arg in args]) # Use a proper mapping for GLSL.std.450 
extended instructions glsl_std450_map = { "sin": "Sin", "cos": "Cos", "tan": "Tan", "asin": "Asin", "acos": "Acos", "atan": "Atan", "sinh": "Sinh", "cosh": "Cosh", "tanh": "Tanh", "exp": "Exp", "log": "Log", "exp2": "Exp2", "log2": "Log2", "sqrt": "Sqrt", "inversesqrt": "InverseSqrt", "abs": "FAbs", "sign": "FSign", "floor": "Floor", "ceil": "Ceil", "fract": "Fract", "trunc": "Trunc", "round": "Round", "roundEven": "RoundEven", "length": "Length", "distance": "Distance", "cross": "Cross", "normalize": "Normalize", "reflect": "Reflect", "refract": "Refract", } glsl_function = glsl_std450_map.get( function_name, function_name[0].upper() + function_name[1:] ) # Find the result type ID result_type_id = None for id_obj in ( [self.primitive_types.get(result_type.base_type)] + list(self.vector_types.values()) + list(self.matrix_types.values()) ): if id_obj and id_obj.type.base_type == result_type.base_type: result_type_id = id_obj.id break if result_type_id is None: result_type_id = float_type.id self.emit( f"%{id_value} = OpExtInst %{result_type_id} %{self.glsl_std450_id} {glsl_function} {arg_list}" ) return SpirvId(id_value, result_type) def create_branch(self, target_label: SpirvId): """Create an unconditional branch.""" self.emit(f"OpBranch %{target_label.id}") def create_conditional_branch( self, condition: SpirvId, true_label: SpirvId, false_label: SpirvId ): """Create a conditional branch.""" self.emit( f"OpBranchConditional %{condition.id} %{true_label.id} %{false_label.id}" ) def create_selection_merge( self, merge_label: SpirvId, selection_control: str = "None" ): """Create a selection merge instruction for if/switch statements.""" self.emit(f"OpSelectionMerge %{merge_label.id} {selection_control}") def create_loop_merge( self, merge_label: SpirvId, continue_label: SpirvId, loop_control: str = "None" ): """Create a loop merge instruction for loops.""" self.emit(f"OpLoopMerge %{merge_label.id} %{continue_label.id} {loop_control}") def create_return(self): 
"""Create a return instruction.""" self.emit("OpReturn") def create_return_value(self, value: SpirvId): """Create a return value instruction.""" self.emit(f"OpReturnValue %{value.id}") def current_block_has_terminator(self) -> bool: """Return whether the current block already ends in a terminator.""" for line in reversed(self.code_lines): stripped = line.strip() if not stripped or stripped.startswith(";"): continue if re.match(r"%\d+ = OpLabel$", stripped): return False return stripped.startswith(("OpBranch", "OpReturn", "OpKill")) return False def normalize_primitive_name(self, type_name: str) -> str: aliases = { "f32": "float", "f64": "double", "i32": "int", "u32": "uint", } return aliases.get(str(type_name), str(type_name)) def normalize_generic_vector_type(self, type_str: str) -> str: compact = re.sub(r"\s+", "", str(type_str)) match = re.fullmatch(r"vec([234])<([^>]+)>", compact) if not match: return compact size, element_type = match.groups() element_type = self.normalize_primitive_name(element_type) prefixes = { "float": "vec", "double": "dvec", "int": "ivec", "uint": "uvec", "bool": "bvec", } return f"{prefixes.get(element_type, 'vec')}{size}" def vector_component_type_and_count( self, type_str: str ) -> Optional[Tuple[str, int]]: type_str = self.normalize_generic_vector_type(type_str) internal_match = re.fullmatch(r"v([234])(float|double|int|uint|bool)", type_str) if internal_match: size, component_type = internal_match.groups() return component_type, int(size) vector_prefixes = ( ("dvec", "double"), ("ivec", "int"), ("uvec", "uint"), ("bvec", "bool"), ("vec", "float"), ) for prefix, component_type in vector_prefixes: if type_str.startswith(prefix) and type_str[len(prefix) :].isdigit(): return component_type, int(type_str[len(prefix) :]) return None def resource_type_info(self, type_str: str): sampler_info = { "sampler": {"kind": "sampler"}, "sampler1D": { "kind": "sampled_image", "component_type": "float", "dim": "1D", "depth": 0, "arrayed": 0, 
"multisampled": 0, "sampled": 1, "format": "Unknown", }, "sampler2D": { "kind": "sampled_image", "component_type": "float", "dim": "2D", "depth": 0, "arrayed": 0, "multisampled": 0, "sampled": 1, "format": "Unknown", }, "sampler3D": { "kind": "sampled_image", "component_type": "float", "dim": "3D", "depth": 0, "arrayed": 0, "multisampled": 0, "sampled": 1, "format": "Unknown", }, "samplerCube": { "kind": "sampled_image", "component_type": "float", "dim": "Cube", "depth": 0, "arrayed": 0, "multisampled": 0, "sampled": 1, "format": "Unknown", }, "sampler2DArray": { "kind": "sampled_image", "component_type": "float", "dim": "2D", "depth": 0, "arrayed": 1, "multisampled": 0, "sampled": 1, "format": "Unknown", }, "sampler2DShadow": { "kind": "sampled_image", "component_type": "float", "dim": "2D", "depth": 1, "arrayed": 0, "multisampled": 0, "sampled": 1, "format": "Unknown", }, "sampler2DArrayShadow": { "kind": "sampled_image", "component_type": "float", "dim": "2D", "depth": 1, "arrayed": 1, "multisampled": 0, "sampled": 1, "format": "Unknown", }, "samplerCubeShadow": { "kind": "sampled_image", "component_type": "float", "dim": "Cube", "depth": 1, "arrayed": 0, "multisampled": 0, "sampled": 1, "format": "Unknown", }, "samplerCubeArray": { "kind": "sampled_image", "component_type": "float", "dim": "Cube", "depth": 0, "arrayed": 1, "multisampled": 0, "sampled": 1, "format": "Unknown", }, "samplerCubeArrayShadow": { "kind": "sampled_image", "component_type": "float", "dim": "Cube", "depth": 1, "arrayed": 1, "multisampled": 0, "sampled": 1, "format": "Unknown", }, "sampler2DMS": { "kind": "sampled_image", "component_type": "float", "dim": "2D", "depth": 0, "arrayed": 0, "multisampled": 1, "sampled": 1, "format": "Unknown", }, "sampler2DMSArray": { "kind": "sampled_image", "component_type": "float", "dim": "2D", "depth": 0, "arrayed": 1, "multisampled": 1, "sampled": 1, "format": "Unknown", }, } if type_str in sampler_info: return sampler_info[type_str] image_match = 
re.fullmatch(r"([iu]?image)(2D|3D|Cube)(MS)?(Array)?", type_str) if image_match: prefix, dim, ms_suffix, array_suffix = image_match.groups() if ms_suffix and dim != "2D": return None component_type = { "image": "float", "iimage": "int", "uimage": "uint", }[prefix] return { "kind": "storage_image", "component_type": component_type, "dim": "Cube" if dim == "Cube" else dim, "depth": 0, "arrayed": 1 if array_suffix else 0, "multisampled": 1 if ms_suffix else 0, "sampled": 2, "format": "Unknown", } return None def is_resource_type_name(self, type_str: str) -> bool: return self.resource_type_info(type_str) is not None def spirv_image_format_map(self): return { "r8": "R8", "r8_snorm": "R8Snorm", "r8i": "R8i", "r8ui": "R8ui", "r16": "R16", "r16_snorm": "R16Snorm", "r16f": "R16f", "r16i": "R16i", "r16ui": "R16ui", "r32f": "R32f", "r32i": "R32i", "r32ui": "R32ui", "rg8": "Rg8", "rg8_snorm": "Rg8Snorm", "rg8i": "Rg8i", "rg8ui": "Rg8ui", "rg16": "Rg16", "rg16_snorm": "Rg16Snorm", "rg16f": "Rg16f", "rg16i": "Rg16i", "rg16ui": "Rg16ui", "rg32f": "Rg32f", "rg32i": "Rg32i", "rg32ui": "Rg32ui", "rgba8": "Rgba8", "rgba8_snorm": "Rgba8Snorm", "rgba8i": "Rgba8i", "rgba8ui": "Rgba8ui", "rgba16": "Rgba16", "rgba16_snorm": "Rgba16Snorm", "rgba16f": "Rgba16f", "rgba16i": "Rgba16i", "rgba16ui": "Rgba16ui", "rgba32f": "Rgba32f", "rgba32i": "Rgba32i", "rgba32ui": "Rgba32ui", } def spirv_image_format_name(self, image_format: Optional[str]) -> Optional[str]: if image_format is None: return None return self.spirv_image_format_map().get(str(image_format).lower()) def image_format_component_type(self, image_format: str) -> str: image_format = str(image_format).lower() if image_format.endswith("ui"): return "uint" if image_format.endswith("i") and not image_format.endswith("_snorm"): return "int" return "float" def image_format_component_count(self, image_format: Optional[str]) -> int: if image_format is None: return 4 image_format = str(image_format).lower() if image_format.startswith("rgba"): 
return 4 if image_format.startswith("rg"): return 2 if image_format.startswith("r"): return 1 return 4 def resource_access_result_type(self, metadata) -> SpirvId: component_type = self.register_primitive_type( metadata.get("component_type", "float") ) component_count = int(metadata.get("component_count", 4)) if component_count <= 1: return component_type return self.register_vector_type(component_type, component_count) def resource_metadata_for_value(self, value_id: SpirvId): result_type = self.value_types.get(value_id.id) if result_type is not None: metadata = self.resource_type_metadata.get(result_type.id) if metadata is not None: return metadata return self.resource_type_metadata.get(value_id.id) def attribute_value_to_string(self, value): if value is None: return None if isinstance(value, str): return value if hasattr(value, "name"): return str(value.name) if hasattr(value, "value"): return str(value.value).strip('"') return str(value) def explicit_image_format(self, node) -> Optional[str]: if not hasattr(node, "attributes"): return None supported_formats = self.spirv_image_format_map() for attr in node.attributes: attr_name = getattr(attr, "name", None) if not attr_name: continue attr_name = str(attr_name).lower() if attr_name in supported_formats: return attr_name if attr_name != "format": continue arguments = getattr(attr, "arguments", []) or [] if not arguments: continue format_name = self.attribute_value_to_string(arguments[0]) if format_name is None: continue format_name = str(format_name).lower() if format_name in supported_formats: return format_name return None def map_resource_type_with_format(self, type_name, node=None) -> SpirvId: if hasattr(type_name, "name") or hasattr(type_name, "element_type"): type_str = self.convert_type_node_to_string(type_name) else: type_str = str(type_name) type_str = self.normalize_generic_vector_type(type_str) explicit_format = self.explicit_image_format(node) if node is not None else None array_type = 
self.split_outer_array_type(type_str) if array_type is not None: base_type = self.array_base_type_name(type_str) element_type_name, size = array_type if self.is_resource_type_name(base_type): element_type = self.map_resource_type_with_format( element_type_name, node ) return self.register_array_type(element_type, size) if self.is_resource_type_name(type_str): return self.register_resource_type(type_str, explicit_format) return self.map_crossgl_type(type_name) def format_array_size(self, size): if size is None: return None if hasattr(size, "value"): return size.value return size def find_registered_type_by_base(self, base_type: str) -> Optional[SpirvId]: for type_dict in [ self.primitive_types, self.vector_types, self.matrix_types, self.struct_types, self.array_types, self.resource_types, self.resource_image_types, ]: for type_id in type_dict.values(): if type_id.type.base_type == base_type: return type_id return None def find_registered_type_by_id(self, id_value: int) -> Optional[SpirvId]: for type_dict in [ self.primitive_types, self.vector_types, self.matrix_types, self.struct_types, self.array_types, self.resource_types, self.resource_image_types, ]: for type_id in type_dict.values(): if type_id.id == id_value: return type_id return None def ensure_registered_type(self, type_ref: Union[SpirvId, SpirvType]) -> SpirvId: if isinstance(type_ref, SpirvId): return type_ref registered_type = self.find_registered_type_by_base(type_ref.base_type) if registered_type is not None: return registered_type return self.map_crossgl_type(type_ref.base_type) def scalar_or_vector_component_type(self, spirv_type: SpirvType) -> str: vector_info = self.vector_component_type_and_count(spirv_type.base_type) if vector_info is not None: return vector_info[0] return spirv_type.base_type def map_crossgl_type(self, type_name) -> SpirvId: """Map a CrossGL type name to a SPIR-V type ID.""" if hasattr(type_name, "name") or hasattr(type_name, "element_type"): type_str = 
self.convert_type_node_to_string(type_name) else: type_str = str(type_name) type_str = self.normalize_generic_vector_type(type_str) array_type = self.split_outer_array_type(type_str) if array_type is not None: element_type_name, size = array_type element_type = self.map_crossgl_type(element_type_name) return self.register_array_type(element_type, size) primitive_type = self.normalize_primitive_name(type_str) if primitive_type in {"float", "double", "int", "uint", "bool", "void"}: return self.register_primitive_type(primitive_type) vector_info = self.vector_component_type_and_count(type_str) if vector_info: component_type, size = vector_info component_type_id = self.register_primitive_type(component_type) return self.register_vector_type(component_type_id, size) matrix_match = re.fullmatch(r"(d)?mat([234])(?:x([234]))?", type_str) if matrix_match: is_double, cols, rows = matrix_match.groups() component_type = self.register_primitive_type( "double" if is_double else "float" ) row_count = int(rows or cols) col_count = int(cols) col_type = self.register_vector_type(component_type, row_count) return self.register_matrix_type(col_type, col_count) registered_type = self.find_registered_type_by_base(type_str) if registered_type: return registered_type if self.is_resource_type_name(type_str): return self.register_resource_type(type_str) if type_str in self.struct_types: # Struct type (reference to existing struct) return self.struct_types[type_str] else: # If type is unknown, return a default float type self.emit(f"; WARNING: Unknown type {type_str}, using float as default") return self.register_primitive_type("float") def convert_type_node_to_string(self, type_node) -> str: """Convert new AST TypeNode to string representation.""" if type_node.__class__.__name__ == "ArrayType": element_type = self.convert_type_node_to_string(type_node.element_type) size = self.format_array_size(type_node.size) return ( f"{element_type}[{size}]" if size is not None else f"{element_type}[]" ) 
if hasattr(type_node, "name"): generic_args = getattr(type_node, "generic_args", []) if generic_args: args = ", ".join( self.convert_type_node_to_string(arg) for arg in generic_args ) return f"{type_node.name}<{args}>" return type_node.name elif hasattr(type_node, "element_type") and hasattr(type_node, "rows"): element_type = self.convert_type_node_to_string(type_node.element_type) prefix = "dmat" if element_type in {"double", "f64"} else "mat" if type_node.rows == type_node.cols: return f"{prefix}{type_node.rows}" return f"{prefix}{type_node.rows}x{type_node.cols}" elif hasattr(type_node, "element_type") and hasattr(type_node, "size"): element_type = self.convert_type_node_to_string(type_node.element_type) size = type_node.size if element_type in {"float", "f32"}: return f"vec{size}" elif element_type in {"int", "i32"}: return f"ivec{size}" elif element_type in {"uint", "u32"}: return f"uvec{size}" elif element_type in {"double", "f64"}: return f"dvec{size}" elif element_type == "bool": return f"bvec{size}" else: return f"{element_type}{size}" else: return str(type_node) def process_crossgl_struct(self, struct_node: StructNode) -> SpirvId: """Process a CrossGL struct definition.""" members = [] for member in struct_node.members: member_type = None member_name = member.name if isinstance(member, ArrayNode): element_type = member.element_type if hasattr(element_type, "name") or hasattr( element_type, "element_type" ): element_type = self.convert_type_node_to_string(element_type) size = self.format_array_size(member.size) member_type = self.map_crossgl_type( f"{element_type}[{size}]" if size is not None else f"{element_type}[]" ) else: member_type_source = getattr( member, "member_type", getattr(member, "var_type", getattr(member, "vtype", None)), ) if member_type_source is not None: member_type = self.map_crossgl_type(member_type_source) if member_type: members.append((member_type, member_name)) return self.register_struct_type(struct_node.name, members) def 
process_function_node(self, function_node): """Process a CrossGL function definition.""" return_type = self.map_crossgl_type(function_node.return_type) previous_return_type = self.current_return_type self.current_return_type = return_type param_types = [] param_value_types = [] resource_array_param_indices = set() param_type_hints = self.function_resource_array_type_hints.get( function_node.name, {} ) for param in getattr( function_node, "parameters", getattr(function_node, "params", []) ): param_type_source = getattr( param, "param_type", getattr(param, "vtype", None) ) param_name = getattr(param, "name", None) if param_name in param_type_hints: param_type_source = param_type_hints[param_name] if param_type_source is not None: param_type = self.map_resource_type_with_format( param_type_source, param ) else: param_type = self.map_crossgl_type("float") param_value_types.append(param_type) if self.is_resource_array_type(param_type): resource_array_param_indices.add(len(param_types)) param_type = self.register_pointer_type(param_type, "UniformConstant") param_types.append(param_type) function_id = self.create_function(function_node.name, return_type, param_types) self.function_resource_array_params[function_node.name] = ( resource_array_param_indices ) for i, param in enumerate( getattr(function_node, "parameters", getattr(function_node, "params", [])) ): if hasattr(param, "name"): param_name = param.name else: param_name = f"param{i}" param_id = self.create_function_parameter(param_types[i], param_name) self.local_variables[param_name] = param_id if i in resource_array_param_indices: self.variable_value_types[param_id.id] = param_value_types[i] self.begin_block() previous_execution_model = self.current_execution_model if self.current_execution_model is None: execution_models = self.function_execution_models.get( function_node.name, set() ) if "GLCompute" in execution_models: self.current_execution_model = "GLCompute" elif len(execution_models) == 1: 
self.current_execution_model = next(iter(execution_models)) self.process_statements(function_node.body) if self.convert_type_node_to_string(function_node.return_type) == "void": self.create_return() self.end_function() self.current_execution_model = previous_execution_model self.current_return_type = previous_return_type self.local_variables.clear() return function_id def process_statements(self, statements): """Process a list of CrossGL statements.""" if hasattr(statements, "statements"): stmt_list = statements.statements elif isinstance(statements, list): stmt_list = statements else: stmt_list = [statements] for stmt in stmt_list: if self.current_block_has_terminator(): break self.process_statement(stmt) def process_statement(self, stmt): """Process a single CrossGL statement.""" if isinstance(stmt, AssignmentNode): self.process_assignment(stmt) elif isinstance(stmt, VariableNode): self.process_variable_declaration(stmt) elif isinstance(stmt, ReturnNode): self.process_return(stmt) elif isinstance(stmt, IfNode): self.process_if(stmt) elif isinstance(stmt, ForNode): self.process_for(stmt) elif isinstance(stmt, WhileNode): self.process_while(stmt) elif isinstance(stmt, BreakNode): self.process_break(stmt) elif isinstance(stmt, ContinueNode): self.process_continue(stmt) elif isinstance(stmt, FunctionCallNode): self.process_expression(stmt) # Just evaluate and discard result elif isinstance(stmt, (UnaryOpNode, BinaryOpNode)): self.process_expression(stmt) elif hasattr(stmt, "expression"): expression = stmt.expression if isinstance(expression, AssignmentNode): self.process_assignment(expression) else: self.process_expression(expression) def process_variable_declaration(self, node: VariableNode): """Process a local CrossGL variable declaration.""" var_type_source = getattr(node, "var_type", getattr(node, "vtype", "float")) var_type = self.map_resource_type_with_format(var_type_source, node) var_id = self.create_variable(var_type, "Function", node.name) 
self.local_variables[node.name] = var_id initial_value = getattr(node, "initial_value", None) if initial_value is not None: if isinstance(initial_value, ArrayLiteralNode): rhs_value = self.process_array_literal(initial_value, var_type) else: rhs_value = self.process_expression(initial_value) if rhs_value is not None: self.store_to_variable(var_id, rhs_value) def process_global_variable_declaration( self, node: VariableNode, default_storage_class: str = "Private" ) -> SpirvId: """Process a module-scope CrossGL variable declaration.""" var_type_source = getattr(node, "var_type", getattr(node, "vtype", "float")) var_type_name = self.type_name_from_value(var_type_source) var_type = self.map_resource_type_with_format(var_type_source, node) storage_class = self.infer_global_storage_class( node, default_storage_class, var_type_name ) initializer = None initial_value = getattr(node, "initial_value", None) if storage_class == "Private" and isinstance(initial_value, ArrayLiteralNode): initializer = self.process_array_literal( initial_value, var_type, constant=True ) if storage_class == "Input": location = self.next_input_location self.next_input_location += 1 var_id = self.register_input(node.name, var_type, location, 0) elif storage_class == "Output": location = self.next_output_location self.next_output_location += 1 var_id = self.register_output(node.name, var_type, location, 0) else: var_id = self.create_variable( var_type, storage_class, node.name, initializer ) if storage_class == "UniformConstant": binding = self.next_resource_binding self.next_resource_binding += 1 self.decorations.append(f"OpDecorate %{var_id.id} DescriptorSet 0") self.decorations.append(f"OpDecorate %{var_id.id} Binding {binding}") self.global_variables[node.name] = var_id return var_id def infer_global_storage_class( self, node: VariableNode, default_storage_class: str, type_name: str = None ) -> str: attribute_names = { getattr(attribute, "name", "").lower() for attribute in getattr(node, 
"attributes", []) } qualifiers = { str(qualifier).lower() for qualifier in getattr(node, "qualifiers", []) } if attribute_names & {"input", "in"} or qualifiers & {"input", "in"}: return "Input" if attribute_names & {"output", "out"} or qualifiers & {"output", "out"}: return "Output" if type_name: base_type_name, _ = parse_array_type(type_name) if self.is_resource_type_name(base_type_name): return "UniformConstant" return default_storage_class def type_name_from_value(self, type_value) -> str: if hasattr(type_value, "name") or hasattr(type_value, "element_type"): return self.convert_type_node_to_string(type_value) return str(type_value) def collect_ast_functions(self, root): functions = [] visited = set() def walk(value): if value is None or isinstance(value, (str, int, float, bool)): return if isinstance(value, dict): for item in value.values(): walk(item) return if isinstance(value, (list, tuple, set)): for item in value: walk(item) return value_id = id(value) if value_id in visited: return visited.add(value_id) if hasattr(value, "body") and hasattr(value, "parameters"): functions.append(value) if hasattr(value, "__dict__"): for child in vars(value).values(): walk(child) walk(root) return functions def walk_ast_nodes(self, root): visited = set() def walk(value): if value is None or isinstance(value, (str, int, float, bool)): return if isinstance(value, dict): for item in value.values(): yield from walk(item) return if isinstance(value, (list, tuple, set)): for item in value: yield from walk(item) return value_id = id(value) if value_id in visited: return visited.add(value_id) yield value if hasattr(value, "__dict__"): for child in vars(value).values(): yield from walk(child) yield from walk(root) def array_dimensions(self, type_name: str): if not type_name or "[" not in type_name: return None suffix = type_name[type_name.find("[") :] dimensions = [] offset = 0 while offset < len(suffix): if suffix[offset] != "[": return None end = suffix.find("]", offset + 1) if 
end == -1: return None dimensions.append(suffix[offset + 1 : end]) offset = end + 1 return dimensions def is_unsized_resource_array_type_name(self, type_name: str) -> bool: type_name = self.normalize_generic_vector_type(str(type_name)) array_type = self.split_outer_array_type(type_name) return ( array_type is not None and array_type[1] is None and self.is_resource_type_name(self.array_base_type_name(type_name)) ) def is_fixed_resource_array_type_name(self, type_name: str) -> bool: type_name = self.normalize_generic_vector_type(str(type_name)) array_type = self.split_outer_array_type(type_name) return ( array_type is not None and array_type[1] is not None and self.is_resource_type_name(self.array_base_type_name(type_name)) ) def fixed_type_for_unsized_resource_param(self, declared_type: str, arg_type: str): declared_type = self.normalize_generic_vector_type(str(declared_type)) arg_type = self.normalize_generic_vector_type(str(arg_type)) if not self.is_unsized_resource_array_type_name(declared_type): return None if not self.is_fixed_resource_array_type_name(arg_type): return None if self.array_base_type_name(declared_type) != self.array_base_type_name( arg_type ): return None declared_dimensions = self.array_dimensions(declared_type) arg_dimensions = self.array_dimensions(arg_type) if not declared_dimensions or not arg_dimensions: return None if len(declared_dimensions) != len(arg_dimensions): return None if declared_dimensions[0] != "": return None if declared_dimensions[1:] != arg_dimensions[1:]: return None return arg_type def expression_name(self, expr): if isinstance(expr, str): return expr if hasattr(expr, "name") and isinstance(expr.name, str): return expr.name if isinstance(expr, ArrayAccessNode): array_expr = getattr(expr, "array", getattr(expr, "array_expr", None)) return self.expression_name(array_expr) return None def function_call_name(self, call): callee = getattr(call, "function", getattr(call, "name", None)) if hasattr(callee, "name"): return 
callee.name if isinstance(callee, str): return callee return None def collect_function_execution_models(self, ast): functions = { getattr(func, "name", None): func for func in self.collect_ast_functions(ast) } functions = {name: func for name, func in functions.items() if name} calls_by_function = {} for function_name, func in functions.items(): calls_by_function[function_name] = { call_name for call_name in ( self.function_call_name(call) for call in self.walk_ast_nodes(getattr(func, "body", [])) if isinstance(call, FunctionCallNode) ) if call_name in functions } execution_models = {function_name: set() for function_name in functions} def mark_callgraph(entry_name: str, execution_model: str): pending = [entry_name] visited = set() while pending: function_name = pending.pop() if function_name in visited or function_name not in functions: continue visited.add(function_name) execution_models[function_name].add(execution_model) pending.extend(calls_by_function.get(function_name, ())) stage_qualifiers = { "vertex", "fragment", "compute", "geometry", "tessellation_control", "tessellation_evaluation", } for func in getattr(ast, "functions", []): qualifier = self.get_function_qualifier(func) if func.name == "main" or qualifier in stage_qualifiers: mark_callgraph(func.name, self.spirv_execution_model(qualifier)) for stage_type, stage in (getattr(ast, "stages", None) or {}).items(): stage_name = self.stage_key(stage_type) execution_model = self.spirv_execution_model(stage_name) entry_function = getattr(stage, "entry_point", None) if entry_function is not None: mark_callgraph(entry_function.name, execution_model) return { function_name: models for function_name, models in execution_models.items() if models } def collect_resource_array_parameter_type_hints(self, ast): functions = { getattr(func, "name", None): func for func in self.collect_ast_functions(ast) } functions = {name: func for name, func in functions.items() if name} global_nodes = list(getattr(ast, 
"global_variables", []) or []) for stage in (getattr(ast, "stages", None) or {}).values(): global_nodes.extend(getattr(stage, "local_variables", []) or []) global_types = {} for node in self.walk_ast_nodes(global_nodes): if isinstance(node, VariableNode): global_types[node.name] = self.type_name_from_value( getattr(node, "var_type", getattr(node, "vtype", "float")) ) declared_param_types = {} for func_name, func in functions.items(): declared_param_types[func_name] = {} for param in getattr(func, "parameters", getattr(func, "params", [])): param_name = getattr(param, "name", None) param_type = getattr(param, "param_type", getattr(param, "vtype", None)) if param_name and param_type is not None: declared_param_types[func_name][param_name] = ( self.type_name_from_value(param_type) ) hints = {func_name: {} for func_name in functions} def visible_types(func_name): visible = dict(global_types) for param_name, param_type in declared_param_types.get( func_name, {} ).items(): visible[param_name] = hints.get(func_name, {}).get( param_name, param_type ) return visible changed = True while changed: changed = False for caller_name, func in functions.items(): caller_visible_types = visible_types(caller_name) for call in self.walk_ast_nodes(getattr(func, "body", [])): if not isinstance(call, FunctionCallNode): continue callee_name = self.function_call_name(call) callee = functions.get(callee_name) if callee is None: continue callee_params = getattr( callee, "parameters", getattr(callee, "params", []) ) args = getattr(call, "arguments", getattr(call, "args", [])) for index, arg in enumerate(args): if index >= len(callee_params): continue param = callee_params[index] param_name = getattr(param, "name", None) declared_type = declared_param_types.get(callee_name, {}).get( param_name ) if not param_name or declared_type is None: continue arg_name = self.expression_name(arg) arg_type = caller_visible_types.get(arg_name) if arg_type is None: continue fixed_type = 
self.fixed_type_for_unsized_resource_param( declared_type, arg_type ) if fixed_type is None: continue existing = hints.setdefault(callee_name, {}).get(param_name) if existing is not None and existing != fixed_type: raise ValueError( "Conflicting SPIR-V resource array parameter sizes for " f"'{param_name}': {existing} and {fixed_type}" ) if existing != fixed_type: hints[callee_name][param_name] = fixed_type changed = True return { func_name: param_hints for func_name, param_hints in hints.items() if param_hints } def split_outer_array_type(self, type_name: str): if not type_name or "[" not in type_name or not type_name.endswith("]"): return None open_bracket = type_name.find("[") close_bracket = type_name.find("]", open_bracket) if close_bracket == -1: return None base_type = type_name[:open_bracket] remaining_suffix = type_name[close_bracket + 1 :] element_type = ( f"{base_type}{remaining_suffix}" if remaining_suffix else base_type ) size_text = type_name[open_bracket + 1 : close_bracket].strip() if not size_text: return element_type, None try: return element_type, int(size_text) except ValueError: return element_type, None def array_base_type_name(self, type_name: str): if not type_name or "[" not in type_name: return type_name return type_name[: type_name.find("[")] def get_variable_value(self, variable_id: SpirvId) -> SpirvId: value_type = self.variable_value_types.get(variable_id.id) if value_type: return self.load_from_variable(variable_id, value_type) if variable_id.type.storage_class: base_type = variable_id.type.base_type.replace("ptr_", "", 1) var_type = self.find_registered_type_by_base(base_type) if var_type: return self.load_from_variable(variable_id, var_type) return variable_id def variable_pointer_from_expression(self, expr) -> Optional[SpirvId]: if isinstance(expr, IdentifierNode): name = expr.name elif isinstance(expr, VariableNode): name = expr.name elif isinstance(expr, str): name = expr elif isinstance(expr, ArrayAccessNode): index = 
self.process_expression(expr.index) if index is None: return None access, _ = self.create_array_element_access(expr.array, index) return access elif isinstance(expr, MemberAccessNode): base_pointer = self.variable_pointer_from_expression(expr.object) if base_pointer is None: return None return self.create_member_access_pointer(base_pointer, expr.member) else: return None return self.local_variables.get(name) or self.global_variables.get(name) def array_element_type_from_type(self, array_type: Optional[SpirvId]): if array_type is None: return None for (element_type_id, _), arr_type_id in self.array_types.items(): if arr_type_id.id == array_type.id: return self.find_registered_type_by_id(element_type_id) return None def array_type_info_from_type(self, array_type: Optional[SpirvId]): if array_type is None: return None for (element_type_id, size), arr_type_id in self.array_types.items(): if arr_type_id.id == array_type.id: return self.find_registered_type_by_id(element_type_id), size return None def vector_type_info_from_type(self, vector_type: Optional[SpirvId]): if vector_type is None: return None for (component_type_id, count), vec_type_id in self.vector_types.items(): if vec_type_id.id == vector_type.id: return self.find_registered_type_by_id(component_type_id), count return None def matrix_type_info_from_type(self, matrix_type: Optional[SpirvId]): if matrix_type is None: return None for (column_type_id, count), mat_type_id in self.matrix_types.items(): if mat_type_id.id == matrix_type.id: return self.find_registered_type_by_id(column_type_id), count return None def default_value_for_type(self, type_id: SpirvId) -> SpirvId: primitive_name = self.normalize_primitive_name(type_id.type.base_type) if primitive_name in {"float", "double"}: return self.register_constant(0.0, type_id) if primitive_name in {"int", "uint"}: return self.register_constant(0, type_id) if primitive_name == "bool": return self.register_constant(False, type_id) vector_info = 
self.vector_type_info_from_type(type_id) if vector_info is not None: component_type, count = vector_info components = [ self.default_value_for_type(component_type) for _ in range(count) ] return self.register_vector_constant(type_id, components) matrix_info = self.matrix_type_info_from_type(type_id) if matrix_info is not None: column_type, count = matrix_info columns = [self.default_value_for_type(column_type) for _ in range(count)] return self.register_composite_constant(type_id, columns) array_info = self.array_type_info_from_type(type_id) if array_info is not None: element_type, size = array_info elements = [ self.default_value_for_type(element_type) for _ in range(size or 0) ] return self.register_composite_constant(type_id, elements) members = self.current_struct_members.get(type_id.type.base_type) if members is not None: values = [ self.default_value_for_type(member_type) for member_type, _ in members ] return self.register_composite_constant(type_id, values) return self.register_constant(0.0, self.register_primitive_type("float")) def process_array_literal( self, expr: ArrayLiteralNode, target_type: Optional[SpirvId] = None, constant: bool = False, ) -> Optional[SpirvId]: array_type = target_type element_type = None target_size = None if array_type is not None: array_type = self.ensure_registered_type(array_type) array_info = self.array_type_info_from_type(array_type) if array_info is not None: element_type, target_size = array_info values = [] for element in expr.elements: value = self.process_array_literal_element(element, element_type, constant) if value is None: return None values.append(value) if array_type is None: if values: element_type = self.value_types.get( values[0].id ) or self.find_registered_type_by_base(values[0].type.base_type) if element_type is None: element_type = self.register_primitive_type("float") target_size = len(values) array_type = self.register_array_type(element_type, target_size) if target_size is not None: values = 
values[:target_size] if element_type is not None: while len(values) < target_size: values.append(self.default_value_for_type(element_type)) if constant: if not all(self.is_constant_instruction(value) for value in values): return None return self.register_composite_constant(array_type, values) id_value = self.get_id() component_list = " ".join(f"%{value.id}" for value in values) self.emit( f"%{id_value} = OpCompositeConstruct %{array_type.id} {component_list}" ) spirv_id = SpirvId(id_value, array_type.type) self.value_types[id_value] = array_type return spirv_id def process_array_literal_element( self, element, target_type: Optional[SpirvId], constant: bool, ) -> Optional[SpirvId]: if isinstance(element, ArrayLiteralNode): return self.process_array_literal(element, target_type, constant) if constant: return self.process_constant_expression(element, target_type) return self.process_expression(element) def process_constant_expression( self, expr, target_type: Optional[SpirvId] = None, ) -> Optional[SpirvId]: if isinstance(expr, ArrayLiteralNode): return self.process_array_literal(expr, target_type, constant=True) if isinstance(expr, LiteralNode): return self.process_expression(expr) if isinstance(expr, FunctionCallNode): callee_expr = getattr(expr, "function", getattr(expr, "name", None)) callee_name = getattr(callee_expr, "name", callee_expr) vector_type = target_type if vector_type is None and isinstance(callee_name, str): if self.vector_component_type_and_count(callee_name) is not None: vector_type = self.map_crossgl_type(callee_name) if vector_type is not None: vector_info = self.vector_type_info_from_type(vector_type) if vector_info is not None: component_type, component_count = vector_info components = [] for arg in getattr(expr, "args", []): value = self.process_constant_expression(arg, component_type) if value is None: return None components.append(value) if len(components) == 1 and component_count > 1: components *= component_count components = 
components[:component_count] while len(components) < component_count: components.append(self.default_value_for_type(component_type)) return self.register_vector_constant(vector_type, components) value = self.process_expression(expr) if value is not None and self.is_constant_instruction(value): return value return None def ensure_assignable_pointer_for_name(self, name: str) -> Optional[SpirvId]: var_id = self.local_variables.get(name) if var_id is not None: if var_id.type.storage_class: return var_id value_type = self.value_types.get(var_id.id) if value_type is None: return None mutable_id = self.create_variable(value_type, "Function", name) self.store_to_variable(mutable_id, var_id) self.local_variables[name] = mutable_id return mutable_id return self.global_variables.get(name) def assignable_pointer_from_expression(self, expr) -> Optional[SpirvId]: if isinstance(expr, IdentifierNode): return self.ensure_assignable_pointer_for_name(expr.name) if isinstance(expr, VariableNode): return self.ensure_assignable_pointer_for_name(expr.name) if isinstance(expr, str): return self.ensure_assignable_pointer_for_name(expr) if isinstance(expr, MemberAccessNode): base_pointer = self.assignable_pointer_from_expression(expr.object) if base_pointer is None: return None return self.create_member_access_pointer(base_pointer, expr.member) if isinstance(expr, ArrayAccessNode): index = self.process_expression(expr.index) if index is None: return None array_variable = self.assignable_pointer_from_expression(expr.array) if array_variable is None: return None array_type = self.variable_value_types.get(array_variable.id) element_type = self.array_element_type_from_type(array_type) if element_type is None: element_type = self.determine_array_element_type(array_variable) if element_type is None: return None storage_class = array_variable.type.storage_class or "Function" ptr_type = self.register_pointer_type(element_type, storage_class) access = self.access_chain(array_variable, [index], 
ptr_type) self.variable_value_types[access.id] = element_type return access return None def is_resource_array_type(self, array_type: Optional[SpirvId]) -> bool: element_type = self.array_element_type_from_type(array_type) while element_type is not None: if element_type.id in self.resource_type_metadata: return True element_type = self.array_element_type_from_type(element_type) return False def create_array_element_access(self, array_expr, index: SpirvId): array_variable = self.variable_pointer_from_expression(array_expr) if array_variable is None or not array_variable.type.storage_class: addressable_array = self.assignable_pointer_from_expression(array_expr) if addressable_array is not None: array_variable = addressable_array if array_variable is not None: array_type = self.variable_value_types.get(array_variable.id) element_type = self.array_element_type_from_type(array_type) if element_type is None: element_type = self.determine_array_element_type(array_variable) if element_type is None: return None, None storage_class = array_variable.type.storage_class or "Function" ptr_type = self.register_pointer_type(element_type, storage_class) access = self.access_chain(array_variable, [index], ptr_type) self.variable_value_types[access.id] = element_type return access, element_type array = self.process_expression(array_expr) if array is None: return None, None element_type = self.determine_array_element_type(array) if element_type is None: return None, None array_type = self.value_types.get( array.id ) or self.find_registered_type_by_base(array.type.base_type) if array_type is not None and not array.type.storage_class: array_variable = self.create_variable(array_type, "Function") self.store_to_variable(array_variable, array) storage_class = array_variable.type.storage_class or "Function" ptr_type = self.register_pointer_type(element_type, storage_class) access = self.access_chain(array_variable, [index], ptr_type) self.variable_value_types[access.id] = element_type return 
access, element_type storage_class = array.type.storage_class or "Function" ptr_type = self.register_pointer_type(element_type, storage_class) access = self.access_chain(array, [index], ptr_type) self.variable_value_types[access.id] = element_type return access, element_type def process_assignment(self, node: AssignmentNode): """Process a CrossGL assignment statement.""" target = getattr( node, "name", getattr(node, "target", getattr(node, "left", None)) ) if isinstance(node.value, ArrayLiteralNode): target_pointer = self.assignable_pointer_from_expression(target) target_type = ( self.variable_value_types.get(target_pointer.id) if target_pointer is not None else None ) rhs_value = self.process_array_literal(node.value, target_type) else: rhs_value = self.process_expression(node.value) if rhs_value is None: return if isinstance(target, IdentifierNode): target = target.name if isinstance(target, str): var_id = self.ensure_assignable_pointer_for_name(target) if var_id is None: var_type = self.primitive_types["float"] if hasattr(rhs_value, "type"): var_type = ( self.find_registered_type_by_base(rhs_value.type.base_type) or var_type ) var_id = self.create_variable(var_type, "Function", target) self.local_variables[target] = var_id self.store_to_variable(var_id, rhs_value) elif isinstance(target, MemberAccessNode): base_pointer = self.assignable_pointer_from_expression(target.object) if base_pointer is None: return member_name = target.member access = self.create_member_access_pointer(base_pointer, member_name) if access is not None: self.store_to_variable(access, rhs_value) return # Default handling if member not found struct_type = self.struct_type_name_from_pointer(base_pointer) self.emit( f"; WARNING: Could not find member {member_name} in {struct_type}" ) elif isinstance(target, ArrayAccessNode): access = self.assignable_pointer_from_expression(target) element_type = ( self.variable_value_types.get(access.id) if access is not None else None ) if access is None or 
element_type is None: self.emit( f"; WARNING: Could not determine array element type for {target.array}" ) return self.store_to_variable(access, rhs_value) else: self.emit( f"; WARNING: Unsupported LHS type in assignment: {type(target).__name__}" ) def process_return(self, node: ReturnNode): """Process a CrossGL return statement.""" if hasattr(node, "value") and node.value: if isinstance(node.value, list) and node.value: return_value = self.process_expression(node.value[0]) if return_value: self.create_return_value(return_value) else: self.create_return() else: if isinstance(node.value, ArrayLiteralNode): return_value = self.process_array_literal( node.value, self.current_return_type ) else: return_value = self.process_expression(node.value) if return_value: self.create_return_value(return_value) else: self.create_return() else: self.create_return() def process_if(self, node: IfNode): """Process a CrossGL if statement.""" condition = self.process_expression(node.if_condition) if condition is None: condition = self.register_constant(True, self.primitive_types["bool"]) merge_label = SpirvId(self.get_id(), SpirvType("label")) then_label = SpirvId(self.get_id(), SpirvType("label")) else_label = SpirvId(self.get_id(), SpirvType("label")) self.create_selection_merge(merge_label) self.create_conditional_branch(condition, then_label, else_label) self.emit(f"%{then_label.id} = OpLabel") self.current_label = then_label.id self.process_statements(node.if_body) if not self.current_block_has_terminator(): self.create_branch(merge_label) self.emit(f"%{else_label.id} = OpLabel") self.current_label = else_label.id if node.else_body: self.process_statements(node.else_body) if not self.current_block_has_terminator(): self.create_branch(merge_label) self.emit(f"%{merge_label.id} = OpLabel") self.current_label = merge_label.id def process_for(self, node: ForNode): """Process a CrossGL for loop.""" if node.init: self.process_statement(node.init) header_label = SpirvId(self.get_id(), 
SpirvType("label")) body_label = SpirvId(self.get_id(), SpirvType("label")) continue_label = SpirvId(self.get_id(), SpirvType("label")) merge_label = SpirvId(self.get_id(), SpirvType("label")) self.create_branch(header_label) self.emit(f"%{header_label.id} = OpLabel") self.current_label = header_label.id condition = self.process_expression(node.condition) if condition is None: condition = self.register_constant(True, self.primitive_types["bool"]) self.create_loop_merge(merge_label, continue_label) self.create_conditional_branch(condition, body_label, merge_label) self.emit(f"%{body_label.id} = OpLabel") self.current_label = body_label.id self.loop_merge_labels.append(merge_label) self.loop_continue_labels.append(continue_label) try: if node.body: self.process_statements(node.body) if not self.current_block_has_terminator(): self.create_branch(continue_label) finally: self.loop_continue_labels.pop() self.loop_merge_labels.pop() self.emit(f"%{continue_label.id} = OpLabel") self.current_label = continue_label.id if node.update: self.process_statement(node.update) if not self.current_block_has_terminator(): self.create_branch(header_label) self.emit(f"%{merge_label.id} = OpLabel") self.current_label = merge_label.id def process_while(self, node: WhileNode): """Process a CrossGL while loop.""" header_label = SpirvId(self.get_id(), SpirvType("label")) body_label = SpirvId(self.get_id(), SpirvType("label")) continue_label = SpirvId(self.get_id(), SpirvType("label")) merge_label = SpirvId(self.get_id(), SpirvType("label")) self.create_branch(header_label) self.emit(f"%{header_label.id} = OpLabel") self.current_label = header_label.id condition = self.process_expression(node.condition) if condition is None: condition = self.register_constant(True, self.primitive_types["bool"]) self.create_loop_merge(merge_label, continue_label) self.create_conditional_branch(condition, body_label, merge_label) self.emit(f"%{body_label.id} = OpLabel") self.current_label = body_label.id 
self.loop_merge_labels.append(merge_label) self.loop_continue_labels.append(continue_label) try: if node.body: self.process_statements(node.body) if not self.current_block_has_terminator(): self.create_branch(continue_label) finally: self.loop_continue_labels.pop() self.loop_merge_labels.pop() self.emit(f"%{continue_label.id} = OpLabel") self.current_label = continue_label.id self.create_branch(header_label) self.emit(f"%{merge_label.id} = OpLabel") self.current_label = merge_label.id def process_break(self, node: BreakNode): """Process a CrossGL break statement.""" if not self.loop_merge_labels: self.emit("; WARNING: break used outside a loop") return self.create_branch(self.loop_merge_labels[-1]) def process_continue(self, node: ContinueNode): """Process a CrossGL continue statement.""" if not self.loop_continue_labels: self.emit("; WARNING: continue used outside a loop") return self.create_branch(self.loop_continue_labels[-1]) def process_increment_expression(self, node: UnaryOpNode) -> SpirvId: """Process prefix/postfix ++ and -- as load/update/store operations.""" variable_id = self.variable_pointer_from_expression(node.operand) if variable_id is None: self.emit("; WARNING: increment target is not assignable") int_type = self.register_primitive_type("int") return self.register_constant(0, int_type) value_type = self.variable_value_types.get(variable_id.id) if value_type is None: value_type = self.find_registered_type_by_base( variable_id.type.base_type.replace("ptr_", "", 1) ) if value_type is None: value_type = self.register_primitive_type("int") old_value = self.load_from_variable(variable_id, value_type) step_value = ( self.register_constant(1.0, value_type) if value_type.type.base_type == "float" else self.register_constant(1, value_type) ) operator = "+" if node.op == "++" else "-" new_value = self.binary_operation(operator, value_type, old_value, step_value) self.store_to_variable(variable_id, new_value) if getattr(node, "is_postfix", getattr(node, 
"postfix", False)): return old_value return new_value def process_expression(self, expr) -> Optional[SpirvId]: """Process a CrossGL expression.""" if expr is None: return None if isinstance(expr, bool): bool_type = self.register_primitive_type("bool") return self.register_constant(expr, bool_type) elif isinstance(expr, int): int_type = self.register_primitive_type("int") return self.register_constant(expr, int_type) elif isinstance(expr, float): float_type = self.register_primitive_type("float") return self.register_constant(expr, float_type) elif isinstance(expr, str): if expr in self.local_variables: var_id = self.local_variables[expr] return self.get_variable_value(var_id) elif expr in self.global_variables: return self.get_variable_value(self.global_variables[expr]) else: # Create a default float constant for missing variables in examples # This is to make the SPIR-V code valid even if we can't find the variable if expr.replace(".", "", 1).isdigit(): # Check if it's a numeric string float_type = self.register_primitive_type("float") try: value = float(expr) return self.register_constant(value, float_type) except ValueError: pass self.emit(f"; WARNING: Unknown variable {expr}") # Return a default value instead of None float_type = self.register_primitive_type("float") return self.register_constant(0.0, float_type) elif isinstance(expr, LiteralNode): literal_type = self.convert_type_node_to_string(expr.literal_type) primitive_type_name = self.normalize_primitive_name(literal_type) if primitive_type_name in {"float", "double"}: literal_type_id = self.register_primitive_type(primitive_type_name) return self.register_constant(float(expr.value), literal_type_id) if primitive_type_name in {"int", "uint"}: literal_type_id = self.register_primitive_type(primitive_type_name) return self.register_constant(int(expr.value), literal_type_id) if primitive_type_name == "bool": literal_type_id = self.register_primitive_type("bool") if isinstance(expr.value, str): value = 
expr.value.lower() == "true" else: value = bool(expr.value) return self.register_constant(value, literal_type_id) return self.process_expression(expr.value) elif isinstance(expr, IdentifierNode): return self.process_expression(expr.name) elif isinstance(expr, VariableNode): if expr.name in self.local_variables: var_id = self.local_variables[expr.name] return self.get_variable_value(var_id) elif expr.name in self.global_variables: return self.get_variable_value(self.global_variables[expr.name]) else: self.emit(f"; WARNING: Unknown variable {expr.name}") # Return a default value instead of None float_type = self.register_primitive_type("float") return self.register_constant(0.0, float_type) elif isinstance(expr, ArrayLiteralNode): return self.process_array_literal(expr) # Array access elif isinstance(expr, ArrayAccessNode): index = self.process_expression(expr.index) if index is None: self.emit(f"; WARNING: Failed to evaluate array access") float_type = self.register_primitive_type("float") return self.register_constant(0.0, float_type) access, element_type = self.create_array_element_access(expr.array, index) if access is None or element_type is None: self.emit( f"; WARNING: Could not determine array element type for {expr.array}" ) element_type = self.primitive_types["float"] return self.register_constant(0.0, element_type) return self.load_from_variable(access, element_type) elif isinstance(expr, BinaryOpNode): left = self.process_expression(expr.left) right = self.process_expression(expr.right) if left is None or right is None: # Return a default value instead of None float_type = self.register_primitive_type("float") return self.register_constant(0.0, float_type) # Determine result type result_type = left.type # Default to left operand's type return self.binary_operation( expr.op, self.map_crossgl_type(result_type.base_type), left, right ) elif isinstance(expr, UnaryOpNode): if expr.op in {"++", "--"}: return self.process_increment_expression(expr) operand = 
self.process_expression(expr.operand) if operand is None: # Return a default value instead of None float_type = self.register_primitive_type("float") return self.register_constant(0.0, float_type) return self.unary_operation(expr.op, operand.type, operand) elif isinstance(expr, TernaryOpNode): condition = self.process_expression(expr.condition) true_value = self.process_expression(expr.true_expr) false_value = self.process_expression(expr.false_expr) if condition is None: condition = self.register_constant( False, self.register_primitive_type("bool") ) if true_value is None or false_value is None: float_type = self.register_primitive_type("float") fallback = self.register_constant(0.0, float_type) true_value = true_value or fallback false_value = false_value or fallback result_type = self.map_crossgl_type(true_value.type.base_type) return self.select_operation( result_type, condition, true_value, false_value ) elif isinstance(expr, FunctionCallNode): callee_expr = getattr(expr, "function", getattr(expr, "name", None)) callee_name = None if hasattr(callee_expr, "name"): callee_name = callee_expr.name elif isinstance(callee_expr, str): callee_name = callee_expr # Evaluate arguments args = [] has_errors = False for arg_index, arg in enumerate(expr.args): arg_value = self.process_call_argument(callee_name, arg, arg_index) if arg_value is None: self.emit( f"; WARNING: Failed to evaluate argument for {callee_name or callee_expr}" ) has_errors = True # Create a default argument float_type = self.register_primitive_type("float") arg_value = self.register_constant(0.0, float_type) args.append(arg_value) if has_errors and callee_name == "vec2": # Special handling for vec2 constructor with errors float_type = self.register_primitive_type("float") vector_type = self.register_vector_type(float_type, 2) id_value = self.get_id() # Create default values if needed while len(args) < 2: args.append(self.register_constant(0.0, float_type)) arg_list = " ".join([f"%{arg.id}" for arg in 
args[:2]]) self.emit( f"%{id_value} = OpCompositeConstruct %{vector_type.id} {arg_list}" ) return SpirvId(id_value, vector_type.type) if callee_name is None: # Non-identifier callee (e.g., function table call) not supported in SPIR-V path self.emit("; WARNING: Unsupported callee expression in SPIR-V backend") float_type = self.register_primitive_type("float") return self.register_constant(0.0, float_type) if callee_name in self.resource_function_names(): return self.call_resource_function(callee_name, args) return self.call_function(callee_name, args) elif isinstance(expr, MemberAccessNode): member_name = expr.member base_pointer = self.variable_pointer_from_expression(expr.object) if base_pointer is not None: access = self.create_member_access_pointer(base_pointer, member_name) if access is not None: member_type = self.variable_value_types.get(access.id) return self.load_from_variable(access, member_type) base = self.process_expression(expr.object) if base is None: return None struct_type = base.type.base_type member_info = self.struct_member_info(struct_type, member_name) if member_info is not None: member_index, member_type = member_info return self.composite_extract(base, member_type, member_index) # Default handling if member not found self.emit( f"; WARNING: Could not find member {member_name} in {struct_type}" ) return None else: self.emit(f"; WARNING: Unknown expression type {type(expr).__name__}") return None def register_input( self, name: str, type_id: SpirvId, location: int, binding: int ) -> SpirvId: """Register an input variable with location decoration.""" ptr_type = self.register_pointer_type(type_id, "Input") id_value = self.get_id() self.emit(f"%{id_value} = OpVariable %{ptr_type.id} Input") self.decorations.append(f"OpDecorate %{id_value} Location {location}") if name: self.emit(f'OpName %{id_value} "{name}"') spirv_id = SpirvId(id_value, ptr_type.type, name) self.variable_value_types[id_value] = type_id self.inputs.append(spirv_id) return 
spirv_id def register_output( self, name: str, type_id: SpirvId, location: int, binding: int ) -> SpirvId: """Register an output variable with location decoration.""" ptr_type = self.register_pointer_type(type_id, "Output") id_value = self.get_id() self.emit(f"%{id_value} = OpVariable %{ptr_type.id} Output") self.decorations.append(f"OpDecorate %{id_value} Location {location}") if name: self.emit(f'OpName %{id_value} "{name}"') spirv_id = SpirvId(id_value, ptr_type.type, name) self.variable_value_types[id_value] = type_id self.outputs.append(spirv_id) return spirv_id def register_array_type( self, element_type: SpirvId, size: Optional[int] = None ) -> SpirvId: """Create and register an array type.""" key = (element_type.id, size) if key in self.array_types: return self.array_types[key] id_value = self.get_id() if size is not None: size_const = self.register_constant( size, self.register_primitive_type("int") ) self.emit(f"%{id_value} = OpTypeArray %{element_type.id} %{size_const.id}") else: self.emit(f"%{id_value} = OpTypeRuntimeArray %{element_type.id}") type_name = f"array_{element_type.type.base_type}_{size if size else 'rt'}" spirv_type = SpirvType(type_name) spirv_id = SpirvId(id_value, spirv_type, type_name) self.array_types[key] = spirv_id return spirv_id def determine_array_element_type(self, array_id: "SpirvId") -> Optional["SpirvId"]: """Determine the element type of an array based on its SpirvId. 
Args: array_id: The SpirvId of the array Returns: SpirvId of the element type, or None if it cannot be determined """ if ( not array_id or not hasattr(array_id, "type") or not hasattr(array_id.type, "base_type") ): return None array_type = array_id.type.base_type # Check if it's a known array type in our registry for (element_type_id, _), arr_type_id in self.array_types.items(): if arr_type_id.type.base_type == array_type: return self.find_registered_type_by_id(element_type_id) # If it's a pointer type, extract the base type if array_type.startswith("ptr_"): base_type = array_type.replace("ptr_", "", 1) for (element_type_id, _), arr_type_id in self.array_types.items(): if arr_type_id.type.base_type == base_type: return self.find_registered_type_by_id(element_type_id) # Look for array type pattern in the base type match = re.search(r"array_([^_]+)_", base_type) if match: element_type_name = match.group(1) # Look up the element type ID for type_dict in [ self.primitive_types, self.vector_types, self.matrix_types, ]: for type_id in type_dict.values(): if type_id.type.base_type == element_type_name: return type_id # Last resort: Try to parse from type name for type_dict in [ self.primitive_types, self.vector_types, self.matrix_types, ]: for type_id in type_dict.values(): # Check if type name is a substring of the array type if type_id.type.base_type in array_type: return type_id # Default to float if we can't determine the element type return self.primitive_types["float"] def get_function_qualifier(self, func) -> Optional[str]: """Return the shader-stage qualifier from old or new function AST shapes.""" if hasattr(func, "qualifiers") and func.qualifiers: return func.qualifiers[0] if func.qualifiers else None if hasattr(func, "qualifier"): return func.qualifier return None def stage_key(self, stage_type) -> str: """Normalize a stage enum or string to a registry key.""" if hasattr(stage_type, "value"): return stage_type.value return str(stage_type).split(".")[-1].lower() 
def spirv_execution_model(self, stage_name: Optional[str]) -> str:
    """Map a CrossGL stage name to a SPIR-V execution model.

    Args:
        stage_name: Normalized stage key such as ``"vertex"`` or
            ``"compute"``. ``None`` (or an empty string) is treated as
            ``"fragment"``.

    Returns:
        The SPIR-V execution-model keyword. Unknown stage names fall back
        to ``"Fragment"``.
    """
    stage_map = {
        "vertex": "Vertex",
        "fragment": "Fragment",
        "compute": "GLCompute",
        "geometry": "Geometry",
        "tessellation_control": "TessellationControl",
        "tessellation_evaluation": "TessellationEvaluation",
    }
    return stage_map.get(stage_name or "fragment", "Fragment")


def compute_local_size(self, stage) -> Tuple[int, int, int]:
    """Return compute workgroup dimensions from a stage execution config.

    The config is read from ``stage.execution_config``. A missing attribute,
    a falsy value, or ``stage=None`` all count as an empty config.

    The keys ``local_size``, ``workgroup_size`` and ``numthreads`` are
    checked in that order. Each may hold a single integer or a non-empty
    list/tuple. Only the first three components are used, and missing
    trailing dimensions default to 1, so ``[64]`` means ``(64, 1, 1)``.

    If none of those keys is usable, the per-axis keys ``local_size_x``,
    ``local_size_y`` and ``local_size_z`` are read, each defaulting to 1.

    Returns:
        ``(x, y, z)`` as Python ints.
    """
    config = getattr(stage, "execution_config", {}) or {}
    for key in ("local_size", "workgroup_size", "numthreads"):
        value = config.get(key)
        # bool is an int subclass; do not treat True/False as a size.
        if isinstance(value, int) and not isinstance(value, bool):
            return int(value), 1, 1
        if isinstance(value, (list, tuple)) and value:
            dims = [int(component) for component in value[:3]]
            dims += [1] * (3 - len(dims))
            return dims[0], dims[1], dims[2]
    return (
        int(config.get("local_size_x", 1)),
        int(config.get("local_size_y", 1)),
        int(config.get("local_size_z", 1)),
    )


def emit_entry_point(
    self, execution_model: str, function_id: "SpirvId", name: str, stage=None
):
    """Emit SPIR-V entry-point and execution-mode declarations.

    Args:
        execution_model: SPIR-V execution model keyword (e.g. ``"Fragment"``).
        function_id: Result id object of the entry function; its ``.id`` is
            referenced.
        name: Entry-point name written into ``OpEntryPoint``.
        stage: Optional stage node. Only used to look up the compute
            workgroup size.

    Side effects:
        - Emits ``OpEntryPoint``, which lists every variable in
          ``self.inputs`` and ``self.outputs`` as its interface.
        - Emits the execution modes for Fragment and GLCompute.
        - Requests the ``Geometry``/``Tessellation`` capability that the
          geometry and tessellation execution models require.

    NOTE(review): the interface list is shared by every entry point in a
    multi-stage module. Confirm that is intended.
    """
    # OpEntryPoint with these models is invalid without the capability.
    stage_capability = {
        "Geometry": "Geometry",
        "TessellationControl": "Tessellation",
        "TessellationEvaluation": "Tessellation",
    }.get(execution_model)
    if stage_capability:
        self.require_capability(stage_capability)

    interface_ids = " ".join(
        f"%{variable.id}" for variable in self.inputs + self.outputs
    )
    interface_suffix = f" {interface_ids}" if interface_ids else ""
    self.emit(
        f'OpEntryPoint {execution_model} %{function_id.id} "{name}"'
        f"{interface_suffix}"
    )
    if execution_model == "Fragment":
        self.emit(f"OpExecutionMode %{function_id.id} OriginUpperLeft")
    elif execution_model == "GLCompute":
        x, y, z = self.compute_local_size(stage)
        if self.requires_compute_derivatives:
            # Quad-based compute derivatives need X and Y workgroup sizes
            # that are multiples of 2, so round both up to an even value of
            # at least 2.
            x = max(2, x + (x % 2))
            y = max(2, y + (y % 2))
        self.emit(f"OpExecutionMode %{function_id.id} LocalSize {x} {y} {z}")
        if self.requires_compute_derivatives:
            self.emit(f"OpExecutionMode %{function_id.id} DerivativeGroupQuadsKHR")


def ordered_module_lines(self) -> List[str]:
    """Return SPIR-V assembly lines in logical module-layout order.

    Output order:

    1. Header comments, with ``; Bound: <next_id>`` inserted after the
       Generator comment and before the Schema comment.
    2. Capabilities: Shader, then ``required_capabilities`` sorted, then
       any emitted ones, deduplicated.
    3. Extensions: ``required_extensions`` sorted, then any emitted ones,
       deduplicated.
    4. ``OpExtInstImport``.
    5. ``OpMemoryModel``.
    6. ``OpEntryPoint``.
    7. ``OpExecutionMode*``.
    8. Debug strings and sources (``OpString``, ``OpSource*``), then debug
       names (``OpName``, ``OpMemberName``).
    9. Decorations, including ``self.decorations``.
    10. Type, constant and undef declarations.
    11. Non-Function ``OpVariable`` declarations.
    12. Everything else, passed through ``ordered_function_body_lines``.
    """
    # generate() emits four header comments first: SPIR-V, Version,
    # Generator and Schema. Bound goes between Generator and Schema, as
    # spirv-dis prints it, and the Schema comment is kept, not dropped.
    header_lines = self.code_lines[:3]
    schema_lines = [line for line in self.code_lines[3:4] if line.startswith(";")]
    bound_line = f"; Bound: {self.next_id}"
    raw_lines = self.code_lines[4:]

    capabilities = ["OpCapability Shader"] + [
        f"OpCapability {capability}"
        for capability in sorted(self.required_capabilities)
    ]
    extensions = [
        f'OpExtension "{extension}"'
        for extension in sorted(self.required_extensions)
    ]
    imports = []
    memory_model = []
    entry_points = []
    execution_modes = []
    debug_sources = []
    debug_names = []
    annotations = []
    declarations = []
    global_variables = []
    body = []

    for line in raw_lines:
        if line.startswith("OpCapability "):
            capabilities.append(line)
        elif line.startswith("OpExtension "):
            extensions.append(line)
        elif " = OpExtInstImport " in line:
            imports.append(line)
        elif line.startswith("OpMemoryModel "):
            memory_model.append(line)
        elif line.startswith("OpEntryPoint "):
            entry_points.append(line)
        elif line.startswith("OpExecutionMode"):
            execution_modes.append(line)
        elif re.match(r"%\d+ = OpString ", line) or line.startswith(
            ("OpString ", "OpSource")
        ):
            # OpString has a result id, so it needs the regex form; strings
            # and source info must come before OpName in the debug section.
            debug_sources.append(line)
        elif line.startswith(("OpName ", "OpMemberName ")):
            debug_names.append(line)
        elif line.startswith(("OpDecorate ", "OpMemberDecorate ")):
            annotations.append(line)
        elif re.match(r"%\d+ = (OpType|OpConstant|OpSpecConstant|OpUndef)", line):
            declarations.append(line)
        elif re.match(r"%\d+ = OpVariable %\d+ (?!Function\b)", line):
            global_variables.append(line)
        else:
            # Function bodies and anything unrecognized keep their emission
            # order. This includes OpLine/OpNoLine, which are not allowed in
            # the debug section.
            body.append(line)

    annotations.extend(self.decorations)

    def unique(lines: List[str]) -> List[str]:
        # dict keeps insertion order: drop repeats but keep first occurrences.
        return list(dict.fromkeys(lines))

    return (
        header_lines
        + [bound_line]
        + schema_lines
        + unique(capabilities)
        + unique(extensions)
        + imports
        + memory_model
        + entry_points
        + execution_modes
        + unique(debug_sources)
        + unique(debug_names)
        + unique(annotations)
        + declarations
        + global_variables
        + self.ordered_function_body_lines(body)
    )


def ordered_function_body_lines(self, lines: List[str]) -> List[str]:
    """Move function-scope variables into each function's first block.

    SPIR-V requires every ``OpVariable ... Function`` to be among the first
    instructions of a function's first block.

    Inside each ``OpFunction`` ... ``OpFunctionEnd`` span:
    - Function-scope variables are placed immediately after the first
      ``OpLabel``, keeping their relative order.
    - Variables seen before that label are held until the label appears.
    - If a function never has a label, its held variables are emitted just
      before ``OpFunctionEnd``.

    Lines outside functions are passed through unchanged.
    """
    ordered = []
    in_function = False
    first_block_insert_at = None
    pending_variables = []

    def is_function_variable(line: str) -> bool:
        return re.match(r"%\d+ = OpVariable %\d+ Function\b", line) is not None

    def is_function_start(line: str) -> bool:
        return re.match(r"%\d+ = OpFunction\b", line) is not None

    def is_label(line: str) -> bool:
        return re.match(r"%\d+ = OpLabel\b", line) is not None

    for line in lines:
        if is_function_start(line):
            in_function = True
            first_block_insert_at = None
            pending_variables = []
            ordered.append(line)
            continue

        if in_function and line == "OpFunctionEnd":
            if pending_variables:
                if first_block_insert_at is None:
                    # No label was seen: keep the variables inside the function.
                    ordered.extend(pending_variables)
                else:
                    ordered[first_block_insert_at:first_block_insert_at] = (
                        pending_variables
                    )
            ordered.append(line)
            in_function = False
            first_block_insert_at = None
            pending_variables = []
            continue

        if in_function and is_function_variable(line):
            if first_block_insert_at is None:
                pending_variables.append(line)
            else:
                ordered.insert(first_block_insert_at, line)
                # Advance so later variables follow earlier ones.
                first_block_insert_at += 1
            continue

        ordered.append(line)
        if in_function and first_block_insert_at is None and is_label(line):
            # Insertion point is right after the function's first label.
            first_block_insert_at = len(ordered)
            if pending_variables:
                ordered[first_block_insert_at:first_block_insert_at] = (
                    pending_variables
                )
                first_block_insert_at += len(pending_variables)
                pending_variables = []

    return ordered
def generate(self, ast):
    """Translate a CrossGL ``ShaderNode`` into SPIR-V assembly text.

    All per-module state is reset first, so the same generator instance can
    be reused for several shaders.

    Emission order:

    1. Header comments, the Shader capability, the GLSL.std.450 import and
       the logical GLSL450 memory model.
    2. The scalar types void/bool/int/float and float vectors of width 2-4.
    3. Structs.
    4. Resource-array parameter hints and per-function execution models.
    5. Global variables.
    6. Every function in ``ast.functions`` that is not ``main`` and has no
       stage qualifier (helper functions).

    Entry points are then chosen:

    - When ``ast.stages`` is non-empty: each stage's local variables that
      are not already global, then each stage's local functions (each
      function object once), then each stage's ``entry_point``.
    - Otherwise: the top-level ``main`` / stage-qualified functions.

    NOTE(review): when ``ast.stages`` is present, the top-level entry
    candidates found in ``ast.functions`` are never processed. Confirm that
    is intended.

    Returns:
        The module as newline-joined assembly in logical layout order, or
        ``"; Error: Not a shader node"`` if ``ast`` is not a ``ShaderNode``.
    """
    if not isinstance(ast, ShaderNode):
        return "; Error: Not a shader node"

    self.reset_generation_state()

    # Header comments; ordered_module_lines() computes the Bound line later.
    for header_comment in (
        "; SPIR-V",
        "; Version: 1.0",
        "; Generator: CrossGL Vulkan SPIR-V Generator",
        "; Schema: 0",
    ):
        self.emit(header_comment)
    self.emit("OpCapability Shader")
    self.glsl_std450_id = self.get_id()
    self.emit(f'%{self.glsl_std450_id} = OpExtInstImport "GLSL.std.450"')
    self.emit("OpMemoryModel Logical GLSL450")

    # Registration order determines id allocation, so keep it fixed.
    for scalar_name in ("void", "bool", "int", "float"):
        self.register_primitive_type(scalar_name)
    scalar_float = self.primitive_types["float"]
    for width in (2, 3, 4):
        self.register_vector_type(scalar_float, width)

    for struct_node in ast.structs:
        self.process_crossgl_struct(struct_node)

    self.function_resource_array_type_hints = (
        self.collect_resource_array_parameter_type_hints(ast)
    )
    self.function_execution_models = self.collect_function_execution_models(ast)

    for global_var in getattr(ast, "global_variables", []):
        self.process_global_variable_declaration(global_var)

    stage_qualifiers = (
        "vertex",
        "fragment",
        "compute",
        "geometry",
        "tessellation_control",
        "tessellation_evaluation",
    )
    candidate_entries = []
    for fn_node in ast.functions:
        fn_qualifier = self.get_function_qualifier(fn_node)
        is_entry = fn_node.name == "main" or fn_qualifier in stage_qualifiers
        if not is_entry:
            # Helpers are emitted right away, before any entry point.
            self.process_function_node(fn_node)
            continue
        candidate_entries.append((fn_node, fn_qualifier))

    collected = []
    stage_map = getattr(ast, "stages", None)
    if not stage_map:
        for fn_node, fn_qualifier in candidate_entries:
            entry_fn_id = self.process_function_node(fn_node)
            model = self.spirv_execution_model(fn_qualifier)
            collected.append((model, entry_fn_id, fn_node.name, None))
    else:
        for stage_node in stage_map.values():
            for stage_var in getattr(stage_node, "local_variables", []):
                if stage_var.name in self.global_variables:
                    continue
                self.process_global_variable_declaration(stage_var)

        # Dedupe by object identity: a function shared by several stages is
        # emitted once.
        emitted_fn_ids = set()
        for stage_node in stage_map.values():
            for local_fn in getattr(stage_node, "local_functions", []):
                if id(local_fn) in emitted_fn_ids:
                    continue
                self.process_function_node(local_fn)
                emitted_fn_ids.add(id(local_fn))

        for stage_type, stage_node in stage_map.items():
            entry_fn = stage_node.entry_point
            entry_fn_id = self.process_function_node(entry_fn)
            model = self.spirv_execution_model(self.stage_key(stage_type))
            collected.append((model, entry_fn_id, entry_fn.name, stage_node))

    if collected:
        self.main_fn_id = collected[0][1].id

    for model, entry_fn_id, entry_name, stage_node in collected:
        self.emit_entry_point(model, entry_fn_id, entry_name, stage_node)

    return "\n".join(self.ordered_module_lines())