Source code for crosstl.translator.codegen.vector_arithmetic

"""Utilities for lowering vector arithmetic expressions during code generation."""

from ..ast import (
    ArrayAccessNode,
    BinaryOpNode,
    FunctionCallNode,
    IdentifierNode,
    LiteralNode,
    MemberAccessNode,
    TernaryOpNode,
    UnaryOpNode,
    VariableNode,
)


[docs] class VectorArithmeticMixin: """Helpers for inferring and lowering vector arithmetic expressions."""
[docs] def collect_function_return_types(self, ast_node): """Collect function return types from global and stage-local functions.""" function_return_types = {} for func in getattr(ast_node, "functions", []): function_return_types[func.name] = getattr(func, "return_type", None) for stage in getattr(ast_node, "stages", {}).values(): entry_point = getattr(stage, "entry_point", None) if entry_point is not None: function_return_types[entry_point.name] = getattr( entry_point, "return_type", None ) for func in getattr(stage, "local_functions", []): function_return_types[func.name] = getattr(func, "return_type", None) return function_return_types
[docs] def resource_call_result_type(self, func_name, raw_args): """Infer the result type of resource-related intrinsic calls.""" if not isinstance(func_name, str): return None resource_type = None if raw_args: resource_type = self.resource_base_type( self.get_expression_type(raw_args[0]) ) if func_name in {"textureSize", "imageSize"} and resource_type: spec = self.dimension_query_spec(resource_type) if spec is None: return None dimensions = spec[0] return self.query_return_type(dimensions) if func_name in {"textureSamples", "imageSamples", "textureQueryLevels"}: return "int" if func_name == "textureQueryLod": return "float2" if func_name in { "textureGather", "textureGatherOffset", "textureGatherOffsets", "textureGatherCompare", "textureGatherCompareOffset", }: return "float4" if func_name in { "textureCompare", "textureCompareLod", "textureCompareGrad", "textureCompareOffset", "textureCompareLodOffset", "textureCompareGradOffset", }: return "float" if func_name in {"texture", "textureLod", "textureGrad", "texelFetch"}: if self.is_shadow_resource_type(resource_type): return "float" return "float4" if func_name == "imageLoad" and resource_type: return self.image_value_type(resource_type) return None
[docs] def expression_result_type(self, node): """Infer the best-effort result type for an expression node.""" if node is None: return None if isinstance(node, (IdentifierNode, VariableNode, ArrayAccessNode)): return self.get_expression_type(node) if isinstance(node, LiteralNode): literal_type = getattr(getattr(node, "literal_type", None), "name", None) if literal_type: return literal_type if isinstance(node.value, bool): return "bool" if isinstance(node.value, float): return "float" if isinstance(node.value, int): return "int" return None if isinstance(node, FunctionCallNode): func_expr = getattr(node, "function", getattr(node, "name", None)) func_name = getattr(func_expr, "name", func_expr) if isinstance(func_name, str) and self.vector_type_info(func_name): return func_name resource_result_type = self.resource_call_result_type( func_name, getattr(node, "arguments", getattr(node, "args", [])) ) if resource_result_type is not None: return resource_result_type if isinstance(func_name, str): return self.function_return_types.get(func_name) return None if isinstance(node, BinaryOpNode): left_type = self.expression_result_type(node.left) right_type = self.expression_result_type(node.right) if self.vector_type_info(left_type): return left_type if self.vector_type_info(right_type): return right_type return left_type or right_type if isinstance(node, UnaryOpNode): return self.expression_result_type(node.operand) if isinstance(node, TernaryOpNode): return self.expression_result_type( node.true_expr ) or self.expression_result_type(node.false_expr) if isinstance(node, MemberAccessNode): object_expr = getattr(node, "object_expr", getattr(node, "object", None)) object_type = self.expression_result_type(object_expr) object_type_name = ( self.convert_type_node_to_string(object_type) if object_type is not None and not isinstance(object_type, str) else object_type ) member = getattr(node, "member", "") struct_members = self.struct_member_types.get(object_type_name, {}) if member in struct_members: return struct_members[member] vector_info = self.vector_type_info(object_type) if not vector_info: return None if len(member) == 1: return vector_info["component_type"] if all(component in "xyzwrgba" for component in member): return self.vector_type_for_components( vector_info["component_type"], len(member) ) return None
[docs] def vector_type_info(self, type_name): """Return constructor and component metadata for a vector type.""" if type_name is None: return None if not isinstance(type_name, str): type_name = self.convert_type_node_to_string(type_name) mapped_type = self.map_vector_arithmetic_type(type_name) vector_details = { "float2": ("make_float2", "float", ("x", "y")), "float3": ("make_float3", "float", ("x", "y", "z")), "float4": ("make_float4", "float", ("x", "y", "z", "w")), "double2": ("make_double2", "double", ("x", "y")), "double3": ("make_double3", "double", ("x", "y", "z")), "double4": ("make_double4", "double", ("x", "y", "z", "w")), "int2": ("make_int2", "int", ("x", "y")), "int3": ("make_int3", "int", ("x", "y", "z")), "int4": ("make_int4", "int", ("x", "y", "z", "w")), "uint2": ("make_uint2", "uint", ("x", "y")), "uint3": ("make_uint3", "uint", ("x", "y", "z")), "uint4": ("make_uint4", "uint", ("x", "y", "z", "w")), "uchar2": ("make_uchar2", "bool", ("x", "y")), "uchar3": ("make_uchar3", "bool", ("x", "y", "z")), "uchar4": ("make_uchar4", "bool", ("x", "y", "z", "w")), } details = vector_details.get(mapped_type) if details is None: return None constructor, component_type, components = details return { "type": mapped_type, "constructor": constructor, "component_type": component_type, "components": components, }
[docs] def vector_type_for_components(self, component_type, component_count): """Return a vector type name for a component type/count pair.""" if component_count < 2 or component_count > 4: return component_type prefixes = { "float": "vec", "double": "dvec", "int": "ivec", "uint": "uvec", "bool": "bvec", } prefix = prefixes.get(component_type) if prefix is None: return None return f"{prefix}{component_count}"
[docs] def lower_vector_binary_operation( self, left_node, left_expr, right_node, right_expr, operator, ): """Lower vector binary arithmetic into a helper call when required.""" if operator not in {"+", "-", "*", "/"}: return None left_type = self.expression_result_type(left_node) right_type = self.expression_result_type(right_node) left_info = self.vector_type_info(left_type) right_info = self.vector_type_info(right_type) if not left_info and not right_info: return None if left_info and right_info: if len(left_info["components"]) != len(right_info["components"]): return None helper_name = self.require_vector_binary_helper( left_info, operator, "vector" ) if helper_name is None: return None return f"{helper_name}({left_expr}, {right_expr})" if left_info and right_type is not None: helper_name = self.require_vector_binary_helper( left_info, operator, "scalar_right" ) if helper_name is None: return None return f"{helper_name}({left_expr}, {right_expr})" if right_info and left_type is not None: helper_name = self.require_vector_binary_helper( right_info, operator, "scalar_left" ) if helper_name is None: return None return f"{helper_name}({left_expr}, {right_expr})" return None
[docs] def require_vector_binary_helper(self, vector_info, operator, operand_shape): """Register and return a helper function for vector binary arithmetic.""" if vector_info["component_type"] == "bool": return None operator_names = { "+": "add", "-": "sub", "*": "mul", "/": "div", } operation_name = operator_names[operator] helper_name = f"cgl_{vector_info['type']}_{operation_name}" if operand_shape == "scalar_right": helper_name += "_scalar" elif operand_shape == "scalar_left": helper_name = f"cgl_scalar_{operation_name}_{vector_info['type']}" if helper_name in self.helper_functions: return helper_name vector_type = vector_info["type"] scalar_type = self.vector_scalar_parameter_type(vector_info) components = vector_info["components"] constructor = vector_info["constructor"] if operand_shape == "vector": params = f"{vector_type} lhs, {vector_type} rhs" args = [ f"(lhs.{component} {operator} rhs.{component})" for component in components ] elif operand_shape == "scalar_right": params = f"{vector_type} lhs, {scalar_type} rhs" args = [f"(lhs.{component} {operator} rhs)" for component in components] else: params = f"{scalar_type} lhs, {vector_type} rhs" args = [f"(lhs {operator} rhs.{component})" for component in components] helper = ( f"__device__ inline {vector_type} {helper_name}({params})\n" "{\n" f" return {constructor}({', '.join(args)});\n" "}" ) self.helper_functions[helper_name] = helper return helper_name
[docs] def vector_scalar_parameter_type(self, vector_info): """Return the scalar parameter type used by generated vector helpers.""" if vector_info["component_type"] == "uint": return "unsigned int" return vector_info["component_type"]