Source code for crosstl.translator.codegen.slang_codegen

"""CrossGL-to-Slang code generator."""

from ..ast import (
    ArrayNode,
    ArrayAccessNode,
    ArrayLiteralNode,
    AssignmentNode,
    BinaryOpNode,
    BreakNode,
    CaseNode,
    CbufferNode,
    ContinueNode,
    ExpressionStatementNode,
    ForNode,
    FunctionCallNode,
    IdentifierNode,
    FunctionNode,
    IfNode,
    LiteralNode,
    MemberAccessNode,
    ReturnNode,
    ShaderNode,
    StructNode,
    SwitchNode,
    TernaryOpNode,
    UnaryOpNode,
    VariableNode,
    WhileNode,
)
from .array_utils import (
    format_c_style_array_declaration,
    get_array_size_from_node,
    split_array_type_suffix,
)


class SlangCodeGen:
    """Translate a CrossGL AST into Slang shader source text."""

    def __init__(self):
        """Prepare output formatting and per-translation bookkeeping."""
        # Output formatting: indent() repeats indent_str indent_level times.
        # NOTE(review): the extracted text shows a single space here; the
        # original may have been wider before whitespace was collapsed.
        self.indent_str = " "
        self.indent_level = 0
        # Name -> type lookups filled in as declarations are emitted, plus
        # helper-function sources that the outermost generate() prepends.
        self.variable_types = {}
        self.image_resource_types = {}
        self.helper_functions = {}
        # Type context consulted while emitting returns and expressions.
        self.current_function_return_type = self.current_expression_expected_type = None
        # Set while a generate() call is in progress.
        self._generating = False

    def indent(self):
        """Return the leading whitespace for the current nesting depth."""
        depth = self.indent_level
        unit = self.indent_str
        return unit * depth
def generate(self, ast):
    """Generate Slang source for a CrossGL AST or AST fragment.

    Nested calls (for example the per-element calls made for a list) share
    the state of the outermost call. Only the outermost call resets the
    per-translation tables. Through ``finish_generation``, it also prepends
    the helper functions registered while generating.

    Args:
        ast: A list of nodes, a ``ShaderNode``, a ``StructNode``, or any other
            object. For other objects, whichever of ``structs``,
            ``global_variables``, ``cbuffers``, ``functions`` and ``stages``
            it exposes are rendered, in that order.

    Returns:
        The generated Slang source text.
    """
    outermost = not self._generating
    if outermost:
        self._generating = True
        self.variable_types = {}
        self.image_resource_types = {}
        self.helper_functions = {}
        self.current_function_return_type = None
        self.current_expression_expected_type = None
    try:
        if isinstance(ast, list):
            # Each element's output is followed by a newline.
            result = "".join(self.generate(node) + "\n" for node in ast)
        elif isinstance(ast, ShaderNode):
            result = self.generate_shader(ast)
        elif isinstance(ast, StructNode):
            result = self.generate_struct(ast)
        else:
            result = self._generate_module_body(ast)
        return self.finish_generation(result, outermost)
    finally:
        # finish_generation clears the flag only on success. Clearing it here
        # too keeps a failed translation from leaving _generating set, which
        # would make every later call skip the state reset and helper output.
        if outermost:
            self._generating = False


def _generate_module_body(self, ast):
    """Render structs, globals, cbuffers, functions and stages of a module AST."""
    result = ""
    for struct in getattr(ast, "structs", []):
        result += self.generate_struct(struct) + "\n\n"
    for node in getattr(ast, "global_variables", []):
        result += self.generate_global_variable(node)
    for node in getattr(ast, "cbuffers", []):
        result += self._generate_cbuffer(node)
    for function in getattr(ast, "functions", []):
        # Handle both the list-valued `qualifiers` and the older single
        # `qualifier` attribute.
        if hasattr(function, "qualifiers") and function.qualifiers:
            qualifier = function.qualifiers[0]
        else:
            qualifier = getattr(function, "qualifier", None)
        if qualifier == "vertex":
            result += "// Vertex Shader\n"
        elif qualifier == "fragment":
            result += "// Fragment Shader\n"
        result += self.generate_function(function) + "\n\n"
    stages = getattr(ast, "stages", None)
    if stages:
        for stage_type, stage in stages.items():
            result += self.generate_stage(stage_type, stage)
    return result


def _generate_cbuffer(self, node):
    """Render one cbuffer declaration; return "" for unrecognized nodes."""
    if isinstance(node, StructNode):
        return "cbuffer " + self.generate_struct_definition(node) + "\n\n"
    if not (hasattr(node, "name") and hasattr(node, "members")):
        return ""
    result = f"cbuffer {node.name} {{\n"
    for member in node.members:
        if hasattr(member, "member_type"):
            # Use the same type-node conversion as generate_struct and
            # generate_struct_definition. Plain str() depends on the node's
            # __str__, which may not produce a type name.
            member_type = self.convert_type_node_to_string(member.member_type)
        else:
            member_type = getattr(member, "vtype", "float")
        # format_declaration maps the type and, like generate_struct_definition,
        # emits array-typed members as C-style declarations.
        result += f" {self.format_declaration(member_type, member.name)};\n"
    result += "};\n\n"
    return result
def finish_generation(self, result, outermost): if not outermost: return result helpers = self.emit_helper_functions() self._generating = False if helpers: return helpers + result return result def emit_helper_functions(self): if not self.helper_functions: return "" return "\n\n".join(self.helper_functions.values()) + "\n\n" def generate_shader(self, node): """Render a full CrossGL shader AST as a Slang translation unit.""" result = "" structs = getattr(node, "structs", []) for struct in structs: result += self.generate_struct(struct) + "\n\n" global_vars = getattr(node, "global_variables", []) for global_var in global_vars: result += self.generate_global_variable(global_var) functions = getattr(node, "functions", []) for function in functions: stage_name = self.get_function_stage(function) if stage_name: result += f"// {stage_name.title()} Shader\n" result += self.generate_function(function, shader_type=stage_name) result += "\n\n" else: result += self.generate_function(function) + "\n\n" stages = getattr(node, "stages", {}) for stage_type, stage in stages.items(): result += self.generate_stage(stage_type, stage) return result def get_stage_name(self, stage_type): if hasattr(stage_type, "value"): return stage_type.value return str(stage_type).split(".")[-1].lower() def get_function_stage(self, function): if hasattr(function, "qualifiers") and function.qualifiers: qualifier = function.qualifiers[0] else: qualifier = getattr(function, "qualifier", None) if qualifier in {"vertex", "fragment", "compute"}: return qualifier return None def generate_stage(self, stage_type, stage): """Render one staged entry point and its local functions.""" stage_name = self.get_stage_name(stage_type) result = f"// {stage_name.title()} Shader\n" local_variables = getattr(stage, "local_variables", []) for local_var in local_variables: result += self.generate_global_variable(local_var) for func in getattr(stage, "local_functions", []): result += self.generate_function(func) + "\n\n" 
entry_point = getattr(stage, "entry_point", None) if entry_point is not None: result += self.generate_function(entry_point, shader_type=stage_name) result += "\n\n" return result def convert_type_node_to_string(self, type_node) -> str: if hasattr(type_node, "name"): generic_args = getattr(type_node, "generic_args", []) if generic_args: args = ", ".join( self.convert_type_node_to_string(arg) for arg in generic_args ) return f"{type_node.name}<{args}>" return type_node.name if hasattr(type_node, "rows") and hasattr(type_node, "cols"): element_type = self.convert_type_node_to_string(type_node.element_type) if element_type == "float": if type_node.rows == type_node.cols: return f"mat{type_node.rows}" return f"mat{type_node.rows}x{type_node.cols}" return f"{element_type}{type_node.rows}x{type_node.cols}" if hasattr(type_node, "element_type") and hasattr(type_node, "size"): element_type = self.convert_type_node_to_string(type_node.element_type) if type_node.__class__.__name__ == "ArrayType": if type_node.size is None: return f"{element_type}[]" size = self.format_array_size_expression(type_node.size) return f"{element_type}[{size}]" if element_type == "float": return f"vec{type_node.size}" if element_type == "int": return f"ivec{type_node.size}" if element_type == "uint": return f"uvec{type_node.size}" if element_type == "bool": return f"bvec{type_node.size}" return f"{element_type}{type_node.size}" return str(type_node) def format_array_size_expression(self, expr): if isinstance(expr, int): return str(expr) if isinstance(expr, BinaryOpNode): left = self.format_array_size_expression(expr.left) right = self.format_array_size_expression(expr.right) return f"({left} {expr.op} {right})" if isinstance(expr, UnaryOpNode): return f"{expr.op}{self.format_array_size_expression(expr.operand)}" return self.generate_expression(expr) def format_declaration(self, type_name, name, node=None): mapped_type = self.map_resource_type_with_format(type_name, node) return 
format_c_style_array_declaration(mapped_type, name) def get_variable_type(self, node): if hasattr(node, "var_type"): return self.convert_type_node_to_string(node.var_type) if hasattr(node, "vtype"): return node.vtype return "float" def register_variable_type(self, name, type_name, node=None): if not name or type_name is None: return if not isinstance(type_name, str): type_name = self.convert_type_node_to_string(type_name) self.variable_types[name] = type_name if self.is_storage_image_type(type_name): self.image_resource_types[name] = self.map_resource_type_with_format( type_name, node ) def generate_global_variable(self, node): if isinstance(node, ArrayNode): self.register_variable_type(node.name, node.element_type) element_type = self.convert_type(node.element_type) size = get_array_size_from_node(node) if size is None: return f"{element_type} {node.name}[];\n" return f"{element_type} {node.name}[{size}];\n" vtype = self.get_variable_type(node) self.register_variable_type(node.name, vtype, node) declaration = self.format_declaration(vtype, node.name, node) initial_value = getattr(node, "initial_value", getattr(node, "value", None)) if initial_value is not None: initial_expr = self.generate_expression_with_expected(initial_value, vtype) return f"{declaration} = {initial_expr};\n" return f"{declaration};\n" def generate_struct(self, node): result = f"struct {node.name}\n{{\n" self.indent_level += 1 members = getattr(node, "members", []) for member in members: if hasattr(member, "member_type"): member_type = self.convert_type( self.convert_type_node_to_string(member.member_type) ) elif hasattr(member, "vtype"): member_type = self.convert_type(member.vtype) else: member_type = "float" semantic = None if hasattr(member, "semantic"): semantic = member.semantic elif hasattr(member, "attributes"): for attr in member.attributes: if hasattr(attr, "name") and attr.name in [ "position", "color", "texcoord", "normal", ]: semantic = attr.name break semantic_str = f" : 
{semantic}" if semantic else "" declaration = self.format_declaration(member_type, member.name) result += f"{self.indent()}{declaration}{semantic_str};\n" self.indent_level -= 1 result += "};" return result def generate_struct_definition(self, node): result = f"{node.name}\n{{\n" members = getattr(node, "members", []) for member in members: if hasattr(member, "member_type"): member_type = self.convert_type_node_to_string(member.member_type) else: member_type = getattr(member, "vtype", "float") result += f" {self.format_declaration(member_type, member.name)};\n" result += "};" return result def generate_function(self, node, shader_type=None): """Render one CrossGL function or shader entry point as Slang code.""" saved_variable_types = self.variable_types.copy() saved_image_resource_types = self.image_resource_types.copy() saved_function_return_type = self.current_function_return_type if hasattr(node, "return_type"): ret_type_name = self.convert_type_node_to_string(node.return_type) ret_type = self.convert_type(ret_type_name) else: ret_type_name = "void" ret_type = "void" self.current_function_return_type = ret_type_name semantic = None if hasattr(node, "semantic"): semantic = node.semantic elif hasattr(node, "attributes"): for attr in node.attributes: if hasattr(attr, "name"): semantic = attr.name break semantic_str = f" : {semantic}" if semantic else "" param_list = getattr(node, "parameters", getattr(node, "params", [])) params_str = "" if param_list: if param_list and hasattr(param_list[0], "name"): params = [] for param in param_list: if hasattr(param, "param_type"): param_type_name = self.convert_type_node_to_string( param.param_type ) self.register_variable_type(param.name, param_type_name, param) param_type = self.map_resource_type_with_format( param_type_name, param ) elif hasattr(param, "vtype"): self.register_variable_type(param.name, param.vtype, param) param_type = self.map_resource_type_with_format( param.vtype, param ) else: param_type = "float" 
params.append( format_c_style_array_declaration(param_type, param.name) ) params_str = ", ".join(params) else: for param_type, param_name in param_list: self.register_variable_type(param_name, param_type) params_str = ", ".join( [ f"{self.map_resource_type_with_format(param_type)} {param_name}" for param_type, param_name in param_list ] ) result = "" if shader_type: result += f'[shader("{shader_type}")]\n' result += f"{ret_type} {node.name}({params_str}){semantic_str}\n{{\n" self.indent_level += 1 body = getattr(node, "body", []) if hasattr(body, "statements"): for stmt in body.statements: result += self.emit_statement(stmt) + "\n" elif isinstance(body, list): for stmt in body: result += self.emit_statement(stmt) + "\n" self.indent_level -= 1 result += "}" self.variable_types = saved_variable_types self.image_resource_types = saved_image_resource_types self.current_function_return_type = saved_function_return_type return result def emit_statement(self, node): statement = self.generate_statement(node) lines = statement.splitlines() return "\n".join( self.indent() + line if line and not line[0].isspace() else line for line in lines ) def generate_statement(self, node): """Render a single CrossGL statement as Slang code.""" if isinstance(node, ReturnNode): if node.value is None: return "return;" return ( "return " f"{self.generate_expression_with_expected(node.value, self.current_function_return_type)};" ) elif isinstance(node, AssignmentNode): return self.generate_assignment(node) + ";" elif isinstance(node, ExpressionStatementNode): return self.generate_expression(node.expression) + ";" elif isinstance(node, VariableNode): var_type = self.get_variable_type(node) self.register_variable_type(node.name, var_type, node) declaration = self.format_declaration(var_type, node.name, node) initial_value = getattr(node, "initial_value", getattr(node, "value", None)) if initial_value is not None: initial_expr = self.generate_expression_with_expected( initial_value, var_type ) 
return f"{declaration} = {initial_expr};" return f"{declaration};" elif isinstance(node, IfNode): return self.generate_if(node) elif isinstance(node, ForNode): return self.generate_for(node) elif isinstance(node, WhileNode): return self.generate_while(node) elif isinstance(node, SwitchNode): return self.generate_switch(node) elif isinstance(node, BreakNode): return "break;" elif isinstance(node, ContinueNode): return "continue;" else: return self.generate_expression(node) + ";" def generate_assignment(self, node): left = self.generate_expression(node.left) right = self.generate_expression_with_expected( node.right, self.expression_result_type(node.left) ) return f"{left} {node.operator} {right}" def generate_expression_with_expected(self, expr, expected_type): previous_expected_type = self.current_expression_expected_type self.current_expression_expected_type = self.type_name_string(expected_type) try: return self.generate_expression(expr) finally: self.current_expression_expected_type = previous_expected_type def type_name_string(self, type_name): if type_name is None: return None if not isinstance(type_name, str): return self.convert_type_node_to_string(type_name) return type_name def is_scalar_value_type(self, type_name): type_name = self.type_name_string(type_name) if not type_name: return False return self.convert_type(type_name) in { "float", "double", "int", "uint", "bool", } def is_vector_value_type(self, type_name): type_name = self.type_name_string(type_name) if not type_name: return False return self.convert_type(type_name) in { "float2", "float3", "float4", "double2", "double3", "double4", "int2", "int3", "int4", "uint2", "uint3", "uint4", "bool2", "bool3", "bool4", } def vector_component_type(self, type_name): mapped_type = self.convert_type(type_name) if mapped_type.startswith("double"): return "double" if mapped_type.startswith("float"): return "float" if mapped_type.startswith("uint"): return "uint" if mapped_type.startswith("int"): return "int" if 
mapped_type.startswith("bool"): return "bool" return None def expression_result_type(self, expr): if expr is None: return None if isinstance(expr, VariableNode): return self.variable_types.get(getattr(expr, "name", None)) if isinstance(expr, IdentifierNode): return self.variable_types.get(getattr(expr, "name", None)) if isinstance(expr, LiteralNode): literal_type = getattr(getattr(expr, "literal_type", None), "name", None) if literal_type: return literal_type if isinstance(expr.value, float): return "float" if isinstance(expr.value, int) and not isinstance(expr.value, bool): return "int" if isinstance(expr.value, bool): return "bool" if isinstance(expr, BinaryOpNode): left_type = self.expression_result_type(expr.left) right_type = self.expression_result_type(expr.right) if self.is_vector_value_type(left_type): return left_type if self.is_vector_value_type(right_type): return right_type if left_type == "float" or right_type == "float": return "float" return left_type or right_type if isinstance(expr, UnaryOpNode): return self.expression_result_type(expr.operand) if isinstance(expr, AssignmentNode): return self.expression_result_type(getattr(expr, "left", None)) if isinstance(expr, ArrayAccessNode): array_type = self.type_name_string(self.expression_result_type(expr.array)) if array_type and "[" in array_type and "]" in array_type: base_type, _ = split_array_type_suffix(array_type) return base_type return array_type if isinstance(expr, MemberAccessNode): object_type = self.expression_result_type(expr.object) member = str(expr.member) if object_type and all(ch in "xyzwrgba" for ch in member): component_type = self.vector_component_type(object_type) if component_type and len(member) == 1: return component_type if component_type: return f"{component_type}{len(member)}" return None if isinstance(expr, FunctionCallNode): func_expr = getattr(expr, "function", None) or getattr(expr, "name", None) func_name = getattr(func_expr, "name", func_expr) if func_name == "imageLoad" 
and getattr(expr, "args", None): return self.image_resource_element_type( self.image_resource_type(expr.args[0]) ) if isinstance(func_name, str) and func_name in { "float", "double", "int", "uint", "bool", "vec2", "vec3", "vec4", "ivec2", "ivec3", "ivec4", "uvec2", "uvec3", "uvec4", "bvec2", "bvec3", "bvec4", "float2", "float3", "float4", "int2", "int3", "int4", "uint2", "uint3", "uint4", "bool2", "bool3", "bool4", }: return str(func_name) return None def generate_literal(self, node): value = node.value literal_type = getattr(getattr(node, "literal_type", None), "name", None) if isinstance(value, bool): return "true" if value else "false" if literal_type == "bool" and isinstance(value, str): lower_value = value.lower() if lower_value in {"true", "false"}: return lower_value if literal_type == "char": escaped = self.escape_literal(value, quote="'") return f"'{escaped}'" if ( literal_type == "uint" and isinstance(value, int) and not isinstance(value, bool) ): return f"{value}u" if isinstance(value, str): escaped = self.escape_literal(value, quote='"') return f'"{escaped}"' return str(value) def escape_literal(self, value, quote): text = str(value) escaped = [] for index, char in enumerate(text): if char == "\n": escaped.append("\\n") elif char == "\r": escaped.append("\\r") elif char == "\t": escaped.append("\\t") elif char == quote and (index == 0 or text[index - 1] != "\\"): escaped.append("\\" + char) else: escaped.append(char) return "".join(escaped) def generate_expression(self, node): """Render a CrossGL expression as Slang expression syntax.""" if isinstance(node, VariableNode): return node.name elif isinstance(node, IdentifierNode): return node.name elif isinstance(node, LiteralNode): return self.generate_literal(node) elif isinstance(node, ExpressionStatementNode): return self.generate_expression(node.expression) elif isinstance(node, AssignmentNode): return self.generate_assignment(node) elif isinstance(node, ArrayAccessNode): array = 
self.generate_expression( getattr(node, "array", getattr(node, "array_expr", None)) ) index = self.format_array_access_index( getattr(node, "index", getattr(node, "index_expr", None)) ) return f"{array}[{index}]" elif isinstance(node, ArrayLiteralNode): elements = ", ".join( self.generate_expression(element) for element in node.elements ) return f"{{{elements}}}" elif isinstance(node, MemberAccessNode): obj = self.generate_expression(node.object) return f"{obj}.{node.member}" elif isinstance(node, BinaryOpNode): left = self.generate_expression(node.left) right = self.generate_expression(node.right) return f"{left} {node.op} {right}" elif isinstance(node, FunctionCallNode): func_expr = getattr(node, "function", None) if func_expr is None: func_expr = node.name if hasattr(func_expr, "name"): callee = func_expr.name elif isinstance(func_expr, str): callee = func_expr else: callee = self.generate_expression(func_expr) resource_call = self.generate_resource_call(callee, node.args) if resource_call is not None: return resource_call args = ", ".join([self.generate_expression(arg) for arg in node.args]) callee = self.convert_type(callee) return f"{callee}({args})" elif isinstance(node, UnaryOpNode): operand = self.generate_expression(node.operand) if getattr(node, "is_postfix", False): return f"{operand}{node.op}" return f"{node.op}{operand}" elif isinstance(node, TernaryOpNode): condition = self.generate_expression(node.condition) true_expr = self.generate_expression(node.true_expr) false_expr = self.generate_expression(node.false_expr) return f"({condition} ? 
{true_expr} : {false_expr})" elif isinstance(node, str): return node else: return str(node) def format_array_access_index(self, index): if isinstance(index, BinaryOpNode): return self.format_array_size_expression(index) return self.generate_expression(index) def generate_if(self, node): condition = self.generate_expression( getattr(node, "condition", getattr(node, "if_condition", None)) ) result = f"if ({condition})\n{{\n" self.indent_level += 1 for stmt in self.get_statements(getattr(node, "if_body", [])): result += self.emit_statement(stmt) + "\n" self.indent_level -= 1 result += self.indent() + "}" else_body = getattr(node, "else_body", None) if else_body: result += "\nelse\n{\n" self.indent_level += 1 for stmt in self.get_statements(else_body): result += self.emit_statement(stmt) + "\n" self.indent_level -= 1 result += self.indent() + "}" return result def generate_for(self, node): init = self.generate_statement(node.init).rstrip(";") condition = self.generate_expression(node.condition) update = self.generate_statement(node.update).rstrip(";") result = f"for ({init}; {condition}; {update})\n{{\n" self.indent_level += 1 for stmt in self.get_statements(node.body): result += self.emit_statement(stmt) + "\n" self.indent_level -= 1 result += self.indent() + "}" return result def generate_while(self, node): condition = self.generate_expression(node.condition) result = f"while ({condition})\n{{\n" self.indent_level += 1 for stmt in self.get_statements(node.body): result += self.emit_statement(stmt) + "\n" self.indent_level -= 1 result += self.indent() + "}" return result def generate_switch(self, node): expression = self.generate_expression(node.expression) result = f"switch ({expression})\n{{\n" self.indent_level += 1 for case in getattr(node, "cases", []): if not isinstance(case, CaseNode): continue if case.value is None: result += self.indent() + "default:\n" else: case_value = self.generate_expression(case.value) result += self.indent() + f"case {case_value}:\n" 
self.indent_level += 1 for stmt in self.get_statements(case.statements): result += self.emit_statement(stmt) + "\n" self.indent_level -= 1 self.indent_level -= 1 result += self.indent() + "}" return result def get_statements(self, body): if body is None: return [] if hasattr(body, "statements"): return body.statements if isinstance(body, list): return body return [body] def convert_type(self, type_name): """Map a CrossGL type name or type node to a Slang type string.""" # Map CrossGL types to Slang types type_map = { "vec2<f32>": "float2", "vec3<f32>": "float3", "vec4<f32>": "float4", "vec2<f64>": "double2", "vec3<f64>": "double3", "vec4<f64>": "double4", "vec2<i32>": "int2", "vec3<i32>": "int3", "vec4<i32>": "int4", "vec2<u32>": "uint2", "vec3<u32>": "uint3", "vec4<u32>": "uint4", "vec2<bool>": "bool2", "vec3<bool>": "bool3", "vec4<bool>": "bool4", "vec2": "float2", "vec3": "float3", "vec4": "float4", "ivec2": "int2", "ivec3": "int3", "ivec4": "int4", "uvec2": "uint2", "uvec3": "uint3", "uvec4": "uint4", "dvec2": "double2", "dvec3": "double3", "dvec4": "double4", "bvec2": "bool2", "bvec3": "bool3", "bvec4": "bool4", "mat2": "float2x2", "mat3": "float3x3", "mat4": "float4x4", "mat2x2": "float2x2", "mat2x3": "float2x3", "mat2x4": "float2x4", "mat3x2": "float3x2", "mat3x3": "float3x3", "mat3x4": "float3x4", "mat4x2": "float4x2", "mat4x3": "float4x3", "mat4x4": "float4x4", "dmat2": "double2x2", "dmat3": "double3x3", "dmat4": "double4x4", "dmat2x2": "double2x2", "dmat2x3": "double2x3", "dmat2x4": "double2x4", "dmat3x2": "double3x2", "dmat3x3": "double3x3", "dmat3x4": "double3x4", "dmat4x2": "double4x2", "dmat4x3": "double4x3", "dmat4x4": "double4x4", "float": "float", "int": "int", "uint": "uint", "bool": "bool", "void": "void", "sampler": "SamplerState", "sampler1D": "Sampler1D<float4>", "sampler2D": "Sampler2D<float4>", "sampler3D": "Sampler3D<float4>", "samplerCube": "SamplerCube<float4>", "sampler2DArray": "Sampler2DArray<float4>", "samplerCubeArray": 
"SamplerCubeArray<float4>", "sampler2DMS": "Sampler2DMS<float4>", "sampler2DMSArray": "Sampler2DMSArray<float4>", "sampler2DShadow": "Sampler2DShadow", "sampler2DArrayShadow": "Sampler2DArrayShadow", "samplerCubeShadow": "SamplerCubeShadow", "samplerCubeArrayShadow": "SamplerCubeArrayShadow", "iimage2D": "RWTexture2D<int>", "iimage3D": "RWTexture3D<int>", "iimage2DArray": "RWTexture2DArray<int>", "iimage2DMS": "RWTexture2DMS<int>", "iimage2DMSArray": "RWTexture2DMSArray<int>", "uimage2D": "RWTexture2D<uint>", "uimage3D": "RWTexture3D<uint>", "uimage2DArray": "RWTexture2DArray<uint>", "uimage2DMS": "RWTexture2DMS<uint>", "uimage2DMSArray": "RWTexture2DMSArray<uint>", "image2D": "RWTexture2D<float4>", "image3D": "RWTexture3D<float4>", "image2DArray": "RWTexture2DArray<float4>", "image2DMS": "RWTexture2DMS<float4>", "image2DMSArray": "RWTexture2DMSArray<float4>", } return type_map.get(type_name, type_name) def supported_image_formats(self): return { "r8", "r8_snorm", "r8i", "r8ui", "r16", "r16_snorm", "r16f", "r16i", "r16ui", "r32f", "r32i", "r32ui", "rg8", "rg8_snorm", "rg8i", "rg8ui", "rg16", "rg16_snorm", "rg16f", "rg16i", "rg16ui", "rg32f", "rg32i", "rg32ui", "rgba8", "rgba8_snorm", "rgba8i", "rgba8ui", "rgba16", "rgba16_snorm", "rgba16f", "rgba16i", "rgba16ui", "rgba32f", "rgba32i", "rgba32ui", } def scalar_image_format_components(self): return { "r8": "float", "r8_snorm": "float", "r16": "float", "r16_snorm": "float", "r16f": "float", "r32f": "float", "r8i": "int", "r16i": "int", "r32i": "int", "r8ui": "uint", "r16ui": "uint", "r32ui": "uint", } def vector_image_format_components(self): return { "rg8": "float2", "rg8_snorm": "float2", "rg16": "float2", "rg16_snorm": "float2", "rg16f": "float2", "rg8i": "int2", "rg16i": "int2", "rg8ui": "uint2", "rg16ui": "uint2", "rg32f": "float2", "rg32i": "int2", "rg32ui": "uint2", "rgba8": "float4", "rgba8_snorm": "float4", "rgba16": "float4", "rgba16_snorm": "float4", "rgba16f": "float4", "rgba32f": "float4", "rgba8i": 
"int4", "rgba16i": "int4", "rgba32i": "int4", "rgba8ui": "uint4", "rgba16ui": "uint4", "rgba32ui": "uint4", } def attribute_value_to_string(self, value): if value is None: return None if isinstance(value, str): return value if hasattr(value, "name"): return str(value.name) if hasattr(value, "value"): return str(value.value).strip('"') return str(value) def explicit_image_format(self, node): if not hasattr(node, "attributes"): return None supported_formats = self.supported_image_formats() for attr in node.attributes: attr_name = getattr(attr, "name", None) if not attr_name: continue attr_name = str(attr_name).lower() if attr_name in supported_formats: return attr_name if attr_name == "format": arguments = getattr(attr, "arguments", []) or [] if not arguments: continue format_name = self.attribute_value_to_string(arguments[0]) if format_name is None: continue format_name = str(format_name).lower() if format_name in supported_formats: return format_name return None def map_resource_type_with_format(self, type_name, node=None): type_name = self.type_name_string(type_name) if type_name is None: return self.convert_type(type_name) if "[" in type_name and "]" in type_name: base_type, array_suffix = split_array_type_suffix(type_name) mapped_base = self.map_image_base_type_with_format(base_type, node) return f"{mapped_base}{array_suffix}" return self.map_image_base_type_with_format(type_name, node) def map_image_base_type_with_format(self, type_name, node=None): base_type = self.resource_base_type(type_name) explicit_format = self.explicit_image_format(node) if node is not None else None component_type = self.scalar_image_format_components().get( explicit_format ) or self.vector_image_format_components().get(explicit_format) texture_types = { "image2D": "RWTexture2D", "iimage2D": "RWTexture2D", "uimage2D": "RWTexture2D", "image3D": "RWTexture3D", "iimage3D": "RWTexture3D", "uimage3D": "RWTexture3D", "image2DArray": "RWTexture2DArray", "iimage2DArray": "RWTexture2DArray", 
"uimage2DArray": "RWTexture2DArray", "image2DMS": "RWTexture2DMS", "iimage2DMS": "RWTexture2DMS", "uimage2DMS": "RWTexture2DMS", "image2DMSArray": "RWTexture2DMSArray", "iimage2DMSArray": "RWTexture2DMSArray", "uimage2DMSArray": "RWTexture2DMSArray", } texture_type = texture_types.get(base_type) if component_type and texture_type: return f"{texture_type}<{component_type}>" return self.convert_type(type_name) def is_storage_image_type(self, type_name): base_type = self.resource_base_type(type_name) return isinstance(base_type, str) and base_type in { "image2D", "iimage2D", "uimage2D", "image3D", "iimage3D", "uimage3D", "image2DArray", "iimage2DArray", "uimage2DArray", "image2DMS", "iimage2DMS", "uimage2DMS", "image2DMSArray", "iimage2DMSArray", "uimage2DMSArray", } def image_resource_type(self, image_arg): image_name = self.get_expression_name(image_arg) if not image_name: return None return self.image_resource_types.get(image_name) def image_resource_element_type(self, image_type): image_type = self.resource_base_type(image_type) if not image_type or "<" not in image_type or not image_type.endswith(">"): return None return image_type[image_type.find("<") + 1 : -1] def vector_size(self, type_name): if not isinstance(type_name, str) or not type_name[-1:].isdigit(): return None size = int(type_name[-1]) return size if size in {2, 3, 4} else None def vector_zero_value(self, type_name): if isinstance(type_name, str) and type_name.startswith("uint"): return "0u" if isinstance(type_name, str) and type_name.startswith("int"): return "0" return "0.0" def image_load_expression(self, args): image_name = self.generate_expression(args[0]) coord = self.generate_expression(args[1]) if len(args) >= 3: sample = self.generate_expression(args[2]) load_expr = f"{image_name}[{coord}, {sample}]" else: load_expr = f"{image_name}[{coord}]" image_type = self.image_resource_type(args[0]) element_type = self.image_resource_element_type(image_type) if self.vector_size(element_type) and 
self.is_scalar_value_type( self.current_expression_expected_type ): return f"{load_expr}.x" return load_expr def image_store_value_expression(self, image_arg, value_arg): value = self.generate_expression(value_arg) image_type = self.image_resource_type(image_arg) element_type = self.image_resource_element_type(image_type) if not self.vector_size(element_type): return value if not self.is_scalar_value_type(self.expression_result_type(value_arg)): return value if self.vector_size(element_type) == 2: return f"{element_type}({value}, {self.vector_zero_value(element_type)})" return f"{element_type}({value})" def image_store_expression(self, args): image_name = self.generate_expression(args[0]) coord = self.generate_expression(args[1]) if len(args) >= 4: sample = self.generate_expression(args[2]) value = self.image_store_value_expression(args[0], args[3]) return f"{image_name}[{coord}, {sample}] = {value}" value = self.image_store_value_expression(args[0], args[2]) return f"{image_name}[{coord}] = {value}" def image_atomic_intrinsic(self, operation): return { "imageAtomicAdd": "InterlockedAdd", "imageAtomicMin": "InterlockedMin", "imageAtomicMax": "InterlockedMax", "imageAtomicAnd": "InterlockedAnd", "imageAtomicOr": "InterlockedOr", "imageAtomicXor": "InterlockedXor", "imageAtomicExchange": "InterlockedExchange", "imageAtomicCompSwap": "InterlockedCompareExchange", }.get(operation) def image_atomic_helper_suffix(self, image_type): return { "RWTexture2D<int>": "iimage2D", "RWTexture2D<uint>": "uimage2D", "RWTexture3D<int>": "iimage3D", "RWTexture3D<uint>": "uimage3D", "RWTexture2DArray<int>": "iimage2DArray", "RWTexture2DArray<uint>": "uimage2DArray", }.get(image_type) def image_atomic_return_type(self, image_type): element_type = self.image_resource_element_type(image_type) if element_type in {"int", "uint"}: return element_type return None def image_atomic_coord_type(self, image_type): if image_type in {"RWTexture2D<int>", "RWTexture2D<uint>"}: return "int2" if 
image_type in { "RWTexture3D<int>", "RWTexture3D<uint>", "RWTexture2DArray<int>", "RWTexture2DArray<uint>", }: return "int3" return None def image_atomic_helper_name(self, operation, image_type): suffix = self.image_atomic_helper_suffix(image_type) if not suffix: return None return f"cgl_{operation}_{suffix}" def image_atomic_zero_value(self, image_type=None): element_type = self.image_resource_element_type(image_type) if isinstance(element_type, str) and element_type.startswith("uint"): return "0u" expected_type = self.convert_type(self.current_expression_expected_type) if expected_type == "uint": return "0u" return "0" def unsupported_image_atomic_call(self, operation, reason, image_type=None): return ( f"/* unsupported Slang image atomic: {operation} {reason} */ " f"{self.image_atomic_zero_value(image_type)}" ) def image_atomic_required_args_reason(self, operation): if operation == "imageAtomicCompSwap": return "requires image, coordinate, compare, and value arguments" return "requires image, coordinate, and value arguments" def image_atomic_expression(self, operation, args): if not self.image_atomic_intrinsic(operation): return None required_args = 4 if operation == "imageAtomicCompSwap" else 3 if len(args) < required_args: return self.unsupported_image_atomic_call( operation, self.image_atomic_required_args_reason(operation) ) image_type = self.resource_base_type(self.image_resource_type(args[0])) helper_name = self.image_atomic_helper_name(operation, image_type) if not helper_name: return self.unsupported_image_atomic_call( operation, "requires scalar int or uint image2D/image3D/image2DArray resource", image_type, ) self.register_helper_function( helper_name, self.build_image_atomic_helper(helper_name, operation, image_type), ) image_name = self.generate_expression(args[0]) coord = self.generate_expression(args[1]) if operation == "imageAtomicCompSwap": compare = self.generate_expression(args[2]) value = self.generate_expression(args[3]) return 
f"{helper_name}({image_name}, {coord}, {compare}, {value})" value = self.generate_expression(args[2]) return f"{helper_name}({image_name}, {coord}, {value})" def build_image_atomic_helper(self, helper_name, operation, image_type): return_type = self.image_atomic_return_type(image_type) coord_type = self.image_atomic_coord_type(image_type) intrinsic = self.image_atomic_intrinsic(operation) if not return_type or not coord_type or not intrinsic: return "" if operation == "imageAtomicCompSwap": return ( f"{return_type} {helper_name}({image_type} image, " f"{coord_type} coord, {return_type} compareValue, " f"{return_type} value)\n" "{\n" f" {return_type} original;\n" " InterlockedCompareExchange(image[coord], compareValue, value, original);\n" " return original;\n" "}" ) return ( f"{return_type} {helper_name}({image_type} image, " f"{coord_type} coord, {return_type} value)\n" "{\n" f" {return_type} original;\n" f" {intrinsic}(image[coord], value, original);\n" " return original;\n" "}" ) def resource_query_slang_type(self, resource_arg, resource_type): if self.is_storage_image_type(resource_type): image_type = self.resource_base_type(self.image_resource_type(resource_arg)) if image_type: return image_type return self.convert_type(resource_type) def resource_query_helper_name(self, func_name, resource_type, resource_slang_type): base_name = f"cgl_{func_name}_{resource_type}" if resource_slang_type == self.convert_type(resource_type): return base_name return f"{base_name}_{self.resource_helper_type_suffix(resource_slang_type)}" def resource_helper_type_suffix(self, resource_slang_type): return "".join( char if char.isalnum() else "_" for char in str(resource_slang_type).strip("_") ).strip("_") def generate_resource_call(self, func_name, args): if func_name == "imageLoad" and len(args) >= 2: return self.image_load_expression(args) if func_name == "imageStore" and len(args) >= 3: return self.image_store_expression(args) if func_name in { "imageAtomicAdd", "imageAtomicMin", 
"imageAtomicMax", "imageAtomicAnd", "imageAtomicOr", "imageAtomicXor", "imageAtomicExchange", "imageAtomicCompSwap", }: return self.image_atomic_expression(func_name, args) if func_name in {"texture", "textureLod", "textureGrad"}: sample_args = self.sampled_texture_args(args) if sample_args is None: return None texture_name, coord, extra_args = sample_args if func_name == "texture": if extra_args: bias = self.generate_expression(extra_args[0]) return f"{texture_name}.SampleBias({coord}, {bias})" return f"{texture_name}.Sample({coord})" if func_name == "textureLod" and extra_args: lod = self.generate_expression(extra_args[0]) return f"{texture_name}.SampleLevel({coord}, {lod})" if func_name == "textureGrad" and len(extra_args) >= 2: ddx = self.generate_expression(extra_args[0]) ddy = self.generate_expression(extra_args[1]) return f"{texture_name}.SampleGrad({coord}, {ddx}, {ddy})" return None if func_name in {"textureOffset", "textureLodOffset", "textureGradOffset"}: return self.generate_texture_offset(func_name, args) if func_name in { "textureProj", "textureProjOffset", "textureProjLod", "textureProjLodOffset", "textureProjGrad", "textureProjGradOffset", }: return self.generate_texture_projected(func_name, args) if func_name in { "textureGather", "textureGatherOffset", "textureGatherOffsets", }: return self.generate_texture_gather(func_name, args) if func_name in { "textureCompare", "textureCompareLod", "textureCompareGrad", "textureCompareOffset", }: return self.generate_texture_compare(func_name, args) if func_name in {"textureGatherCompare", "textureGatherCompareOffset"}: return self.generate_texture_gather_compare(func_name, args) if func_name == "texelFetch": fetch_args = self.sampled_texture_args(args) if fetch_args is None: return None texture_name, coord, extra_args = fetch_args if not extra_args: return None lod_or_sample = self.generate_expression(extra_args[0]) texture_type = self.get_expression_type(args[0]) if 
self.is_multisample_sampler_type(texture_type): return f"{texture_name}[{coord}, {lod_or_sample}]" coord_constructor = self.texel_fetch_coord_constructor(texture_type) return f"{texture_name}.Load({coord_constructor}({coord}, {lod_or_sample}))" if func_name in {"textureSize", "imageSize"}: return self.generate_dimension_query(func_name, args) if func_name in {"textureSamples", "imageSamples"}: return self.generate_sample_count_query(func_name, args) if func_name == "textureQueryLevels": return self.generate_texture_query_levels(args) if func_name == "textureQueryLod": return self.generate_texture_query_lod(args) return None def generate_dimension_query(self, func_name, args): if not args: return None resource_name = self.generate_expression(args[0]) resource_type = self.resource_base_type(self.get_expression_type(args[0])) spec = self.dimension_query_spec(resource_type) if spec is None: return None resource_slang_type = self.resource_query_slang_type(args[0], resource_type) helper_name = self.resource_query_helper_name( func_name, resource_type, resource_slang_type ) self.register_helper_function( helper_name, self.build_dimension_query_helper( helper_name, resource_type, spec, resource_slang_type ), ) if spec["mip"]: lod = self.generate_expression(args[1]) if len(args) > 1 else "0" return f"{helper_name}({resource_name}, {lod})" return f"{helper_name}({resource_name})" def generate_sample_count_query(self, func_name, args): if not args: return None resource_name = self.generate_expression(args[0]) resource_type = self.resource_base_type(self.get_expression_type(args[0])) spec = self.dimension_query_spec(resource_type) if spec is None or not spec["samples"]: return None resource_slang_type = self.resource_query_slang_type(args[0], resource_type) helper_name = self.resource_query_helper_name( func_name, resource_type, resource_slang_type ) self.register_helper_function( helper_name, self.build_sample_count_query_helper( helper_name, resource_type, spec, 
resource_slang_type ), ) return f"{helper_name}({resource_name})" def sampled_texture_args(self, args): coord_index = self.sampled_texture_coord_index(args) if len(args) <= coord_index: return None texture_name = self.generate_expression(args[0]) coord = self.generate_expression(args[coord_index]) return texture_name, coord, args[coord_index + 1 :] def sampled_texture_coord_index(self, args): return 2 if self.is_explicit_sampler_argument(args) else 1 def generate_texture_offset(self, func_name, args): sample_args = self.sampled_texture_args(args) if sample_args is None: return self.unsupported_texture_offset_call( func_name, "requires texture and coordinate arguments" ) texture_name, coord, extra_args = sample_args if func_name == "textureOffset": if len(extra_args) != 1: return self.unsupported_texture_offset_call( func_name, "requires one offset argument" ) offset = self.generate_expression(extra_args[0]) return f"{texture_name}.Sample({coord}, {offset})" if func_name == "textureLodOffset": if len(extra_args) != 2: return self.unsupported_texture_offset_call( func_name, "requires lod and offset arguments" ) lod = self.generate_expression(extra_args[0]) offset = self.generate_expression(extra_args[1]) return f"{texture_name}.SampleLevel({coord}, {lod}, {offset})" if len(extra_args) != 3: return self.unsupported_texture_offset_call( func_name, "requires gradient x, gradient y, and offset arguments" ) ddx = self.generate_expression(extra_args[0]) ddy = self.generate_expression(extra_args[1]) offset = self.generate_expression(extra_args[2]) return f"{texture_name}.SampleGrad({coord}, {ddx}, {ddy}, {offset})" def unsupported_texture_offset_call(self, func_name, reason): return ( f"/* unsupported Slang texture offset: {func_name} {reason} */ float4(0.0)" ) def generate_texture_projected(self, func_name, args): sample_args = self.sampled_texture_args(args) if sample_args is None: return self.unsupported_texture_projected_call( func_name, "requires texture and projected 
coordinate arguments" ) texture_name, coord, extra_args = sample_args coord_node = args[self.sampled_texture_coord_index(args)] projected_coord = self.projected_texture_coord(args[0], coord_node, coord) if projected_coord is None: return self.unsupported_texture_projected_call( func_name, "requires sampler1D/2D/3D projection coordinates" ) if func_name == "textureProj": if not extra_args: return f"{texture_name}.Sample({projected_coord})" if len(extra_args) == 1: bias = self.generate_expression(extra_args[0]) return f"{texture_name}.SampleBias({projected_coord}, {bias})" return self.unsupported_texture_projected_call( func_name, "accepts at most one bias argument" ) if func_name == "textureProjOffset": if len(extra_args) == 1: offset = self.generate_expression(extra_args[0]) return f"{texture_name}.Sample({projected_coord}, {offset})" if len(extra_args) == 2: offset = self.generate_expression(extra_args[0]) bias = self.generate_expression(extra_args[1]) return f"{texture_name}.SampleBias({projected_coord}, {bias}, {offset})" return self.unsupported_texture_projected_call( func_name, "requires offset and optional bias arguments" ) if func_name == "textureProjLod": if len(extra_args) != 1: return self.unsupported_texture_projected_call( func_name, "requires one lod argument" ) lod = self.generate_expression(extra_args[0]) return f"{texture_name}.SampleLevel({projected_coord}, {lod})" if func_name == "textureProjLodOffset": if len(extra_args) != 2: return self.unsupported_texture_projected_call( func_name, "requires lod and offset arguments" ) lod = self.generate_expression(extra_args[0]) offset = self.generate_expression(extra_args[1]) return f"{texture_name}.SampleLevel({projected_coord}, {lod}, {offset})" if func_name == "textureProjGrad": if len(extra_args) != 2: return self.unsupported_texture_projected_call( func_name, "requires gradient x and gradient y arguments" ) ddx = self.generate_expression(extra_args[0]) ddy = self.generate_expression(extra_args[1]) 
return f"{texture_name}.SampleGrad({projected_coord}, {ddx}, {ddy})" if len(extra_args) != 3: return self.unsupported_texture_projected_call( func_name, "requires gradient x, gradient y, and offset arguments" ) ddx = self.generate_expression(extra_args[0]) ddy = self.generate_expression(extra_args[1]) offset = self.generate_expression(extra_args[2]) return f"{texture_name}.SampleGrad({projected_coord}, {ddx}, {ddy}, {offset})" def projected_texture_coord(self, texture_node, coord_node, coord): resource_type = self.resource_base_type(self.get_expression_type(texture_node)) coord_type = self.resource_base_type(self.get_expression_type(coord_node)) specs = { "sampler1D": { "vec2": ("x", "y"), "float2": ("x", "y"), "vec4": ("x", "w"), "float4": ("x", "w"), }, "sampler2D": { "vec3": ("xy", "z"), "float3": ("xy", "z"), "vec4": ("xy", "w"), "float4": ("xy", "w"), }, "sampler3D": { "vec4": ("xyz", "w"), "float4": ("xyz", "w"), }, } resource_specs = specs.get(resource_type) if resource_specs is None: return None coord_spec = resource_specs.get(coord_type) if coord_spec is None: return None numerator, divisor = coord_spec return f"{coord}.{numerator} / {coord}.{divisor}" def unsupported_texture_projected_call(self, func_name, reason): return ( f"/* unsupported Slang projected texture: " f"{func_name} {reason} */ float4(0.0)" ) def generate_texture_gather(self, func_name, args): gather_args = self.sampled_texture_args(args) if gather_args is None: return self.unsupported_texture_gather_call( func_name, "requires texture and coordinate arguments" ) texture_name, coord, extra_args = gather_args offset_args = [] component_arg = None if func_name == "textureGather": if len(extra_args) > 1: return self.unsupported_texture_gather_call( func_name, "accepts at most one component argument" ) if extra_args: component_arg = extra_args[0] elif func_name == "textureGatherOffset": if len(extra_args) not in {1, 2}: return self.unsupported_texture_gather_call( func_name, "requires offset and 
optional component arguments" ) offset_args = [extra_args[0]] if len(extra_args) == 2: component_arg = extra_args[1] else: offset_args, component_arg = self.texture_gather_offsets_args(extra_args) if offset_args is None: return self.unsupported_texture_gather_call( func_name, "requires a typed offsets array or four offset arguments", ) method_args = [coord] + [ self.generate_expression(offset_arg) for offset_arg in offset_args ] method = self.texture_gather_method(component_arg) if method is not None: return f"{texture_name}.{method}({', '.join(method_args)})" if isinstance(component_arg, LiteralNode): return self.unsupported_texture_gather_call( func_name, "component literal must be 0, 1, 2, or 3" ) component = self.generate_expression(component_arg) return self.texture_gather_component_expression( texture_name, method_args, component ) def texture_gather_offsets_args(self, extra_args): if len(extra_args) in {1, 2} and self.is_array_expression(extra_args[0]): offsets_name = self.generate_expression(extra_args[0]) offset_args = [f"{offsets_name}[{index}]" for index in range(4)] component_arg = extra_args[1] if len(extra_args) == 2 else None return offset_args, component_arg if len(extra_args) in {4, 5}: component_arg = extra_args[4] if len(extra_args) == 5 else None return extra_args[:4], component_arg return None, None def texture_gather_method(self, component_arg): if component_arg is None: return "Gather" methods = { 0: "GatherRed", 1: "GatherGreen", 2: "GatherBlue", 3: "GatherAlpha", } return methods.get(self.literal_int_value(component_arg)) def texture_gather_component_expression(self, texture_name, method_args, component): arg_list = ", ".join(method_args) component_calls = [ f"{texture_name}.{method}({arg_list})" for method in ( "GatherRed", "GatherGreen", "GatherBlue", "GatherAlpha", ) ] return ( f"({component} == 0 ? {component_calls[0]} : " f"{component} == 1 ? {component_calls[1]} : " f"{component} == 2 ? 
{component_calls[2]} : " f"{component_calls[3]})" ) def unsupported_texture_gather_call(self, func_name, reason): return ( f"/* unsupported Slang texture gather: {func_name} {reason} */ float4(0.0)" ) def generate_texture_compare(self, func_name, args): compare_args = self.texture_compare_args(func_name, args) if compare_args is None: return self.unsupported_texture_compare_call( func_name, "requires texture, coordinate, and compare arguments" ) texture_name, coord, compare, extra_args = compare_args if not self.is_shadow_compare_resource(args[0]): return self.unsupported_texture_compare_call( func_name, "requires a shadow sampler resource" ) if func_name == "textureCompare": if extra_args: return self.unsupported_texture_compare_call( func_name, "accepts no extra arguments" ) return f"{texture_name}.SampleCmp({coord}, {compare})" if func_name == "textureCompareOffset": if len(extra_args) != 1: return self.unsupported_texture_compare_call( func_name, "requires one offset argument" ) offset = self.generate_expression(extra_args[0]) return f"{texture_name}.SampleCmp({coord}, {compare}, {offset})" if func_name == "textureCompareLod": if len(extra_args) != 1: return self.unsupported_texture_compare_call( func_name, "requires one lod argument" ) lod = self.generate_expression(extra_args[0]) return f"{texture_name}.SampleCmpLevel({coord}, {compare}, {lod})" if len(extra_args) != 2: return self.unsupported_texture_compare_call( func_name, "requires gradient x and gradient y arguments" ) ddx = self.generate_expression(extra_args[0]) ddy = self.generate_expression(extra_args[1]) return f"{texture_name}.SampleCmpGrad({coord}, {compare}, {ddx}, {ddy})" def generate_texture_gather_compare(self, func_name, args): compare_args = self.texture_compare_args(func_name, args) if compare_args is None: return self.unsupported_texture_gather_compare_call( func_name, "requires texture, coordinate, and compare arguments" ) texture_name, coord, compare, extra_args = compare_args if not 
self.is_shadow_compare_resource(args[0]): return self.unsupported_texture_gather_compare_call( func_name, "requires a shadow sampler resource" ) if func_name == "textureGatherCompare": if extra_args: return self.unsupported_texture_gather_compare_call( func_name, "accepts no extra arguments" ) return f"{texture_name}.GatherCmp({coord}, {compare})" if len(extra_args) != 1: return self.unsupported_texture_gather_compare_call( func_name, "requires one offset argument" ) offset = self.generate_expression(extra_args[0]) return f"{texture_name}.GatherCmp({coord}, {compare}, {offset})" def texture_compare_args(self, func_name, args): coord_index = 2 if self.is_explicit_sampler_argument(args) else 1 if len(args) <= coord_index + 1: return None texture_name = self.generate_expression(args[0]) coord = self.generate_expression(args[coord_index]) compare = self.generate_expression(args[coord_index + 1]) return texture_name, coord, compare, args[coord_index + 2 :] def is_shadow_compare_resource(self, node): resource_type = self.resource_base_type(self.get_expression_type(node)) return resource_type is None or resource_type in { "sampler2DShadow", "sampler2DArrayShadow", "samplerCubeShadow", "samplerCubeArrayShadow", } def unsupported_texture_compare_call(self, func_name, reason): return f"/* unsupported Slang shadow compare: {func_name} {reason} */ 0.0" def unsupported_texture_gather_compare_call(self, func_name, reason): return ( f"/* unsupported Slang shadow gather compare: " f"{func_name} {reason} */ float4(0.0)" ) def literal_int_value(self, node): if not isinstance(node, LiteralNode): return None value = node.value if isinstance(value, bool): return None if isinstance(value, int): return value if isinstance(value, str): try: return int(value, 0) except ValueError: return None return None def is_array_expression(self, node): type_name = self.get_expression_type(node) return isinstance(type_name, str) and "[" in type_name and "]" in type_name def 
generate_texture_query_levels(self, args): if not args: return None resource_name = self.generate_expression(args[0]) resource_type = self.resource_base_type(self.get_expression_type(args[0])) spec = self.dimension_query_spec(resource_type) if spec is None or not spec["mip"]: return None helper_name = f"cgl_textureQueryLevels_{resource_type}" self.register_helper_function( helper_name, self.build_texture_query_levels_helper(helper_name, resource_type, spec), ) return f"{helper_name}({resource_name})" def generate_texture_query_lod(self, args): query_args = self.texture_query_lod_args(args) if query_args is None: return None texture_name, coord = query_args unclamped = f"{texture_name}.CalculateLevelOfDetailUnclamped({coord})" clamped = f"{texture_name}.CalculateLevelOfDetail({coord})" return f"float2({unclamped}, {clamped})" def texture_query_lod_args(self, args): coord_index = 2 if self.is_explicit_sampler_argument(args) else 1 if len(args) <= coord_index: return None resource_type = self.resource_base_type(self.get_expression_type(args[0])) if not self.is_lod_query_sampler_type(resource_type): return None texture_name = self.generate_expression(args[0]) coord = self.generate_expression(args[coord_index]) return texture_name, coord def register_helper_function(self, name, source): if name not in self.helper_functions: self.helper_functions[name] = source def build_dimension_query_helper( self, helper_name, resource_type, spec, resource_slang_type=None ): resource_slang_type = resource_slang_type or self.convert_type(resource_type) return_type = self.query_return_type(spec["dimensions"]) params = f"{resource_slang_type} tex" if spec["mip"]: params += ", uint mipLevel" declarations = self.query_local_declarations(spec) get_dimensions_args = self.get_dimensions_args(spec) dimensions = ", ".join(spec["dimensions"]) if len(spec["dimensions"]) == 1: return_value = spec["dimensions"][0] else: return_value = f"{return_type}({dimensions})" return ( f"{return_type} 
{helper_name}({params})\n" "{\n" f"{declarations}" f" tex.GetDimensions({get_dimensions_args});\n" f" return {return_value};\n" "}" ) def build_sample_count_query_helper( self, helper_name, resource_type, spec, resource_slang_type=None ): resource_slang_type = resource_slang_type or self.convert_type(resource_type) declarations = self.query_local_declarations(spec) get_dimensions_args = self.get_dimensions_args(spec) return ( f"int {helper_name}({resource_slang_type} tex)\n" "{\n" f"{declarations}" f" tex.GetDimensions({get_dimensions_args});\n" " return samples;\n" "}" ) def build_texture_query_levels_helper(self, helper_name, resource_type, spec): resource_slang_type = self.convert_type(resource_type) declarations = self.query_local_declarations(spec) get_dimensions_args = self.texture_query_levels_args(spec) return ( f"int {helper_name}({resource_slang_type} tex)\n" "{\n" f"{declarations}" f" tex.GetDimensions({get_dimensions_args});\n" " return levels;\n" "}" ) def query_return_type(self, dimensions): if len(dimensions) == 1: return "int" return f"int{len(dimensions)}" def query_local_declarations(self, spec): names = list(spec["dimensions"]) if spec["samples"]: names.append("samples") if spec["mip"]: names.append("levels") return "".join(f" int {name};\n" for name in names) def get_dimensions_args(self, spec): args = [] if spec["mip"]: args.append("mipLevel") args.extend(spec["dimensions"]) if spec["samples"]: args.append("samples") if spec["mip"]: args.append("levels") return ", ".join(args) def texture_query_levels_args(self, spec): args = list(spec["dimensions"]) args.append("levels") return ", ".join(args) def dimension_query_spec(self, type_name): specs = { "sampler1D": (("width",), True, False), "sampler1DArray": (("width", "elements"), True, False), "sampler2D": (("width", "height"), True, False), "sampler2DShadow": (("width", "height"), True, False), "sampler2DArray": (("width", "height", "elements"), True, False), "sampler2DArrayShadow": ( ("width", 
"height", "elements"), True, False, ), "sampler3D": (("width", "height", "depth"), True, False), "samplerCube": (("width", "height"), True, False), "samplerCubeShadow": (("width", "height"), True, False), "samplerCubeArray": (("width", "height", "elements"), True, False), "samplerCubeArrayShadow": ( ("width", "height", "elements"), True, False, ), "sampler2DMS": (("width", "height"), False, True), "sampler2DMSArray": (("width", "height", "elements"), False, True), "image2D": (("width", "height"), False, False), "iimage2D": (("width", "height"), False, False), "uimage2D": (("width", "height"), False, False), "image2DArray": (("width", "height", "elements"), False, False), "iimage2DArray": (("width", "height", "elements"), False, False), "uimage2DArray": (("width", "height", "elements"), False, False), "image3D": (("width", "height", "depth"), False, False), "iimage3D": (("width", "height", "depth"), False, False), "uimage3D": (("width", "height", "depth"), False, False), "image2DMS": (("width", "height"), False, True), "iimage2DMS": (("width", "height"), False, True), "uimage2DMS": (("width", "height"), False, True), "image2DMSArray": (("width", "height", "elements"), False, True), "iimage2DMSArray": (("width", "height", "elements"), False, True), "uimage2DMSArray": (("width", "height", "elements"), False, True), } spec = specs.get(type_name) if spec is None: return None dimensions, mip, samples = spec return { "dimensions": dimensions, "mip": mip, "samples": samples, } def is_explicit_sampler_argument(self, args): if len(args) < 3: return False return self.is_sampler_state_type(self.get_expression_type(args[1])) def is_sampler_state_type(self, type_name): return self.resource_base_type(type_name) in { "sampler", "SamplerState", "SamplerComparisonState", } def is_lod_query_sampler_type(self, type_name): resource_type = self.resource_base_type(type_name) return ( isinstance(resource_type, str) and resource_type.startswith("sampler") and resource_type != "sampler" and 
"MS" not in resource_type and "Shadow" not in resource_type ) def get_expression_type(self, node): name = self.get_expression_name(node) if name is None: return None return self.variable_types.get(name) def get_expression_name(self, node): if isinstance(node, IdentifierNode): return node.name if isinstance(node, VariableNode): return node.name if isinstance(node, str): return node if isinstance(node, ArrayAccessNode): return self.get_expression_name( getattr(node, "array", getattr(node, "array_expr", None)) ) return None def resource_base_type(self, type_name): if not isinstance(type_name, str): return None return type_name.split("[", 1)[0] def is_multisample_sampler_type(self, type_name): return self.resource_base_type(type_name) in { "sampler2DMS", "sampler2DMSArray", } def texel_fetch_coord_constructor(self, type_name): base_type = self.resource_base_type(type_name) if base_type in {"sampler1D", "sampler1DArray"}: return "int2" if base_type == "sampler1D" else "int3" if base_type in {"sampler3D", "sampler2DArray"}: return "int4" return "int3"