Source code for crosstl.translator.codegen.metal_codegen
"""CrossGL-to-Metal code generator."""
from ..ast import (
AssignmentNode,
ArrayNode,
ArrayAccessNode,
BinaryOpNode,
BreakNode,
ContinueNode,
ForInNode,
ForNode,
FunctionCallNode,
IfNode,
LiteralPatternNode,
LoopNode,
MatchNode,
MemberAccessNode,
MeshOpNode,
PreprocessorNode,
RayQueryOpNode,
RayTracingOpNode,
RangeNode,
ReturnNode,
StructNode,
SwitchNode,
TernaryOpNode,
UnaryOpNode,
VariableNode,
WaveOpNode,
WhileNode,
WildcardPatternNode,
)
from .array_utils import (
parse_array_type,
format_array_type,
format_c_style_array_declaration,
split_array_type_suffix,
get_array_size_from_node,
evaluate_literal_int_expression,
collect_literal_int_constants,
collect_struct_member_types,
)
from ..validation import (
collect_cbuffer_declaration_name_conflicts,
collect_cbuffer_member_global_conflicts,
collect_duplicate_cbuffer_member_names,
collect_duplicate_cbuffer_names,
collect_non_resource_global_resource_shadows,
expression_debug_name,
floating_coordinate_dimension,
integer_coordinate_dimension,
is_floating_scalar_type,
is_integer_scalar_type,
is_numeric_scalar_type,
IMAGE_RESOURCE_INTRINSIC_NAMES,
INTEGER_COORDINATE_INTRINSIC_NAMES,
OFFSET_DIMENSION_INTRINSIC_NAMES,
texture_bias_argument_index,
texture_compare_argument_index,
texture_gather_component_argument_index,
texture_gradient_argument_indices,
texture_intrinsic_allowed_argument_counts,
texture_intrinsic_max_argument_count,
texture_intrinsic_min_argument_count,
texture_lod_argument_index,
texture_mip_level_argument_index,
texture_offset_argument_indices,
texture_query_lod_coordinate_argument_index,
texture_sample_index_argument_index,
)
from .stage_utils import (
normalize_stage_name,
should_emit_qualified_function,
stage_matches,
)
from .resource_arrays import collect_resource_array_size_hints
class CharTypeMapper:
    """Translate CrossGL char-family type names into Metal integer types."""

    # Shared lookup table: char-like scalar/vector name -> Metal integer type.
    _CHAR_TO_METAL = {
        "char": "int",
        "signed char": "int",
        "unsigned char": "uint",
        "char2": "int2",
        "char3": "int3",
        "char4": "int4",
        "uchar2": "uint2",
        "uchar3": "uint3",
        "uchar4": "uint4",
    }

    def map_char_type(self, vtype):
        """Return the Metal type for *vtype*; unmapped names pass through unchanged."""
        return self._CHAR_TO_METAL.get(vtype, vtype)
[docs]
class MetalCodeGen:
"""Emit Metal Shading Language from the shared CrossGL translator AST."""
def __init__(self):
    """Initialize Metal type maps and per-generation resource state.

    Every attribute that ``generate_program`` resets is created here so the
    instance is fully formed before the first generation run.
    """
    self.current_shader = None
    self.vertex_item = None
    self.fragment_item = None
    self.gl_position = False
    self.char_mapper = CharTypeMapper()

    # Global resources discovered while walking the AST.
    self.texture_variables = []
    self.sampler_variables = []
    self.cbuffer_variables = []
    self.cbuffer_parameter_names = {}
    self.cbuffer_member_references = {}
    self.ambiguous_cbuffer_members = set()
    self.cbuffers_by_name = {}

    # Function-level dependency bookkeeping.
    self.user_function_names = set()
    self.function_cbuffer_dependencies = {}
    self.function_global_resource_dependencies = {}

    # Resource typing / image format tracking.
    self.current_sampler_parameters = set()
    self.texture_variable_types = {}
    self.current_texture_parameters = {}
    self.image_variable_formats = {}
    self.current_image_format_parameters = {}
    # generate_program() resets this set on every run; it is created here too
    # so the attribute exists before the first run.
    self.required_image_atomic_compare_helpers = set()
    self.resource_array_size_hints = {}
    self.function_resource_array_size_hints = {}
    self.literal_int_constants = {}

    # State describing the function / expression currently being emitted.
    self.current_function_name = None
    self.current_function_return_type = None
    self.current_expression_expected_type = None
    self.local_variable_types = {}
    self.struct_member_types = {}

    # CrossGL type name -> Metal type name.
    self.type_mapping = {
        # Scalar Types
        "void": "void",
        "short": "int",
        "signed short": "int",
        "unsigned short": "uint",
        "int": "int",
        "signed int": "int",
        "unsigned int": "uint",
        "long": "int64_t",
        "signed long": "int64_t",
        "unsigned long": "uint64_t",
        "float": "float",
        "half": "half",
        "bool": "bool",
        # Vector Types
        "vec2": "float2",
        "vec3": "float3",
        "vec4": "float4",
        "ivec2": "int2",
        "ivec3": "int3",
        "ivec4": "int4",
        "short2": "int2",
        "short3": "int3",
        "short4": "int4",
        "ushort2": "uint2",
        "ushort3": "uint3",
        "ushort4": "uint4",
        "int2": "int2",
        "int3": "int3",
        "int4": "int4",
        "uint2": "uint2",
        "uint3": "uint3",
        "uint4": "uint4",
        "uvec2": "uint2",
        "uvec3": "uint3",
        "uvec4": "uint4",
        "float2": "float2",
        "float3": "float3",
        "float4": "float4",
        "half2": "half2",
        "half3": "half3",
        "half4": "half4",
        "bvec2": "bool2",
        "bvec3": "bool3",
        "bvec4": "bool4",
        "bool2": "bool2",
        "bool3": "bool3",
        "bool4": "bool4",
        # Sampled texture types
        "sampler1D": "texture1d<float>",
        "sampler2D": "texture2d<float>",
        "sampler3D": "texture3d<float>",
        "samplerCube": "texturecube<float>",
        "sampler2DArray": "texture2d_array<float>",
        "samplerCubeArray": "texturecube_array<float>",
        "sampler2DMS": "texture2d_ms<float>",
        "sampler2DMSArray": "texture2d_ms_array<float>",
        # Shadow samplers map to depth textures
        "sampler2DShadow": "depth2d<float>",
        "sampler2DArrayShadow": "depth2d_array<float>",
        "samplerCubeShadow": "depthcube<float>",
        "samplerCubeArrayShadow": "depthcube_array<float>",
        # Storage images map to read_write textures
        "iimage2D": "texture2d<int, access::read_write>",
        "iimage3D": "texture3d<int, access::read_write>",
        "iimage2DArray": "texture2d_array<int, access::read_write>",
        "uimage2D": "texture2d<uint, access::read_write>",
        "uimage3D": "texture3d<uint, access::read_write>",
        "uimage2DArray": "texture2d_array<uint, access::read_write>",
        "image2D": "texture2d<float, access::read_write>",
        "image3D": "texture3d<float, access::read_write>",
        "imageCube": "texture2d_array<float, access::read_write>",
        "image2DArray": "texture2d_array<float, access::read_write>",
        # Matrix Types
        "mat2": "float2x2",
        "mat3": "float3x3",
        "mat4": "float4x4",
        "half2x2": "half2x2",
        "half3x3": "half3x3",
        "half4x4": "half4x4",
    }

    # CrossGL semantic / GLSL builtin name -> Metal attribute text.
    self.semantic_map = {
        # Vertex inputs
        "gl_VertexID": "vertex_id",
        "gl_InstanceID": "instance_id",
        "gl_IsFrontFace": "is_front_facing",
        "gl_PrimitiveID": "primitive_id",
        "POSITION": "attribute(0)",
        "NORMAL": "attribute(1)",
        "TANGENT": "attribute(2)",
        "BINORMAL": "attribute(3)",
        "TEXCOORD": "attribute(4)",
        "TEXCOORD0": "attribute(5)",
        "TEXCOORD1": "attribute(6)",
        "TEXCOORD2": "attribute(7)",
        "TEXCOORD3": "attribute(8)",
        "TEXCOORD4": "attribute(9)",
        "TEXCOORD5": "attribute(10)",
        "TEXCOORD6": "attribute(11)",
        "TEXCOORD7": "attribute(12)",
        # Vertex outputs
        "gl_Position": "position",
        "gl_PointSize": "point_size",
        "gl_ClipDistance": "clip_distance",
        # Fragment outputs (these entries already carry their [[...]] brackets)
        "gl_FragColor": "[[color(0)]]",
        "gl_FragColor0": "[[color(0)]]",
        "gl_FragColor1": "[[color(1)]]",
        "gl_FragColor2": "[[color(2)]]",
        "gl_FragColor3": "[[color(3)]]",
        "gl_FragColor4": "[[color(4)]]",
        "gl_FragColor5": "[[color(5)]]",
        "gl_FragColor6": "[[color(6)]]",
        "gl_FragColor7": "[[color(7)]]",
        "gl_FragDepth": "depth(any)",
        # Fragment inputs
        "gl_FragCoord": "position",
        "gl_FrontFacing": "is_front_facing",
        "gl_PointCoord": "point_coord",
        # Compute shader specific
        "gl_GlobalInvocationID": "thread_position_in_grid",
        "gl_LocalInvocationID": "thread_position_in_threadgroup",
        "gl_WorkGroupID": "threadgroup_position_in_grid",
        "gl_LocalInvocationIndex": "thread_index_in_threadgroup",
        "gl_WorkGroupSize": "threads_per_threadgroup",
        "gl_NumWorkGroups": "threadgroups_per_grid",
        # Ray tracing / payload semantics
        "payload": "payload",
        "hit_attribute": "hit_attribute",
        "callable_data": "callable_data",
        "shader_record": "shader_record",
    }
[docs]
def generate(self, ast):
    """Translate the whole CrossGL AST to Metal source without stage filtering."""
    metal_source = self.generate_program(ast)
    return metal_source
def generate_stage(self, ast, shader_type):
    """Translate the AST to Metal, keeping only the *shader_type* stage."""
    stage_source = self.generate_program(ast, target_stage=shader_type)
    return stage_source
def generate_program(self, ast, target_stage=None):
"""Render an AST to Metal, optionally filtering stage entry points."""
# Output order: preprocessor lines / metal_stdlib header, constants, structs,
# plain globals, cbuffer structs, image-atomic helpers, then functions.
# Textures and samplers found among the globals are recorded (with a running
# register index) instead of being emitted as globals.
target_stage = normalize_stage_name(target_stage)
# --- Reset per-run state so one generator instance can be reused. ---
self.texture_variables = []
self.sampler_variables = []
self.cbuffer_variables = getattr(ast, "cbuffers", []) or []
self.cbuffers_by_name = {
cbuffer.name: cbuffer
for cbuffer in self.cbuffer_variables
if getattr(cbuffer, "name", None)
}
all_functions = self.all_functions(ast)
self.user_function_names = {
func.name for func in all_functions if getattr(func, "name", None)
}
self.function_cbuffer_dependencies = self.collect_function_cbuffer_dependencies(
all_functions
)
self.cbuffer_parameter_names = self.collect_cbuffer_parameter_names(
self.cbuffer_variables
)
self.cbuffer_member_references = self.collect_cbuffer_member_references(
self.cbuffer_variables
)
if not self.cbuffer_variables:
self.ambiguous_cbuffer_members = set()
self.current_sampler_parameters = set()
self.texture_variable_types = {}
self.current_texture_parameters = {}
self.image_variable_formats = {}
self.current_image_format_parameters = {}
self.function_global_resource_dependencies = {}
self.required_image_atomic_compare_helpers = set()
self.literal_int_constants = collect_literal_int_constants(
getattr(ast, "constants", [])
)
(
self.resource_array_size_hints,
self.function_resource_array_size_hints,
) = self.collect_resource_array_size_hints(ast)
self.validate_global_resource_shadows(ast)
self.current_function_name = None
self.current_function_return_type = None
self.current_expression_expected_type = None
self.local_variable_types = {}
self.struct_member_types = collect_struct_member_types(
getattr(ast, "structs", []), self.type_name_string
)
# --- Header: user preprocessor lines first, then the Metal prelude. ---
code = "\n"
preprocessors = getattr(ast, "preprocessors", []) or []
pre_lines = []
for directive in preprocessors:
if isinstance(directive, PreprocessorNode):
line = f"#{directive.directive} {directive.content}".strip()
else:
line = str(directive).strip()
if line:
pre_lines.append(line)
if pre_lines:
code += "\n".join(pre_lines) + "\n"
# Only add the stdlib include when no user directive already mentions it.
if not any("metal_stdlib" in line for line in pre_lines):
code += "#include <metal_stdlib>\n"
code += "using namespace metal;\n"
code += "\n"
code += self.generate_constants(ast)
# --- Struct declarations. ---
structs = getattr(ast, "structs", [])
for node in structs:
if isinstance(node, StructNode):
code += f"struct {node.name} {{\n"
members = getattr(node, "members", [])
for member in members:
if isinstance(member, ArrayNode):
# Handle array types in structs
element_type = getattr(
member, "element_type", getattr(member, "vtype", "float")
)
if member.size:
code += f" {self.map_type(element_type)} {member.name}[{member.size}];\n"
else:
# Dynamic arrays in Metal use array<type>
code += f" array<{self.map_type(element_type)}> {member.name};\n"
else:
# The semantic comes from `member.semantic` when that attribute
# exists, otherwise from an entry in `member.attributes`.
semantic = None
if hasattr(member, "semantic"):
semantic = member.semantic
elif hasattr(member, "attributes"):
for attr in member.attributes:
if hasattr(attr, "name"):
semantic = attr.name
break
if hasattr(member, "member_type"):
# Array member types are detected by class name rather than isinstance.
if str(type(member.member_type)).find("ArrayType") != -1:
# Handle array types with C-style syntax for struct members
element_type = self.convert_type_node_to_string(
member.member_type.element_type
)
element_type = self.map_type(element_type)
if member.member_type.size is not None:
size_str = self.expression_to_string(
member.member_type.size
)
# For Metal, use C-style array syntax: type name[size]
semantic_attr = (
self.map_semantic(semantic) if semantic else ""
)
code += f" {element_type} {member.name}[{size_str}]{semantic_attr};\n"
else:
# Dynamic arrays - use array<type> syntax
semantic_attr = (
self.map_semantic(semantic) if semantic else ""
)
code += f" array<{element_type}> {member.name}{semantic_attr};\n"
continue # Skip the normal member_type handling
else:
member_type_str = self.convert_type_node_to_string(
member.member_type
)
member_type = self.map_type(member_type_str)
elif hasattr(member, "vtype"):
member_type = self.map_type(member.vtype)
else:
member_type = "float"
semantic_attr = self.map_semantic(semantic) if semantic else ""
code += f" {member_type} {member.name}{semantic_attr};\n"
code += "};\n"
# --- Global variables. Texture and sampler registers are counted separately;
# an array resource advances its counter by `resource_count`. ---
global_vars = getattr(ast, "global_variables", [])
texture_register = 0
sampler_register = 0
for i, node in enumerate(global_vars):
# Handle both old and new AST variable structures
resource_count = 1
array_size = None
if hasattr(node, "var_type"):
if hasattr(node.var_type, "name") or hasattr(
node.var_type, "element_type"
):
# Check if it's an ArrayType and handle specially for global variables
if (
hasattr(node.var_type, "element_type")
and str(type(node.var_type)).find("ArrayType") != -1
): # ArrayType
base_type = self.convert_type_node_to_string(
node.var_type.element_type
)
# An unsized array falls back to the size hint collected for this name.
array_size = (
self.expression_to_string(node.var_type.size)
if node.var_type.size
else self.resource_array_size_hints.get(node.name, "")
)
vtype = base_type
array_suffix = f"[{array_size}]" if array_size else "[]"
resource_count = self.resource_array_count(
node.var_type.size if node.var_type.size else array_size
)
else:
# Use the proper type conversion for TypeNode objects
vtype = self.convert_type_node_to_string(node.var_type)
array_suffix = ""
else:
vtype = str(node.var_type)
array_suffix = ""
elif hasattr(node, "vtype"):
vtype = node.vtype
array_suffix = ""
else:
vtype = "float"
array_suffix = ""
# Texture/image types are recorded for later use as entry-point parameters.
if vtype in [
"sampler1D",
"sampler2D",
"sampler3D",
"samplerCube",
"sampler2DArray",
"samplerCubeArray",
"sampler2DMS",
"sampler2DMSArray",
"sampler2DShadow",
"sampler2DArrayShadow",
"samplerCubeShadow",
"samplerCubeArrayShadow",
"iimage2D",
"iimage3D",
"iimage2DArray",
"uimage2D",
"uimage3D",
"uimage2DArray",
"image2D",
"image3D",
"imageCube",
"image2DArray",
]:
mapped_type = self.map_resource_type_with_format(vtype, node)
self.texture_variables.append(
(node, texture_register, mapped_type, array_size)
)
self.texture_variable_types[node.name] = mapped_type
explicit_format = self.explicit_image_format(node)
if explicit_format:
self.image_variable_formats[node.name] = explicit_format
texture_register += resource_count
elif vtype in ["sampler"]:
self.sampler_variables.append((node, sampler_register, array_size))
sampler_register += resource_count
else:
# Anything else is emitted as an ordinary global declaration.
code += f"{self.map_type(vtype)} {node.name}{array_suffix};\n"
# Computed only after the loop above has filled the texture/sampler lists.
self.function_global_resource_dependencies = (
self.collect_function_global_resource_dependencies(all_functions)
)
cbuffers = getattr(ast, "cbuffers", [])
if cbuffers:
code += "// Constant Buffers\n"
code += self.generate_cbuffers(ast)
# --- Functions are rendered into a separate buffer so the image-atomic
# helpers can be placed before them in the final output. ---
functions = getattr(ast, "functions", [])
functions_code = ""
for func in functions:
# Handle both old and new AST function structures
if hasattr(func, "qualifiers") and func.qualifiers:
qualifier = func.qualifiers[0] if func.qualifiers else None
else:
qualifier = getattr(func, "qualifier", None)
qualifier_name = normalize_stage_name(qualifier)
if not should_emit_qualified_function(target_stage, qualifier_name):
continue
if qualifier_name == "vertex":
functions_code += "// Vertex Shader\n"
functions_code += self.generate_function(func, shader_type="vertex")
elif qualifier_name == "fragment":
functions_code += "// Fragment Shader\n"
functions_code += self.generate_function(func, shader_type="fragment")
elif qualifier_name == "compute":
functions_code += "// Compute Shader\n"
functions_code += self.generate_function(func, shader_type="compute")
else:
functions_code += self.generate_function(func)
# Handle shader stages (new AST structure)
if hasattr(ast, "stages") and ast.stages:
for stage_type, stage in ast.stages.items():
if hasattr(stage, "entry_point"):
stage_name = normalize_stage_name(stage_type)
if not stage_matches(target_stage, stage_name):
continue
functions_code += f"// {stage_name.title()} Shader\n"
functions_code += self.generate_function(
stage.entry_point, shader_type=stage_name
)
if hasattr(stage, "local_functions"):
stage_name = normalize_stage_name(stage_type)
if not stage_matches(target_stage, stage_name):
continue
for func in stage.local_functions:
functions_code += self.generate_function(func)
code += self.generate_image_atomic_compare_helpers()
code += functions_code
return code
def generate_constants(self, ast):
    """Emit one Metal ``constant`` declaration per named constant in *ast*.

    Returns an empty string when there is nothing to declare; otherwise the
    declarations followed by one blank line.
    """
    declarations = []
    for constant in getattr(ast, "constants", []) or []:
        name = getattr(constant, "name", None)
        if not name:
            # Unnamed constants cannot be declared; skip them.
            continue
        declared_type = getattr(
            constant, "const_type", getattr(constant, "vtype", "float")
        )
        rendered_value = self.generate_constant_expression(
            getattr(constant, "value", None)
        )
        declarations.append(
            f"constant {self.map_type(declared_type)} {name} = {rendered_value};\n"
        )
    if not declarations:
        return ""
    return "".join(declarations) + "\n"
def generate_constant_expression(self, expr):
    """Render *expr*, rewriting Python's ``True``/``False`` text as ``true``/``false``."""
    rendered = self.generate_expression(expr)
    if rendered == "True" or rendered == "False":
        return "true" if rendered == "True" else "false"
    return rendered
def generate_cbuffers(self, ast):
    """Emit one Metal ``struct`` per constant buffer declared on *ast*.

    Raises:
        ValueError: if cbuffer names are duplicated, a cbuffer name conflicts
            with another declaration, or a cbuffer member name conflicts with
            a global declaration (as reported by the validation helpers).
    """
    cbuffers = getattr(ast, "cbuffers", []) or []

    duplicate_names = collect_duplicate_cbuffer_names(cbuffers)
    if duplicate_names:
        names = ", ".join(sorted(duplicate_names))
        raise ValueError(f"Duplicate cbuffer name(s) in Metal output: {names}")

    declaration_conflicts = collect_cbuffer_declaration_name_conflicts(ast)
    if declaration_conflicts:
        names = ", ".join(sorted(declaration_conflicts))
        raise ValueError(
            "Cbuffer name(s) conflict with existing Metal declaration(s): "
            f"{names}"
        )

    global_member_conflicts = collect_cbuffer_member_global_conflicts(ast)
    if global_member_conflicts:
        names = ", ".join(sorted(global_member_conflicts))
        raise ValueError(
            "Cbuffer member name(s) conflict with Metal global declaration(s): "
            f"{names}"
        )

    # StructNode cbuffers and any other node exposing ``name``/``members``
    # render identically, so both go through the same helper.
    return "".join(
        self._generate_cbuffer_struct(node)
        for node in cbuffers
        if isinstance(node, StructNode)
        or (hasattr(node, "name") and hasattr(node, "members"))
    )

def _generate_cbuffer_struct(self, node):
    """Render a single cbuffer node as a Metal struct declaration."""
    lines = [f"struct {node.name} {{\n"]
    for member in getattr(node, "members", []) or []:
        if isinstance(member, ArrayNode):
            element_type = getattr(
                member, "element_type", getattr(member, "vtype", "float")
            )
            if member.size:
                lines.append(
                    f" {self.map_type(element_type)} {member.name}[{member.size}];\n"
                )
            else:
                # Dynamic arrays in buffer blocks
                lines.append(
                    f" array<{self.map_type(element_type)}> {member.name};\n"
                )
            continue
        # Handle both old (vtype) and new (member_type) AST member structures
        if hasattr(member, "member_type"):
            member_type = self.map_type(member.member_type)
        else:
            member_type = self.map_type(getattr(member, "vtype", "float"))
        declaration = format_c_style_array_declaration(member_type, member.name)
        lines.append(f" {declaration};\n")
    lines.append("};\n")
    return "".join(lines)
def generate_function(self, func, indent=0, shader_type=None):
"""Render a function or stage entry point with Metal attributes."""
# `shader_type` selects the emitted signature: None produces a plain helper
# function; a stage name produces a prefixed entry point. Per-function
# generator state is saved on entry and restored before returning.
code = ""
code += " " * indent
param_list = getattr(func, "parameters", getattr(func, "params", []))
params = []
# Names already taken by declared parameters; used below to keep generated
# cbuffer parameter names from colliding with them.
reserved_parameter_names = {
getattr(parameter, "name", None)
for parameter in param_list
if getattr(parameter, "name", None)
}
sampler_parameters = set()
texture_parameters = {}
image_format_parameters = {}
# Save the state that this call overwrites.
previous_function_name = self.current_function_name
previous_function_return_type = self.current_function_return_type
previous_local_variable_types = self.local_variable_types
previous_cbuffer_parameter_names = self.cbuffer_parameter_names
previous_cbuffer_member_references = self.cbuffer_member_references
previous_ambiguous_cbuffer_members = self.ambiguous_cbuffer_members
self.current_function_name = getattr(func, "name", None)
self.local_variable_types = {}
for p in param_list:
# Resolve the raw (CrossGL) parameter type from whichever attribute exists.
if hasattr(p, "param_type"):
if hasattr(p.param_type, "name"):
raw_param_type = p.param_type.name
else:
raw_param_type = p.param_type
elif hasattr(p, "vtype"):
raw_param_type = p.vtype
else:
raw_param_type = "float"
self.local_variable_types[p.name] = self.type_name_string(raw_param_type)
if self.is_sampler_type(raw_param_type):
sampler_parameters.add(p.name)
elif self.is_resource_parameter_type(raw_param_type):
texture_parameters[p.name] = self.map_resource_type_with_format(
self.resource_base_type(raw_param_type), p
)
explicit_format = self.explicit_image_format(p)
if explicit_format:
image_format_parameters[p.name] = explicit_format
param_type = self.map_resource_type_with_format(raw_param_type, p)
semantic = self.semantic_from_node(p)
param_attr = self.parameter_attribute(raw_param_type, semantic, shader_type)
declaration = self.format_parameter_declaration(
raw_param_type, param_type, p.name, p
)
params.append(f"{declaration}{param_attr}")
# Compute entry points get an extra parameter for each gl_* builtin the body
# uses, unless a parameter with that name is already declared.
if shader_type == "compute":
existing_param_names = {getattr(p, "name", None) for p in param_list}
for name, param_type, attribute in self.required_compute_builtin_parameters(
func
):
if name not in existing_param_names:
params.append(f"{param_type} {name} [[{attribute}]]")
reserved_parameter_names.add(name)
reserved_parameter_names.update(self.global_resource_parameter_names())
# Recompute cbuffer parameter names for this function, avoiding the reserved names.
self.cbuffer_parameter_names = self.collect_cbuffer_parameter_names(
self.cbuffer_variables, reserved_names=reserved_parameter_names
)
self.cbuffer_member_references = self.collect_cbuffer_member_references(
self.cbuffer_variables
)
params_str = ", ".join(params)
# Helper functions (no stage) receive the cbuffers and global resources they
# need as extra parameters.
if shader_type is None:
params_str = self.append_required_cbuffer_parameters(
params_str, self.current_function_name
)
params_str = self.append_required_global_resource_parameters(
params_str, self.current_function_name
)
if hasattr(func, "return_type"):
raw_return_type = self.type_name_string(func.return_type)
return_type = self.map_type(raw_return_type)
else:
raw_return_type = "void"
return_type = "void"
self.current_function_return_type = raw_return_type
# Vertex, fragment, compute and ray_generation entry points additionally get
# every global resource appended with its binding attribute.
if shader_type == "vertex":
params_str = self.append_global_resource_parameters(params_str)
code += f"vertex {return_type} vertex_{func.name}({params_str}) {{\n"
elif shader_type == "fragment":
params_str = self.append_global_resource_parameters(params_str)
code += f"fragment {return_type} fragment_{func.name}({params_str}) {{\n"
elif shader_type in ["compute", "ray_generation"]:
params_str = self.append_global_resource_parameters(params_str)
code += f"kernel {return_type} kernel_{func.name}({params_str}) {{\n"
elif shader_type in ["mesh", "object", "task", "amplification"]:
# Every non-mesh stage in this group is emitted with the "object" keyword.
stage_keyword = "mesh" if shader_type == "mesh" else "object"
code += f"{stage_keyword} {return_type} {stage_keyword}_{func.name}({params_str}) {{\n"
elif shader_type in [
"ray_intersection",
"ray_any_hit",
"ray_closest_hit",
"ray_miss",
"ray_callable",
"intersection",
"anyhit",
"closesthit",
"miss",
"callable",
]:
# Both the "ray_"-prefixed and short stage names map to the short keyword.
rt_stage_map = {
"ray_intersection": "intersection",
"ray_any_hit": "anyhit",
"ray_closest_hit": "closesthit",
"ray_miss": "miss",
"ray_callable": "callable",
"intersection": "intersection",
"anyhit": "anyhit",
"closesthit": "closesthit",
"miss": "miss",
"callable": "callable",
}
stage_keyword = rt_stage_map.get(shader_type, shader_type)
code += f"{stage_keyword} {return_type} {stage_keyword}_{func.name}({params_str}) {{\n"
else:
# Handle semantic - get from attributes in new AST
semantic = None
if hasattr(func, "semantic"):
semantic = func.semantic
elif hasattr(func, "attributes"):
for attr in func.attributes:
if hasattr(attr, "name"):
semantic = attr.name
break
code += f"{return_type} {func.name}({params_str}) {self.map_semantic(semantic)} {{\n"
# Install this function's parameter-derived resource sets while the body is generated.
previous_sampler_parameters = self.current_sampler_parameters
previous_texture_parameters = self.current_texture_parameters
previous_image_format_parameters = self.current_image_format_parameters
self.current_sampler_parameters = sampler_parameters
self.current_texture_parameters = texture_parameters
self.current_image_format_parameters = image_format_parameters
# The body is either an object with `.statements` or a plain list of statements.
body = getattr(func, "body", [])
if hasattr(body, "statements"):
for stmt in body.statements:
code += self.generate_statement(stmt, 1)
elif isinstance(body, list):
for stmt in body:
code += self.generate_statement(stmt, 1)
# Restore everything saved above.
self.current_sampler_parameters = previous_sampler_parameters
self.current_texture_parameters = previous_texture_parameters
self.current_image_format_parameters = previous_image_format_parameters
self.current_function_name = previous_function_name
self.current_function_return_type = previous_function_return_type
self.local_variable_types = previous_local_variable_types
self.cbuffer_parameter_names = previous_cbuffer_parameter_names
self.cbuffer_member_references = previous_cbuffer_member_references
self.ambiguous_cbuffer_members = previous_ambiguous_cbuffer_members
code += "}\n\n"
return code
def required_compute_builtin_parameters(self, func):
    """Return ``(name, metal_type, attribute)`` for each compute builtin *func* uses.

    Entries keep the fixed order of the candidate table below.
    """
    referenced = self.used_compute_builtin_names(getattr(func, "body", []))
    candidates = (
        ("gl_GlobalInvocationID", "uint3", "thread_position_in_grid"),
        ("gl_LocalInvocationID", "uint3", "thread_position_in_threadgroup"),
        ("gl_WorkGroupID", "uint3", "threadgroup_position_in_grid"),
        ("gl_LocalInvocationIndex", "uint", "thread_index_in_threadgroup"),
        ("gl_WorkGroupSize", "uint3", "threads_per_threadgroup"),
        ("gl_NumWorkGroups", "uint3", "threadgroups_per_grid"),
    )
    selected = []
    for entry in candidates:
        builtin_name = entry[0]
        if builtin_name in referenced:
            selected.append(entry)
    return selected
def used_compute_builtin_names(self, body):
    """Collect the compute builtin names referenced by identifier nodes in *body*.

    A node counts as an identifier when its class's string form contains
    ``"Identifier"``; only the part of its name before the first ``.`` is
    compared against the builtin set.
    """
    builtin_names = frozenset(
        (
            "gl_GlobalInvocationID",
            "gl_LocalInvocationID",
            "gl_WorkGroupID",
            "gl_LocalInvocationIndex",
            "gl_WorkGroupSize",
            "gl_NumWorkGroups",
        )
    )
    found = set()
    for node in self.iter_ast_nodes(body):
        if "Identifier" not in str(node.__class__):
            continue
        root_name = getattr(node, "name", "").partition(".")[0]
        if root_name in builtin_names:
            found.add(root_name)
    return found
def append_global_resource_parameters(self, params_str):
    """Append every global cbuffer, texture and sampler to *params_str* with bindings.

    Cbuffers are bound to ``[[buffer(i)]]`` by list position; textures and
    samplers use the register stored alongside them.
    """
    bindings = []

    for slot, cbuffer in enumerate(self.cbuffer_variables or ()):
        parameter_name = self.cbuffer_parameter_name(cbuffer)
        bindings.append(
            f"constant {cbuffer.name}& {parameter_name} [[buffer({slot})]]"
        )

    for variable, slot, metal_type, count in self.texture_variables or ():
        declaration = self.format_resource_parameter(
            metal_type, variable.name, count
        )
        bindings.append(f"{declaration} [[texture({slot})]]")

    for variable, slot, count in self.sampler_variables or ():
        declaration = self.format_resource_parameter(
            "sampler", variable.name, count
        )
        bindings.append(f"{declaration} [[sampler({slot})]]")

    if not bindings:
        return params_str
    leading = [params_str] if params_str else []
    return ", ".join(leading + bindings)
def global_resource_parameter_names(self):
    """Return the set of names of all recorded global textures and samplers."""
    texture_names = {
        variable.name
        for variable, _, _, _ in self.texture_variables
        if getattr(variable, "name", None)
    }
    sampler_names = {
        variable.name
        for variable, _, _ in self.sampler_variables
        if getattr(variable, "name", None)
    }
    return texture_names | sampler_names
def append_required_cbuffer_parameters(self, params_str, func_name):
    """Append a ``constant T& name`` parameter for each cbuffer *func_name* requires."""
    extras = [
        f"constant {cbuffer.name}& {self.cbuffer_parameter_name(cbuffer)}"
        for cbuffer in self.required_function_cbuffers(func_name)
    ]
    if not extras:
        return params_str
    leading = [params_str] if params_str else []
    return ", ".join(leading + extras)
def append_required_global_resource_parameters(self, params_str, func_name):
    """Append the global textures and samplers *func_name* requires as plain parameters.

    Resources whose variable has no (truthy) ``name`` are skipped. No binding
    attributes are added here.
    """
    extras = []

    for variable, metal_type, count in self.required_function_textures(func_name):
        name = getattr(variable, "name", None)
        if not name:
            continue
        extras.append(self.format_resource_parameter(metal_type, name, count))

    for variable, count in self.required_function_samplers(func_name):
        name = getattr(variable, "name", None)
        if not name:
            continue
        extras.append(self.format_resource_parameter("sampler", name, count))

    if not extras:
        return params_str
    leading = [params_str] if params_str else []
    return ", ".join(leading + extras)
def cbuffer_parameter_name(self, cbuffer):
    """Return the assigned parameter name for *cbuffer*, or its default name if none is set."""
    assigned = self.cbuffer_parameter_names.get(id(cbuffer))
    return assigned if assigned else self.default_cbuffer_parameter_name(cbuffer)
def default_cbuffer_parameter_name(self, cbuffer):
    """Return the cbuffer's name with its first character lower-cased.

    Falls back to ``"constants"`` when the cbuffer has no (truthy) name.
    """
    name = getattr(cbuffer, "name", "constants") or "constants"
    head, tail = name[:1], name[1:]
    return head.lower() + tail
def collect_cbuffer_parameter_names(self, cbuffers, reserved_names=None):
    """Assign each cbuffer a unique parameter name, keyed by ``id(cbuffer)``.

    A default name that is reserved or already assigned gets the first free
    numeric suffix (``name1``, ``name2``, ...).
    """
    taken = set(reserved_names or [])
    assigned = {}
    for cbuffer in cbuffers:
        stem = self.default_cbuffer_parameter_name(cbuffer)
        candidate = stem
        counter = 0
        while candidate in taken:
            counter += 1
            candidate = f"{stem}{counter}"
        taken.add(candidate)
        assigned[id(cbuffer)] = candidate
    return assigned
def collect_cbuffer_member_references(self, cbuffers):
    """Map each unambiguous cbuffer member name to ``<parameter>.<member>``.

    Member names reported as duplicated across cbuffers are left out of the
    mapping and stored on ``self.ambiguous_cbuffer_members`` instead.
    """
    ambiguous = collect_duplicate_cbuffer_member_names(cbuffers)
    qualified = {}
    for cbuffer in cbuffers:
        owner = self.cbuffer_parameter_name(cbuffer)
        for member in getattr(cbuffer, "members", []) or []:
            member_name = getattr(member, "name", None)
            if member_name and member_name not in ambiguous:
                qualified[member_name] = f"{owner}.{member_name}"
    self.ambiguous_cbuffer_members = ambiguous
    return qualified
def generate_statement(self, stmt, indent=0):
    """Render one CrossGL AST statement as a line (or block) of Metal source."""
    pad = " " * indent

    if isinstance(stmt, VariableNode):
        if hasattr(stmt, "var_type"):
            var_type = self.convert_type_node_to_string(stmt.var_type)
        elif hasattr(stmt, "vtype"):
            var_type = stmt.vtype
        else:
            var_type = "float"
        # Remember the declared type so later expressions can look it up.
        self.local_variable_types[stmt.name] = var_type
        declaration = format_c_style_array_declaration(
            self.map_type(var_type), stmt.name
        )
        declaration = f"{self.local_variable_qualifier(stmt)}{declaration}"
        initial_value = getattr(stmt, "initial_value", None)
        if initial_value is None:
            return f"{pad}{declaration};\n"
        init_expr = self.generate_expression_with_expected(initial_value, var_type)
        return f"{pad}{declaration} = {init_expr};\n"

    if isinstance(stmt, ArrayNode):
        element_type = self.map_type(stmt.element_type)
        size = get_array_size_from_node(stmt)
        if size is None:
            # Unsized arrays are declared with a fixed 1024-element capacity.
            return f"{pad}device array<{element_type}, 1024> {stmt.name};\n"
        return f"{pad}array<{element_type}, {size}> {stmt.name};\n"

    if isinstance(stmt, AssignmentNode):
        return f"{pad}{self.generate_assignment(stmt)};\n"
    if isinstance(stmt, BreakNode):
        return f"{pad}break;\n"
    if isinstance(stmt, ContinueNode):
        return f"{pad}continue;\n"

    # Control-flow nodes each have a dedicated renderer taking (stmt, indent).
    # Renderers are looked up by name so only the matching one is resolved.
    block_renderers = (
        (IfNode, "generate_if"),
        (ForNode, "generate_for"),
        (ForInNode, "generate_for_in"),
        (WhileNode, "generate_while"),
        (LoopNode, "generate_loop"),
        (SwitchNode, "generate_switch"),
        (MatchNode, "generate_match"),
    )
    for node_type, renderer_name in block_renderers:
        if isinstance(stmt, node_type):
            return getattr(self, renderer_name)(stmt, indent)

    if isinstance(stmt, ReturnNode):
        value = getattr(stmt, "value", None)
        if value is None:
            return f"{pad}return;\n"
        if isinstance(value, list):
            # Multiple return values are emitted comma-separated.
            joined = ", ".join(
                f"{self.generate_expression(item)}" for item in value
            )
            return f"{pad}return {joined};\n"
        rendered = self.generate_expression_with_expected(
            value, self.current_function_return_type
        )
        return f"{pad}return {rendered};\n"

    # Expression statements are detected by class name, not isinstance.
    if "ExpressionStatementNode" in str(type(stmt)):
        return f"{pad}{self.generate_expression_statement(stmt)};\n"

    # Anything else is treated as a bare expression.
    return f"{pad}{self.generate_expression(stmt)};\n"
def local_variable_qualifier(self, node):
    """Return ``"const "`` when *node* carries a ``const`` qualifier, else ``""``.

    A missing ``qualifiers`` attribute and an explicit ``None`` are both
    treated as "no qualifiers" instead of raising ``TypeError``.
    """
    qualifiers = getattr(node, "qualifiers", None) or ()
    return "const " if "const" in qualifiers else ""
def type_name_string(self, vtype):
    """Normalize a type reference to a string name.

    ``None`` is passed through. Objects exposing ``name`` or
    ``element_type`` are rendered via ``convert_type_node_to_string``;
    anything else is converted with ``str``.
    """
    if vtype is None:
        return None
    looks_like_type_node = any(
        hasattr(vtype, attribute) for attribute in ("name", "element_type")
    )
    if not looks_like_type_node:
        return str(vtype)
    return self.convert_type_node_to_string(vtype)
def generate_expression_with_expected(self, expr, expected_type):
    """Render ``expr`` while an expected result type is in effect.

    ``current_expression_expected_type`` is set to the string form of
    ``expected_type`` for the duration of the call and restored afterwards,
    even when expression generation raises.
    """
    saved_expected_type = self.current_expression_expected_type
    self.current_expression_expected_type = self.type_name_string(expected_type)
    try:
        rendered = self.generate_expression(expr)
    finally:
        # Always restore so nested/failed renders do not leak their context.
        self.current_expression_expected_type = saved_expected_type
    return rendered
def is_scalar_value_type(self, vtype):
    """Return True when ``vtype`` maps to one of the scalar value types.

    The type is first normalized with ``type_name_string``; an empty or
    missing name is never scalar.
    """
    scalar_names = {"bool", "int", "uint", "half", "float", "double"}
    type_name = self.type_name_string(vtype)
    if type_name:
        return self.map_type(type_name) in scalar_names
    return False
def is_vector_value_type(self, vtype):
    """Return True when ``vtype`` maps to a 2-, 3- or 4-component vector.

    Recognized component types are float, half, double, int, uint and bool.
    """
    type_name = self.type_name_string(vtype)
    if not type_name:
        return False
    vector_names = {
        f"{component}{width}"
        for component in ("float", "half", "double", "int", "uint", "bool")
        for width in (2, 3, 4)
    }
    return self.map_type(type_name) in vector_names
def vector_component_type(self, vtype):
    """Return the scalar component name for a mapped type, by prefix.

    The mapped type name is matched against the known scalar prefixes;
    ``None`` is returned when none of them matches.
    """
    mapped_type = self.map_type(vtype)
    for component in ("float", "half", "double", "uint", "int", "bool"):
        if mapped_type.startswith(component):
            return component
    return None
def expression_result_type(self, expr):
"""Best-effort inference of the result type of an expression node.

Returns whatever type value can be derived from the node's shape
(entries of ``local_variable_types`` / ``struct_member_types`` or a
plain type-name string) and ``None`` when nothing can be derived.
"""
if expr is None:
return None
# Declared variables: look the name up in the local type table.
if isinstance(expr, VariableNode):
return self.local_variable_types.get(getattr(expr, "name", None))
# Raw Python numbers embedded in the AST.
if isinstance(expr, (int, float)):
return "float" if isinstance(expr, float) else "int"
if isinstance(expr, BinaryOpNode):
left_type = self.expression_result_type(expr.left)
right_type = self.expression_result_type(expr.right)
# A vector operand decides the result type (left operand checked first).
if self.is_vector_value_type(left_type):
return left_type
if self.is_vector_value_type(right_type):
return right_type
# Otherwise "float" wins over any other scalar operand type.
if left_type == "float" or right_type == "float":
return "float"
return left_type or right_type
if isinstance(expr, UnaryOpNode):
return self.expression_result_type(expr.operand)
# Assignments take the type of their destination ("target", else "left").
if isinstance(expr, AssignmentNode):
return self.expression_result_type(
getattr(expr, "target", getattr(expr, "left", None))
)
if isinstance(expr, ArrayAccessNode):
array_type = self.expression_result_type(expr.array)
# For "T[...]" style type strings, indexing yields the base type T.
if array_type and "[" in array_type and "]" in array_type:
base_type, _ = split_array_type_suffix(array_type)
return base_type
# Anything else is returned unchanged.
return array_type
if isinstance(expr, MemberAccessNode):
object_type = self.expression_result_type(expr.object)
member = str(expr.member)
# Members made only of xyzw/rgba characters are treated as swizzles when
# the object type has a recognizable component type.
if object_type and all(ch in "xyzwrgba" for ch in member):
component_type = self.vector_component_type(object_type)
if component_type and len(member) == 1:
return component_type
if component_type:
return f"{component_type}{len(member)}"
# Otherwise try the member table of the object's struct type.
if object_type:
member_type = self.struct_member_types.get(
self.type_name_string(object_type), {}
).get(member)
if member_type:
return member_type
# Fallback: search every known struct for a member with this name and
# accept the answer only when all matches agree on one type.
member_types = {
self.type_name_string(members[member])
for members in self.struct_member_types.values()
if member in members
}
if len(member_types) == 1:
return next(iter(member_types))
return None
if isinstance(expr, FunctionCallNode):
func_expr = getattr(expr, "function", None) or getattr(expr, "name", None)
func_name = getattr(func_expr, "name", func_expr)
# Constructor-style calls: the callee name itself is returned as the type
# (it is not passed through map_type here). Other calls fall through.
if func_name in {
"float",
"half",
"double",
"int",
"uint",
"bool",
"vec2",
"vec3",
"vec4",
"ivec2",
"ivec3",
"ivec4",
"uvec2",
"uvec3",
"uvec4",
"bvec2",
"bvec3",
"bvec4",
"float2",
"float3",
"float4",
"int2",
"int3",
"int4",
"uint2",
"uint3",
"uint4",
"bool2",
"bool3",
"bool4",
}:
return str(func_name)
# Literal nodes are detected by class name, not by isinstance.
if hasattr(expr, "__class__") and "Literal" in str(expr.__class__):
value = getattr(expr, "value", None)
if isinstance(value, float):
return "float"
if isinstance(value, int):
return "int"
# String-valued literals: a "." in the text is taken to mean float.
if isinstance(value, str):
return "float" if "." in value else "int"
# Identifier nodes (also matched by class name) use the local type table.
if hasattr(expr, "__class__") and "Identifier" in str(expr.__class__):
return self.local_variable_types.get(getattr(expr, "name", None))
return None
def generate_expression_statement(self, stmt):
    """Render an expression statement without its trailing semicolon.

    Nodes that wrap their payload in an ``expression`` attribute are
    unwrapped first; any other object is rendered directly.
    """
    payload = stmt.expression if hasattr(stmt, "expression") else stmt
    return self.generate_expression(payload)
def generate_assignment(self, node):
    """Render an assignment as ``<lhs> <op> <rhs>`` (no semicolon).

    Supports nodes using ``target``/``value`` as well as the older
    ``left``/``right`` attribute names. The right-hand side is rendered
    with the inferred type of the destination as its expected type, and
    the operator defaults to ``=``.
    """
    uses_target_value = hasattr(node, "target") and hasattr(node, "value")
    if uses_target_value:
        destination_attr, source_attr = "target", "value"
    else:
        destination_attr, source_attr = "left", "right"

    destination = getattr(node, destination_attr)
    lhs = self.generate_expression(destination)
    rhs = self.generate_expression_with_expected(
        getattr(node, source_attr), self.expression_result_type(destination)
    )
    operator = getattr(node, "operator", "=")
    return f"{lhs} {operator} {rhs}"
def generate_if(self, node, indent):
    """Render an ``if`` / ``else if`` / ``else`` chain.

    Args:
        node: If-node exposing ``condition`` (or ``if_condition``), a body
            in ``then_branch`` (or ``if_body``) and an optional
            ``else_branch``. An ``else_branch`` whose class name contains
            ``"If"`` continues the chain as ``else if``.
        indent: Indentation level of the ``if`` keyword; bodies are
            rendered one level deeper.

    Returns:
        The rendered statement, terminated by a newline.

    Chains of any length are walked iteratively so every ``else if``
    keeps its condition and its own pair of braces. The earlier
    implementation discarded the header of the third and later links,
    which produced unbalanced output.
    """
    indent_str = " " * indent

    def is_if_node(candidate):
        # If-nodes are recognized by class name, as elsewhere in this file.
        return "If" in str(candidate.__class__)

    def render_condition(if_node):
        return self.generate_expression(
            if_node.condition
            if hasattr(if_node, "condition")
            else if_node.if_condition
        )

    def render_then_body(if_node):
        body = getattr(if_node, "then_branch", getattr(if_node, "if_body", None))
        if hasattr(body, "statements"):
            statements = body.statements
        elif isinstance(body, list):
            statements = body
        else:
            # Any other body shape contributes no statements.
            statements = []
        rendered = ""
        for stmt in statements:
            rendered += self.generate_statement(stmt, indent + 1)
        return rendered

    def render_else_body(branch):
        rendered = ""
        if hasattr(branch, "statements"):
            for stmt in branch.statements:
                rendered += self.generate_statement(stmt, indent + 1)
        elif isinstance(branch, list):
            for stmt in branch:
                rendered += self.generate_statement(stmt, indent + 1)
        else:
            # A bare statement used directly as the else body.
            rendered += self.generate_statement(branch, indent + 1)
        return rendered

    code = f"{indent_str}if ({render_condition(node)}) {{\n"
    code += render_then_body(node)
    code += f"{indent_str}}}"

    branch = getattr(node, "else_branch", None)
    while branch:
        if is_if_node(branch):
            # Continue the chain on the same line as the closing brace.
            code += f" else if ({render_condition(branch)}) {{\n"
            code += render_then_body(branch)
            code += f"{indent_str}}}"
            branch = getattr(branch, "else_branch", None)
            continue
        # Terminal else clause ends the chain.
        code += " else {\n"
        code += render_else_body(branch)
        code += f"{indent_str}}}"
        break

    code += "\n"
    return code
def generate_for(self, node, indent):
    """Render a C-style ``for`` loop.

    Missing or falsy ``condition`` / ``update`` parts are rendered as
    empty strings. The body may be a block with ``statements``, a list of
    statements, or a single statement.
    """
    indent_str = " " * indent
    init = self.generate_for_initializer(getattr(node, "init", None))

    condition = ""
    if getattr(node, "condition", None):
        condition = self.generate_expression(node.condition)

    update = ""
    if getattr(node, "update", None):
        update = self.generate_expression(node.update)

    pieces = [f"{indent_str}for ({init}; {condition}; {update}) {{\n"]
    if hasattr(node.body, "statements"):
        for stmt in node.body.statements:
            pieces.append(self.generate_statement(stmt, indent + 1))
    elif isinstance(node.body, list):
        for stmt in node.body:
            pieces.append(self.generate_statement(stmt, indent + 1))
    else:
        pieces.append(self.generate_statement(node.body, indent + 1))
    pieces.append(f"{indent_str}}}\n")
    return "".join(pieces)
def generate_for_in(self, node, indent):
    """Lower a ``for ... in`` loop to a counted ``for`` over an ``int``.

    Range iterables loop from ``start`` to ``end`` (``<=`` when the range
    is inclusive, ``<`` otherwise). Any other iterable expression is used
    as the exclusive upper bound of a loop starting at 0.
    """
    indent_str = " " * indent
    pattern = getattr(node, "pattern", "item")
    iterable_node = getattr(node, "iterable", "")

    if not isinstance(iterable_node, RangeNode):
        upper_bound = self.generate_expression(iterable_node)
        header = (
            f"{indent_str}for (int {pattern} = 0; {pattern} < {upper_bound}; "
            f"++{pattern}) {{\n"
        )
    else:
        start = self.generate_expression(iterable_node.start)
        end = self.generate_expression(iterable_node.end)
        comparator = "<=" if iterable_node.inclusive else "<"
        header = (
            f"{indent_str}for (int {pattern} = {start}; "
            f"{pattern} {comparator} {end}; ++{pattern}) {{\n"
        )

    loop_body = self.generate_statement_body(getattr(node, "body", []), indent + 1)
    return header + loop_body + f"{indent_str}}}\n"
def generate_while(self, node, indent):
    """Render a ``while`` loop with its body one indentation level deeper."""
    indent_str = " " * indent
    condition = self.generate_expression(getattr(node, "condition", ""))
    header = f"{indent_str}while ({condition}) {{\n"
    loop_body = self.generate_statement_body(getattr(node, "body", []), indent + 1)
    footer = f"{indent_str}}}\n"
    return header + loop_body + footer
def generate_loop(self, node, indent):
    """Render an unconditional loop as ``while (true) { ... }``."""
    indent_str = " " * indent
    loop_body = self.generate_statement_body(getattr(node, "body", []), indent + 1)
    return f"{indent_str}while (true) {{\n" + loop_body + f"{indent_str}}}\n"
def generate_switch(self, node, indent):
    """Render a ``switch`` statement.

    Each entry of ``cases`` becomes a ``case <value>:`` label, or
    ``default:`` when its value is ``None``. A separate ``default_case``
    body, when present, is emitted after all cases. No ``break`` is
    inserted here; case bodies are rendered as given.
    """
    indent_str = " " * indent
    subject = self.generate_expression(getattr(node, "expression", ""))
    chunks = [f"{indent_str}switch ({subject}) {{\n"]

    for case in getattr(node, "cases", []) or []:
        value = getattr(case, "value", None)
        if value is not None:
            chunks.append(f"{indent_str} case {self.generate_expression(value)}:\n")
        else:
            chunks.append(f"{indent_str} default:\n")
        chunks.append(
            self.generate_statement_body(getattr(case, "statements", []), indent + 2)
        )

    default_case = getattr(node, "default_case", None)
    if default_case is not None:
        chunks.append(f"{indent_str} default:\n")
        chunks.append(self.generate_statement_body(default_case, indent + 2))

    chunks.append(f"{indent_str}}}\n")
    return "".join(chunks)
def generate_match(self, node, indent):
    """Lower a ``match`` expression to a ``switch`` statement.

    Wildcard arms become ``default:`` and literal arms become
    ``case <literal>:``. A ``break;`` is appended to every arm whose body
    does not already end in a terminating statement.

    Raises:
        ValueError: If an arm is not supported by
            ``is_supported_switch_match_arm``.
    """
    indent_str = " " * indent
    subject = self.generate_expression(getattr(node, "expression", ""))
    parts = [f"{indent_str}switch ({subject}) {{\n"]

    for arm in getattr(node, "arms", []) or []:
        pattern = getattr(arm, "pattern", None)
        if not self.is_supported_switch_match_arm(arm):
            raise ValueError(
                "Unsupported match arm for Metal codegen; only unguarded "
                "literal and wildcard patterns can be lowered to switch"
            )

        if isinstance(pattern, WildcardPatternNode):
            label = f"{indent_str} default:\n"
        else:
            label = f"{indent_str} case {self.generate_expression(pattern.literal)}:\n"
        parts.append(label)

        body = getattr(arm, "body", [])
        parts.append(self.generate_statement_body(body, indent + 2))
        if not self.statement_body_terminates(body):
            # Prevent fall-through into the next arm.
            parts.append(f"{indent_str} break;\n")

    parts.append(f"{indent_str}}}\n")
    return "".join(parts)
def is_supported_switch_match_arm(self, arm):
    """Return True when a match arm can be lowered to a switch label.

    Only arms without a guard whose pattern is a literal or wildcard
    pattern qualify.
    """
    has_guard = getattr(arm, "guard", None) is not None
    if has_guard:
        return False
    supported_patterns = (LiteralPatternNode, WildcardPatternNode)
    return isinstance(getattr(arm, "pattern", None), supported_patterns)
def statement_body_terminates(self, body):
    """Return True when the last statement of ``body`` leaves the block.

    ``body`` may be a block with ``statements``, a list, a single
    statement, or ``None``. Only a trailing break, continue or return
    node counts as terminating; an empty body never terminates.
    """
    if hasattr(body, "statements"):
        statements = body.statements
    elif isinstance(body, list):
        statements = body
    else:
        statements = [] if body is None else [body]

    if not statements:
        return False
    return isinstance(statements[-1], (BreakNode, ContinueNode, ReturnNode))
def generate_statement_body(self, body, indent):
    """Render every statement of ``body`` at the given indentation level.

    Accepts a block with ``statements``, a list of statements, a single
    statement, or ``None`` (which renders as an empty string).
    """
    if hasattr(body, "statements"):
        statements = body.statements
    elif isinstance(body, list):
        statements = body
    elif body is not None:
        statements = [body]
    else:
        statements = []
    return "".join(self.generate_statement(stmt, indent) for stmt in statements)
def generate_for_initializer(self, init):
    """Render the initializer clause of a ``for`` loop.

    ``None`` gives an empty clause and strings are used verbatim.
    Variable declarations and expression-statement nodes are rendered as
    statements; everything else as an expression. Surrounding whitespace
    and trailing semicolons are stripped from rendered nodes.
    """
    if init is None:
        return ""
    if isinstance(init, str):
        return init

    is_statement_like = isinstance(init, VariableNode) or (
        hasattr(init, "__class__") and "ExpressionStatement" in str(init.__class__)
    )
    if is_statement_like:
        rendered = self.generate_statement(init, 0)
    else:
        rendered = self.generate_expression(init)
    return rendered.strip().rstrip(";")
def generate_expression(self, expr):
"""Render a CrossGL AST expression into Metal expression syntax."""
# Dispatch is an ordered chain: plain Python values first, then the imported
# node classes, and finally nodes recognized only by class name.
if expr is None:
return ""
elif isinstance(expr, str):
# Strings are treated as already-rendered source text.
return expr
elif isinstance(expr, int) or isinstance(expr, float):
return str(expr)
elif isinstance(expr, VariableNode):
# Fix infinite recursion - directly return the name
if hasattr(expr, "name"):
return expr.name
else:
return str(expr)
elif isinstance(expr, BinaryOpNode):
# NOTE(review): no parentheses are emitted around the operands here, so
# grouping of nested operations is not made explicit - confirm that is
# handled elsewhere.
left = self.generate_expression(expr.left)
right = self.generate_expression(expr.right)
return f"{left} {self.map_operator(expr.op)} {right}"
elif isinstance(expr, AssignmentNode):
# NOTE(review): this path reads only left/right/operator, whereas
# generate_assignment also accepts target/value - presumably the node
# exposes both spellings; verify.
left = self.generate_expression(expr.left)
right = self.generate_expression(expr.right)
return f"{left} {self.map_operator(expr.operator)} {right}"
elif isinstance(expr, UnaryOpNode):
# The mapped operator is always placed before the operand.
operand = self.generate_expression(expr.operand)
return f"{self.map_operator(expr.op)}{operand}"
# Wave, ray-tracing and mesh operations are emitted as plain calls named
# after the node's operation.
elif isinstance(expr, WaveOpNode):
args = ", ".join(self.generate_expression(arg) for arg in expr.arguments)
return f"{expr.operation}({args})"
elif isinstance(expr, RayTracingOpNode):
args = ", ".join(self.generate_expression(arg) for arg in expr.arguments)
return f"{expr.operation}({args})"
elif isinstance(expr, MeshOpNode):
args = ", ".join(self.generate_expression(arg) for arg in expr.arguments)
return f"{expr.operation}({args})"
elif isinstance(expr, RayQueryOpNode):
# Ray-query operations are emitted as method calls on the query expression.
query = self.generate_expression(expr.query_expr)
args = ", ".join(self.generate_expression(arg) for arg in expr.arguments)
return f"{query}.{expr.operation}({args})"
elif isinstance(expr, ArrayAccessNode):
# Handle array access
array = self.generate_expression(expr.array)
index = self.generate_expression(expr.index)
return f"{array}[{index}]"
elif isinstance(expr, FunctionCallNode):
# Resolve callee expression (can be Identifier/Member/Array access)
func_expr = getattr(expr, "function", None)
if func_expr is None:
func_expr = expr.name
# func_name stays None when the callee is not a simple name; callee is
# the text that is actually emitted before the argument list.
func_name = None
if hasattr(func_expr, "name") and isinstance(func_expr.name, str):
func_name = func_expr.name
callee = func_name
elif isinstance(func_expr, str):
func_name = func_expr
callee = func_expr
else:
callee = self.generate_expression(func_expr)
# Texture intrinsics get first refusal; a None result means "not one".
texture_call = self.generate_texture_call(func_name, expr.args)
if texture_call is not None:
return texture_call
# Special handling for common GLSL functions
elif func_name == "normalize":
args = ", ".join(self.generate_expression(arg) for arg in expr.args)
return f"normalize({args})"
elif func_name in ["mix", "clamp", "smoothstep", "step", "dot", "cross"]:
# These function names are the same in GLSL and Metal
args = ", ".join(self.generate_expression(arg) for arg in expr.args)
return f"{func_name}({args})"
# Vector constructors
elif func_name in [
"vec2",
"vec3",
"vec4",
"ivec2",
"ivec3",
"ivec4",
"uvec2",
"uvec3",
"uvec4",
"bvec2",
"bvec3",
"bvec4",
]:
# The constructor name is translated through map_type.
metal_type = self.map_type(func_name)
args = ", ".join(self.generate_expression(arg) for arg in expr.args)
return f"{metal_type}({args})"
else:
# Standard function call
args = [self.generate_expression(arg) for arg in expr.args]
# Calls to user-defined functions get extra trailing arguments: one per
# cbuffer the callee requires, followed by its required resources.
if func_name in self.user_function_names:
args.extend(
self.cbuffer_parameter_name(cbuffer)
for cbuffer in self.required_function_cbuffers(func_name)
)
args.extend(
self.required_function_resource_argument_names(func_name)
)
args = ", ".join(args)
return f"{callee}({args})"
elif isinstance(expr, MemberAccessNode):
obj = self.generate_expression(expr.object)
return f"{obj}.{expr.member}"
elif isinstance(expr, TernaryOpNode):
return f"{self.generate_expression(expr.condition)} ? {self.generate_expression(expr.true_expr)} : {self.generate_expression(expr.false_expr)}"
elif hasattr(expr, "__class__") and "Literal" in str(expr.__class__):
# Handle LiteralNode
if hasattr(expr, "value"):
value = expr.value
literal_type = getattr(
getattr(expr, "literal_type", None), "name", None
)
# Integer literals typed as uint get a "u" suffix; bools are excluded
# because bool is a subclass of int.
if (
literal_type == "uint"
and isinstance(value, int)
and not isinstance(value, bool)
):
return f"{value}u"
# String values that are not already wrapped in double quotes get quoted.
if isinstance(value, str) and not (
value.startswith('"') and value.endswith('"')
):
return f'"{value}"' # Add quotes for string literals
return str(value)
return str(expr)
elif hasattr(expr, "__class__") and "Identifier" in str(expr.__class__):
# Handle IdentifierNode
name = getattr(expr, "name", str(expr))
# Names not shadowed by a local variable may refer to cbuffer members:
# ambiguous ones are rejected, known ones are rewritten to their
# registered reference text.
if (
name not in self.local_variable_types
and name in self.ambiguous_cbuffer_members
):
raise ValueError(
f"Ambiguous cbuffer member reference '{name}' appears in multiple cbuffers"
)
if (
name not in self.local_variable_types
and name in self.cbuffer_member_references
):
return self.cbuffer_member_references[name]
return name
else:
# Unknown node kinds fall back to their string representation.
return str(expr)
def default_sampler_expression(self):
    """Return the sampler constructor expression used as the default sampler."""
    filter_arguments = "mag_filter::linear, min_filter::linear"
    return "sampler(" + filter_arguments + ")"
def sampler_variable_names(self):
    """Return the names of all samplers currently in scope.

    Combines the names of the recorded sampler variables with the names
    in ``current_sampler_parameters``.
    """
    declared_names = set()
    for sampler_variable, _, _ in self.sampler_variables:
        declared_names.add(sampler_variable.name)
    return declared_names | self.current_sampler_parameters
def is_sampler_type(self, vtype):
    """Return True when the resource base type of ``vtype`` is ``sampler``."""
    base_type = self.resource_base_type(vtype)
    return base_type == "sampler"
def is_resource_parameter_type(self, vtype):
    """Return True when ``vtype`` names a sampler, texture or image resource.

    The check is made on the resource base type of ``vtype`` against the
    fixed list of CrossGL sampler and image type names.
    """
    sampler_kinds = {
        "sampler",
        "sampler1D",
        "sampler2D",
        "sampler3D",
        "samplerCube",
        "sampler2DArray",
        "samplerCubeArray",
        "sampler2DMS",
        "sampler2DMSArray",
        "sampler2DShadow",
        "sampler2DArrayShadow",
        "samplerCubeShadow",
        "samplerCubeArrayShadow",
    }
    image_kinds = {
        "iimage2D",
        "iimage3D",
        "iimage2DArray",
        "uimage2D",
        "uimage3D",
        "uimage2DArray",
        "image2D",
        "image3D",
        "imageCube",
        "image2DArray",
    }
    base_type = self.resource_base_type(vtype)
    return base_type in sampler_kinds or base_type in image_kinds
def is_texture_or_image_resource_type(self, vtype):
    """Return a truthy value for resource types other than plain samplers."""
    is_resource = self.is_resource_parameter_type(vtype)
    return is_resource and not self.is_sampler_type(vtype)
def is_integer_coordinate_type(self, vtype):
    """Return True for signed/unsigned integer scalars and 2-4 vectors.

    Both the CrossGL spelling of the base type (``ivec2``, ``uvec3`` ...)
    and its mapped spelling (``int2``, ``uint3`` ...) are accepted.
    """
    source_spellings = {
        "int",
        "uint",
        "ivec2",
        "ivec3",
        "ivec4",
        "uvec2",
        "uvec3",
        "uvec4",
    }
    mapped_spellings = {
        "int",
        "uint",
        "int2",
        "int3",
        "int4",
        "uint2",
        "uint3",
        "uint4",
    }
    type_name = self.type_name_string(vtype)
    base_type = self.resource_base_type(type_name)
    mapped_type = self.map_type(base_type)
    if base_type in source_spellings:
        return True
    return mapped_type in mapped_spellings
def resource_coordinate_dimension(self, texture_type):
    """Return the coordinate component count for a texture type prefix.

    Array variants count one extra component over their non-array form.
    Returns ``None`` for missing types, any type containing ``cube``, and
    types that match none of the known prefixes.
    """
    texture_type = self.resource_base_type(texture_type)
    if not texture_type or "cube" in texture_type:
        return None

    dimension_by_prefix = (
        ("texture1d_array<", 2),
        ("texture1d<", 1),
        ("texture2d_ms_array<", 3),
        ("texture2d_ms<", 2),
        ("texture2d_array<", 3),
        ("texture2d<", 2),
        ("depth2d_array<", 3),
        ("depth2d<", 2),
        ("texture3d<", 3),
    )
    for prefix, dimension in dimension_by_prefix:
        if texture_type.startswith(prefix):
            return dimension
    return None
def resource_offset_dimension(self, func_name, texture_type):
    """Return the component count of the offset argument, or ``None``.

    ``None`` means the intrinsic/texture combination takes no offset:
    missing or cube texture types, multisample ``texelFetchOffset``, and
    combinations rejected by the gather/compare/sample support helpers.
    """
    texture_type = self.resource_base_type(texture_type)
    if not texture_type or "cube" in texture_type:
        return None

    if func_name == "texelFetchOffset":
        if self.is_multisample_texture_resource(texture_type):
            return None
        fetch_dimensions = (
            ("texture1d_array<", 1),
            ("texture1d<", 1),
            ("texture2d_array<", 2),
            ("texture2d<", 2),
            ("texture3d<", 3),
        )
        for prefix, dimension in fetch_dimensions:
            if texture_type.startswith(prefix):
                return dimension
        return None

    if func_name in {"textureGatherOffset", "textureGatherOffsets"}:
        if self.texture_gather_supports_offset(texture_type):
            return 2
        return None

    if func_name == "textureGatherCompareOffset":
        if self.texture_gather_compare_offset_supported(texture_type):
            return 2
        return None

    compare_offset_intrinsics = {
        "textureCompareOffset",
        "textureCompareLodOffset",
        "textureCompareGradOffset",
        "textureCompareProjOffset",
        "textureCompareProjLodOffset",
        "textureCompareProjGradOffset",
    }
    if func_name in compare_offset_intrinsics:
        if self.texture_compare_offset_supported(texture_type):
            return 2
        return None

    if (
        func_name in OFFSET_DIMENSION_INTRINSIC_NAMES
        and not self.texture_sample_supports_offset(texture_type)
    ):
        return None

    # Remaining intrinsics only take offsets on 2D (array) textures.
    if texture_type.startswith(("texture2d_array<", "texture2d<")):
        return 2
    return None
def resource_gradient_dimension(self, func_name, texture_type):
    """Return the gradient vector width for a texture type, or ``None``.

    2D (array) colour and depth textures use 2 components; 3D and cube
    (array) textures use 3. Types carrying an ``access::`` qualifier or
    containing ``ms`` have no gradient form. ``func_name`` is accepted
    for signature parity but does not influence the result.
    """
    texture_type = self.resource_base_type(texture_type)
    if not texture_type or "access::" in texture_type:
        return None
    if "ms" in texture_type:
        return None

    two_component_prefixes = (
        "texture2d_array<",
        "depth2d_array<",
        "texture2d<",
        "depth2d<",
    )
    three_component_prefixes = (
        "texture3d<",
        "texturecube_array<",
        "depthcube_array<",
        "texturecube<",
        "depthcube<",
    )
    if texture_type.startswith(two_component_prefixes):
        return 2
    if texture_type.startswith(three_component_prefixes):
        return 3
    return None
def resource_query_lod_coordinate_dimension(self, texture_type):
    """Return the coordinate width used for LOD queries, or ``None``.

    Types that are missing, carry an ``access::`` qualifier, or contain
    ``_ms`` yield ``None``, as do types matching no known prefix.
    """
    texture_type = self.resource_base_type(texture_type)
    if not texture_type:
        return None
    if "access::" in texture_type or "_ms" in texture_type:
        return None

    dimension_by_prefixes = (
        (("texture1d_array<",), 2),
        (("texture1d<",), 1),
        (("texture2d_array<", "depth2d_array<"), 3),
        (("texture2d<", "depth2d<"), 2),
        (("texture3d<",), 3),
        (("texturecube_array<", "depthcube_array<"), 4),
        (("texturecube<", "depthcube<"), 3),
    )
    for prefixes, dimension in dimension_by_prefixes:
        if texture_type.startswith(prefixes):
            return dimension
    return None
def parameter_attribute(self, raw_param_type, semantic, shader_type):
    """Return the attribute suffix for a function parameter.

    An explicit semantic is translated with ``map_semantic``. Without
    one, non-resource parameters of vertex and fragment functions get
    `` [[stage_in]]``; everything else gets no attribute.
    """
    if semantic:
        return self.map_semantic(semantic)
    is_stage_input = (
        not self.is_resource_parameter_type(raw_param_type)
        and shader_type in {"vertex", "fragment"}
    )
    return " [[stage_in]]" if is_stage_input else ""
def format_parameter_declaration(
    self, raw_param_type, mapped_type, name, node=None
):
    """Format ``<type> <name>`` for a function parameter.

    Resource arrays (as reported by ``resource_array_parameter``) are
    formatted through ``format_resource_parameter``; all other parameters
    use the C-style array declaration helper.
    """
    resource_array = self.resource_array_parameter(raw_param_type, node)
    if resource_array is None:
        return format_c_style_array_declaration(mapped_type, name)
    resource_type, array_size = resource_array
    return self.format_resource_parameter(resource_type, name, array_size)
def format_resource_parameter(self, resource_type, name, array_size):
    """Format a resource parameter, wrapping arrays in ``array<T, N>``.

    ``array_size`` of ``None`` means "not an array". Any other falsy size
    (for example an empty string) is rendered as ``1``.
    """
    if array_size is None:
        return f"{resource_type} {name}"
    element_count = array_size if array_size else "1"
    return f"array<{resource_type}, {element_count}> {name}"
def resource_array_parameter(self, vtype, node=None):
    """Describe ``vtype`` as a resource array, if it is one.

    Args:
        vtype: Either an array type node (class name containing
            ``ArrayType`` with ``element_type``/``size``) or a type string
            such as ``"sampler2D[4]"``.
        node: Optional parameter node. Its ``name`` is used to look up a
            size hint for unsized arrays and it is forwarded to
            ``map_resource_type_with_format``.

    Returns:
        ``(mapped_resource_type, array_size)`` for arrays of resource
        types, otherwise ``None``. Unsized arrays take their size from
        ``function_resource_array_size_hints`` for the current function,
        falling back to ``""`` when no hint exists.

    ``node`` may be omitted: the hint lookup then simply finds nothing
    instead of failing on ``node.name``.
    """

    def hinted_size():
        # Looked up lazily: only unsized arrays need a hint. getattr keeps
        # the default ``node=None`` from raising AttributeError.
        hints = self.function_resource_array_size_hints.get(
            self.current_function_name, {}
        )
        return hints.get(getattr(node, "name", None), "")

    if hasattr(vtype, "element_type") and str(type(vtype)).find("ArrayType") != -1:
        base_type = self.convert_type_node_to_string(vtype.element_type)
        if not self.is_resource_parameter_type(base_type):
            return None
        array_size = (
            self.safe_expression_to_string(vtype.size)
            if vtype.size is not None
            else hinted_size()
        )
        return self.map_resource_type_with_format(base_type, node), array_size

    # Other type nodes are never resource arrays.
    if hasattr(vtype, "name") or hasattr(vtype, "element_type"):
        return None

    type_string = str(vtype)
    if "[" not in type_string or "]" not in type_string:
        return None
    base_type, array_size = parse_array_type(type_string)
    if not self.is_resource_parameter_type(base_type):
        return None
    return self.map_resource_type_with_format(base_type, node), (
        hinted_size() if array_size is None else array_size
    )
def collect_resource_array_size_hints(self, ast):
    """Gather the inputs for resource-array size inference and run it.

    Delegates to the module-level ``collect_resource_array_size_hints``
    helper, passing the unsized and fixed-size resource arrays found in
    ``ast`` together with this generator's AST utility callbacks.
    """
    unsized_globals = self.collect_unsized_resource_globals(ast)
    unsized_parameters = self.collect_unsized_resource_parameters(ast)
    fixed_global_sizes = self.collect_fixed_resource_global_sizes(ast)
    fixed_parameter_sizes = self.collect_fixed_resource_parameter_sizes(ast)
    functions = self.all_functions(ast)
    return collect_resource_array_size_hints(
        global_arrays=unsized_globals,
        function_arrays=unsized_parameters,
        fixed_global_array_sizes=fixed_global_sizes,
        fixed_function_array_sizes=fixed_parameter_sizes,
        functions=functions,
        walk_nodes=self.iter_ast_nodes,
        expression_name=self.expression_name,
        literal_int_value=self.literal_int_value,
        visible_literal_int_constants=self.visible_literal_int_constants,
        function_call_name=self.function_call_name,
        initial_size=1,
        format_size=str,
    )
def collect_unsized_resource_globals(self, ast):
    """Map global variable names to their type for unsized resource arrays.

    Both attribute spellings are accepted: ``name`` / ``variable_name``
    for the name and ``var_type`` / ``vtype`` for the type.
    """
    unsized = {}
    for node in getattr(ast, "global_variables", []) or []:
        name = getattr(node, "name", getattr(node, "variable_name", None))
        vtype = getattr(node, "var_type", getattr(node, "vtype", None))
        if not name:
            continue
        if self.is_unsized_resource_array_type(vtype):
            unsized[name] = vtype
    return unsized
def collect_fixed_resource_global_sizes(self, ast):
    """Map global variable names to the size of fixed-size resource arrays.

    Globals for which ``fixed_resource_array_size`` yields ``None`` or
    that have no name are left out.
    """
    sizes_by_name = {}
    for node in getattr(ast, "global_variables", []) or []:
        name = getattr(node, "name", getattr(node, "variable_name", None))
        vtype = getattr(node, "var_type", getattr(node, "vtype", None))
        size = self.fixed_resource_array_size(vtype)
        if size is None or not name:
            continue
        sizes_by_name[name] = size
    return sizes_by_name
def collect_unsized_resource_parameters(self, ast):
    """Map function name -> {parameter name -> type} for unsized resource arrays.

    Functions without a name are skipped; functions with no matching
    parameter do not appear in the result.
    """
    per_function = {}
    for func in self.all_functions(ast):
        func_name = getattr(func, "name", None)
        if not func_name:
            continue
        parameters = getattr(func, "parameters", getattr(func, "params", []))
        for param in parameters:
            vtype = getattr(param, "param_type", getattr(param, "vtype", None))
            if not self.is_unsized_resource_array_type(vtype):
                continue
            per_function.setdefault(func_name, {})[param.name] = vtype
    return per_function
def collect_fixed_resource_parameter_sizes(self, ast):
    """Map function name -> {parameter name -> size} for fixed resource arrays.

    Functions without a name are skipped; parameters for which
    ``fixed_resource_array_size`` yields ``None`` are left out.
    """
    per_function = {}
    for func in self.all_functions(ast):
        func_name = getattr(func, "name", None)
        if not func_name:
            continue
        parameters = getattr(func, "parameters", getattr(func, "params", []))
        for param in parameters:
            param_type = getattr(param, "param_type", getattr(param, "vtype", None))
            size = self.fixed_resource_array_size(param_type)
            if size is None:
                continue
            per_function.setdefault(func_name, {})[param.name] = size
    return per_function
def fixed_resource_array_size(self, vtype):
    """Return the declared size of a fixed-size resource array, or ``None``.

    Array type nodes have their size resolved through
    ``literal_int_value``; only positive results are returned. Type
    strings are parsed with ``parse_array_type`` and clamped to at least
    1. Non-arrays, unsized arrays and non-resource element types all
    yield ``None``.
    """
    is_array_node = (
        hasattr(vtype, "element_type") and str(type(vtype)).find("ArrayType") != -1
    )
    if is_array_node:
        if vtype.size is None:
            return None
        element_name = self.convert_type_node_to_string(vtype.element_type)
        if not self.is_resource_parameter_type(element_name):
            return None
        resolved = self.literal_int_value(vtype.size, self.literal_int_constants)
        if resolved is not None and resolved > 0:
            return resolved
        return None

    # Any other type node cannot be a resource array.
    if hasattr(vtype, "name") or hasattr(vtype, "element_type"):
        return None

    type_string = str(vtype)
    if not ("[" in type_string and "]" in type_string):
        return None
    element_name, declared_size = parse_array_type(type_string)
    if declared_size is None or not self.is_resource_parameter_type(element_name):
        return None
    return max(declared_size, 1)
def is_unsized_resource_array_type(self, vtype):
    """Return True when ``vtype`` is a resource array without a declared size.

    Works on array type nodes (class name containing ``ArrayType``) and on
    type strings containing a bracket pair; everything else is False.
    """
    is_array_node = (
        hasattr(vtype, "element_type") and str(type(vtype)).find("ArrayType") != -1
    )
    if is_array_node:
        if vtype.size is not None:
            return False
        element_name = self.convert_type_node_to_string(vtype.element_type)
        return self.is_resource_parameter_type(element_name)

    if hasattr(vtype, "name") or hasattr(vtype, "element_type"):
        return False

    type_string = str(vtype)
    if not ("[" in type_string and "]" in type_string):
        return False
    element_name, declared_size = parse_array_type(type_string)
    return declared_size is None and self.is_resource_parameter_type(element_name)
def all_functions(self, ast):
    """Return every function in ``ast`` as a flat list.

    Top-level functions come first, followed for each stage by its entry
    point (when present) and its local functions.
    """
    collected = list(getattr(ast, "functions", []) or [])
    for stage in getattr(ast, "stages", {}).values():
        entry_point = getattr(stage, "entry_point", None)
        if entry_point is not None:
            collected.append(entry_point)
        local_functions = getattr(stage, "local_functions", []) or []
        collected.extend(local_functions)
    return collected
def collect_global_resource_names(self, root):
    """Return the names of all global variables that have a resource type.

    A global without a type attribute is treated as ``float``; globals
    without a name are ignored.
    """
    names = set()
    for node in getattr(root, "global_variables", []) or []:
        var_type = getattr(node, "var_type", getattr(node, "vtype", "float"))
        var_name = getattr(node, "name", getattr(node, "variable_name", None))
        if not var_name:
            continue
        if self.is_resource_parameter_type(var_type):
            names.add(var_name)
    return names
def validate_global_resource_shadows(self, ast):
    """Reject non-resource local declarations that shadow global resources.

    Raises:
        ValueError: Listing the conflicting names in sorted order when the
            shadow collector reports any conflict.
    """
    global_resource_names = self.collect_global_resource_names(ast)
    conflicts = collect_non_resource_global_resource_shadows(
        ast,
        global_resource_names,
        self.is_resource_parameter_type,
    )
    if not conflicts:
        return
    names = ", ".join(sorted(conflicts))
    raise ValueError(
        "Non-resource local declaration(s) shadow Metal global resource(s): "
        f"{names}"
    )
def collect_function_cbuffer_dependencies(self, functions):
    """Compute, per function, the cbuffers it needs directly or via callees.

    Direct dependencies are propagated along the call graph until no
    function's set grows any further. Functions without a name are
    ignored.
    """
    direct = {}
    callees = {}
    for func in functions:
        func_name = getattr(func, "name", None)
        if not func_name:
            continue
        direct[func_name] = self.direct_cbuffer_dependencies(func)
        callees[func_name] = self.called_user_function_names(func)

    closure = {name: set(deps) for name, deps in direct.items()}
    while True:
        grew = False
        for func_name, calls in callees.items():
            snapshot = set(closure.get(func_name, set()))
            for called_name in calls:
                closure.setdefault(func_name, set()).update(
                    closure.get(called_name, set())
                )
            if closure.get(func_name, set()) != snapshot:
                grew = True
        if not grew:
            break
    return closure
def direct_cbuffer_dependencies(self, func):
    """Return the names of cbuffers whose members ``func`` references itself.

    Identifier nodes in the body are matched against cbuffer member
    names, ignoring names that are parameters or locally declared
    variables. When several cbuffers declare the same member name, the
    last one in ``cbuffer_variables`` is the one recorded.
    """
    parameters = getattr(func, "parameters", getattr(func, "params", []))
    shadowing_names = set()
    for param in parameters:
        param_name = getattr(param, "name", None)
        if param_name:
            shadowing_names.add(param_name)
    for node in self.iter_ast_nodes(getattr(func, "body", [])):
        if isinstance(node, VariableNode) and getattr(node, "name", None):
            shadowing_names.add(node.name)

    owner_by_member = {}
    for cbuffer in self.cbuffer_variables:
        owner = getattr(cbuffer, "name", None)
        if not owner:
            continue
        for member in getattr(cbuffer, "members", []) or []:
            member_name = getattr(member, "name", None)
            if member_name:
                owner_by_member[member_name] = owner

    referenced = set()
    for node in self.iter_ast_nodes(getattr(func, "body", [])):
        is_identifier = hasattr(node, "__class__") and "Identifier" in str(
            node.__class__
        )
        if not is_identifier:
            continue
        name = getattr(node, "name", None)
        if not name or name in shadowing_names:
            continue
        owner = owner_by_member.get(name)
        if owner:
            referenced.add(owner)
    return referenced
def called_user_function_names(self, func):
    """Return the names of user-defined functions called from ``func``.

    Calls to ``func`` itself are not included.
    """
    own_name = getattr(func, "name", None)
    callee_names = set()
    for node in self.iter_ast_nodes(getattr(func, "body", [])):
        if not isinstance(node, FunctionCallNode):
            continue
        callee = self.function_call_name(node)
        if callee in self.user_function_names and callee != own_name:
            callee_names.add(callee)
    return callee_names
def required_function_cbuffers(self, func_name):
    """Return the cbuffer nodes ``func_name`` depends on, in declaration order."""
    needed_names = self.function_cbuffer_dependencies.get(func_name, set())
    required = []
    for cbuffer in self.cbuffer_variables:
        if getattr(cbuffer, "name", None) in needed_names:
            required.append(cbuffer)
    return required
def collect_function_global_resource_dependencies(self, functions):
    """Compute, per function, the global resources it needs.

    Starts from each function's direct resource references and
    propagates them from callees to callers until nothing changes.
    Functions without a name are ignored.
    """
    direct = {}
    callees = {}
    for func in functions:
        func_name = getattr(func, "name", None)
        if not func_name:
            continue
        direct[func_name] = self.direct_global_resource_dependencies(func)
        callees[func_name] = self.called_user_function_names(func)

    closure = {name: set(deps) for name, deps in direct.items()}
    while True:
        grew = False
        for func_name, calls in callees.items():
            snapshot = set(closure.get(func_name, set()))
            for called_name in calls:
                closure.setdefault(func_name, set()).update(
                    closure.get(called_name, set())
                )
            if closure.get(func_name, set()) != snapshot:
                grew = True
        if not grew:
            break
    return closure
def direct_global_resource_dependencies(self, func):
    """Return the global texture/sampler names ``func`` references itself.

    Identifier nodes naming a global texture or sampler count unless the
    name is a parameter or a locally declared variable. Function-call
    nodes are additionally passed to
    ``add_texture_call_resource_dependencies``.
    """
    parameters = getattr(func, "parameters", getattr(func, "params", []))
    shadowing_names = set()
    for param in parameters:
        param_name = getattr(param, "name", None)
        if param_name:
            shadowing_names.add(param_name)
    for node in self.iter_ast_nodes(getattr(func, "body", [])):
        if isinstance(node, VariableNode) and getattr(node, "name", None):
            shadowing_names.add(node.name)

    texture_names = self.global_texture_names()
    sampler_names = self.global_sampler_names()
    referenced = set()
    for node in self.iter_ast_nodes(getattr(func, "body", [])):
        if hasattr(node, "__class__") and "Identifier" in str(node.__class__):
            name = getattr(node, "name", None)
            is_global_resource = (
                name
                and name not in shadowing_names
                and (name in texture_names or name in sampler_names)
            )
            if is_global_resource:
                referenced.add(name)
        if isinstance(node, FunctionCallNode):
            self.add_texture_call_resource_dependencies(
                node, shadowing_names, texture_names, sampler_names, referenced
            )
    return referenced
def add_texture_call_resource_dependencies(
    self, call, local_names, texture_names, sampler_names, dependencies
):
    """Add the global resources used by one texture/image call.

    Only calls whose name starts with ``texture`` or ``image`` are
    considered. The first argument may name a global texture. With three or
    more arguments the second may name an explicit global sampler; if not,
    sampling intrinsics fall back to the ``<texture>Sampler`` naming
    convention. Results are added to ``dependencies`` in place.
    """
    callee = self.function_call_name(call)
    if not callee or not str(callee).startswith(("texture", "image")):
        return

    call_args = getattr(call, "arguments", getattr(call, "args", []))
    if not call_args:
        return

    texture_name = self.expression_name(call_args[0])
    if texture_name in texture_names and texture_name not in local_names:
        dependencies.add(texture_name)

    if len(call_args) >= 3:
        explicit_sampler = self.expression_name(call_args[1])
        if explicit_sampler in sampler_names and explicit_sampler not in local_names:
            dependencies.add(explicit_sampler)
            # An explicit sampler makes the implicit lookup unnecessary.
            return

    if not self.texture_sampling_uses_implicit_sampler(callee):
        return

    implicit_sampler = f"{texture_name}Sampler" if texture_name else None
    if implicit_sampler in sampler_names and implicit_sampler not in local_names:
        dependencies.add(implicit_sampler)
def texture_sampling_uses_implicit_sampler(self, func_name):
    """Return True if ``func_name`` is a plain or projected sampling intrinsic.

    The accepted names are ``texture`` optionally followed by ``Proj``, then
    optionally ``Lod`` or ``Grad``, then optionally ``Offset``.
    """
    implicit_sampler_intrinsics = {
        f"texture{projection}{level}{offset}"
        for projection in ("", "Proj")
        for level in ("", "Lod", "Grad")
        for offset in ("", "Offset")
    }
    return func_name in implicit_sampler_intrinsics
def global_texture_names(self):
    """Return the set of non-empty names of the registered texture variables."""
    names = set()
    for variable, _, _, _ in self.texture_variables:
        if getattr(variable, "name", None):
            names.add(variable.name)
    return names
def global_sampler_names(self):
    """Return the set of non-empty names of the registered sampler variables."""
    names = set()
    for variable, _, _ in self.sampler_variables:
        if getattr(variable, "name", None):
            names.add(variable.name)
    return names
def required_function_textures(self, func_name):
    """Return ``(variable, type, array_size)`` for textures ``func_name`` needs.

    Entries follow the order of ``self.texture_variables``; the second field
    of each stored tuple is dropped.
    """
    needed = self.function_global_resource_dependencies.get(func_name, set())
    selected = []
    for variable, _, texture_type, array_size in self.texture_variables:
        if getattr(variable, "name", None) in needed:
            selected.append((variable, texture_type, array_size))
    return selected
def required_function_samplers(self, func_name):
    """Return ``(variable, array_size)`` for samplers ``func_name`` needs.

    Entries follow the order of ``self.sampler_variables``; the second field
    of each stored tuple is dropped.
    """
    needed = self.function_global_resource_dependencies.get(func_name, set())
    selected = []
    for variable, _, array_size in self.sampler_variables:
        if getattr(variable, "name", None) in needed:
            selected.append((variable, array_size))
    return selected
def required_function_resource_argument_names(self, func_name):
    """Return the names of required textures followed by required samplers."""
    argument_names = [
        texture.name for texture, _, _ in self.required_function_textures(func_name)
    ]
    argument_names.extend(
        sampler.name for sampler, _ in self.required_function_samplers(func_name)
    )
    return argument_names
def iter_ast_nodes(self, node):
    """Yield ``node`` and every object reachable from it, depth first.

    ``None`` and str/int/float/bool values yield nothing. Lists, tuples and
    sets are descended into without being yielded themselves. Any other
    object is yielded only if it has a ``__dict__``; its attributes are then
    visited, except ``parent`` and ``annotations``.
    """
    if node is None or isinstance(node, (str, int, float, bool)):
        return

    if isinstance(node, (list, tuple, set)):
        children = node
    elif hasattr(node, "__dict__"):
        yield node
        skipped_attributes = ("parent", "annotations")
        children = (
            value
            for key, value in vars(node).items()
            if key not in skipped_attributes
        )
    else:
        return

    for child in children:
        yield from self.iter_ast_nodes(child)
def literal_int_value(self, expr, constants=None):
    """Delegate to ``evaluate_literal_int_expression`` and return its result.

    ``constants`` is passed through unchanged as the second argument.
    """
    evaluated = evaluate_literal_int_expression(expr, constants)
    return evaluated
def visible_literal_int_constants(self, func):
    """Return the integer constants visible inside ``func``.

    Starts from a copy of ``self.literal_int_constants``. Parameters and
    named ``VariableNode`` instances in the body remove any constant of the
    same name. A node whose qualifiers contain ``const`` and whose
    ``initial_value`` evaluates to an integer is then recorded under its
    own name.
    """
    constants = dict(self.literal_int_constants)

    for parameter in getattr(func, "parameters", []) or []:
        constants.pop(getattr(parameter, "name", None), None)

    for node in self.iter_ast_nodes(getattr(func, "body", [])):
        if not isinstance(node, VariableNode):
            continue
        declared_name = getattr(node, "name", None)
        if not declared_name:
            continue
        # The declaration shadows whatever constant had this name before.
        constants.pop(declared_name, None)
        if "const" in getattr(node, "qualifiers", []):
            folded = self.literal_int_value(
                getattr(node, "initial_value", None), constants
            )
            if folded is not None:
                constants[declared_name] = folded
    return constants
def function_call_name(self, call):
    """Return the callee name of ``call`` as a string, or None.

    The callee is read from ``call.function``, falling back to ``call.name``
    when that is missing or None. It may be a string itself or an object
    with a string ``name`` attribute.
    """
    callee = getattr(call, "function", None)
    if callee is None:
        callee = getattr(call, "name", None)

    if isinstance(callee, str):
        return callee
    nested_name = getattr(callee, "name", None)
    return nested_name if isinstance(nested_name, str) else None
def supported_image_formats(self):
    """Return the set of recognised image format names.

    Every name is a channel layout (``r``, ``rg``, ``rgba``) followed by one
    of twelve bit-width/encoding suffixes, e.g. ``rgba16f`` or ``r8_snorm``.
    """
    channel_layouts = ("r", "rg", "rgba")
    encodings = (
        "8",
        "8_snorm",
        "8i",
        "8ui",
        "16",
        "16_snorm",
        "16f",
        "16i",
        "16ui",
        "32f",
        "32i",
        "32ui",
    )
    return {
        f"{channels}{encoding}"
        for channels in channel_layouts
        for encoding in encodings
    }
def scalar_image_format_components(self):
    """Map each single-channel (``r*``) image format to its component type.

    The component type is one of ``float``, ``int`` or ``uint``.
    """
    component_by_encoding = (
        ("8", "float"),
        ("8_snorm", "float"),
        ("16", "float"),
        ("16_snorm", "float"),
        ("16f", "float"),
        ("32f", "float"),
        ("8i", "int"),
        ("16i", "int"),
        ("32i", "int"),
        ("8ui", "uint"),
        ("16ui", "uint"),
        ("32ui", "uint"),
    )
    return {
        f"r{encoding}": component for encoding, component in component_by_encoding
    }
def vector_image_format_components(self):
    """Map each multi-channel (``rg*``/``rgba*``) image format to its component type.

    The component type is one of ``float``, ``int`` or ``uint``.
    """
    rg_encodings = (
        ("8", "float"),
        ("8_snorm", "float"),
        ("16", "float"),
        ("16_snorm", "float"),
        ("16f", "float"),
        ("8i", "int"),
        ("16i", "int"),
        ("8ui", "uint"),
        ("16ui", "uint"),
        ("32f", "float"),
        ("32i", "int"),
        ("32ui", "uint"),
    )
    rgba_encodings = (
        ("8", "float"),
        ("8_snorm", "float"),
        ("16", "float"),
        ("16_snorm", "float"),
        ("16f", "float"),
        ("32f", "float"),
        ("8i", "int"),
        ("16i", "int"),
        ("32i", "int"),
        ("8ui", "uint"),
        ("16ui", "uint"),
        ("32ui", "uint"),
    )
    components = {f"rg{encoding}": kind for encoding, kind in rg_encodings}
    components.update((f"rgba{encoding}", kind) for encoding, kind in rgba_encodings)
    return components
def attribute_value_to_string(self, value):
    """Convert an attribute argument to text.

    None and strings are returned unchanged. Otherwise the ``name``
    attribute is preferred, then ``value`` (with surrounding double quotes
    stripped), then ``str(value)``.
    """
    if value is None or isinstance(value, str):
        return value

    if hasattr(value, "name"):
        text = str(value.name)
    elif hasattr(value, "value"):
        text = str(value.value).strip('"')
    else:
        text = str(value)
    return text
def explicit_image_format(self, node):
    """Return the image format named by ``node``'s attributes, or None.

    Two spellings are recognised, both case-insensitively: an attribute
    whose own name is a supported format, and a ``format`` attribute whose
    first argument is a supported format. The first match wins.

    A missing, ``None`` or empty ``attributes`` value is treated as "no
    attributes" so that nodes carrying ``attributes = None`` do not raise.
    """
    attributes = getattr(node, "attributes", None)
    if not attributes:
        return None

    supported_formats = self.supported_image_formats()
    for attr in attributes:
        attr_name = getattr(attr, "name", None)
        if not attr_name:
            continue
        attr_name = str(attr_name).lower()
        if attr_name in supported_formats:
            return attr_name
        if attr_name != "format":
            continue
        arguments = getattr(attr, "arguments", []) or []
        if not arguments:
            continue
        format_name = self.attribute_value_to_string(arguments[0])
        if format_name is None:
            continue
        format_name = str(format_name).lower()
        if format_name in supported_formats:
            return format_name
    return None
def is_image_format_attribute(self, attr):
    """Return True if ``attr`` is ``format`` or names a supported image format.

    The comparison is case-insensitive; an attribute without a name is
    never a format attribute.
    """
    raw_name = getattr(attr, "name", None)
    if not raw_name:
        return False
    lowered = str(raw_name).lower()
    if lowered == "format":
        return True
    return lowered in self.supported_image_formats()
def semantic_from_node(self, node):
    """Return the semantic attached to ``node``, or None.

    A ``semantic`` attribute on the node wins outright, even when its value
    is None. Otherwise the name of the first attribute that is not an image
    format attribute is returned.

    A missing or ``None`` ``attributes`` value is treated as an empty
    attribute list so that such nodes do not raise.
    """
    if hasattr(node, "semantic"):
        return node.semantic
    for attr in getattr(node, "attributes", None) or ():
        if self.is_image_format_attribute(attr):
            continue
        if hasattr(attr, "name"):
            return attr.name
    return None
def map_resource_type_with_format(self, vtype, node=None):
    """Map a resource type to Metal, honouring an image format on ``node``.

    Type nodes are first converted to their string form. For array types
    the base type is mapped on its own and the original array suffix is
    appended to the result. ``None`` is passed straight to ``map_type``.
    """
    if vtype is None:
        return self.map_type(vtype)

    is_type_node = hasattr(vtype, "name") or hasattr(vtype, "element_type")
    type_text = self.convert_type_node_to_string(vtype) if is_type_node else str(vtype)

    if "[" not in type_text or "]" not in type_text:
        return self.map_image_base_type_with_format(type_text, node)

    element_type, suffix = split_array_type_suffix(type_text)
    mapped_element = self.map_image_base_type_with_format(element_type, node)
    return f"{mapped_element}{suffix}"
def map_image_base_type_with_format(self, vtype, node=None):
    """Map an image type to a read-write Metal texture when a format is given.

    If ``node`` carries an explicit image format and the base type is one of
    the known image types, the result is
    ``<texture kind><<component>, access::read_write>``. In every other case
    the type is handed to ``map_type`` unchanged.
    """
    base_type = self.resource_base_type(vtype)
    explicit_format = None
    if node is not None:
        explicit_format = self.explicit_image_format(node)

    component_type = self.scalar_image_format_components().get(explicit_format)
    if not component_type:
        component_type = self.vector_image_format_components().get(explicit_format)

    # image / iimage / uimage variants share one Metal texture kind.
    texture_types = {}
    for dimensionality, metal_kind in (
        ("2D", "texture2d"),
        ("3D", "texture3d"),
        ("2DArray", "texture2d_array"),
    ):
        for prefix in ("", "i", "u"):
            texture_types[f"{prefix}image{dimensionality}"] = metal_kind
    texture_types["imageCube"] = "texture2d_array"

    texture_type = texture_types.get(base_type)
    if not (component_type and texture_type):
        return self.map_type(vtype)
    return f"{texture_type}<{component_type}, access::read_write>"
def resource_base_type(self, vtype):
    """Return the element type name of ``vtype`` with any array part removed.

    ``None`` gives an empty string. Objects whose class name contains
    ``ArrayType`` are unwrapped through ``element_type``. Other type nodes
    are converted to strings, and a bracketed array suffix is split off
    with ``parse_array_type``.
    """
    if vtype is None:
        return ""

    has_element_type = hasattr(vtype, "element_type")
    if has_element_type and "ArrayType" in str(type(vtype)):
        return self.resource_base_type(vtype.element_type)

    if hasattr(vtype, "name") or has_element_type:
        vtype = self.convert_type_node_to_string(vtype)
    type_text = str(vtype)

    if "[" not in type_text or "]" not in type_text:
        return type_text
    element_type, _ = parse_array_type(type_text)
    return element_type
def resource_array_count(self, size):
    """Return the number of elements in a resource array, never less than 1.

    ``None`` means a single resource. Otherwise the size is evaluated as a
    literal integer expression using ``self.literal_int_constants``; if that
    fails, its string form is parsed when it consists of decimal digits
    only. Anything else counts as 1.
    """
    if size is None:
        return 1
    resolved_size = self.literal_int_value(size, self.literal_int_constants)
    if resolved_size is not None:
        return max(resolved_size, 1)
    size_str = str(size)
    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as superscript digits, which int() rejects with ValueError.
    if size_str.isdecimal():
        return max(int(size_str), 1)
    return 1
def expression_name(self, expr):
    """Return the identifier an expression refers to, or None.

    Strings are returned as-is, objects with a string ``name`` yield that
    name, and array accesses are resolved through their ``array`` operand.
    """
    if isinstance(expr, str):
        return expr
    own_name = getattr(expr, "name", None)
    if isinstance(own_name, str):
        return own_name
    if isinstance(expr, ArrayAccessNode):
        return self.expression_name(expr.array)
    return None
def texture_sampler_expression(self, texture_name):
    """Return the sampler expression to use for ``texture_name``.

    A declared sampler called ``<texture_name>Sampler`` is preferred;
    otherwise ``default_sampler_expression()`` supplies the result.
    """
    paired_sampler = next(
        (
            variable.name
            for variable, _, _ in self.sampler_variables
            if variable.name == texture_name + "Sampler"
        ),
        "",
    )
    return paired_sampler or self.default_sampler_expression()
def is_explicit_sampler_argument(self, args):
    """Return whether a call with at least three arguments passes a sampler second.

    Calls with fewer than three arguments always give False.
    """
    return len(args) >= 3 and self.texture_call_uses_explicit_sampler(args)
def texture_call_uses_explicit_sampler(self, args):
    """Return whether the second call argument is a sampler.

    The argument qualifies if its name (or, failing that, its generated
    expression text) is a known sampler variable, or if its result type is
    a sampler type. Calls with fewer than two arguments give False.
    """
    if len(args) < 2:
        return False

    candidate = args[1]
    candidate_text = self.expression_name(candidate) or self.generate_expression(
        candidate
    )
    if candidate_text in self.sampler_variable_names():
        return True

    candidate_type = self.expression_result_type(candidate)
    return candidate_type is not None and self.is_sampler_type(candidate_type)
def texture_call_parts(self, args):
    """Split texture call arguments into texture, sampler, coordinate and rest.

    Returns ``(texture_text, sampler_text, coord_text, extra_args)``, or None
    when the call has no coordinate argument. Without an explicit sampler
    the sampler text comes from ``texture_sampler_expression``. The extra
    arguments are returned as they were passed in, not as generated text.
    """
    has_sampler = self.is_explicit_sampler_argument(args)
    coord_position = 2 if has_sampler else 1
    if coord_position >= len(args):
        return None

    texture_text = self.generate_expression(args[0])
    if has_sampler:
        sampler_text = self.generate_expression(args[1])
    else:
        lookup_name = self.expression_name(args[0]) or texture_text
        sampler_text = self.texture_sampler_expression(lookup_name)
    coord_text = self.generate_expression(args[coord_position])
    remaining = args[coord_position + 1 :]
    return texture_text, sampler_text, coord_text, remaining
def texture_resource_type(self, texture_arg):
    """Return the recorded texture type for an argument, or None.

    Texture parameters of the current function take precedence over global
    texture variables of the same name.
    """
    name = self.expression_name(texture_arg)
    if not name:
        return None
    global_type = self.texture_variable_types.get(name)
    return self.current_texture_parameters.get(name, global_type)
def texture_argument_resource_type(self, texture_arg):
    """Return the Metal resource type of a texture argument, or None.

    A recorded texture type is used when available. Otherwise the
    expression's result type is mapped, provided it is a texture or image
    resource type.
    """
    recorded = self.texture_resource_type(texture_arg)
    if recorded is not None:
        return recorded

    inferred = self.expression_result_type(texture_arg)
    if inferred is not None and self.is_texture_or_image_resource_type(inferred):
        return self.map_resource_type_with_format(self.resource_base_type(inferred))
    return None
def validate_texture_resource_argument(self, func_name, args):
    """Ensure the first argument of a texture operation is a texture or image.

    Calls that are not texture resource operations, or have no arguments,
    are not checked.

    Raises:
        ValueError: if the argument is neither a recorded texture nor an
            expression with a texture/image result type.
    """
    if not args or func_name not in self.texture_resource_operation_names():
        return

    resource_arg = args[0]
    if self.texture_resource_type(resource_arg) is not None:
        return
    inferred_type = self.expression_result_type(resource_arg)
    if inferred_type is not None and self.is_texture_or_image_resource_type(
        inferred_type
    ):
        return

    shown_name = self.expression_name(resource_arg) or str(resource_arg)
    raise ValueError(
        f"Metal texture operation '{func_name}' requires a declared "
        f"texture or image resource argument: {shown_name}"
    )
def validate_image_resource_argument(self, func_name, args):
    """Ensure the first argument of an image intrinsic is a storage image.

    Only names in ``IMAGE_RESOURCE_INTRINSIC_NAMES`` with at least one
    argument are checked.

    Raises:
        ValueError: if the argument's resource type is not a storage image.
    """
    if not args or func_name not in IMAGE_RESOURCE_INTRINSIC_NAMES:
        return

    resource_type = self.texture_argument_resource_type(args[0])
    if not self.is_storage_image_resource(resource_type):
        shown_name = self.expression_name(args[0]) or str(args[0])
        raise ValueError(
            f"Metal image operation '{func_name}' requires a storage "
            f"image resource argument: {shown_name}"
        )
def validate_integer_coordinate_argument(self, func_name, args):
    """Ensure integer-coordinate intrinsics receive an integer coordinate.

    Only names in ``INTEGER_COORDINATE_INTRINSIC_NAMES`` with at least two
    arguments are checked; a coordinate of unknown type is accepted.

    Raises:
        ValueError: if the second argument has a known, non-integer
            coordinate type.
    """
    applies = func_name in INTEGER_COORDINATE_INTRINSIC_NAMES and len(args) >= 2
    if not applies:
        return

    coord_type = self.expression_result_type(args[1])
    if coord_type is not None and not self.is_integer_coordinate_type(coord_type):
        raise ValueError(
            f"Metal resource operation '{func_name}' requires an integer "
            f"coordinate argument: {expression_debug_name(args[1])} has type "
            f"{self.type_name_string(coord_type)}"
        )
def validate_coordinate_dimension_argument(self, func_name, args):
    """Ensure an integer coordinate has the dimension its resource expects.

    Only names in ``INTEGER_COORDINATE_INTRINSIC_NAMES`` with at least two
    arguments are checked. The check is skipped when either the expected or
    the actual dimension cannot be determined.

    Raises:
        ValueError: if both dimensions are known and differ.
    """
    applies = func_name in INTEGER_COORDINATE_INTRINSIC_NAMES and len(args) >= 2
    if not applies:
        return

    texture_type = self.texture_argument_resource_type(args[0])
    expected_dimension = self.resource_coordinate_dimension(texture_type)
    if expected_dimension is None:
        return

    coord_type = self.expression_result_type(args[1])
    actual_dimension = integer_coordinate_dimension(self.type_name_string(coord_type))
    if actual_dimension is not None and actual_dimension != expected_dimension:
        raise ValueError(
            f"Metal resource operation '{func_name}' requires a "
            f"{expected_dimension}D integer coordinate for "
            f"{self.resource_base_type(texture_type)}: "
            f"{expression_debug_name(args[1])} has type "
            f"{self.type_name_string(coord_type)}"
        )
def validate_offset_dimension_argument(self, func_name, args):
    """Ensure every offset argument is an integer vector of the right size.

    Offset positions come from ``texture_offset_argument_indices``. Nothing
    is checked when the intrinsic has no offsets or the expected dimension
    is unknown; offsets of unknown type are skipped individually.

    Raises:
        ValueError: if an offset is not an integer coordinate type, or its
            dimension is known and differs from the expected one.
    """
    offset_positions = texture_offset_argument_indices(
        func_name,
        self.texture_call_uses_explicit_sampler(args),
        len(args),
    )
    if not offset_positions:
        return

    texture_type = self.texture_argument_resource_type(args[0])
    expected_dimension = self.resource_offset_dimension(func_name, texture_type)
    if expected_dimension is None:
        return

    for position in offset_positions:
        offset_type = self.expression_result_type(args[position])
        if offset_type is None:
            continue
        if not self.is_integer_coordinate_type(offset_type):
            raise ValueError(
                f"Metal resource operation '{func_name}' requires an integer "
                f"offset argument: {expression_debug_name(args[position])} "
                f"has type {self.type_name_string(offset_type)}"
            )
        actual_dimension = integer_coordinate_dimension(
            self.type_name_string(offset_type)
        )
        if actual_dimension is not None and actual_dimension != expected_dimension:
            raise ValueError(
                f"Metal resource operation '{func_name}' requires a "
                f"{expected_dimension}D integer offset for "
                f"{self.resource_base_type(texture_type)}: "
                f"{expression_debug_name(args[position])} has type "
                f"{self.type_name_string(offset_type)}"
            )
def gradient_argument_dimension(self, vtype):
    """Return the floating coordinate dimension of a gradient argument type.

    The mapped type name is tried first; the unmapped base name is the
    fallback when the mapped name gives no dimension.
    """
    source_name = self.resource_base_type(self.type_name_string(vtype))
    mapped_name = self.map_type(source_name)
    dimension = floating_coordinate_dimension(mapped_name)
    if not dimension:
        dimension = floating_coordinate_dimension(source_name)
    return dimension
def query_lod_coordinate_dimension(self, vtype):
    """Return the floating coordinate dimension of a LOD-query coordinate type.

    The mapped type name is tried first; the unmapped base name is the
    fallback when the mapped name gives no dimension.
    """
    source_name = self.resource_base_type(self.type_name_string(vtype))
    mapped_name = self.map_type(source_name)
    dimension = floating_coordinate_dimension(mapped_name)
    if not dimension:
        dimension = floating_coordinate_dimension(source_name)
    return dimension
def validate_query_lod_coordinate_argument(self, func_name, args):
    """Ensure a LOD query receives a floating coordinate of the right size.

    Nothing is checked when the intrinsic has no query coordinate, the
    expected dimension is unknown, or the coordinate's type is unknown.

    Raises:
        ValueError: if the coordinate is not a floating coordinate type, or
            its dimension differs from the one the texture expects.
    """
    coord_position = texture_query_lod_coordinate_argument_index(
        func_name,
        self.texture_call_uses_explicit_sampler(args),
        len(args),
    )
    if coord_position is None:
        return

    texture_type = self.texture_argument_resource_type(args[0])
    expected_dimension = self.resource_query_lod_coordinate_dimension(texture_type)
    if expected_dimension is None:
        return

    coord_type = self.expression_result_type(args[coord_position])
    if coord_type is None:
        return

    actual_dimension = self.query_lod_coordinate_dimension(coord_type)
    if actual_dimension is None:
        raise ValueError(
            f"Metal texture query operation '{func_name}' requires a floating "
            f"coordinate argument: {expression_debug_name(args[coord_position])} "
            f"has type {self.type_name_string(coord_type)}"
        )
    if actual_dimension != expected_dimension:
        raise ValueError(
            f"Metal texture query operation '{func_name}' requires a "
            f"{expected_dimension}D floating coordinate for "
            f"{self.resource_base_type(texture_type)}: "
            f"{expression_debug_name(args[coord_position])} has type "
            f"{self.type_name_string(coord_type)}"
        )
def validate_gradient_dimension_arguments(self, func_name, args):
    """Ensure gradient arguments are floating vectors of the right size.

    Gradient positions come from ``texture_gradient_argument_indices``.
    Nothing is checked when the intrinsic has no gradients or the expected
    dimension is unknown; gradients of unknown type are skipped.

    Raises:
        ValueError: if a gradient is not a floating coordinate type, or its
            dimension differs from the expected one.
    """
    gradient_positions = texture_gradient_argument_indices(
        func_name,
        self.texture_call_uses_explicit_sampler(args),
        len(args),
    )
    if not gradient_positions:
        return

    texture_type = self.texture_argument_resource_type(args[0])
    expected_dimension = self.resource_gradient_dimension(func_name, texture_type)
    if expected_dimension is None:
        return

    for position in gradient_positions:
        gradient_type = self.expression_result_type(args[position])
        if gradient_type is None:
            continue
        actual_dimension = self.gradient_argument_dimension(gradient_type)
        if actual_dimension is None:
            raise ValueError(
                f"Metal resource operation '{func_name}' requires a floating "
                f"gradient argument: {expression_debug_name(args[position])} "
                f"has type {self.type_name_string(gradient_type)}"
            )
        if actual_dimension != expected_dimension:
            raise ValueError(
                f"Metal resource operation '{func_name}' requires a "
                f"{expected_dimension}D floating gradient for "
                f"{self.resource_base_type(texture_type)}: "
                f"{expression_debug_name(args[position])} has type "
                f"{self.type_name_string(gradient_type)}"
            )
def is_scalar_floating_type(self, vtype):
    """Return whether ``vtype`` is a non-array floating scalar type.

    Both the mapped and the unmapped type name are consulted. Empty names
    and names containing ``[`` give False.
    """
    name = self.type_name_string(vtype)
    if not name or "[" in str(name):
        return False
    mapped_verdict = is_floating_scalar_type(self.map_type(name))
    return mapped_verdict or is_floating_scalar_type(name)
def is_scalar_numeric_type(self, vtype):
    """Return whether ``vtype`` is a non-array numeric scalar type.

    Both the mapped and the unmapped type name are consulted. Empty names
    and names containing ``[`` give False.
    """
    name = self.type_name_string(vtype)
    if not name or "[" in str(name):
        return False
    mapped_verdict = is_numeric_scalar_type(self.map_type(name))
    return mapped_verdict or is_numeric_scalar_type(name)
def is_scalar_integer_type(self, vtype):
    """Return whether ``vtype`` is a non-array integer scalar type.

    Both the mapped and the unmapped type name are consulted. Empty names
    and names containing ``[`` give False.
    """
    name = self.type_name_string(vtype)
    if not name or "[" in str(name):
        return False
    mapped_verdict = is_integer_scalar_type(self.map_type(name))
    return mapped_verdict or is_integer_scalar_type(name)
def texture_argument_diagnostic_type(self, arg):
    """Return the type used to describe ``arg`` in validation errors.

    Recorded textures report their texture type, and known samplers (global
    or parameters of the current function) report ``"sampler"``. Everything
    else reports its expression result type.
    """
    recorded_texture = self.texture_resource_type(arg)
    if recorded_texture is not None:
        return recorded_texture

    name = self.expression_name(arg)
    global_samplers = set()
    for variable, _, _ in self.sampler_variables:
        global_samplers.add(variable.name)
    if name in global_samplers or name in self.current_sampler_parameters:
        return "sampler"
    return self.expression_result_type(arg)
def validate_compare_argument(self, func_name, args):
    """Ensure the compare reference of a compare intrinsic is a floating scalar.

    Nothing is checked when the intrinsic has no compare argument or the
    argument's type is unknown.

    Raises:
        ValueError: if the compare argument has a known type that is not a
            scalar floating type.
    """
    compare_position = texture_compare_argument_index(
        func_name,
        self.texture_call_uses_explicit_sampler(args),
        len(args),
    )
    if compare_position is None:
        return

    compare_type = self.expression_result_type(args[compare_position])
    if compare_type is not None and not self.is_scalar_floating_type(compare_type):
        raise ValueError(
            f"Metal texture compare operation '{func_name}' requires a scalar "
            f"floating compare argument: {expression_debug_name(args[compare_position])} "
            f"has type {self.type_name_string(compare_type)}"
        )
def validate_lod_argument(self, func_name, args):
    """Ensure the LOD argument of a LOD intrinsic is a numeric scalar.

    Nothing is checked when the intrinsic has no LOD argument or the
    argument's diagnostic type is unknown.

    Raises:
        ValueError: if the LOD argument has a known type that is not a
            scalar numeric type.
    """
    lod_position = texture_lod_argument_index(
        func_name,
        self.texture_call_uses_explicit_sampler(args),
        len(args),
    )
    if lod_position is None:
        return

    lod_type = self.texture_argument_diagnostic_type(args[lod_position])
    if lod_type is not None and not self.is_scalar_numeric_type(lod_type):
        raise ValueError(
            f"Metal texture LOD operation '{func_name}' requires a scalar "
            f"numeric lod argument: {expression_debug_name(args[lod_position])} "
            f"has type {self.type_name_string(lod_type)}"
        )
def validate_bias_argument(self, func_name, args):
    """Ensure the bias argument of a sampling intrinsic is a numeric scalar.

    Nothing is checked when the intrinsic has no bias argument or the
    argument's diagnostic type is unknown.

    Raises:
        ValueError: if the bias argument has a known type that is not a
            scalar numeric type.
    """
    bias_position = texture_bias_argument_index(
        func_name,
        self.texture_call_uses_explicit_sampler(args),
        len(args),
    )
    if bias_position is None:
        return

    bias_type = self.texture_argument_diagnostic_type(args[bias_position])
    if bias_type is not None and not self.is_scalar_numeric_type(bias_type):
        raise ValueError(
            f"Metal texture bias operation '{func_name}' requires a scalar "
            f"numeric bias argument: {expression_debug_name(args[bias_position])} "
            f"has type {self.type_name_string(bias_type)}"
        )
def validate_mip_level_argument(self, func_name, args):
    """Ensure the mip/sample level argument is an integer scalar.

    Nothing is checked when the intrinsic has no level argument or the
    argument's diagnostic type is unknown.

    Raises:
        ValueError: if the level argument has a known type that is not a
            scalar integer type.
    """
    level_position = texture_mip_level_argument_index(func_name, len(args))
    if level_position is None:
        return

    level_type = self.texture_argument_diagnostic_type(args[level_position])
    if level_type is not None and not self.is_scalar_integer_type(level_type):
        raise ValueError(
            f"Metal resource operation '{func_name}' requires a scalar integer "
            f"mip/sample level argument: {expression_debug_name(args[level_position])} "
            f"has type {self.type_name_string(level_type)}"
        )
def validate_sample_index_argument(self, func_name, args):
    """Ensure a multisample fetch receives an integer scalar sample index.

    Nothing is checked when the intrinsic has no sample index, the texture
    is not a multisample texture, or the index's diagnostic type is unknown.

    Raises:
        ValueError: if the sample index has a known type that is not a
            scalar integer type.
    """
    index_position = texture_sample_index_argument_index(func_name, len(args))
    if index_position is None:
        return

    texture_type = self.texture_argument_resource_type(args[0])
    if not self.is_multisample_texture_resource(texture_type):
        return

    index_type = self.texture_argument_diagnostic_type(args[index_position])
    if index_type is not None and not self.is_scalar_integer_type(index_type):
        raise ValueError(
            f"Metal multisample texel fetch operation '{func_name}' requires a "
            f"scalar integer sample index argument: "
            f"{expression_debug_name(args[index_position])} has type "
            f"{self.type_name_string(index_type)}"
        )
def validate_gather_component_argument(self, func_name, args):
    """Ensure the component selector of a gather call is an integer scalar.

    Nothing is checked when the intrinsic has no component argument or the
    argument's diagnostic type is unknown.

    Raises:
        ValueError: if the component argument has a known type that is not
            a scalar integer type.
    """
    component_position = texture_gather_component_argument_index(
        func_name,
        self.texture_call_uses_explicit_sampler(args),
        len(args),
    )
    if component_position is None:
        return

    component_type = self.texture_argument_diagnostic_type(args[component_position])
    if component_type is not None and not self.is_scalar_integer_type(component_type):
        raise ValueError(
            f"Metal texture gather operation '{func_name}' requires a scalar "
            f"integer component argument: "
            f"{expression_debug_name(args[component_position])} has type "
            f"{self.type_name_string(component_type)}"
        )
def validate_texture_call_arity(self, func_name, args):
    """Check the argument count of a texture resource operation.

    The minimum count, the set of allowed counts and the maximum count are
    checked in that order; each limit is skipped when it is None. Names
    that are not texture resource operations are not checked.

    Raises:
        ValueError: on the first limit the call violates.
    """
    if func_name not in self.texture_resource_operation_names():
        return

    explicit_sampler = self.texture_call_uses_explicit_sampler(args)
    arg_count = len(args)

    minimum = texture_intrinsic_min_argument_count(func_name, explicit_sampler)
    if minimum is not None and arg_count < minimum:
        raise ValueError(
            f"Metal texture operation '{func_name}' requires at least "
            f"{minimum} argument(s), got {arg_count}"
        )

    allowed = texture_intrinsic_allowed_argument_counts(func_name, explicit_sampler)
    if allowed is not None and arg_count not in allowed:
        counts = ", ".join(str(count) for count in allowed)
        raise ValueError(
            f"Metal texture operation '{func_name}' accepts "
            f"{counts} argument(s), got {arg_count}"
        )

    maximum = texture_intrinsic_max_argument_count(func_name, explicit_sampler)
    if maximum is not None and arg_count > maximum:
        raise ValueError(
            f"Metal texture operation '{func_name}' accepts at most "
            f"{maximum} argument(s), got {arg_count}"
        )
def texture_resource_operation_names(self):
    """Return the names of all intrinsics that operate on a texture or image.

    Covers the sampling and compare families (plain and projected, with
    optional Lod/Grad and Offset suffixes), the gather family, the size,
    level and fetch queries, and the image load/store/atomic operations.
    """
    sample_variants = [
        f"{projection}{level}{offset}"
        for projection in ("", "Proj")
        for level in ("", "Lod", "Grad")
        for offset in ("", "Offset")
    ]

    names = set()
    for family in ("texture", "textureCompare"):
        names.update(f"{family}{variant}" for variant in sample_variants)
    names.update(
        f"textureGather{suffix}"
        for suffix in ("", "Offset", "Offsets", "Compare", "CompareOffset")
    )
    names.update(
        (
            "textureQueryLod",
            "textureQueryLevels",
            "textureSize",
            "textureSamples",
            "texelFetch",
            "texelFetchOffset",
        )
    )
    names.update(f"image{operation}" for operation in ("Load", "Store", "Size", "Samples"))
    names.update(
        f"imageAtomic{operation}"
        for operation in ("Add", "Min", "Max", "And", "Or", "Xor", "Exchange", "CompSwap")
    )
    return names
def image_resource_format(self, texture_arg):
    """Return the recorded image format for an argument, or None.

    Image parameters of the current function take precedence over global
    image variables of the same name.
    """
    name = self.expression_name(texture_arg)
    if not name:
        return None
    global_format = self.image_variable_formats.get(name)
    return self.current_image_format_parameters.get(name, global_format)
def is_array_texture_resource(self, texture_type):
    """Return whether ``texture_type`` is a float 2D-array or cube-array texture.

    Both colour (``texture*``) and depth (``depth*``) variants are included.
    """
    array_texture_types = {
        f"{kind}<float>"
        for kind in (
            "texture2d_array",
            "depth2d_array",
            "texturecube_array",
            "depthcube_array",
        )
    }
    return texture_type in array_texture_types
def is_multisample_texture_resource(self, texture_type):
    """Return whether ``texture_type`` is a float multisample texture type."""
    multisample_types = {
        f"{kind}<float>" for kind in ("texture2d_ms", "texture2d_ms_array")
    }
    return texture_type in multisample_types
def is_storage_image_resource(self, texture_type):
    """Return whether ``texture_type`` is a read-write 2D, 3D or 2D-array texture.

    Any array suffix is removed through ``resource_base_type`` before the
    check.
    """
    base_type = self.resource_base_type(texture_type)
    storage_prefixes = ("texture2d<", "texture3d<", "texture2d_array<")
    return base_type.startswith(storage_prefixes) and "access::read_write" in base_type
def vector_component(self, expression, component):
    """Return ``expression`` with ``.component`` appended.

    The expression is wrapped in parentheses unless it consists solely of
    alphanumeric characters and ``_``, ``.``, ``[``, ``]``.
    """
    is_simple = all(char.isalnum() or char in "_.[]" for char in expression)
    target = expression if is_simple else f"({expression})"
    return f"{target}.{component}"
def array_texture_coordinate_parts(self, coord):
    """Split a 2D-array coordinate into its ``xy`` part and a ``uint`` layer from ``z``."""
    layer_source = self.vector_component(coord, "z")
    return self.vector_component(coord, "xy"), f"uint({layer_source})"
def cube_array_texture_coordinate_parts(self, coord):
    """Split a cube-array coordinate into its ``xyz`` part and a ``uint`` layer from ``w``."""
    layer_source = self.vector_component(coord, "w")
    return self.vector_component(coord, "xyz"), f"uint({layer_source})"
def texture_coordinate_parts(self, texture_type, coord):
    """Split an array-texture coordinate into a position part and a layer part.

    Float cube-array types use the cube-array split; every other type uses
    the 2D-array split.
    """
    cube_array_types = {"texturecube_array<float>", "depthcube_array<float>"}
    if texture_type in cube_array_types:
        splitter = self.cube_array_texture_coordinate_parts
    else:
        splitter = self.array_texture_coordinate_parts
    return splitter(coord)
def texture_gradient_options(self, texture_type, ddx, ddy):
    """Return the Metal gradient option expression for a texture type.

    Float cube and cube-array types use ``gradientcube``,
    ``texture3d<float>`` uses ``gradient3d`` and every other type uses
    ``gradient2d``.
    """
    cube_types = {
        "texturecube<float>",
        "depthcube<float>",
        "texturecube_array<float>",
        "depthcube_array<float>",
    }
    if texture_type in cube_types:
        wrapper = "gradientcube"
    elif texture_type == "texture3d<float>":
        wrapper = "gradient3d"
    else:
        wrapper = "gradient2d"
    return f"{wrapper}({ddx}, {ddy})"
def texture_gather_supports_offset(self, texture_type):
    """Return whether ``texture_type`` is a float 2D or 2D-array texture."""
    offset_capable_types = {"texture2d<float>", "texture2d_array<float>"}
    return texture_type in offset_capable_types
def texture_gather_supported(self, texture_type):
    """Return whether ``texture_type`` is a float 2D, 2D-array, cube or cube-array texture."""
    gather_capable_types = {
        f"{kind}<float>"
        for kind in ("texture2d", "texture2d_array", "texturecube", "texturecube_array")
    }
    return texture_type in gather_capable_types
def texture_sample_supports_offset(self, texture_type):
    """Return whether the base of ``texture_type`` is a float 2D or 2D-array texture."""
    base_type = self.resource_base_type(texture_type)
    offset_capable_types = {"texture2d<float>", "texture2d_array<float>"}
    return base_type in offset_capable_types
def unsupported_texture_sample_offset_call(self, func_name, reason):
    """Return a placeholder for an offset sample that cannot be emitted.

    The placeholder is a comment naming the intrinsic and the reason,
    followed by ``float4(0.0)``.
    """
    explanation = f"unsupported Metal texture offset: {func_name} {reason}"
    return f"/* {explanation} */ float4(0.0)"
def texture_sample_offset_coord_args(self, texture_type, coord):
    """Return the coordinate arguments for an offset sample call.

    Array textures get the result of ``texture_coordinate_parts``; all
    other textures get a one-element tuple holding ``coord``.
    """
    if not self.is_array_texture_resource(texture_type):
        return (coord,)
    return self.texture_coordinate_parts(texture_type, coord)
def generate_texture_sample_offset_call(
    self, func_name, texture_name, sampler_arg, coord, extra_args, texture_type
):
    """Emit a ``.sample(...)`` call for an offset sampling intrinsic.

    Handles ``textureOffset`` (offset, optional bias), ``textureLodOffset``
    (lod, offset) and ``textureGradOffset`` (ddx, ddy, offset). The offset
    is always emitted as the last argument. Unsupported texture types,
    wrong argument counts and unknown intrinsics produce the placeholder
    from ``unsupported_texture_sample_offset_call`` instead.
    """
    unsupported = self.unsupported_texture_sample_offset_call
    if not self.texture_sample_supports_offset(texture_type):
        return unsupported(func_name, "offsets require 2D or 2D-array textures")

    call_args = [
        sampler_arg,
        *self.texture_sample_offset_coord_args(texture_type, coord),
    ]
    extra_count = len(extra_args)

    if func_name == "textureOffset":
        if extra_count not in {1, 2}:
            return unsupported(func_name, "requires offset and optional bias arguments")
        offset = self.generate_expression(extra_args[0])
        if extra_count == 2:
            bias = self.generate_expression(extra_args[1])
            call_args.append(f"bias({bias})")
        call_args.append(offset)
    elif func_name == "textureLodOffset":
        if extra_count != 2:
            return unsupported(func_name, "requires lod and offset arguments")
        lod = self.generate_expression(extra_args[0])
        offset = self.generate_expression(extra_args[1])
        call_args += [f"level({lod})", offset]
    elif func_name == "textureGradOffset":
        if extra_count != 3:
            return unsupported(
                func_name,
                "requires gradient x, gradient y, and offset arguments",
            )
        ddx = self.generate_expression(extra_args[0])
        ddy = self.generate_expression(extra_args[1])
        offset = self.generate_expression(extra_args[2])
        call_args += [self.texture_gradient_options(texture_type, ddx, ddy), offset]
    else:
        return unsupported(func_name, "is not a supported texture offset operation")

    return f"{texture_name}.sample({', '.join(call_args)})"
def unsupported_texture_projected_call(self, func_name, reason):
    """Return a placeholder for a projected sample that cannot be emitted.

    The placeholder is a comment naming the intrinsic and the reason,
    followed by ``float4(0.0)``.
    """
    explanation = f"unsupported Metal projected texture: {func_name} {reason}"
    return f"/* {explanation} */ float4(0.0)"
def projected_texture_coord(self, texture_arg, coord_arg, coord):
    """Return the perspective-divided coordinate text for a projected sample.

    The swizzles used for the numerator and divisor depend on the texture
    type and on the width of the coordinate vector. For 2D-array textures
    the layer, taken from ``z``, is appended as a second argument. Returns
    None for unsupported texture/coordinate combinations.
    """
    texture_type = self.resource_base_type(self.texture_resource_type(texture_arg))
    coord_type = self.resource_base_type(self.expression_result_type(coord_arg))

    # (texture type, coordinate width) -> (numerator swizzle, divisor swizzle)
    swizzles = {
        ("texture1d<float>", 2): ("x", "y"),
        ("texture1d<float>", 4): ("x", "w"),
        ("texture2d<float>", 3): ("xy", "z"),
        ("texture2d<float>", 4): ("xy", "w"),
        ("texture2d_array<float>", 4): ("xy", "w"),
        ("texture3d<float>", 4): ("xyz", "w"),
    }
    coord_widths = {
        "vec2": 2,
        "float2": 2,
        "vec3": 3,
        "float3": 3,
        "vec4": 4,
        "float4": 4,
    }
    selected = swizzles.get((texture_type, coord_widths.get(coord_type)))
    if selected is None:
        return None

    numerator, divisor = selected
    numerator_text = self.vector_component(coord, numerator)
    divisor_text = self.vector_component(coord, divisor)
    projected = f"{numerator_text} / {divisor_text}"
    if texture_type != "texture2d_array<float>":
        return projected
    layer_text = self.vector_component(coord, "z")
    return f"{projected}, uint({layer_text})"
def projected_texture_offset_supported(self, texture_type):
    """Return whether the base of ``texture_type`` is a float 2D or 2D-array texture."""
    base_type = self.resource_base_type(texture_type)
    offset_capable_types = {"texture2d<float>", "texture2d_array<float>"}
    return base_type in offset_capable_types
def generate_texture_projected_call(
    self,
    func_name,
    texture_name,
    sampler_arg,
    coord,
    extra_args,
    texture_type,
    args,
):
    """Emit a ``.sample(...)`` call for a projected sampling intrinsic.

    The coordinate is replaced by the perspective-divided text from
    ``projected_texture_coord``. ``textureProj``, ``textureProjLod`` and
    ``textureProjGrad`` are handled by name, as are ``textureProjOffset``
    and ``textureProjLodOffset``; any other name is treated as the
    gradient-plus-offset form. Offset forms require a texture type accepted
    by ``projected_texture_offset_supported``. Unsupported combinations
    produce the placeholder from ``unsupported_texture_projected_call``.
    """
    unsupported = self.unsupported_texture_projected_call
    generate = self.generate_expression

    coord_index = 2 if self.is_explicit_sampler_argument(args) else 1
    projected_coord = self.projected_texture_coord(args[0], args[coord_index], coord)
    if projected_coord is None:
        return unsupported(func_name, "requires 1D, 2D, or 3D projection coordinates")

    # Every emitted call starts the same way; only the trailing options differ.
    call_prefix = f"{texture_name}.sample({sampler_arg}, {projected_coord}"
    extra_count = len(extra_args)

    if func_name == "textureProj":
        if not extra_args:
            return f"{call_prefix})"
        if extra_count == 1:
            bias = generate(extra_args[0])
            return f"{call_prefix}, bias({bias}))"
        return unsupported(func_name, "accepts at most one bias argument")

    if func_name == "textureProjLod":
        if extra_count != 1:
            return unsupported(func_name, "requires one lod argument")
        lod = generate(extra_args[0])
        return f"{call_prefix}, level({lod}))"

    if func_name == "textureProjGrad":
        if extra_count != 2:
            return unsupported(func_name, "requires gradient x and gradient y arguments")
        ddx = generate(extra_args[0])
        ddy = generate(extra_args[1])
        gradient_options = self.texture_gradient_options(texture_type, ddx, ddy)
        return f"{call_prefix}, {gradient_options})"

    # Everything below takes an offset, which not every texture type allows.
    if not self.projected_texture_offset_supported(texture_type):
        return unsupported(func_name, "offsets require 2D textures")

    if func_name == "textureProjOffset":
        if extra_count == 1:
            offset = generate(extra_args[0])
            return f"{call_prefix}, {offset})"
        if extra_count == 2:
            offset = generate(extra_args[0])
            bias = generate(extra_args[1])
            return f"{call_prefix}, bias({bias}), {offset})"
        return unsupported(func_name, "requires offset and optional bias arguments")

    if func_name == "textureProjLodOffset":
        if extra_count != 2:
            return unsupported(func_name, "requires lod and offset arguments")
        lod = generate(extra_args[0])
        offset = generate(extra_args[1])
        return f"{call_prefix}, level({lod}), {offset})"

    if extra_count != 3:
        return unsupported(
            func_name, "requires gradient x, gradient y, and offset arguments"
        )
    ddx = generate(extra_args[0])
    ddy = generate(extra_args[1])
    offset = generate(extra_args[2])
    gradient_options = self.texture_gradient_options(texture_type, ddx, ddy)
    return f"{call_prefix}, {gradient_options}, {offset})"
def is_array_expression(self, node):
    """Return True when the resolved type of ``node`` is an array-style type string.

    The type is looked up through ``expression_result_type``; it counts as an
    array when it is a string containing both ``[`` and ``]``.
    """
    resolved = self.expression_result_type(node)
    if not isinstance(resolved, str):
        return False
    has_open = "[" in resolved
    has_close = "]" in resolved
    return has_open and has_close
def texture_gather_offsets_args(self, extra_args):
    """Split the trailing textureGatherOffsets arguments.

    Two layouts are recognized:

    * one array-typed argument (plus an optional component argument), which
      is expanded into four indexed element expressions (strings);
    * four separate offset arguments (plus an optional component argument),
      returned as the original argument objects.

    Returns:
        ``(offset_args, component_arg)``; ``(None, None)`` when neither
        layout matches.
    """
    count = len(extra_args)
    if count in (1, 2) and self.is_array_expression(extra_args[0]):
        array_text = self.generate_expression(extra_args[0])
        indexed = []
        for slot in range(4):
            indexed.append(f"{array_text}[{slot}]")
        trailing = None if count == 1 else extra_args[1]
        return indexed, trailing
    if count == 4:
        return extra_args[:4], None
    if count == 5:
        return extra_args[:4], extra_args[4]
    return None, None
def texture_gather_component_option(self, component_arg):
    """Map a literal gather component index (0-3) to ``component::x`` .. ``component::w``.

    Returns None when no component argument was given, or when
    ``literal_int_value`` does not yield one of 0, 1, 2 or 3.
    """
    if component_arg is None:
        return None
    options = dict(
        enumerate(("component::x", "component::y", "component::z", "component::w"))
    )
    literal = self.literal_int_value(component_arg)
    return options.get(literal)
def texture_gather_coord_args(self, texture_type, coord):
    """Return the coordinate argument list for a gather call.

    Array textures get two entries (the pair produced by
    ``texture_coordinate_parts``); every other texture gets ``coord`` alone.
    """
    if not self.is_array_texture_resource(texture_type):
        return [coord]
    plane, layer = self.texture_coordinate_parts(texture_type, coord)
    return [plane, layer]
def texture_gather_call_expression(
    self,
    texture_name,
    sampler_arg,
    coord_args,
    offset_arg=None,
    component=None,
    default_offset_for_component=False,
):
    """Assemble a ``<texture>.gather(...)`` call string.

    Arguments are emitted in the order sampler, coordinates, offset,
    component. When a component is given without an explicit offset and
    ``default_offset_for_component`` is true, ``int2(0)`` is inserted as the
    offset so the component lands in the position after it.
    """
    call_args = [sampler_arg]
    call_args.extend(coord_args)
    has_component = component is not None
    if offset_arg is not None:
        call_args.append(offset_arg)
    elif has_component and default_offset_for_component:
        call_args.append("int2(0)")
    if has_component:
        call_args.append(component)
    joined = ", ".join(call_args)
    return f"{texture_name}.gather({joined})"
def texture_gather_offsets_expression(
    self, texture_name, sampler_arg, coord_args, offset_args, component
):
    """Emulate a per-texel-offset gather with one gather call per offset.

    Each entry of ``offset_args`` produces its own gather call; the n-th call
    contributes its n-th lane (``.x``, ``.y``, ``.z``, ``.w``) and the lanes
    are packed into a ``float4(...)`` constructor.
    """
    # NOTE(review): when the offsets came from an array argument the entries
    # are already strings; this relies on generate_expression accepting them.
    lanes = "xyzw"
    picked = [
        self.texture_gather_call_expression(
            texture_name,
            sampler_arg,
            coord_args,
            self.generate_expression(offset_node),
            component,
            default_offset_for_component=True,
        )
        + "."
        + lanes[position]
        for position, offset_node in enumerate(offset_args)
    ]
    return "float4(" + ", ".join(picked) + ")"
def texture_gather_dynamic_component_expression(
    self, build_expression, component_expr
):
    """Select a gather component at run time with a chained conditional.

    ``build_expression`` is called once for each of ``component::x`` ..
    ``component::w`` (in that order) and the results are combined as
    ``(sel == 0 ? x : sel == 1 ? y : sel == 2 ? z : w)``.

    Args:
        build_expression: callable taking a component option string and
            returning the gather call text for it.
        component_expr: generated text of the run-time component selector.

    Returns:
        The parenthesized conditional expression string.
    """
    selector = component_expr
    # ``==`` binds tighter than ``&``, ``|``, ``^``, ``&&``, ``||`` and ``?:``,
    # so a compound selector such as ``a & 3`` must be wrapped before it is
    # compared. Plain names and member accesses are left untouched so the
    # common output stays unchanged.
    is_simple = isinstance(selector, str) and all(
        part.isidentifier() for part in selector.split(".")
    )
    if not is_simple:
        selector = f"({selector})"
    x_call, y_call, z_call, w_call = (
        build_expression(f"component::{lane}") for lane in "xyzw"
    )
    return (
        f"({selector} == 0 ? {x_call} : "
        f"{selector} == 1 ? {y_call} : "
        f"{selector} == 2 ? {z_call} : {w_call})"
    )
def unsupported_texture_gather_call(self, func_name, reason):
    """Return a commented ``float4(0.0)`` placeholder for an unsupported gather."""
    note = f"unsupported Metal texture gather: {func_name} {reason}"
    return f"/* {note} */ float4(0.0)"
def unsupported_multisample_texture_call(self, func_name, texture_type):
    """Return a commented ``float4(0.0)`` placeholder for a call on a multisample texture."""
    note = f"unsupported Metal multisample texture call: {func_name} on {texture_type}"
    return f"/* {note} */ float4(0.0)"
def unsupported_multisample_texture_query_lod_call(self, texture_type):
    """Return a commented ``float2(0.0)`` placeholder for textureQueryLod on a multisample texture."""
    note = f"unsupported Metal multisample texture query: textureQueryLod on {texture_type}"
    return f"/* {note} */ float2(0.0)"
def unsupported_texture_query_levels_call(self, texture_type):
    """Return a commented ``0`` placeholder for an unsupported textureQueryLevels.

    The texture type is normalized through ``resource_base_type`` before it
    is written into the comment.
    """
    base_type = self.resource_base_type(texture_type)
    note = f"unsupported Metal texture query: textureQueryLevels on {base_type}"
    return f"/* {note} */ 0"
def unsupported_texture_query_lod_call(self, texture_type):
    """Return a commented ``float2(0.0)`` placeholder for an unsupported textureQueryLod.

    The texture type is normalized through ``resource_base_type`` before it
    is written into the comment.
    """
    base_type = self.resource_base_type(texture_type)
    note = f"unsupported Metal texture query: textureQueryLod on {base_type}"
    return f"/* {note} */ float2(0.0)"
def storage_image_texture_operation_expression(self, func_name, texture_type):
    """Return a placeholder for sampler-style calls made on a storage image.

    Returns None when ``texture_type`` is not a storage image or when
    ``func_name`` is not one of the listed texture operations. Comparison
    functions yield a commented ``0.0``; sampling, gather and texel-fetch
    functions yield a commented ``float4(0.0)``.
    """
    if not self.is_storage_image_resource(texture_type):
        return None
    base_type = self.resource_base_type(texture_type)
    comparison_names = frozenset(
        (
            "textureCompare",
            "textureCompareOffset",
            "textureCompareLod",
            "textureCompareLodOffset",
            "textureCompareGrad",
            "textureCompareGradOffset",
            "textureCompareProj",
            "textureCompareProjOffset",
            "textureCompareProjLod",
            "textureCompareProjLodOffset",
            "textureCompareProjGrad",
            "textureCompareProjGradOffset",
        )
    )
    operation_names = frozenset(
        (
            "texture",
            "textureLod",
            "textureGrad",
            "textureOffset",
            "textureLodOffset",
            "textureGradOffset",
            "textureProj",
            "textureProjOffset",
            "textureProjLod",
            "textureProjLodOffset",
            "textureProjGrad",
            "textureProjGradOffset",
            "textureGather",
            "textureGatherOffset",
            "textureGatherOffsets",
            "textureGatherCompare",
            "textureGatherCompareOffset",
            "texelFetch",
            "texelFetchOffset",
        )
    )
    # Checked in this order; the first matching group decides the placeholder.
    rules = (
        (comparison_names, "comparison", "0.0"),
        (operation_names, "operation", "float4(0.0)"),
    )
    for names, label, fallback in rules:
        if func_name in names:
            return (
                f"/* unsupported Metal storage image texture {label}: "
                f"{func_name} on {base_type} */ {fallback}"
            )
    return None
def is_cube_texture_resource(self, texture_type):
    """Return True for the float cube and cube-array texture/depth type names."""
    cube_types = frozenset(
        f"{family}cube{layout}<float>"
        for family in ("texture", "depth")
        for layout in ("", "_array")
    )
    return texture_type in cube_types
def unsupported_cube_texel_fetch_call(self, func_name, texture_type):
    """Return a commented ``float4(0.0)`` placeholder for an unsupported texel fetch."""
    note = f"unsupported Metal texel fetch: {func_name} on {texture_type}"
    return f"/* {note} */ float4(0.0)"
def generate_texture_gather_call(
    self, func_name, texture_name, sampler_arg, coord, extra_args, texture_type
):
    """Generate the Metal expression for textureGather / Offset / Offsets.

    Validation happens first (multisample textures, unsupported texture
    kinds, argument counts, offset support); each failure returns a commented
    placeholder. The component argument is then handled in one of three ways:

    * absent or a literal 0-3: a single gather (or per-offset gather) call;
    * a literal outside 0-3: an "unsupported" placeholder;
    * non-literal: a run-time conditional over all four components.
    """

    def reject(reason):
        return self.unsupported_texture_gather_call(func_name, reason)

    if self.is_multisample_texture_resource(texture_type):
        return self.unsupported_multisample_texture_call(func_name, texture_type)
    plain_gather = func_name == "textureGather"
    if plain_gather and not self.texture_gather_supported(texture_type):
        return reject("requires 2D, 2D-array, cube, or cube-array textures")

    coord_args = self.texture_gather_coord_args(texture_type, coord)
    offsets_allowed = self.texture_gather_supports_offset(texture_type)
    offset_nodes = []
    component_node = None
    if plain_gather:
        if len(extra_args) > 1:
            return reject("accepts at most one component argument")
        if extra_args:
            component_node = extra_args[0]
    elif func_name == "textureGatherOffset":
        if len(extra_args) not in (1, 2):
            return reject("requires offset and optional component arguments")
        if not offsets_allowed:
            return reject("offsets require 2D or 2D-array textures")
        offset_nodes = [extra_args[0]]
        if len(extra_args) == 2:
            component_node = extra_args[1]
    else:
        # Any other name is treated with the textureGatherOffsets argument layout.
        if not offsets_allowed:
            return reject("offsets require 2D or 2D-array textures")
        offset_nodes, component_node = self.texture_gather_offsets_args(extra_args)
        if offset_nodes is None:
            return reject("requires a typed offsets array or four offset arguments")

    per_texel_offsets = func_name == "textureGatherOffsets"

    def single_gather(option, offset_text):
        return self.texture_gather_call_expression(
            texture_name,
            sampler_arg,
            coord_args,
            offset_text,
            option,
            default_offset_for_component=offsets_allowed,
        )

    def offsets_gather(option):
        return self.texture_gather_offsets_expression(
            texture_name, sampler_arg, coord_args, offset_nodes, option
        )

    def first_offset_text():
        if offset_nodes:
            return self.generate_expression(offset_nodes[0])
        return None

    component = self.texture_gather_component_option(component_node)
    if component is not None or component_node is None:
        # Static case: no component, or a literal that mapped to an option.
        if per_texel_offsets:
            return offsets_gather(component)
        return single_gather(component, first_offset_text())

    if self.literal_int_value(component_node) is not None:
        return reject("component literal must be 0, 1, 2, or 3")

    # Dynamic case: the selector is generated before the offset expression.
    component_expr = self.generate_expression(component_node)
    if per_texel_offsets:
        return self.texture_gather_dynamic_component_expression(
            offsets_gather, component_expr
        )
    offset_text = first_offset_text()
    return self.texture_gather_dynamic_component_expression(
        lambda option: single_gather(option, offset_text),
        component_expr,
    )
def texture_compare_offset_supported(self, texture_type):
    """Return True when ``texture_type`` is a float 2D or 2D-array depth texture."""
    offset_capable = frozenset(("depth2d<float>", "depth2d_array<float>"))
    return texture_type in offset_capable
def unsupported_texture_compare_call(self, func_name, reason):
    """Return a commented ``0.0`` placeholder for an unsupported shadow compare call."""
    note = f"unsupported Metal texture compare: {func_name} {reason}"
    return f"/* {note} */ 0.0"
def texture_compare_projected_coord_args(self, texture_type, coord_arg, coord):
    """Build the projected coordinate arguments for a projective shadow compare.

    * ``depth2d<float>`` with a 3- or 4-component coordinate: ``xy`` divided
      by ``z`` (3 components) or ``w`` (4 components).
    * ``depth2d_array<float>`` with a 4-component coordinate: ``xy / w`` plus
      a ``uint(z)`` layer.

    Returns a list of argument strings, or None for any other combination.
    """
    base_texture = self.resource_base_type(texture_type)
    coord_kind = self.resource_base_type(self.expression_result_type(coord_arg))
    three_wide = coord_kind in ("vec3", "float3")
    four_wide = coord_kind in ("vec4", "float4")

    if base_texture == "depth2d<float>":
        if three_wide:
            divisor_lane = "z"
        elif four_wide:
            divisor_lane = "w"
        else:
            return None
        divisor = self.vector_component(coord, divisor_lane)
        plane = self.vector_component(coord, "xy")
        return [f"{plane} / {divisor}"]

    if base_texture == "depth2d_array<float>" and four_wide:
        plane = self.vector_component(coord, "xy")
        divisor = self.vector_component(coord, "w")
        layer = self.vector_component(coord, "z")
        return [f"{plane} / {divisor}", f"uint({layer})"]
    return None
def generate_texture_compare_call(
    self,
    func_name,
    texture_name,
    sampler_arg,
    coord,
    extra_args,
    texture_type,
    args=None,
):
    """Generate a Metal ``sample_compare`` call for the textureCompare* family.

    Handles ``textureCompare`` and ``textureCompareProj`` together with their
    ``Offset``, ``Lod``, ``LodOffset``, ``Grad`` and ``GradOffset`` variants.

    Args:
        func_name: name of the compare function being translated.
        texture_name: generated texture expression.
        sampler_arg: generated sampler expression.
        coord: generated coordinate expression.
        extra_args: argument nodes after the coordinate; the first one is the
            compare reference value.
        texture_type: resolved texture type of the texture argument.
        args: full original argument list; only used by the projective
            variants to find the coordinate node.

    Returns:
        The call expression string, or a commented ``0.0`` placeholder when
        the call cannot be expressed (missing compare value, wrong argument
        count, unsupported offset, unsupported projection coordinate, or an
        unknown function name).
    """

    def reject(reason):
        return self.unsupported_texture_compare_call(func_name, reason)

    if not extra_args:
        return reject("requires a compare argument")
    compare = self.generate_expression(extra_args[0])

    # variant suffix -> (required len(extra_args), arity message, lod kind, has offset)
    variants = {
        "": (1, "accepts no extra arguments", None, False),
        "Offset": (2, "requires compare and offset arguments", None, True),
        "Lod": (2, "requires compare and lod arguments", "level", False),
        "LodOffset": (
            3,
            "requires compare, lod, and offset arguments",
            "level",
            True,
        ),
        "Grad": (
            3,
            "requires compare, gradient x, and gradient y arguments",
            "gradient",
            False,
        ),
        "GradOffset": (
            4,
            "requires compare, gradient x, gradient y, and offset arguments",
            "gradient",
            True,
        ),
    }
    prefix = "textureCompare"
    suffix = None
    if isinstance(func_name, str) and func_name.startswith(prefix):
        suffix = func_name[len(prefix):]
    projected = (
        suffix is not None and suffix.startswith("Proj") and suffix[4:] in variants
    )

    if projected:
        call_args = args or []
        coord_index = 2 if self.is_explicit_sampler_argument(call_args) else 1
        # Guarded lookup: a short or missing argument list yields None instead
        # of raising IndexError.
        coord_arg = call_args[coord_index] if coord_index < len(call_args) else None
        coord_parts = self.texture_compare_projected_coord_args(
            texture_type, coord_arg, coord
        )
        if coord_parts is None:
            return reject(
                "requires depth2d vec3/vec4 or depth2d_array vec4 projection coordinates"
            )
        variant = variants[suffix[4:]]
        # Projective coordinates are only produced for 2D / 2D-array depth
        # textures, so no separate offset-support check is made here.
        check_offset_support = False
    else:
        if self.is_array_texture_resource(texture_type):
            coord_parts = self.texture_coordinate_parts(texture_type, coord)
        else:
            coord_parts = (coord,)
        variant = variants.get(suffix)
        if variant is None:
            return reject("is not a supported shadow compare operation")
        check_offset_support = True

    required_count, arity_message, lod_kind, has_offset = variant
    if len(extra_args) != required_count:
        return reject(arity_message)
    if (
        has_offset
        and check_offset_support
        and not self.texture_compare_offset_supported(texture_type)
    ):
        return reject("offsets require 2D or 2D-array depth textures")

    call = [sampler_arg, *coord_parts, compare]
    if lod_kind == "level":
        lod = self.generate_expression(extra_args[1])
        call.append(f"level({lod})")
    elif lod_kind == "gradient":
        ddx = self.generate_expression(extra_args[1])
        ddy = self.generate_expression(extra_args[2])
        call.append(self.texture_gradient_options(texture_type, ddx, ddy))
    if has_offset:
        # The offset is always the last extra argument once the count matched.
        call.append(self.generate_expression(extra_args[-1]))
    return f"{texture_name}.sample_compare({', '.join(call)})"
def texture_gather_compare_offset_supported(self, texture_type):
    """Return True when ``texture_type`` is a float 2D or 2D-array depth texture."""
    offset_capable = frozenset(("depth2d<float>", "depth2d_array<float>"))
    return texture_type in offset_capable
def unsupported_texture_gather_compare_call(self, func_name, reason):
    """Return a commented ``float4(0.0)`` placeholder for an unsupported gather compare."""
    note = f"unsupported Metal texture gather compare: {func_name} {reason}"
    return f"/* {note} */ float4(0.0)"
def generate_texture_gather_compare_call(
    self, func_name, texture_name, sampler_arg, coord, extra_args, texture_type
):
    """Generate a Metal ``gather_compare`` call.

    ``textureGatherCompare`` takes exactly one extra argument (the compare
    value). Any other function name is treated as the offset variant and
    takes the compare value plus an offset, which is only accepted when
    ``texture_gather_compare_offset_supported`` allows it. Violations return
    a commented placeholder.
    """

    def reject(reason):
        return self.unsupported_texture_gather_compare_call(func_name, reason)

    if not extra_args:
        return reject("requires a compare argument")
    compare = self.generate_expression(extra_args[0])
    call = [sampler_arg] + self.texture_gather_coord_args(texture_type, coord)
    call.append(compare)
    extra_count = len(extra_args)

    if func_name == "textureGatherCompare":
        if extra_count != 1:
            return reject("accepts no extra arguments")
    else:
        if extra_count != 2:
            return reject("requires compare and offset arguments")
        if not self.texture_gather_compare_offset_supported(texture_type):
            return reject("offsets require 2D or 2D-array depth textures")
        call.append(self.generate_expression(extra_args[1]))
    return f"{texture_name}.gather_compare({', '.join(call)})"
def texture_query_size_expression(self, texture_arg, lod_arg=None):
    """Generate the Metal expression for ``textureSize`` / ``imageSize``.

    Storage images and multisample textures are queried without a mip level;
    the other listed float texture types pass ``uint(<lod>)`` (``0`` when no
    lod argument is given) to the width/height/depth getters. Returns None
    for texture types that are not listed.
    """
    texture_name = self.generate_expression(texture_arg)
    texture_type = self.texture_resource_type(texture_arg)
    lod = "0" if lod_arg is None else self.generate_expression(lod_arg)
    level = f"uint({lod})"

    def getter(method, argument=""):
        return f"{texture_name}.{method}({argument})"

    if self.is_storage_image_resource(texture_type):
        base_type = self.resource_base_type(texture_type)
        plain_width = getter("get_width")
        plain_height = getter("get_height")
        if base_type.startswith("texture2d_array<"):
            return f"int3({plain_width}, {plain_height}, {getter('get_array_size')})"
        if base_type.startswith("texture3d<"):
            return f"int3({plain_width}, {plain_height}, {getter('get_depth')})"
        return f"int2({plain_width}, {plain_height})"

    width = getter("get_width", level)
    height = getter("get_height", level)
    if texture_type == "texture1d<float>":
        return f"int({width})"
    if texture_type in (
        "texture2d<float>",
        "depth2d<float>",
        "texturecube<float>",
        "depthcube<float>",
    ):
        return f"int2({width}, {height})"
    if texture_type in (
        "texture2d_array<float>",
        "depth2d_array<float>",
        "texturecube_array<float>",
        "depthcube_array<float>",
    ):
        return f"int3({width}, {height}, {getter('get_array_size')})"
    if texture_type == "texture3d<float>":
        return f"int3({width}, {height}, {getter('get_depth', level)})"
    if texture_type == "texture2d_ms<float>":
        return f"int2({getter('get_width')}, {getter('get_height')})"
    if texture_type == "texture2d_ms_array<float>":
        return (
            f"int3({getter('get_width')}, {getter('get_height')}, "
            f"{getter('get_array_size')})"
        )
    return None
def texture_query_levels_expression(self, texture_arg):
    """Generate the Metal expression for ``textureQueryLevels``.

    Storage images get an "unsupported" placeholder, multisample textures the
    constant ``1``, and everything else ``int(<tex>.get_num_mip_levels())``.
    """
    texture_name = self.generate_expression(texture_arg)
    texture_type = self.texture_resource_type(texture_arg)
    if self.is_storage_image_resource(texture_type):
        return self.unsupported_texture_query_levels_call(texture_type)
    if not self.is_multisample_texture_resource(texture_type):
        return f"int({texture_name}.get_num_mip_levels())"
    return "1"
def texture_samples_expression(self, texture_arg):
    """Generate the sample-count query; only multisample textures are supported."""
    texture_name = self.generate_expression(texture_arg)
    texture_type = self.texture_resource_type(texture_arg)
    if self.is_multisample_texture_resource(texture_type):
        return f"int({texture_name}.get_num_samples())"
    return "/* unsupported Metal texture samples query: requires multisample texture */ 0"
def image_coordinate_expression(self, image_type, coord):
    """Convert an image coordinate into Metal texel-coordinate arguments.

    Returns:
        ``(texel_coord, layer)``. Read-write 2D-array images split the
        coordinate into ``uint2(xy)`` and ``uint(z)``; read-write 3D images
        use ``uint3(coord)`` with no layer; everything else uses
        ``uint2(coord)`` with no layer.
    """
    element_types = ("float", "int", "uint")
    layered = {f"texture2d_array<{kind}, access::read_write>" for kind in element_types}
    volumetric = {f"texture3d<{kind}, access::read_write>" for kind in element_types}
    if image_type in layered:
        plane = self.vector_component(coord, "xy")
        slice_index = self.vector_component(coord, "z")
        return f"uint2({plane})", f"uint({slice_index})"
    if image_type in volumetric:
        return f"uint3({coord})", None
    return f"uint2({coord})", None
def is_integer_image_type(self, image_type):
    """Return True for read-write 2D, 3D and 2D-array images with int or uint texels."""
    integer_images = frozenset(
        f"{shape}<{element}, access::read_write>"
        for shape in ("texture2d", "texture3d", "texture2d_array")
        for element in ("int", "uint")
    )
    return image_type in integer_images
def is_scalar_image_format(self, image_format):
    """Return True when ``image_format`` is one of the single-channel (``r*``) formats."""
    single_channel = frozenset(
        "r8 r8_snorm r16 r16_snorm r16f r32f r8i r16i r32i r8ui r16ui r32ui".split()
    )
    return image_format in single_channel
def is_two_component_image_format(self, image_format):
    """Return True when ``image_format`` is one of the two-channel (``rg*``) formats."""
    two_channel = frozenset(
        "rg8 rg8_snorm rg16 rg16_snorm rg16f rg8i rg16i rg8ui rg16ui "
        "rg32f rg32i rg32ui".split()
    )
    return image_format in two_channel
def is_scalar_integer_image_resource(self, image_type, image_format):
    """Decide whether an image is handled as a single-component resource.

    With a known format the answer comes from ``is_scalar_image_format``
    alone; without one it falls back to ``is_integer_image_type``.
    """
    if image_format is None:
        return self.is_integer_image_type(image_type)
    return self.is_scalar_image_format(image_format)
def is_float_image_resource(self, image_type):
    """Return True for read-write 2D, 3D and 2D-array images with float texels."""
    float_images = frozenset(
        f"{shape}<float, access::read_write>"
        for shape in ("texture2d", "texture3d", "texture2d_array")
    )
    return image_type in float_images
def image_load_component_suffix(self, image_type, image_format):
    """Return the swizzle appended to an image ``read`` call.

    ``.x`` for single-component resources, and for float or two-channel
    images when the currently expected expression type is scalar; ``.xy``
    for two-channel formats otherwise; an empty string in every other case.
    """

    def scalar_expected():
        return self.is_scalar_value_type(self.current_expression_expected_type)

    if self.is_scalar_integer_image_resource(image_type, image_format):
        return ".x"
    if self.is_float_image_resource(image_type) and scalar_expected():
        return ".x"
    if not self.is_two_component_image_format(image_format):
        return ""
    return ".x" if scalar_expected() else ".xy"
def image_format_store_constructor(self, image_format):
    """Return the 4-component constructor name for a single-channel format.

    Unsigned-normalized, signed-normalized and float ``r*`` formats map to
    ``float4``, ``r*i`` to ``int4`` and ``r*ui`` to ``uint4``; any other
    format returns None.
    """
    formats_by_constructor = {
        "float4": ("r8", "r8_snorm", "r16", "r16_snorm", "r16f", "r32f"),
        "int4": ("r8i", "r16i", "r32i"),
        "uint4": ("r8ui", "r16ui", "r32ui"),
    }
    constructor_by_format = {
        name: constructor
        for constructor, names in formats_by_constructor.items()
        for name in names
    }
    return constructor_by_format.get(image_format)
def integer_image_store_constructor(self, image_type):
    """Return ``int4`` / ``uint4`` for read-write integer images, else None."""
    constructor_by_image = {
        f"{shape}<{element}, access::read_write>": f"{element}4"
        for element in ("int", "uint")
        for shape in ("texture2d", "texture3d", "texture2d_array")
    }
    return constructor_by_image.get(image_type)
def two_component_image_store_expression(
    self, image_format, value, value_type=None
):
    """Pad a value stored to a two-channel image up to a 4-component constructor.

    Returns None when ``image_format`` is not a two-channel format. A scalar
    value (per ``is_scalar_value_type``) is followed by three zeros; any
    other value is followed by two zeros.
    """
    # NOTE(review): the non-scalar branch presumably expects a 2-component
    # value; wider vectors are not distinguished here - confirm with callers.
    zero_by_vector = {"float4": "0.0", "int4": "0", "uint4": "0u"}
    formats_by_vector = {
        "float4": ("rg8", "rg8_snorm", "rg16", "rg16_snorm", "rg16f", "rg32f"),
        "int4": ("rg8i", "rg16i", "rg32i"),
        "uint4": ("rg8ui", "rg16ui", "rg32ui"),
    }
    vector_by_format = {
        name: vector for vector, names in formats_by_vector.items() for name in names
    }
    vector_type = vector_by_format.get(image_format)
    if vector_type is None:
        return None
    zero = zero_by_vector[vector_type]
    pad_count = 3 if self.is_scalar_value_type(value_type) else 2
    padding = ", ".join([zero] * pad_count)
    return f"{vector_type}({value}, {padding})"
def image_store_value_expression(
    self, image_type, image_format, value, value_type=None
):
    """Wrap the value passed to an image ``write`` in a 4-component constructor if needed.

    Two-channel formats are handled by ``two_component_image_store_expression``.
    Otherwise single-component resources use the integer-image constructor
    (falling back to the format's constructor), float images with a scalar
    value use ``float4``, and any remaining case returns ``value`` unchanged.
    """
    padded = self.two_component_image_store_expression(
        image_format, value, value_type
    )
    if padded is not None:
        return padded

    if self.is_scalar_integer_image_resource(image_type, image_format):
        wrapper = self.integer_image_store_constructor(image_type)
        if wrapper is None:
            wrapper = self.image_format_store_constructor(image_format)
    elif self.is_float_image_resource(image_type) and self.is_scalar_value_type(
        value_type
    ):
        wrapper = "float4"
    else:
        wrapper = None

    if not wrapper:
        return value
    return f"{wrapper}({value})"
def image_atomic_method(self, func_name):
    """Map an ``imageAtomic*`` function name to its Metal texture atomic method.

    Returns None for names that are not listed (including
    ``imageAtomicCompSwap``, which is not part of this table).
    """
    method_by_function = {
        f"imageAtomic{operation}": f"atomic_fetch_{operation.lower()}"
        for operation in ("Add", "Min", "Max", "And", "Or", "Xor")
    }
    method_by_function["imageAtomicExchange"] = "atomic_exchange"
    return method_by_function.get(func_name)
def image_atomic_compare_helper_name(self, texture_type):
    """Return the ``imageAtomicCompSwap_<image kind>`` helper name for an integer image.

    Only read-write int/uint 2D, 3D and 2D-array images have a helper; any
    other type returns None.
    """
    dimension_by_shape = {
        "texture2d": "2D",
        "texture3d": "3D",
        "texture2d_array": "2DArray",
    }
    prefix_by_element = {"int": "iimage", "uint": "uimage"}
    helper_by_type = {
        f"{shape}<{element}, access::read_write>": (
            f"imageAtomicCompSwap_{prefix}{dimension}"
        )
        for shape, dimension in dimension_by_shape.items()
        for element, prefix in prefix_by_element.items()
    }
    return helper_by_type.get(texture_type)
def image_atomic_compare_return_type(self, texture_type):
    """Return ``int`` or ``uint`` to match a read-write integer image, else None."""
    scalar_by_type = {
        f"{shape}<{element}, access::read_write>": element
        for element in ("int", "uint")
        for shape in ("texture2d", "texture3d", "texture2d_array")
    }
    return scalar_by_type.get(texture_type)
def image_atomic_compare_vector_type(self, texture_type):
    """Return ``int4`` or ``uint4`` to match a read-write integer image, else None."""
    vector_by_type = {
        f"{shape}<{element}, access::read_write>": f"{element}4"
        for element in ("int", "uint")
        for shape in ("texture2d", "texture3d", "texture2d_array")
    }
    return vector_by_type.get(texture_type)
def image_atomic_compare_coord_type(self, texture_type):
    """Return the helper's coordinate parameter type for an integer image.

    ``int2`` for 2D images, ``int3`` for 3D and 2D-array images, None for
    any other type.
    """
    coord_by_shape = {
        "texture2d": "int2",
        "texture3d": "int3",
        "texture2d_array": "int3",
    }
    coord_by_type = {
        f"{shape}<{element}, access::read_write>": coord_type
        for shape, coord_type in coord_by_shape.items()
        for element in ("int", "uint")
    }
    return coord_by_type.get(texture_type)
def image_atomic_compare_exchange_expression(self, texture_type):
    """Return the ``atomic_compare_exchange_weak`` call used inside the helper body.

    The call refers to the helper's own ``image``, ``coord``, ``original``
    and ``value`` names. Returns None for types other than read-write
    int/uint 2D, 3D and 2D-array images.
    """
    location_by_shape = {
        "texture2d": "uint2(coord)",
        "texture3d": "uint3(coord)",
        "texture2d_array": "uint2(coord.xy), uint(coord.z)",
    }
    call_by_type = {
        f"{shape}<{element}, access::read_write>": (
            f"image.atomic_compare_exchange_weak({location}, &original, value)"
        )
        for shape, location in location_by_shape.items()
        for element in ("int", "uint")
    }
    return call_by_type.get(texture_type)
def generate_image_atomic_compare_helpers(self):
    """Emit the Metal helper functions backing ``imageAtomicCompSwap``.

    One helper is generated per texture type recorded in
    ``required_image_atomic_compare_helpers`` (in sorted order). Types for
    which any of the name, return type, vector type, coordinate type or
    exchange expression is unavailable are skipped. Returns the concatenated
    helper source, or an empty string when nothing is required.
    """
    pending = self.required_image_atomic_compare_helpers
    if not pending:
        return ""
    emitted = []
    for texture_type in sorted(pending):
        pieces = (
            self.image_atomic_compare_helper_name(texture_type),
            self.image_atomic_compare_return_type(texture_type),
            self.image_atomic_compare_vector_type(texture_type),
            self.image_atomic_compare_coord_type(texture_type),
            self.image_atomic_compare_exchange_expression(texture_type),
        )
        if not all(pieces):
            continue
        helper_name, return_type, vector_type, coord_type, exchange_expr = pieces
        # The emitted loop retries only while the weak exchange failed and the
        # observed value still equals the expected one.
        body = (
            f"{return_type} {helper_name}({texture_type} image, {coord_type} coord, {return_type} compareValue, {return_type} value) {{\n"
            f" {vector_type} original;\n"
            " do {\n"
            " original.x = compareValue;\n"
            f" }} while (!{exchange_expr} && original.x == compareValue);\n"
            " return original.x;\n"
            "}\n\n"
        )
        emitted.append(body)
    return "".join(emitted)
def generate_image_call(self, func_name, args):
    """Generate the Metal expression for an image load/store/atomic call.

    Handles ``imageAtomicCompSwap`` (through a generated helper that is
    recorded in ``required_image_atomic_compare_helpers``), the
    ``imageAtomic*`` operations known to ``image_atomic_method``,
    ``imageLoad`` and ``imageStore``. Returns None when ``func_name`` is not
    one of these, when too few arguments are supplied, or when no
    compare-swap helper exists for the image type.
    """
    arg_count = len(args)
    render = self.generate_expression

    def texel_location(image_type, coord_text):
        # Array images address a texel with "coord, layer"; others with "coord".
        texel, layer = self.image_coordinate_expression(image_type, coord_text)
        return texel if layer is None else f"{texel}, {layer}"

    if func_name == "imageAtomicCompSwap" and arg_count >= 4:
        image_name, coord, compare, value = (render(node) for node in args[:4])
        image_type = self.texture_resource_type(args[0])
        helper = self.image_atomic_compare_helper_name(image_type)
        if not helper:
            return None
        self.required_image_atomic_compare_helpers.add(image_type)
        return f"{helper}({image_name}, {coord}, {compare}, {value})"

    atomic_method = self.image_atomic_method(func_name)
    if atomic_method and arg_count >= 3:
        image_name, coord, value = (render(node) for node in args[:3])
        image_type = self.texture_resource_type(args[0])
        location = texel_location(image_type, coord)
        return f"{image_name}.{atomic_method}({location}, {value}).x"

    if func_name == "imageLoad" and arg_count >= 2:
        image_name, coord = (render(node) for node in args[:2])
        image_type = self.texture_resource_type(args[0])
        location = texel_location(image_type, coord)
        image_format = self.image_resource_format(args[0])
        suffix = self.image_load_component_suffix(image_type, image_format)
        return f"{image_name}.read({location}){suffix}"

    if func_name == "imageStore" and arg_count >= 3:
        image_name, coord, raw_value = (render(node) for node in args[:3])
        image_type = self.texture_resource_type(args[0])
        image_format = self.image_resource_format(args[0])
        stored = self.image_store_value_expression(
            image_type, image_format, raw_value, self.expression_result_type(args[2])
        )
        location = texel_location(image_type, coord)
        return f"{image_name}.write({stored}, {location})"
    return None
def generate_texture_call(self, func_name, args):
    """Translate a texture/image intrinsic call into a Metal expression.

    The argument validators run first. The call is then resolved by trying,
    in order: image operations (``generate_image_call``), size / mip-level /
    sample-count queries, and finally the sampler-based operations whose
    pieces come from ``texture_call_parts``.

    Args:
        func_name: Name of the intrinsic being called.
        args: Argument expression nodes of the call; ``args[0]`` is used as
            the texture/image resource.

    Returns:
        The value produced by the first matching branch (an expression
        string for the branches built directly in this method), or ``None``
        when ``func_name`` is falsy, fewer than two arguments remain for a
        sampler-based operation, ``texture_call_parts`` returns ``None``,
        or no branch handles ``func_name``.
    """
    if not func_name:
        return None
    # Validation happens before any code is generated.
    # NOTE(review): presumably these raise on invalid calls -- confirm
    # against the validator implementations; their results are not used here.
    self.validate_texture_call_arity(func_name, args)
    self.validate_image_resource_argument(func_name, args)
    self.validate_texture_resource_argument(func_name, args)
    self.validate_integer_coordinate_argument(func_name, args)
    self.validate_coordinate_dimension_argument(func_name, args)
    self.validate_query_lod_coordinate_argument(func_name, args)
    self.validate_compare_argument(func_name, args)
    self.validate_lod_argument(func_name, args)
    self.validate_bias_argument(func_name, args)
    self.validate_sample_index_argument(func_name, args)
    self.validate_mip_level_argument(func_name, args)
    self.validate_gradient_dimension_arguments(func_name, args)
    self.validate_offset_dimension_argument(func_name, args)
    self.validate_gather_component_argument(func_name, args)
    # Image operations take priority over every texture branch below.
    image_call = self.generate_image_call(func_name, args)
    if image_call is not None:
        return image_call
    # Queries only need the resource argument (plus an optional LOD).
    if func_name in {"textureSize", "imageSize"} and args:
        lod_arg = args[1] if len(args) > 1 else None
        return self.texture_query_size_expression(args[0], lod_arg)
    if func_name == "textureQueryLevels" and args:
        return self.texture_query_levels_expression(args[0])
    if func_name in {"textureSamples", "imageSamples"} and args:
        return self.texture_samples_expression(args[0])
    # Everything below needs at least a resource and a coordinate.
    if len(args) < 2:
        return None
    parts = self.texture_call_parts(args)
    if parts is None:
        return None
    texture_name, sampler_arg, coord, extra_args = parts
    texture_type = self.texture_resource_type(args[0])
    storage_image_operation = self.storage_image_texture_operation_expression(
        func_name, texture_type
    )
    if storage_image_operation is not None:
        return storage_image_operation
    is_array_texture = self.is_array_texture_resource(texture_type)
    if is_array_texture:
        # coord_xy/layer are only bound for array textures; every later use
        # is guarded by is_array_texture.
        coord_xy, layer = self.texture_coordinate_parts(texture_type, coord)
    # Plain sampling of a multisample texture is delegated to the
    # unsupported-call helper.
    if func_name in {
        "texture",
        "textureLod",
        "textureGrad",
    } and self.is_multisample_texture_resource(texture_type):
        return self.unsupported_multisample_texture_call(func_name, texture_type)
    if func_name == "texture":
        if extra_args:
            # The first extra argument is emitted as a bias() sample option.
            bias = self.generate_expression(extra_args[0])
            if is_array_texture:
                return (
                    f"{texture_name}.sample("
                    f"{sampler_arg}, {coord_xy}, {layer}, bias({bias}))"
                )
            return f"{texture_name}.sample({sampler_arg}, {coord}, bias({bias}))"
        if is_array_texture:
            return f"{texture_name}.sample({sampler_arg}, {coord_xy}, {layer})"
        return f"{texture_name}.sample({sampler_arg}, {coord})"
    if func_name == "textureLod" and extra_args:
        lod = self.generate_expression(extra_args[0])
        if is_array_texture:
            return f"{texture_name}.sample({sampler_arg}, {coord_xy}, {layer}, level({lod}))"
        return f"{texture_name}.sample({sampler_arg}, {coord}, level({lod}))"
    if func_name == "textureGrad" and len(extra_args) >= 2:
        ddx = self.generate_expression(extra_args[0])
        ddy = self.generate_expression(extra_args[1])
        gradient_options = self.texture_gradient_options(texture_type, ddx, ddy)
        if is_array_texture:
            return f"{texture_name}.sample({sampler_arg}, {coord_xy}, {layer}, {gradient_options})"
        return f"{texture_name}.sample({sampler_arg}, {coord}, {gradient_options})"
    # The remaining sampling families are handled by dedicated helpers.
    if func_name in {
        "textureOffset",
        "textureLodOffset",
        "textureGradOffset",
    }:
        return self.generate_texture_sample_offset_call(
            func_name,
            texture_name,
            sampler_arg,
            coord,
            extra_args,
            texture_type,
        )
    if func_name in {
        "textureProj",
        "textureProjOffset",
        "textureProjLod",
        "textureProjLodOffset",
        "textureProjGrad",
        "textureProjGradOffset",
    }:
        return self.generate_texture_projected_call(
            func_name,
            texture_name,
            sampler_arg,
            coord,
            extra_args,
            texture_type,
            args,
        )
    if func_name in {
        "textureGather",
        "textureGatherOffset",
        "textureGatherOffsets",
    }:
        return self.generate_texture_gather_call(
            func_name, texture_name, sampler_arg, coord, extra_args, texture_type
        )
    if func_name in {
        "textureCompare",
        "textureCompareOffset",
        "textureCompareLod",
        "textureCompareLodOffset",
        "textureCompareGrad",
        "textureCompareGradOffset",
        "textureCompareProj",
        "textureCompareProjOffset",
        "textureCompareProjLod",
        "textureCompareProjLodOffset",
        "textureCompareProjGrad",
        "textureCompareProjGradOffset",
    }:
        return self.generate_texture_compare_call(
            func_name,
            texture_name,
            sampler_arg,
            coord,
            extra_args,
            texture_type,
            args,
        )
    if func_name in {"textureGatherCompare", "textureGatherCompareOffset"}:
        return self.generate_texture_gather_compare_call(
            func_name, texture_name, sampler_arg, coord, extra_args, texture_type
        )
    if func_name == "textureQueryLod":
        if self.is_multisample_texture_resource(texture_type):
            return self.unsupported_multisample_texture_query_lod_call(texture_type)
        if self.is_storage_image_resource(texture_type):
            return self.unsupported_texture_query_lod_call(texture_type)
        # For array textures the LOD is computed from the coordinate without
        # the layer component.
        lod_coord = coord_xy if is_array_texture else coord
        # Result packs (unclamped LOD, clamped LOD) into a float2.
        return (
            f"float2({texture_name}.calculate_unclamped_lod({sampler_arg}, {lod_coord}), "
            f"{texture_name}.calculate_clamped_lod({sampler_arg}, {lod_coord}))"
        )
    if func_name == "texelFetch" and len(args) >= 3:
        # NOTE(review): the LOD / sample index is read from args[2] directly
        # rather than from extra_args -- confirm this matches the layout
        # texture_call_parts assumes.
        lod = self.generate_expression(args[2])
        if self.is_cube_texture_resource(texture_type):
            return self.unsupported_cube_texel_fetch_call(func_name, texture_type)
        if self.is_multisample_texture_resource(texture_type):
            # For multisample reads the third argument is emitted wrapped in
            # uint(...).
            if texture_type == "texture2d_ms_array<float>":
                texel_xy, layer = self.array_texture_coordinate_parts(coord)
                return f"{texture_name}.read({texel_xy}, {layer}, uint({lod}))"
            return f"{texture_name}.read({coord}, uint({lod}))"
        if is_array_texture:
            texel_xy, layer = self.array_texture_coordinate_parts(coord)
            return f"{texture_name}.read({texel_xy}, {layer}, {lod})"
        return f"{texture_name}.read({coord}, {lod})"
    if func_name == "texelFetchOffset" and len(args) >= 4:
        lod = self.generate_expression(args[2])
        offset = self.generate_expression(args[3])
        if self.is_cube_texture_resource(texture_type):
            return self.unsupported_cube_texel_fetch_call(func_name, texture_type)
        if self.is_multisample_texture_resource(texture_type):
            return "/* unsupported Metal texel fetch offset: multisample textures do not support offsets */ float4(0.0)"
        # The offset is folded into the texel coordinate of the read() call.
        if is_array_texture:
            texel_xy, layer = self.array_texture_coordinate_parts(coord)
            return f"{texture_name}.read(({texel_xy} + {offset}), {layer}, {lod})"
        return f"{texture_name}.read(({coord} + {offset}), {lod})"
    # No branch recognised func_name.
    return None
def convert_type_node_to_string(self, type_node) -> str:
    """Convert a new-AST TypeNode into its string type name.

    Dispatch is duck-typed on the attributes the node exposes, checked in
    this order:

    * ``name``: the name is returned unchanged.
    * ``rows`` and ``cols``: a matrix, rendered as ``float{cols}x{rows}``.
    * ``element_type`` and ``size``: an array when the node's class name
      contains ``"ArrayType"`` (``elem[size]`` / ``elem[]``), otherwise a
      vector (``elem{size}``).
    * anything else: ``str(type_node)``.

    Args:
        type_node: The type node (or any object) to render.

    Returns:
        The string representation of the type.
    """
    if hasattr(type_node, "name"):
        return type_node.name

    if hasattr(type_node, "rows") and hasattr(type_node, "cols"):
        # The element type was previously converted here and then discarded,
        # which only cost a recursive call (and raised AttributeError on
        # matrix nodes without ``element_type``). For square matrices
        # cols == rows, so one format covers both former branches.
        # NOTE(review): matrices are always emitted with a ``float`` element
        # type regardless of the node's element type -- confirm intended.
        return f"float{type_node.cols}x{type_node.rows}"

    if hasattr(type_node, "element_type") and hasattr(type_node, "size"):
        element_type = self.convert_type_node_to_string(type_node.element_type)
        size = type_node.size

        if "ArrayType" not in str(type(type_node)):
            # Vector: the component count is appended to the element type.
            return f"{element_type}{size}"

        if size is None:
            # Unsized array.
            return f"{element_type}[]"
        if isinstance(size, int):
            return f"{element_type}[{size}]"
        # The size is an expression node; render it as text.
        return f"{element_type}[{self.safe_expression_to_string(size)}]"

    return str(type_node)
def safe_expression_to_string(self, expr):
    """Render an expression node as text without going through the main generator.

    Delegates to ``safe_expression_to_string_with_precedence`` using its
    default parent precedence.
    """
    rendered = self.safe_expression_to_string_with_precedence(expr)
    return rendered
def safe_expression_to_string_with_precedence(self, expr, parent_precedence=0):
    """Render an expression as text, parenthesising by operator precedence.

    Args:
        expr: An expression node, or a plain int/float/str.
        parent_precedence: Precedence of the enclosing operator; a binary
            expression whose own precedence is lower gets wrapped in
            parentheses.

    Returns:
        The textual form of ``expr``.
    """
    # Anything exposing ``value`` is rendered from that value.
    if hasattr(expr, "value"):
        return str(expr.value)

    # Next, anything exposing a non-None ``name``.
    node_name = getattr(expr, "name", None)
    if node_name is not None:
        return str(node_name)

    if isinstance(expr, (int, float)):
        return str(expr)
    if isinstance(expr, str):
        return expr

    if isinstance(expr, BinaryOpNode):
        symbol = self.map_operator(expr.op)
        own_level = self.expression_precedence(symbol)
        lhs = self.safe_expression_to_string_with_precedence(expr.left, own_level)
        # The right operand is rendered one level tighter, so an operand of
        # equal precedence on the right-hand side is parenthesised.
        rhs = self.safe_expression_to_string_with_precedence(
            expr.right, own_level + 1
        )
        joined = f"{lhs} {symbol} {rhs}"
        return f"({joined})" if own_level < parent_precedence else joined

    if isinstance(expr, UnaryOpNode):
        inner = self.safe_expression_to_string_with_precedence(
            expr.operand, self.expression_precedence("unary")
        )
        return f"{self.map_operator(expr.op)}{inner}"

    # Fallback - generate_expression is deliberately not called here, to
    # prevent infinite recursion.
    return str(expr)
def expression_precedence(self, operator):
    """Return the binding strength of ``operator``; higher binds tighter.

    Operators that are not in the table get precedence 0.
    """
    # Tiers are listed from loosest (1) to tightest (11).
    tiers = (
        ("||",),
        ("&&",),
        ("|",),
        ("^",),
        ("&",),
        ("==", "!="),
        ("<", ">", "<=", ">="),
        ("<<", ">>"),
        ("+", "-"),
        ("*", "/", "%"),
        ("unary",),
    )
    levels = {
        symbol: level
        for level, tier in enumerate(tiers, start=1)
        for symbol in tier
    }
    return levels.get(operator, 0)
def expression_to_string(self, expr):
    """Render an expression node as text via ``safe_expression_to_string``."""
    text = self.safe_expression_to_string(expr)
    return text
def map_type(self, vtype):
    """Translate a type (string or TypeNode-like object) to its Metal name.

    ``None`` maps to ``"float"``. Objects exposing ``name`` or
    ``element_type`` are first rendered with ``convert_type_node_to_string``;
    everything else is passed through ``str``. Array types have only their
    base type looked up in ``type_mapping`` and keep their suffix; names
    without a mapping are returned unchanged.
    """
    if vtype is None:
        return "float"

    looks_like_type_node = hasattr(vtype, "name") or hasattr(vtype, "element_type")
    type_text = (
        self.convert_type_node_to_string(vtype)
        if looks_like_type_node
        else str(vtype)
    )

    # Non-array types are a direct table lookup.
    if "[" not in type_text or "]" not in type_text:
        return self.type_mapping.get(type_text, type_text)

    element_name, dimensions = split_array_type_suffix(type_text)
    mapped_element = self.type_mapping.get(element_name, element_name)
    return f"{mapped_element}{dimensions}"
def map_operator(self, op):
    """Translate an operator token name into its Metal operator symbol.

    Args:
        op: An operator token name (e.g. ``"PLUS"``) or an operator that is
            already in symbol form.

    Returns:
        The mapped symbol, or ``op`` unchanged when it has no table entry.
    """
    op_map = {
        # Arithmetic
        "PLUS": "+",
        "MINUS": "-",
        "MULTIPLY": "*",
        "DIVIDE": "/",
        # Previously missing although ASSIGN_MOD was mapped; the unmapped
        # name was passed through verbatim.
        "MOD": "%",
        # Bitwise
        "BITWISE_XOR": "^",
        "BITWISE_OR": "|",
        "BITWISE_AND": "&",
        "BITWISE_SHIFT_RIGHT": ">>",
        "BITWISE_SHIFT_LEFT": "<<",
        # Comparison
        "LESS_THAN": "<",
        "GREATER_THAN": ">",
        "LESS_EQUAL": "<=",
        "GREATER_EQUAL": ">=",
        "EQUAL": "==",
        "NOT_EQUAL": "!=",
        # Logical
        "AND": "&&",
        "OR": "||",
        "LOGICAL_AND": "&&",
        # Previously missing although LOGICAL_AND was mapped.
        "LOGICAL_OR": "||",
        # Assignment
        "EQUALS": "=",
        "ASSIGN_ADD": "+=",
        "ASSIGN_SUB": "-=",
        "ASSIGN_MUL": "*=",
        "ASSIGN_DIV": "/=",
        "ASSIGN_MOD": "%=",
        "ASSIGN_OR": "|=",
        "ASSIGN_XOR": "^=",
        "ASSIGN_AND": "&=",
        "ASSIGN_SHIFT_LEFT": "<<=",
        "ASSIGN_SHIFT_RIGHT": ">>=",
    }
    return op_map.get(op, op)
def map_semantic(self, semantic):
    """Return the Metal attribute text for a CrossGL semantic.

    The semantic is looked up in ``semantic_map`` (falling back to the
    semantic itself) and returned with a leading space, wrapped in
    ``[[...]]`` unless the mapped text already carries the brackets.
    ``None`` yields an empty string.
    """
    if semantic is None:
        return ""

    attribute = self.semantic_map.get(semantic, semantic)
    # Mapped values that are already bracketed are used as-is.
    already_bracketed = attribute.startswith("[[") and attribute.endswith("]]")
    return f" {attribute}" if already_bracketed else f" [[{attribute}]]"