"""CrossGL-to-Mojo code generator."""
from ..ast import (
ArrayNode,
ArrayAccessNode,
ArrayLiteralNode,
AssignmentNode,
BinaryOpNode,
CbufferNode,
ForNode,
FunctionCallNode,
FunctionNode,
IfNode,
MemberAccessNode,
ReturnNode,
ShaderNode,
StructNode,
TernaryOpNode,
UnaryOpNode,
VariableNode,
)
from .array_utils import parse_array_type, format_array_type, get_array_size_from_node
# Vector type table: CrossGL vector type name -> 4-tuple, unpacked by
# generate_vector_constructor as (dtype, source_width, storage_width, pad_literal).
# - dtype: Mojo DType expression used as the SIMD element type.
# - source_width: number of components the CrossGL type exposes.
# - storage_width: lane count of the emitted SIMD[...] type.
# - pad_literal: literal appended as the extra lane when a 3-component vector
#   is stored in 4 lanes; None for every type whose widths already match.
MOJO_VECTOR_TYPES = {
"vec2": ("DType.float32", 2, 2, None),
"vec3": ("DType.float32", 3, 4, "0.0"),
"vec4": ("DType.float32", 4, 4, None),
"vec2<f32>": ("DType.float32", 2, 2, None),
"vec3<f32>": ("DType.float32", 3, 4, "0.0"),
"vec4<f32>": ("DType.float32", 4, 4, None),
"vec2<f64>": ("DType.float64", 2, 2, None),
"vec3<f64>": ("DType.float64", 3, 4, "0.0"),
"vec4<f64>": ("DType.float64", 4, 4, None),
"vec2<i32>": ("DType.int32", 2, 2, None),
"vec3<i32>": ("DType.int32", 3, 4, "0"),
"vec4<i32>": ("DType.int32", 4, 4, None),
"vec2<u32>": ("DType.uint32", 2, 2, None),
"vec3<u32>": ("DType.uint32", 3, 4, "0"),
"vec4<u32>": ("DType.uint32", 4, 4, None),
"vec2<bool>": ("DType.bool", 2, 2, None),
"vec3<bool>": ("DType.bool", 3, 4, "False"),
"vec4<bool>": ("DType.bool", 4, 4, None),
"ivec2": ("DType.int32", 2, 2, None),
"ivec3": ("DType.int32", 3, 4, "0"),
"ivec4": ("DType.int32", 4, 4, None),
"uvec2": ("DType.uint32", 2, 2, None),
"uvec3": ("DType.uint32", 3, 4, "0"),
"uvec4": ("DType.uint32", 4, 4, None),
"dvec2": ("DType.float64", 2, 2, None),
"dvec3": ("DType.float64", 3, 4, "0.0"),
"dvec4": ("DType.float64", 4, 4, None),
"bvec2": ("DType.bool", 2, 2, None),
"bvec3": ("DType.bool", 3, 4, "False"),
"bvec4": ("DType.bool", 4, 4, None),
"bool2": ("DType.bool", 2, 2, None),
"bool3": ("DType.bool", 3, 4, "False"),
"bool4": ("DType.bool", 4, 4, None),
}
# Matrix type table: CrossGL matrix type name -> (dtype, columns, rows), the
# order in which generate_matrix_constructor unpacks each entry. Square
# shorthand names (mat2, dmat3, ...) carry the same tuple as their NxN form.
MOJO_MATRIX_TYPES = {
"mat2": ("DType.float32", 2, 2),
"mat3": ("DType.float32", 3, 3),
"mat4": ("DType.float32", 4, 4),
"mat2x2": ("DType.float32", 2, 2),
"mat2x3": ("DType.float32", 2, 3),
"mat2x4": ("DType.float32", 2, 4),
"mat3x2": ("DType.float32", 3, 2),
"mat3x3": ("DType.float32", 3, 3),
"mat3x4": ("DType.float32", 3, 4),
"mat4x2": ("DType.float32", 4, 2),
"mat4x3": ("DType.float32", 4, 3),
"mat4x4": ("DType.float32", 4, 4),
"dmat2": ("DType.float64", 2, 2),
"dmat3": ("DType.float64", 3, 3),
"dmat4": ("DType.float64", 4, 4),
"dmat2x2": ("DType.float64", 2, 2),
"dmat2x3": ("DType.float64", 2, 3),
"dmat2x4": ("DType.float64", 2, 4),
"dmat3x2": ("DType.float64", 3, 2),
"dmat3x3": ("DType.float64", 3, 3),
"dmat3x4": ("DType.float64", 3, 4),
"dmat4x2": ("DType.float64", 4, 2),
"dmat4x3": ("DType.float64", 4, 3),
"dmat4x4": ("DType.float64", 4, 4),
}
# Swizzle alphabets: for each naming scheme, component letter -> lane index.
SWIZZLE_SETS = {
"xyzw": {"x": 0, "y": 1, "z": 2, "w": 3},
"rgba": {"r": 0, "g": 1, "b": 2, "a": 3},
}
# Per-dtype metadata keyed by the Mojo DType expression. The third element is
# the zero literal (read by matrix_zero_literal). The first two look like the
# CrossGL scalar type name and vector-name prefix -- NOTE(review): confirm
# against the code that consumes them.
MOJO_DTYPE_INFO = {
"DType.float32": ("float", "vec", "0.0"),
"DType.float64": ("double", "dvec", "0.0"),
"DType.int32": ("int", "ivec", "0"),
"DType.uint32": ("uint", "uvec", "0"),
"DType.bool": ("bool", "bvec", "False"),
}
# Short dtype tag embedded in generated matrix type names and constructor
# helper names.
MOJO_DTYPE_SUFFIX = {
"DType.float32": "f32",
"DType.float64": "f64",
"DType.int32": "i32",
"DType.uint32": "u32",
"DType.bool": "bool",
}
# CrossGL scalar type name -> Mojo DType expression.
MOJO_SCALAR_DTYPES = {
"float": "DType.float32",
"double": "DType.float64",
"int": "DType.int32",
"uint": "DType.uint32",
"bool": "DType.bool",
}
# Integer scalar type names; presumably the types accepted as array indices
# (the consumer is not visible in this part of the file -- TODO confirm).
MOJO_INTEGER_INDEX_TYPES = {"int", "uint", "short", "ushort", "long", "ulong"}
# Arithmetic operator token -> short operation name; presumably used to name
# vector arithmetic helpers (consumer not visible here -- TODO confirm).
MOJO_VECTOR_ARITHMETIC_OPS = {
"+": "add",
"-": "sub",
"*": "mul",
"/": "div",
}
[docs]
class MojoCodeGen:
"""Emit Mojo-like shader source from the shared CrossGL AST."""
def __init__(self):
    """Create per-run bookkeeping and the static CrossGL-to-Mojo lookup tables."""
    # Context tracked while walking the AST.
    self.current_shader = None
    self.current_return_type = None

    # Symbol information collected during generation.
    self.variable_types = {}
    self.function_return_types = {}
    self.struct_types = {}

    # Helper snippets requested by emitted expressions.
    self.required_helpers = set()
    self.required_splat_helpers = set()
    self.required_swizzle_helpers = set()
    self.required_matrix_types = set()
    self.required_constructor_helpers = {}
    self.required_matrix_constructor_helpers = {}

    self.vector_constructor_info = MOJO_VECTOR_TYPES

    # Scalar types first, then vectors, matrices and textures, so the
    # table keeps the same key order as a single literal would.
    type_table = {
        "void": "None",
        "int": "Int32",
        "short": "Int16",
        "long": "Int64",
        "uint": "UInt32",
        "ushort": "UInt16",
        "ulong": "UInt64",
        "float": "Float32",
        "double": "Float64",
        "half": "Float16",
        "bool": "Bool",
        "string": "String",
        "char": "String",
    }
    for vector_name, (dtype, _, storage_width, _) in MOJO_VECTOR_TYPES.items():
        type_table[vector_name] = f"SIMD[{dtype}, {storage_width}]"
    for matrix_name, (dtype, columns, rows) in MOJO_MATRIX_TYPES.items():
        type_table[matrix_name] = self.matrix_type_name(dtype, columns, rows)
    # Texture types (Mojo equivalents).
    type_table["sampler2D"] = "Texture2D"
    type_table["samplerCube"] = "TextureCube"
    type_table["sampler"] = "Sampler"
    self.type_mapping = type_table

    vertex_builtins = {
        "gl_VertexID": "vertex_id",
        "gl_InstanceID": "instance_id",
        "gl_Position": "position",
        "gl_PointSize": "point_size",
        "gl_ClipDistance": "clip_distance",
    }
    fragment_builtins = {
        "gl_FragColor": "color(0)",
        "gl_FragColor0": "color(0)",
        "gl_FragColor1": "color(1)",
        "gl_FragColor2": "color(2)",
        "gl_FragColor3": "color(3)",
        "gl_FragDepth": "depth(any)",
        "gl_FragCoord": "position",
        "gl_FrontFacing": "front_facing",
        "gl_PointCoord": "point_coord",
    }
    standard_semantics = {
        "POSITION": "position",
        "NORMAL": "normal",
        "TANGENT": "tangent",
        "BINORMAL": "binormal",
        "TEXCOORD": "texcoord",
        "TEXCOORD0": "texcoord0",
        "TEXCOORD1": "texcoord1",
        "TEXCOORD2": "texcoord2",
        "TEXCOORD3": "texcoord3",
        "COLOR": "color",
        "COLOR0": "color0",
        "COLOR1": "color1",
    }
    self.semantic_map = {
        **vertex_builtins,
        **fragment_builtins,
        **standard_semantics,
    }

    # Common shader function names and their Mojo-side spelling.
    self.function_map = dict(
        texture="sample",
        normalize="normalize",
        dot="dot_product",
        cross="cross_product",
        length="magnitude",
        reflect="reflect",
        refract="refract",
        sin="sin",
        cos="cos",
        tan="tan",
        sqrt="sqrt",
        pow="power",
        abs="abs",
        min="min",
        max="max",
        clamp="clamp",
        mix="lerp",
        smoothstep="smoothstep",
        step="step",
    )
[docs]
def generate(self, ast):
    """Generate complete Mojo-like shader source for a CrossGL AST.

    Output order: fixed import header, helper definitions requested while
    generating, structs, global variables, constant buffers, top-level
    functions, then per-stage entry points and their local functions.
    """
    # Start every run from a clean slate.
    self.struct_types = {}
    self.function_return_types = {}
    self.variable_types = {}
    self.required_helpers = set()
    self.required_splat_helpers = set()
    self.required_swizzle_helpers = set()
    self.required_constructor_helpers = {}
    self.required_matrix_types = set()
    self.required_matrix_constructor_helpers = {}
    self.current_return_type = None
    self.collect_function_return_types(ast)

    header = (
        "# Generated Mojo Shader Code\n"
        "from math import *\n"
        "from simd import *\n"
        "from gpu import *\n\n"
    )
    chunks = []

    for struct in getattr(ast, "structs", []):
        if isinstance(struct, StructNode):
            chunks.append(self.generate_struct(struct))

    for var in getattr(ast, "global_variables", []):
        if isinstance(var, ArrayNode):
            chunks.append(self.generate_array_declaration(var))
            continue

        # Both the old (vtype) and new (var_type) AST layouts are accepted.
        if hasattr(var, "var_type"):
            vtype = self.convert_type_node_to_string(var.var_type)
        elif hasattr(var, "vtype"):
            vtype = var.vtype
        else:
            vtype = "float"
        self.register_variable_type(var.name, vtype)

        initial = getattr(var, "initial_value", None)
        if initial is not None:
            if isinstance(initial, ArrayLiteralNode) and self.is_array_type_name(
                vtype
            ):
                init_expr = self.generate_array_literal_expression(initial, vtype)
            else:
                init_expr = self.generate_expression(initial)
            chunks.append(f"var {var.name}: {self.map_type(vtype)} = {init_expr}\n")
        elif self.is_array_type_name(vtype):
            chunks.append(
                f"var {var.name} = {self.array_initial_value_for_type(vtype)}\n"
            )
        elif self.is_struct_type_name(vtype):
            chunks.append(f"var {var.name} = {self.zero_value_for_type(vtype)}\n")
        else:
            chunks.append(f"var {var.name}: {self.map_type(vtype)}\n")

    if getattr(ast, "cbuffers", []):
        chunks.append("# Constant Buffers\n")
        chunks.append(self.generate_cbuffers(ast))

    for func in getattr(ast, "functions", []):
        # New AST keeps a qualifier list, old AST a single qualifier.
        qualifiers = getattr(func, "qualifiers", None)
        if qualifiers:
            qualifier = qualifiers[0]
        else:
            qualifier = getattr(func, "qualifier", None)

        for stage in ("vertex", "fragment", "compute"):
            if qualifier == stage:
                chunks.append(f"# {stage.title()} Shader\n")
                chunks.append(self.generate_function(func, shader_type=stage))
                break
        else:
            chunks.append(self.generate_function(func))

    # Shader stages (new AST structure).
    stages = getattr(ast, "stages", None)
    if stages:
        for stage_type, stage in stages.items():
            if hasattr(stage, "entry_point"):
                # Keep only the enum member name, e.g. "ShaderStage.VERTEX".
                stage_name = str(stage_type).rsplit(".", 1)[-1].lower()
                chunks.append(f"# {stage_name.title()} Shader\n")
                chunks.append(
                    self.generate_function(stage.entry_point, shader_type=stage_name)
                )
            if hasattr(stage, "local_functions"):
                for local_func in stage.local_functions:
                    chunks.append(self.generate_function(local_func))

    # Helpers are rendered last because the set of required helpers is only
    # known once every expression has been generated.
    return header + self.generate_required_helpers() + "".join(chunks)
def collect_function_return_types(self, ast):
functions = list(getattr(ast, "functions", []))
stages = getattr(ast, "stages", {})
if stages:
for stage in stages.values():
entry_point = getattr(stage, "entry_point", None)
if entry_point is not None:
functions.append(entry_point)
functions.extend(getattr(stage, "local_functions", []))
for func in functions:
self.register_function_return_type(func)
def register_function_return_type(self, func):
if not hasattr(func, "name"):
return
if hasattr(func, "return_type"):
return_type = self.convert_type_node_to_string(func.return_type)
else:
return_type = "void"
self.function_return_types[func.name] = return_type
def convert_type_node_to_string(self, type_node) -> str:
"""Convert new AST TypeNode to string representation."""
if type_node.__class__.__name__ == "ArrayType":
element_type = self.convert_type_node_to_string(type_node.element_type)
size = self.format_array_size(type_node.size)
return (
f"{element_type}[{size}]" if size is not None else f"{element_type}[]"
)
if hasattr(type_node, "name"):
generic_args = getattr(type_node, "generic_args", [])
if generic_args:
args = ", ".join(
self.convert_type_node_to_string(arg) for arg in generic_args
)
return f"{type_node.name}<{args}>"
return type_node.name
elif hasattr(type_node, "element_type") and hasattr(type_node, "size"):
element_type = self.convert_type_node_to_string(type_node.element_type)
size = type_node.size
if element_type == "float":
return f"vec{size}"
elif element_type == "int":
return f"ivec{size}"
elif element_type == "uint":
return f"uvec{size}"
elif element_type == "double":
return f"dvec{size}"
elif element_type == "bool":
return f"bvec{size}"
else:
return f"{element_type}{size}"
elif hasattr(type_node, "element_type") and hasattr(type_node, "rows"):
element_type = self.convert_type_node_to_string(type_node.element_type)
prefix = "dmat" if element_type == "double" else "mat"
if type_node.rows == type_node.cols:
return f"{prefix}{type_node.rows}"
return f"{prefix}{type_node.rows}x{type_node.cols}"
else:
return str(type_node)
def format_array_size(self, size):
    """Return the printable array size: ``size.value`` if present, else ``size``.

    ``None`` (unsized array) is passed through unchanged.
    """
    return None if size is None else getattr(size, "value", size)
def extract_semantic_from_attributes(self, attributes):
    """Return the first attribute name that is a recognized semantic.

    Args:
        attributes: iterable of attribute nodes (objects that may carry a
            ``name``); ``None`` is treated as "no attributes".

    Returns:
        The matching attribute name, or ``None`` when no attribute names a
        known semantic.
    """
    semantic_names = (
        "position",
        "color",
        "texcoord",
        "normal",
        "tangent",
        "binormal",
        "POSITION",
        "COLOR",
        "TEXCOORD",
        "NORMAL",
        "TANGENT",
        "BINORMAL",
        "TEXCOORD0",
        "TEXCOORD1",
        "TEXCOORD2",
        "TEXCOORD3",
    )
    # Nodes may expose `attributes = None`; iterate nothing instead of
    # raising TypeError.
    for attr in attributes or ():
        if hasattr(attr, "name") and attr.name in semantic_names:
            return attr.name
    return None
def generate_struct(self, node):
    """Emit an ``@value`` struct and record each member's CrossGL type.

    Member types are stored in ``self.struct_types[node.name]``. Array
    members are emitted with their array storage type; other members get a
    trailing semantic comment when a semantic is present.
    """
    member_types = {}
    self.struct_types[node.name] = member_types
    pieces = [f"@value\nstruct {node.name}:\n"]

    for member in getattr(node, "members", []):
        if isinstance(member, ArrayNode):
            element_type = getattr(
                member, "element_type", getattr(member, "vtype", "float")
            )
            size = get_array_size_from_node(member)
            member_types[member.name] = self.array_type_name(element_type, size)
            storage = self.array_storage_type(element_type, size)
            pieces.append(f" var {member.name}: {storage}\n")
            continue

        # Old AST exposes `vtype`, new AST a `member_type` node.
        if hasattr(member, "member_type"):
            member_type = self.convert_type_node_to_string(member.member_type)
        elif hasattr(member, "vtype"):
            member_type = member.vtype
        else:
            member_type = "float"
        member_types[member.name] = member_type

        if hasattr(member, "semantic"):
            semantic = member.semantic
        elif hasattr(member, "attributes"):
            semantic = self.extract_semantic_from_attributes(member.attributes)
        else:
            semantic = None

        trailer = ""
        if semantic:
            trailer = f" # {self.map_semantic(semantic)}"
        pieces.append(f" var {member.name}: {self.map_type(member_type)}{trailer}\n")

    pieces.append("\n")
    return "".join(pieces)
def generate_cbuffers(self, ast):
    """Emit one ``@value`` struct per constant buffer in ``ast.cbuffers``.

    Accepts both StructNode entries and cbuffer-like nodes that expose
    ``name`` and ``members``; any other entry is skipped. Array members use
    their array storage type. Other members are mapped through ``map_type``,
    with new-AST ``member_type`` nodes first converted by
    ``convert_type_node_to_string`` (as ``generate_struct`` does) rather
    than ``str()``.

    Returns:
        The generated source text (empty string when there are no cbuffers).
    """
    code = ""
    for node in getattr(ast, "cbuffers", []):
        if isinstance(node, StructNode):
            members = getattr(node, "members", [])
        elif hasattr(node, "name") and hasattr(node, "members"):  # CbufferNode
            members = node.members
        else:
            continue

        code += f"@value\nstruct {node.name}:\n"
        for member in members:
            if isinstance(member, ArrayNode):
                element_type = getattr(
                    member, "element_type", getattr(member, "vtype", "float")
                )
                size = get_array_size_from_node(member)
                member_type = self.array_storage_type(element_type, size)
            elif hasattr(member, "member_type"):
                # New AST: resolve the type node to its CrossGL name first.
                member_type = self.map_type(
                    self.convert_type_node_to_string(member.member_type)
                )
            else:
                # Old AST: plain type string in `vtype`.
                member_type = self.map_type(getattr(member, "vtype", "float"))
            code += f" var {member.name}: {member_type}\n"
        code += "\n"
    return code
def generate_function(self, func, indent=0, shader_type=None):
    """Render one CrossGL function or shader entry point as Mojo code.

    Args:
        func: function node; parameters are read from ``parameters`` or
            ``params``, the body from ``body`` (a list or an object with
            ``statements``).
        indent: nesting level; body statements are generated at
            ``indent + 1``.
        shader_type: "vertex", "fragment" or "compute" adds the matching
            decorator line; any other value adds none.

    Returns:
        The function source followed by a blank line. A function whose body
        yields no statements gets a ``pass`` body.

    Variable types registered while generating and the current return type
    are restored to their previous values before returning.
    """
    saved_variable_types = self.variable_types.copy()
    saved_return_type = self.current_return_type

    param_list = getattr(func, "parameters", getattr(func, "params", []))
    param_names = {p.name for p in param_list if hasattr(p, "name")}
    mutated_params = self.collect_mutated_parameters(
        getattr(func, "body", []), param_names
    )

    params = []
    for p in param_list:
        if hasattr(p, "param_type"):
            param_type = self.convert_type_node_to_string(p.param_type)
        elif hasattr(p, "vtype"):
            param_type = p.vtype
        else:
            param_type = "float"

        semantic = None
        if hasattr(p, "semantic"):
            semantic = p.semantic
        elif hasattr(p, "attributes"):
            semantic = self.extract_semantic_from_attributes(p.attributes)

        self.register_variable_type(p.name, param_type)
        param_semantic = f" # {self.map_semantic(semantic)}" if semantic else ""
        # Parameters assigned to inside the body must be taken as `owned`.
        ownership = "owned " if p.name in mutated_params else ""
        params.append(
            f"{ownership}{p.name}: {self.map_type(param_type)}{param_semantic}"
        )
    params_str = ", ".join(params)

    if hasattr(func, "return_type"):
        return_type = self.convert_type_node_to_string(func.return_type)
    else:
        return_type = "void"
    self.function_return_types[func.name] = return_type
    self.current_return_type = return_type

    decorators = {
        "vertex": "@vertex_shader\n",
        "fragment": "@fragment_shader\n",
        "compute": "@compute_shader\n",
    }
    code = decorators.get(shader_type, "")
    code += f"fn {func.name}({params_str}) -> {self.map_type(return_type)}:\n"

    body = getattr(func, "body", [])
    if hasattr(body, "statements"):
        statements = body.statements or []
    elif isinstance(body, list):
        statements = body
    else:
        statements = []
    body_code = "".join(
        self.generate_statement(stmt, indent + 1) for stmt in statements
    )
    # A header with nothing under it is not a valid function, so an empty
    # body (including an empty statement list) falls back to `pass`.
    code += body_code if body_code.strip() else " pass\n"
    code += "\n"

    self.variable_types = saved_variable_types
    self.current_return_type = saved_return_type
    return code
def collect_mutated_parameters(self, body, param_names):
mutated = set()
for stmt in self.body_statements(body):
self.collect_mutated_parameters_from_node(stmt, param_names, mutated)
return mutated
def body_statements(self, body):
    """Normalize a function/block body to an iterable of statements.

    Objects with ``statements`` yield that attribute, lists are returned
    as-is, ``None`` becomes an empty list and any other value is wrapped in
    a one-element list.
    """
    try:
        return body.statements
    except AttributeError:
        pass
    if body is None:
        return []
    return body if isinstance(body, list) else [body]
def collect_mutated_parameters_from_node(self, node, param_names, mutated):
    """Add to ``mutated`` every parameter that ``node`` (recursively) writes to.

    Assignments contribute the root name of their left-hand side and are
    searched further on the right-hand side only; ``++``/``--`` contribute
    the root name of their operand. Other nodes are searched through
    ``node_children``.
    """
    if node is None:
        return

    if isinstance(node, AssignmentNode):
        written, to_visit = node.left, [node.right]
    elif isinstance(node, UnaryOpNode) and self.map_operator(node.op) in ["++", "--"]:
        written, to_visit = node.operand, []
    else:
        for child in self.node_children(node):
            self.collect_mutated_parameters_from_node(child, param_names, mutated)
        return

    root_name = self.assignment_target_root(written)
    if root_name in param_names:
        mutated.add(root_name)
    for expr in to_visit:
        self.collect_mutated_parameters_from_node(expr, param_names, mutated)
def node_children(self, node):
    """Return the child nodes of ``node`` for the mutated-parameter walk.

    Children come from a fixed set of single-valued attributes, from the
    sequence attributes ``statements`` / ``args`` / ``arguments``, and from
    the elements of an array literal. List- or tuple-valued attributes
    (for example a loop ``body`` given as a plain statement list) are
    flattened so that their statements are visited individually.

    Args:
        node: an AST node, a list/tuple of nodes, or a primitive.

    Returns:
        A list of children (possibly containing ``None`` for unset
        single-valued attributes, which the caller skips).
    """
    # A bare sequence is its own list of children; primitives have none.
    if isinstance(node, (list, tuple)):
        return list(node)
    if isinstance(node, (str, bytes, int, float)):
        return []

    children = []
    for attr in (
        "init",
        "condition",
        "update",
        "body",
        "then_branch",
        "if_body",
        "else_branch",
        "else_body",
        "value",
        "expression",
        "left",
        "right",
        "object",
        "object_expr",
        "array",
        "array_expr",
        "index",
        "index_expr",
        "operand",
        "vector_expr",
    ):
        if not hasattr(node, attr):
            continue
        value = getattr(node, attr)
        if isinstance(value, (list, tuple)):
            # Bodies/branches stored as plain lists would otherwise be
            # appended as one opaque child and never searched.
            children.extend(value)
        else:
            children.append(value)

    for attr in ("statements", "args", "arguments"):
        # Tolerate attributes that exist but are None.
        children.extend(getattr(node, attr, None) or ())

    if isinstance(node, ArrayLiteralNode):
        children.extend(node.elements)
    return children
def assignment_target_root(self, target):
    """Return the name of the variable at the root of an assignment target.

    Array accesses, member accesses and swizzles are peeled off until a
    plain name (string, variable node or identifier node) is reached.
    Returns ``None`` when no root name can be determined.
    """
    current = target
    while True:
        if isinstance(current, str):
            return current
        if isinstance(current, VariableNode) and hasattr(current, "name"):
            return current.name
        if isinstance(current, ArrayAccessNode):
            current = current.array
            continue
        if isinstance(current, MemberAccessNode):
            current = current.object
            continue

        # Newer AST node kinds are recognized by class name.
        class_repr = str(current.__class__)
        if "Identifier" in class_repr:
            return getattr(current, "name", None)
        if "Swizzle" in class_repr:
            current = getattr(current, "vector_expr", None)
            continue
        return None
def generate_statement(self, stmt, indent=0):
    """Render a single CrossGL statement as Mojo code.

    Returns the statement text, indented for ``indent`` and terminated by
    a newline. Nodes that are not a recognized statement kind are rendered
    through ``generate_expression``.
    """
    pad = " " * indent

    if isinstance(stmt, VariableNode):
        if hasattr(stmt, "var_type"):
            declared_type = self.convert_type_node_to_string(stmt.var_type)
        elif hasattr(stmt, "vtype") and stmt.vtype:
            # Old AST: the type may really be an array declaration that was
            # parsed as an array access; detect that from its repr.
            type_repr = str(stmt.vtype)
            markers = ("ArrayAccessNode", "array=", "index=")
            if all(marker in type_repr for marker in markers):
                import re

                found = re.search(r"array=(\w+).*?index=(\w+)", type_repr)
                if found:
                    extent = found.group(2)
                    element = "Float32"  # Default, could be improved
                    return (
                        f"{pad}var {stmt.name} = "
                        f"InlineArray[{element}, {extent}]"
                        "(unsafe_uninitialized=True)\n"
                    )
            declared_type = stmt.vtype
        else:
            declared_type = "float"
        self.register_variable_type(stmt.name, declared_type)

        initializer = getattr(stmt, "initial_value", None)
        if initializer is not None:
            if isinstance(initializer, ArrayLiteralNode) and self.is_array_type_name(
                declared_type
            ):
                rendered = self.generate_array_literal_expression(
                    initializer, declared_type
                )
            else:
                rendered = self.generate_expression(initializer)
            return f"{pad}var {stmt.name}: {self.map_type(declared_type)} = {rendered}\n"
        if self.is_array_type_name(declared_type):
            return (
                f"{pad}var {stmt.name} = "
                f"{self.array_initial_value_for_type(declared_type)}\n"
            )
        if self.is_struct_type_name(declared_type):
            return f"{pad}var {stmt.name} = {self.zero_value_for_type(declared_type)}\n"
        return f"{pad}var {stmt.name}: {self.map_type(declared_type)}\n"

    if isinstance(stmt, ArrayNode):
        return self.generate_array_declaration(stmt, indent)

    if isinstance(stmt, AssignmentNode):
        return f"{pad}{self.generate_assignment(stmt)}\n"

    if isinstance(stmt, IfNode):
        return self.generate_if(stmt, indent)

    if isinstance(stmt, ForNode):
        return self.generate_for(stmt, indent)

    if isinstance(stmt, ReturnNode):
        value = stmt.value
        if isinstance(value, list):
            # Multiple return values
            rendered = ", ".join(map(self.generate_expression, value))
        elif isinstance(value, ArrayLiteralNode) and self.is_array_type_name(
            self.current_return_type
        ):
            rendered = self.generate_array_literal_expression(
                value, self.current_return_type
            )
        else:
            rendered = self.generate_expression(value)
        return f"{pad}return {rendered}\n"

    if isinstance(stmt, ArrayAccessNode):
        # A bare array access is not a statement; it is most likely a
        # misclassified array declaration, so only leave a marker comment.
        return f"{pad}# Unhandled ArrayAccessNode: {stmt}\n"

    # Expressions used as statements.
    rendered = self.generate_expression(stmt)
    if not rendered.strip():
        return f"{pad}# Unhandled statement: {type(stmt).__name__}\n"
    return f"{pad}{rendered}\n"
def generate_array_declaration(self, node, indent=0):
    """Register an array variable's type and emit its ``var`` declaration."""
    pad = " " * indent
    size = get_array_size_from_node(node)
    declared_type = self.array_type_name(node.element_type, size)
    self.register_variable_type(node.name, declared_type)
    initial_value = self.array_initial_value(node.element_type, size)
    return f"{pad}var {node.name} = {initial_value}\n"
def generate_assignment(self, node):
    """Render ``left <op> right``; array literals are typed from the target."""
    target = self.generate_expression(node.left)
    target_type = self.expression_result_type(node.left)

    rhs_is_array_literal = isinstance(
        node.right, ArrayLiteralNode
    ) and self.is_array_type_name(target_type)
    if rhs_is_array_literal:
        value = self.generate_array_literal_expression(node.right, target_type)
    else:
        value = self.generate_expression(node.right)

    operator = self.map_operator(node.operator)
    return f"{target} {operator} {value}"
def generate_if(self, node, indent):
    """Render an if / elif / else chain.

    An ``else_branch`` that is itself an if-node becomes an ``elif``; any
    other truthy ``else_branch`` becomes a final ``else`` block.
    """
    pad = " " * indent

    def render_block(block, allow_single_statement):
        # Blocks are either objects with `statements` or plain lists; else
        # branches may additionally be a single statement node.
        if hasattr(block, "statements"):
            statements = block.statements
        elif isinstance(block, list):
            statements = block
        elif allow_single_statement:
            statements = [block]
        else:
            statements = []
        return "".join(
            self.generate_statement(statement, indent + 1)
            for statement in statements
        )

    test = node.condition if hasattr(node, "condition") else node.if_condition
    pieces = [f"{pad}if {self.generate_expression(test)}:\n"]
    then_block = getattr(node, "then_branch", getattr(node, "if_body", None))
    pieces.append(render_block(then_block, False))

    alternative = getattr(node, "else_branch", None)
    if alternative:
        if "If" in str(alternative.__class__):
            # Render the nested if at the same level and turn its leading
            # keyword into `elif`; the output starts with pad + "if ", so
            # the first "if " occurrence is that keyword.
            chained = self.generate_if(alternative, indent)
            pieces.append(chained.replace("if ", "elif ", 1))
        else:
            pieces.append(f"{pad}else:\n")
            pieces.append(render_block(alternative, True))

    return "".join(pieces)
def generate_for(self, node, indent):
    """Lower a C-style ``for`` loop to an init statement plus a ``while`` loop.

    The update expression is appended as the last statement of the loop
    body. Omitted clauses are handled explicitly: a missing init emits
    nothing, a missing condition becomes ``while True``, a missing update
    is skipped, and a loop with no statements at all gets ``pass``.

    Args:
        node: for-loop node with ``init``, ``condition``, ``update``, ``body``.
        indent: nesting level of the loop itself.

    Returns:
        The generated loop source.
    """
    pad = " " * indent

    # Clause texts are produced in source order (init, condition, update)
    # before the body, matching the order expressions appear in the loop.
    init = getattr(node, "init", None)
    init_text = None if init is None else self.generate_statement(init, 0).strip()
    condition = getattr(node, "condition", None)
    condition_text = (
        "True" if condition is None else self.generate_expression(condition)
    )
    update = getattr(node, "update", None)
    update_text = None if update is None else self.generate_expression(update)

    code = ""
    if init_text is not None:
        code += f"{pad}{init_text}\n"
    code += f"{pad}while {condition_text}:\n"

    body = getattr(node, "body", None)
    if hasattr(body, "statements"):
        statements = body.statements or []
    elif isinstance(body, list):
        statements = body
    elif body is None:
        statements = []
    else:
        statements = [body]
    body_code = "".join(
        self.generate_statement(stmt, indent + 1) for stmt in statements
    )
    code += body_code

    if update_text is not None:
        # Add update at the end of the loop
        code += f"{pad} {update_text}\n"
    elif not body_code.strip():
        # Nothing inside the loop: keep the emitted block non-empty.
        code += f"{pad} pass\n"
    return code
def generate_expression(self, expr):
"""Render a CrossGL expression as Mojo expression syntax."""
if isinstance(expr, str):
return expr
elif isinstance(expr, (int, float, bool)):
return self.format_literal(expr)
elif isinstance(expr, VariableNode):
if hasattr(expr, "vtype") and expr.vtype and expr.name:
return f"{expr.name}"
elif hasattr(expr, "name"):
return expr.name
else:
return str(expr)
elif isinstance(expr, BinaryOpNode):
vector_binary = self.generate_vector_binary_op(expr)
if vector_binary is not None:
return vector_binary
left = self.generate_expression(expr.left)
right = self.generate_expression(expr.right)
op = self.map_operator(expr.op)
return f"({left} {op} {right})"
elif isinstance(expr, AssignmentNode):
return self.generate_assignment(expr)
elif isinstance(expr, ArrayLiteralNode):
return self.generate_array_literal_expression(expr)
elif isinstance(expr, UnaryOpNode):
operand = self.generate_expression(expr.operand)
op = self.map_operator(expr.op)
if op in ["++", "--"]:
assignment_op = "+=" if op == "++" else "-="
return f"{operand} {assignment_op} 1"
return f"({op}{operand})"
elif isinstance(expr, ArrayAccessNode):
# Handle array access properly
if hasattr(expr, "array") and hasattr(expr, "index"):
return self.generate_array_access_expression(expr)
else:
# Fallback for malformed ArrayAccessNode
return str(expr)
elif isinstance(expr, FunctionCallNode):
# Extract function name properly (might be IdentifierNode)
func_expr = getattr(expr, "function", None)
if func_expr is None:
func_expr = expr.name
func_name = None
if hasattr(func_expr, "name"):
# It's an IdentifierNode, extract the name
func_name = func_expr.name
callee = func_name
elif isinstance(func_expr, str):
func_name = func_expr
callee = func_expr
else:
callee = self.generate_expression(func_expr)
# Map function names to Mojo equivalents
func_name = self.function_map.get(func_name, func_name)
# Handle vector constructors
if func_name in self.vector_constructor_info:
return self.generate_vector_constructor(func_name, expr.args)
if func_name in MOJO_MATRIX_TYPES:
return self.generate_matrix_constructor(func_name, expr.args)
# Handle standard function calls
args = ", ".join(self.generate_expression(arg) for arg in expr.args)
return f"{callee}({args})"
elif isinstance(expr, MemberAccessNode):
obj = self.generate_expression(expr.object)
swizzle_indices = self.get_swizzle_indices(expr.member)
if swizzle_indices is not None:
obj_type = self.expression_result_type(expr.object)
return self.generate_swizzle(
expr.object, obj, obj_type, expr.member, swizzle_indices
)
return f"{obj}.{expr.member}"
elif isinstance(expr, TernaryOpNode):
condition = self.generate_expression(expr.condition)
true_expr = self.generate_expression(expr.true_expr)
false_expr = self.generate_expression(expr.false_expr)
return f"({true_expr} if {condition} else {false_expr})"
elif hasattr(expr, "__class__") and "Literal" in str(expr.__class__):
# Handle LiteralNode
if hasattr(expr, "value"):
literal_type = getattr(
getattr(expr, "literal_type", None), "name", None
)
return self.format_literal(expr.value, literal_type)
return str(expr)
elif hasattr(expr, "__class__") and "Identifier" in str(expr.__class__):
# Handle IdentifierNode
return getattr(expr, "name", str(expr))
elif hasattr(expr, "__class__") and "ExpressionStatement" in str(
expr.__class__
):
# Handle ExpressionStatementNode
if hasattr(expr, "expression"):
return self.generate_expression(expr.expression)
else:
return self.generate_expression(expr)
else:
# For unknown expression types, handle special cases
expr_str = str(expr)
# Check if this looks like an array declaration being misinterpreted
if (
"ArrayAccessNode" in expr_str
and "array=" in expr_str
and "index=" in expr_str
):
# Try to extract array name and size for array declarations
import re
array_match = re.search(r"array=(\w+).*?index=(\w+)", expr_str)
if array_match:
array_name = array_match.group(1)
array_match.group(2)
return f"{array_name}" # Just return the array name for now
return expr_str
def generate_vector_constructor(self, func_name, args):
    """Render a vector constructor call as a ``SIMD[dtype, width](...)`` expression.

    A constructor helper call is used when one is available. Otherwise the
    arguments are flattened into lanes, truncated to the vector's source
    width, and three-component vectors get the pad literal as a fourth
    lane. A single duplicate-sensitive scalar for a three-component vector
    is routed through the vec3 splat helper instead of being repeated.
    """
    via_helper = self.generate_constructor_helper_call(func_name, args)
    if via_helper is not None:
        return via_helper

    info = self.vector_constructor_info[func_name]
    dtype, source_width, storage_width, pad_literal = info
    simd_type = f"SIMD[{dtype}, {storage_width}]"
    lanes = []

    if len(args) == 1:
        only = args[0]
        expanded = self.vector_components_for_expression(only, dtype)
        if expanded is not None:
            lanes.extend(expanded[:source_width])
        elif source_width == 3:
            scalar = self.generate_constructor_scalar_expression(only, dtype)
            if self.is_duplicate_sensitive_expression(only):
                splat_helper = self.vec3_splat_helper_name(dtype)
                self.required_splat_helpers.add(dtype)
                return f"{splat_helper}({scalar})"
            lanes.extend(scalar for _ in range(source_width))
        else:
            lanes.append(self.generate_constructor_scalar_expression(only, dtype))
    else:
        for arg in args:
            expanded = self.vector_components_for_expression(arg, dtype)
            if expanded is None:
                lanes.append(self.generate_constructor_scalar_expression(arg, dtype))
            else:
                lanes.extend(expanded)

    # Drop any surplus lanes beyond the declared width.
    del lanes[source_width:]
    if source_width == 3 and len(lanes) == 3:
        lanes.append(pad_literal)
    return f"{simd_type}({', '.join(lanes)})"
def generate_matrix_constructor(self, func_name, args):
    """Render a matrix constructor as the generated matrix struct built from SIMD columns.

    Registers the matrix type as required, prefers a constructor helper
    call when one applies, and otherwise flattens the arguments into
    components. A single scalar argument fills the diagonal; too many
    components are truncated and too few are zero-filled. Three-row
    columns receive one extra zero lane.
    """
    dtype, columns, rows = MOJO_MATRIX_TYPES[func_name]
    self.required_matrix_types.add((dtype, columns, rows))

    via_helper = self.generate_matrix_constructor_helper_call(
        dtype, columns, rows, args
    )
    if via_helper is not None:
        return via_helper

    total = columns * rows
    flat = []
    for arg in args:
        expanded = self.vector_components_for_expression(arg, dtype)
        if expanded is None:
            flat.append(self.generate_constructor_scalar_expression(arg, dtype))
        else:
            flat.extend(expanded)

    zero = self.matrix_zero_literal(dtype)
    if len(args) == 1 and len(flat) == 1:
        # Single scalar: place it on the diagonal, column by column.
        diagonal = flat[0]
        flat = [
            diagonal if column == row else zero
            for column in range(columns)
            for row in range(rows)
        ]

    shortfall = total - len(flat)
    if shortfall < 0:
        flat = flat[:total]
    elif shortfall > 0:
        flat.extend([zero] * shortfall)

    matrix_type = self.matrix_type_name(dtype, columns, rows)
    column_type = f"SIMD[{dtype}, {self.matrix_storage_rows(rows)}]"
    rendered_columns = []
    for column in range(columns):
        lanes = flat[column * rows : (column + 1) * rows]
        if rows == 3:
            lanes.append(zero)
        rendered_columns.append(f"{column_type}({', '.join(lanes)})")
    return f"{matrix_type}({', '.join(rendered_columns)})"
def generate_matrix_constructor_helper_call(self, dtype, columns, rows, args):
    """Return a call to a generated matrix-constructor helper, or None.

    None is returned when any argument cannot be described as a piece, when
    the only argument is a scalar, when the pieces do not exactly fill
    ``columns * rows`` components, or when no selected vector piece is
    flagged as duplicate sensitive.
    """
    described = []
    for arg in args:
        piece = self.constructor_piece_for_expression(arg, dtype)
        if piece is None:
            return None
        described.append(piece)

    # A single scalar argument is left to the inline constructor path.
    if len(args) == 1 and described and described[0]["kind"] == "scalar":
        return None

    chosen = self.select_constructor_pieces(described, columns * rows)
    if chosen is None:
        return None

    needs_helper = False
    for piece in chosen:
        if piece["kind"] == "vector" and piece["duplicate_sensitive"]:
            needs_helper = True
            break
    if not needs_helper:
        return None

    key = self.matrix_constructor_helper_key(dtype, columns, rows, chosen)
    helper_name = self.matrix_constructor_helper_name(key)
    # Record the helper so its definition is emitted later.
    self.required_matrix_constructor_helpers[key] = dict(
        key=key,
        dtype=dtype,
        columns=columns,
        rows=rows,
        pieces=chosen,
    )
    rendered = ", ".join(self.generate_expression(piece["expr"]) for piece in chosen)
    return f"{helper_name}({rendered})"
def matrix_constructor_helper_key(self, dtype, columns, rows, pieces):
    """Build the registry key ``(dtype, columns, rows, piece_signature)``.

    The piece signature is taken from the vector constructor key computed
    for the same pieces.
    """
    _, _, _, piece_signature = self.constructor_helper_key(
        dtype, columns * rows, columns, pieces
    )
    return (dtype, columns, rows, piece_signature)
def matrix_constructor_helper_name(self, key):
    """Derive the Mojo function name for a matrix constructor helper key."""
    dtype, columns, rows, signature = key
    vector_name = self.constructor_helper_name(
        (dtype, columns * rows, columns, signature)
    )
    # Keep only what follows the fourth underscore of the vector helper name.
    tail = vector_name.split("_", 4)[4]
    prefix = f"_crossgl_construct_matrix_{MOJO_DTYPE_SUFFIX[dtype]}_"
    return f"{prefix}c{columns}_r{rows}_{tail}"
def matrix_type_name(self, dtype, columns, rows):
    """Return the generated struct name for a matrix of the given shape."""
    return "CrossGLMatrix{}C{}R{}".format(
        MOJO_DTYPE_SUFFIX[dtype].upper(), columns, rows
    )
def matrix_storage_rows(self, rows):
    """Return the SIMD width used per column: 4 for three rows, else ``rows``."""
    if rows == 3:
        return 4
    return rows
def matrix_zero_literal(self, dtype):
    """Return the third field of the ``MOJO_DTYPE_INFO`` entry for ``dtype``."""
    _, _, zero_literal = MOJO_DTYPE_INFO[dtype]
    return zero_literal
def generate_matrix_type(self, key):
    """Emit the Mojo struct definition for one matrix type.

    ``key`` is a ``(dtype, columns, rows)`` tuple. The struct has one SIMD
    field per column (``c0``, ``c1``, ...) plus ``__getitem__`` and
    ``__setitem__`` overloads for ``Int``, ``Int32`` and ``UInt32`` indices.
    Any index that matches none of the earlier columns resolves to the last
    column.

    Mojo blocks are delimited by indentation, so struct members are emitted
    at 4 spaces, method bodies at 8 and ``if`` bodies at 12; a flat indent
    would not nest the ``if``/``return`` lines under their ``fn``.
    """
    dtype, columns, rows = key
    name = self.matrix_type_name(dtype, columns, rows)
    column_type = f"SIMD[{dtype}, {self.matrix_storage_rows(rows)}]"
    last = columns - 1

    lines = ["@value", f"struct {name}:"]
    lines.extend(f"    var c{column}: {column_type}" for column in range(columns))
    lines.append("")

    for index_type in ("Int", "Int32", "UInt32"):
        lines.append(
            f"    fn __getitem__(self, index: {index_type}) -> {column_type}:"
        )
        for column in range(last):
            lines.append(f"        if index == {column}:")
            lines.append(f"            return self.c{column}")
        lines.append(f"        return self.c{last}")
        lines.append("")

        lines.append(
            f"    fn __setitem__(inout self, index: {index_type}, "
            f"value: {column_type}):"
        )
        for column in range(last):
            lines.append(f"        if index == {column}:")
            lines.append(f"            self.c{column} = value")
            lines.append("            return")
        lines.append(f"        self.c{last} = value")
        lines.append("")

    # Same trailing blank lines as before: one after the last method plus one
    # separating this struct from whatever is emitted next.
    return "\n".join(lines) + "\n\n"
def generate_matrix_constructor_helper(self, helper):
    """Emit the Mojo ``fn`` definition for one registered matrix helper.

    ``helper`` is the record stored in
    ``required_matrix_constructor_helpers``. Vector pieces become ``v<i>``
    parameters and scalar pieces become ``s<i>`` parameters; pieces whose
    dtype differs from the matrix dtype are cast inside the helper.
    """
    dtype = helper["dtype"]
    columns = helper["columns"]
    rows = helper["rows"]
    matrix_type = self.matrix_type_name(dtype, columns, rows)
    target_scalar_type = self.map_type(MOJO_DTYPE_INFO[dtype][0])

    signature = []
    casts = []
    flat = []
    for position, piece in enumerate(helper["pieces"]):
        if piece["kind"] != "vector":
            arg_name = f"s{position}"
            source_dtype = piece.get("dtype")
            if source_dtype is None or source_dtype == dtype:
                arg_type = target_scalar_type
            else:
                arg_type = self.map_type(MOJO_DTYPE_INFO[source_dtype][0])
            signature.append(f"{arg_name}: {arg_type}")
            if source_dtype is None:
                flat.append(arg_name)
            else:
                flat.append(self.cast_scalar_text(arg_name, source_dtype, dtype))
            continue

        arg_name = f"v{position}"
        signature.append(
            f"{arg_name}: SIMD[{piece['dtype']}, {piece['storage_width']}]"
        )
        source = arg_name
        if piece["dtype"] != dtype:
            # Cast the whole vector once, then read lanes from the cast copy.
            source = f"{arg_name}_cast"
            casts.append(f" var {source} = {arg_name}.cast[{dtype}]()\n")
        flat.extend(f"{source}[{lane}]" for lane in piece["indices"])

    column_type = f"SIMD[{dtype}, {self.matrix_storage_rows(rows)}]"
    zero = self.matrix_zero_literal(dtype)
    column_args = []
    for column in range(columns):
        lanes = flat[column * rows : (column + 1) * rows]
        if rows == 3:
            lanes.append(zero)
        column_args.append(f"{column_type}({', '.join(lanes)})")

    helper_name = self.matrix_constructor_helper_name(helper["key"])
    header = f"fn {helper_name}({', '.join(signature)}) -> {matrix_type}:\n"
    footer = f" return {matrix_type}({', '.join(column_args)})\n\n"
    return header + "".join(casts) + footer
def generate_vector_binary_op(self, expr):
    """Return a helper call for arithmetic on a three-component vector.

    None is returned when the operator is not in
    ``MOJO_VECTOR_ARITHMETIC_OPS``, when neither operand is a
    three-component vector, when a three-component vector is paired with a
    vector of another width or dtype, or when the dtype is boolean or has
    no suffix entry. On success the helper is registered in
    ``required_helpers``.
    """
    op = self.map_operator(expr.op)
    if op not in MOJO_VECTOR_ARITHMETIC_OPS:
        return None

    lhs_type = self.expression_result_type(expr.left)
    rhs_type = self.expression_result_type(expr.right)
    lhs_info = self.vector_type_info(lhs_type)
    rhs_info = self.vector_type_info(rhs_type)
    lhs_vec3 = lhs_info is not None and lhs_info[1] == 3
    rhs_vec3 = rhs_info is not None and rhs_info[1] == 3

    if lhs_vec3 and rhs_vec3:
        if lhs_info[0] != rhs_info[0]:
            return None
        dtype, helper_kind = lhs_info[0], "vv"
    elif lhs_vec3:
        # The other side must be a non-vector operand.
        if rhs_info is not None:
            return None
        dtype, helper_kind = lhs_info[0], "vs"
    elif rhs_vec3:
        if lhs_info is not None:
            return None
        dtype, helper_kind = rhs_info[0], "sv"
    else:
        return None

    if dtype == "DType.bool" or dtype not in MOJO_DTYPE_SUFFIX:
        return None

    lhs_text = self.generate_expression(expr.left)
    rhs_text = self.generate_expression(expr.right)
    helper_name = self.vector_binary_helper_name(dtype, op, helper_kind)
    self.required_helpers.add((dtype, op, helper_kind))
    return f"{helper_name}({lhs_text}, {rhs_text})"
def generate_required_helpers(self):
    """Emit all registered matrix types and helper functions as Mojo source.

    Returns an empty string when nothing was registered; otherwise the
    matrix type section (if any), the vector helper section (if any) and a
    trailing newline.
    """
    needs_vector_section = bool(
        self.required_helpers
        or self.required_splat_helpers
        or self.required_swizzle_helpers
        or self.required_constructor_helpers
        or self.required_matrix_constructor_helpers
    )
    if not (needs_vector_section or self.required_matrix_types):
        return ""

    sections = []
    if self.required_matrix_types:
        sections.append("# CrossGL matrix types\n")
        for key in sorted(self.required_matrix_types):
            sections.append(self.generate_matrix_type(key))
        sections.append("\n")

    if needs_vector_section:
        sections.append("# CrossGL vector helpers\n")
        for dtype, op, helper_kind in sorted(self.required_helpers):
            sections.append(
                self.generate_vector_binary_helper(dtype, op, helper_kind)
            )
        for dtype in sorted(self.required_splat_helpers):
            sections.append(self.generate_vec3_splat_helper(dtype))
        for dtype, source_width, member in sorted(self.required_swizzle_helpers):
            sections.append(
                self.generate_swizzle_helper(dtype, source_width, member)
            )
        # The two constructor registries map key -> helper record.
        for key in sorted(self.required_constructor_helpers):
            sections.append(
                self.generate_constructor_helper(
                    self.required_constructor_helpers[key]
                )
            )
        for key in sorted(self.required_matrix_constructor_helpers):
            sections.append(
                self.generate_matrix_constructor_helper(
                    self.required_matrix_constructor_helpers[key]
                )
            )

    sections.append("\n")
    return "".join(sections)
def generate_vector_binary_helper(self, dtype, op, helper_kind):
    """Emit a helper applying ``op`` lane by lane to the first three lanes.

    ``helper_kind`` selects the operand shapes: ``"vv"`` (vector, vector),
    ``"vs"`` (vector, scalar); any other value is treated as scalar, vector.
    The fourth lane of the result is the dtype's pad literal.
    """
    scalar_type, _, pad_literal = MOJO_DTYPE_INFO[dtype]
    mojo_scalar = self.map_type(scalar_type)
    vec = f"SIMD[{dtype}, 4]"
    helper_name = self.vector_binary_helper_name(dtype, op, helper_kind)

    # kind -> (parameter list, left operand template, right operand template)
    shapes = {
        "vv": (f"a: {vec}, b: {vec}", "a[{i}]", "b[{i}]"),
        "vs": (f"v: {vec}, s: {mojo_scalar}", "v[{i}]", "s"),
    }
    scalar_first = (f"s: {mojo_scalar}, v: {vec}", "s", "v[{i}]")
    params, lhs, rhs = shapes.get(helper_kind, scalar_first)

    lanes = [f"{lhs.format(i=i)} {op} {rhs.format(i=i)}" for i in range(3)]
    lanes.append(pad_literal)

    header = f"fn {helper_name}({params}) -> {vec}:\n"
    return header + f" return {vec}({', '.join(lanes)})\n\n"
def vector_binary_helper_name(self, dtype, op, helper_kind):
    """Return the helper function name for a vec3 binary operation."""
    return "_crossgl_vec3_{}_{}_{}".format(
        MOJO_VECTOR_ARITHMETIC_OPS[op],
        MOJO_DTYPE_SUFFIX[dtype],
        helper_kind,
    )
def generate_vec3_splat_helper(self, dtype):
    """Emit a helper that repeats one scalar in three lanes plus a pad lane."""
    scalar_type, _, pad_literal = MOJO_DTYPE_INFO[dtype]
    param_type = self.map_type(scalar_type)
    result_type = f"SIMD[{dtype}, 4]"
    header = (
        f"fn {self.vec3_splat_helper_name(dtype)}(s: {param_type}) "
        f"-> {result_type}:\n"
    )
    body = f" return {result_type}(s, s, s, {pad_literal})\n\n"
    return header + body
def vec3_splat_helper_name(self, dtype):
    """Return the splat helper function name for ``dtype``."""
    dtype_suffix = MOJO_DTYPE_SUFFIX[dtype]
    return "_crossgl_vec3_splat_" + f"{dtype_suffix}"
def generate_swizzle_helper(self, dtype, source_width, member):
    """Emit a helper that builds the swizzle ``member`` from one vector argument.

    Two-lane swizzles return a width-2 SIMD; all others return width 4, and
    three-lane swizzles receive the dtype's pad literal as the fourth lane.
    """
    pad_literal = MOJO_DTYPE_INFO[dtype][2]
    lanes = self.get_swizzle_indices(member)
    lane_count = len(lanes)
    result_type = f"SIMD[{dtype}, {2 if lane_count == 2 else 4}]"
    helper_name = self.swizzle_helper_name(dtype, source_width, member)

    reads = [f"v[{lane}]" for lane in lanes]
    if lane_count == 3:
        reads.append(pad_literal)

    header = (
        f"fn {helper_name}(v: SIMD[{dtype}, {source_width}]) -> {result_type}:\n"
    )
    return header + f" return {result_type}({', '.join(reads)})\n\n"
def swizzle_helper_name(self, dtype, source_width, member):
    """Return the swizzle helper name for a dtype, source width and member."""
    return "_crossgl_swizzle_{}_{}_{}".format(
        MOJO_DTYPE_SUFFIX[dtype], source_width, member
    )
def generate_constructor_helper_call(self, func_name, args):
    """Return a call to a generated vector-constructor helper, or None.

    None is returned when an argument cannot be described as a piece, when
    the pieces do not exactly fill the vector's source width, or when no
    selected vector piece is flagged as duplicate sensitive. On success the
    helper record is stored in ``required_constructor_helpers``.
    """
    info = self.vector_constructor_info[func_name]
    dtype, source_width, storage_width, pad_literal = info

    described = []
    for arg in args:
        piece = self.constructor_piece_for_expression(arg, dtype)
        if piece is None:
            return None
        described.append(piece)

    chosen = self.select_constructor_pieces(described, source_width)
    if chosen is None:
        return None

    sensitive = [
        piece
        for piece in chosen
        if piece["kind"] == "vector" and piece["duplicate_sensitive"]
    ]
    if not sensitive:
        return None

    key = self.constructor_helper_key(dtype, source_width, storage_width, chosen)
    helper_name = self.constructor_helper_name(key)
    self.required_constructor_helpers[key] = dict(
        key=key,
        dtype=dtype,
        storage_width=storage_width,
        pad_literal=pad_literal,
        pieces=chosen,
    )
    rendered = [self.generate_expression(piece["expr"]) for piece in chosen]
    return f"{helper_name}({', '.join(rendered)})"
def select_constructor_pieces(self, pieces, source_width):
    """Pick leading pieces that together supply exactly ``source_width`` lanes.

    Vector pieces are trimmed to the lanes still needed; scalar pieces count
    as one lane. Returns the selected pieces, or None when the pieces run
    out before the width is filled.
    """
    chosen = []
    budget = source_width
    for piece in pieces:
        if budget == 0:
            break
        if piece["kind"] != "vector":
            chosen.append(piece)
            budget -= 1
            continue
        kept = tuple(piece["indices"][:budget])
        if not kept:
            continue
        # Copy the piece so the caller's dict keeps its full index list.
        chosen.append(dict(piece, indices=kept))
        budget -= len(kept)
    return chosen if budget == 0 else None
def constructor_piece_for_expression(self, expr, target_dtype):
    """Describe a constructor argument as a ``vector`` or ``scalar`` piece.

    A swizzle on a known vector type becomes a vector piece over the swizzled
    object; a swizzle-like member on anything else yields None. Any other
    expression with a known vector type becomes a vector piece covering all
    of its lanes, and everything else becomes a scalar piece.
    ``target_dtype`` is accepted but not consulted here.
    """

    def vector_piece(source_expr, info, indices):
        return {
            "kind": "vector",
            "dtype": info[0],
            "storage_width": info[2],
            "indices": tuple(indices),
            "expr": source_expr,
            "duplicate_sensitive": self.is_duplicate_sensitive_expression(
                source_expr
            ),
        }

    if isinstance(expr, MemberAccessNode):
        lanes = self.get_swizzle_indices(expr.member)
        if lanes is not None:
            owner_type = self.expression_result_type(expr.object)
            owner_info = self.vector_type_info(owner_type)
            if owner_info is None:
                return None
            return vector_piece(expr.object, owner_info, lanes)

    whole_info = self.vector_type_info(self.expression_result_type(expr))
    if whole_info is not None:
        return vector_piece(expr, whole_info, range(whole_info[1]))

    return {
        "kind": "scalar",
        "expr": expr,
        "dtype": self.expression_mojo_dtype(expr),
    }
def constructor_helper_key(self, dtype, source_width, storage_width, pieces):
    """Build a hashable key identifying a constructor helper's shape.

    Vector pieces contribute ``("v", dtype, storage_width, indices)``;
    scalar pieces contribute ``("s", dtype)`` when their dtype is known and
    differs from the target, otherwise ``("s",)``.
    """

    def describe(piece):
        if piece["kind"] == "vector":
            return (
                "v",
                piece["dtype"],
                piece["storage_width"],
                piece["indices"],
            )
        piece_dtype = piece.get("dtype")
        if piece_dtype is None or piece_dtype == dtype:
            return ("s",)
        return ("s", piece_dtype)

    signature = tuple(describe(piece) for piece in pieces)
    return (dtype, source_width, storage_width, signature)
def constructor_helper_name(self, key):
    """Derive the Mojo function name for a vector constructor helper key."""
    dtype, _, storage_width, signature = key

    def tag(entry):
        # Vector entries encode dtype suffix, storage width and lane indices.
        if entry[0] == "v":
            _, entry_dtype, entry_width, indices = entry
            lanes = "".join(str(lane) for lane in indices)
            return f"v{MOJO_DTYPE_SUFFIX[entry_dtype]}{entry_width}_{lanes}"
        # Scalar entries carry a dtype only when it differs from the target.
        if len(entry) > 1:
            return f"s{MOJO_DTYPE_SUFFIX[entry[1]]}"
        return "s"

    suffix = "_".join([tag(entry) for entry in signature])
    return f"_crossgl_construct_{MOJO_DTYPE_SUFFIX[dtype]}_{storage_width}_{suffix}"
def generate_constructor_helper(self, helper):
    """Emit the Mojo ``fn`` definition for one registered vector helper.

    ``helper`` is the record stored in ``required_constructor_helpers``.
    Vector pieces become ``v<i>`` parameters and scalar pieces ``s<i>``
    parameters; pieces whose dtype differs from the target are cast inside
    the helper. A three-lane result gets the record's pad literal appended
    when one is set.
    """
    dtype = helper["dtype"]
    target_scalar_type = self.map_type(MOJO_DTYPE_INFO[dtype][0])
    result_type = f"SIMD[{dtype}, {helper['storage_width']}]"

    signature = []
    casts = []
    lanes = []
    for position, piece in enumerate(helper["pieces"]):
        if piece["kind"] != "vector":
            arg_name = f"s{position}"
            source_dtype = piece.get("dtype")
            if source_dtype is None or source_dtype == dtype:
                arg_type = target_scalar_type
            else:
                arg_type = self.map_type(MOJO_DTYPE_INFO[source_dtype][0])
            signature.append(f"{arg_name}: {arg_type}")
            if source_dtype is None:
                lanes.append(arg_name)
            else:
                lanes.append(self.cast_scalar_text(arg_name, source_dtype, dtype))
            continue

        arg_name = f"v{position}"
        signature.append(
            f"{arg_name}: SIMD[{piece['dtype']}, {piece['storage_width']}]"
        )
        source = arg_name
        if piece["dtype"] != dtype:
            # Cast the whole vector once, then read lanes from the cast copy.
            source = f"{arg_name}_cast"
            casts.append(f" var {source} = {arg_name}.cast[{dtype}]()\n")
        lanes.extend(f"{source}[{lane}]" for lane in piece["indices"])

    pad_literal = helper["pad_literal"]
    if pad_literal is not None and len(lanes) == 3:
        lanes.append(pad_literal)

    helper_name = self.constructor_helper_name(helper["key"])
    header = f"fn {helper_name}({', '.join(signature)}) -> {result_type}:\n"
    footer = f" return {result_type}({', '.join(lanes)})\n\n"
    return header + "".join(casts) + footer
def register_variable_type(self, name, var_type):
    """Record the type name of ``name``; ignored when either value is falsy."""
    if not name or not var_type:
        return
    self.variable_types[name] = self.type_name(var_type)
def type_name(self, type_value):
    """Return ``type_value`` as a type-name string.

    Objects exposing ``name`` or ``element_type`` are converted through
    ``convert_type_node_to_string``; anything else goes through ``str``.
    """
    looks_like_node = hasattr(type_value, "name") or hasattr(
        type_value, "element_type"
    )
    if not looks_like_node:
        return str(type_value)
    return self.convert_type_node_to_string(type_value)
def is_array_type_name(self, type_name):
    """Return True when the string form of ``type_name`` has both ``[`` and ``]``."""
    if type_name is None:
        return False
    text = str(type_name)
    return "[" in text and "]" in text
def is_struct_type_name(self, type_name):
    """Return True when ``type_name`` names an entry of ``struct_types``."""
    return type_name is not None and self.type_name(type_name) in self.struct_types
def array_type_name(self, element_type, size):
    """Format an array type name: ``T[]`` when ``size`` is None, else ``T[size]``."""
    base = self.type_name(element_type)
    extent = "" if size is None else size
    return f"{base}[{extent}]"
def array_storage_type(self, element_type, size):
    """Return the Mojo storage type: ``List`` when unsized, else ``InlineArray``."""
    mapped = self.map_type(element_type)
    return (
        f"List[{mapped}]" if size is None else f"InlineArray[{mapped}, {size}]"
    )
def array_initial_value(self, element_type, size):
    """Return the Mojo expression used to default-initialise an array.

    Unsized arrays use an empty constructor call; sized arrays pass
    ``unsafe_uninitialized=True``.
    """
    storage = self.array_storage_type(element_type, size)
    ctor_args = "" if size is None else "unsafe_uninitialized=True"
    return f"{storage}({ctor_args})"
def array_initial_value_for_type(self, type_name):
    """Parse an array type name and return its initial-value expression."""
    base, extent = parse_array_type(str(type_name))
    return self.array_initial_value(base, extent)
def array_element_type(self, type_name):
    """Return the element type of an array type name, or None for non-arrays."""
    if self.is_array_type_name(type_name):
        base, _ = parse_array_type(str(type_name))
        return base
    return None
def generate_array_literal_expression(self, expr, target_type=None):
    """Emit the Mojo constructor call for an array literal.

    When ``target_type`` is an array type name, its element type and size
    are used; otherwise the element type is inferred from the literal and
    the size is the number of elements. With a numeric size the element
    list is truncated to that size or padded with zero values.

    A size that is not convertible to ``int`` leaves the element list
    untouched instead of raising, matching how ``zero_array_value`` treats
    such sizes.
    """
    if target_type is not None and self.is_array_type_name(target_type):
        element_type, size = parse_array_type(str(target_type))
    else:
        element_type = self.infer_array_literal_element_type(expr)
        size = len(expr.elements)
    array_type = self.array_storage_type(element_type, size)
    elements = [
        self.generate_array_literal_element(element, element_type)
        for element in expr.elements
    ]
    if size is not None:
        try:
            element_count = int(size)
        except (TypeError, ValueError):
            # Non-numeric size: cannot pad or truncate.
            element_count = None
        if element_count is not None:
            elements = elements[:element_count]
            missing = element_count - len(elements)
            if missing > 0:
                # The zero text is the same for every filler slot.
                filler = self.zero_value_for_type(element_type)
                elements.extend([filler] * missing)
    return f"{array_type}({', '.join(elements)})"
def infer_array_literal_element_type(self, expr):
    """Return the first element's type, defaulting to ``"float"`` when unknown."""
    elements = expr.elements
    first_type = self.expression_result_type(elements[0]) if elements else None
    return first_type or "float"
def generate_array_literal_element(self, element, element_type):
    """Emit one array element, casting to the element dtype when it is scalar."""
    scalar_dtype = MOJO_SCALAR_DTYPES.get(self.type_name(element_type))
    if scalar_dtype is None:
        return self.generate_expression(element)
    return self.generate_constructor_scalar_expression(element, scalar_dtype)
def zero_value_for_type(self, type_name):
    """Return a Mojo expression for the zero value of ``type_name``.

    Tried in order: array, struct, vector, matrix, scalar. Anything else
    falls back to a no-argument constructor call on the mapped type.
    """
    resolved = self.type_name(type_name)

    if self.is_array_type_name(resolved):
        base, extent = parse_array_type(resolved)
        return self.zero_array_value(base, extent)

    if resolved in self.struct_types:
        return self.zero_struct_value(resolved)

    as_vector = self.vector_type_info(resolved)
    if as_vector is not None:
        dtype, lane_count, storage_width, pad_literal = as_vector
        lanes = [MOJO_DTYPE_INFO[dtype][2]] * lane_count
        # Three-lane vectors are stored with an extra pad lane.
        if pad_literal is not None and len(lanes) == 3:
            lanes.append(pad_literal)
        return f"SIMD[{dtype}, {storage_width}]({', '.join(lanes)})"

    as_matrix = self.matrix_type_info(resolved)
    if as_matrix is not None:
        matrix_dtype, columns, rows = as_matrix
        return self.zero_matrix_value(matrix_dtype, columns, rows)

    scalar_dtype = MOJO_SCALAR_DTYPES.get(resolved)
    if scalar_dtype is None:
        return f"{self.map_type(resolved)}()"
    return MOJO_DTYPE_INFO[scalar_dtype][2]
def zero_array_value(self, element_type, size):
    """Return a Mojo expression for an array filled with zero values.

    Unsized arrays yield an empty constructor call; a size that is not
    convertible to ``int`` yields an uninitialised array.
    """
    storage = self.array_storage_type(element_type, size)
    if size is None:
        return f"{storage}()"
    try:
        count = int(size)
    except (TypeError, ValueError):
        return f"{storage}(unsafe_uninitialized=True)"
    zeros = ", ".join(
        [self.zero_value_for_type(element_type) for _ in range(count)]
    )
    return f"{storage}({zeros})"
def zero_struct_value(self, type_name):
    """Return a constructor call for ``type_name`` with every field zeroed."""
    field_types = self.struct_types.get(type_name, {}).values()
    zeros = ", ".join(
        [self.zero_value_for_type(field_type) for field_type in field_types]
    )
    return f"{type_name}({zeros})"
def zero_matrix_value(self, dtype, columns, rows):
    """Return a constructor call for an all-zero matrix and register its type."""
    self.required_matrix_types.add((dtype, columns, rows))
    matrix_type = self.matrix_type_name(dtype, columns, rows)
    storage_rows = self.matrix_storage_rows(rows)
    zero = self.matrix_zero_literal(dtype)

    # Every column is identical: ``rows`` zeros plus a pad zero for 3 rows.
    lanes = [zero] * rows
    if rows == 3:
        lanes.append(zero)
    column = f"SIMD[{dtype}, {storage_rows}]({', '.join(lanes)})"
    return f"{matrix_type}({', '.join([column] * columns)})"
def generate_array_access_expression(self, expr):
    """Emit an indexing expression for arrays, vectors and matrices.

    A matrix indexed by an integer literal becomes a direct column field
    access (``.c<N>``); every other case becomes ``base[index]``.
    """
    container_type = self.expression_result_type(expr.array)
    matrix_info = self.matrix_type_info(container_type)
    indexes_vector = self.vector_type_info(container_type) is not None
    indexes_array = self.array_element_type(container_type) is not None

    base = self.generate_expression(expr.array)
    subscript = self.generate_array_index_expression(
        expr.index,
        cast_integer_index=indexes_vector or indexes_array,
    )

    if matrix_info is not None:
        literal_column = self.literal_int_value(expr.index)
        if literal_column is not None:
            return f"{base}.c{literal_column}"
    return f"{base}[{subscript}]"
def generate_array_index_expression(self, expr, cast_integer_index=False):
    """Emit an index expression, wrapping it in ``int(...)`` when requested.

    The wrap is applied only when ``cast_integer_index`` is set, the index
    is not an integer literal, and its type is in
    ``MOJO_INTEGER_INDEX_TYPES``.
    """
    rendered = self.generate_expression(expr)
    wants_cast = cast_integer_index and self.literal_int_value(expr) is None
    if wants_cast and self.expression_result_type(expr) in MOJO_INTEGER_INDEX_TYPES:
        return f"int({rendered})"
    return rendered
def literal_int_value(self, expr):
    """Return ``expr`` as an int when it is an integer-convertible literal.

    Objects with a ``value`` attribute are converted from that attribute;
    plain strings are converted directly. Anything else, or a failed
    conversion, yields None.
    """
    from_node = hasattr(expr, "value")
    if not from_node and not isinstance(expr, str):
        return None
    try:
        return int(expr.value) if from_node else int(expr)
    except (TypeError, ValueError):
        return None
def expression_result_type(self, expr):
    """Infer the CrossGL type name produced by ``expr``, or None if unknown.

    Handles plain names, variables, array literals, array/matrix/vector
    indexing, binary operations, calls and member accesses. For binary
    operations a vector operand's type wins (left first); otherwise the
    left type is used, falling back to the right type.

    For member accesses the struct field table is consulted before the
    swizzle rules, so a struct field whose name happens to consist of
    swizzle letters (``x``, ``rgb``, ``st``, ...) reports its declared
    field type rather than a swizzle result type.
    """
    if isinstance(expr, str):
        return self.variable_types.get(expr)
    if isinstance(expr, VariableNode) and hasattr(expr, "name"):
        return self.variable_types.get(expr.name)
    if isinstance(expr, ArrayLiteralNode):
        element_type = self.infer_array_literal_element_type(expr)
        return self.array_type_name(element_type, len(expr.elements))
    if isinstance(expr, ArrayAccessNode):
        array_type = self.expression_result_type(expr.array)
        array_element_type = self.array_element_type(array_type)
        if array_element_type is not None:
            return array_element_type
        matrix_info = self.matrix_type_info(array_type)
        if matrix_info is not None:
            # Indexing a matrix yields one column: a vector of ``rows`` lanes.
            dtype, _, rows = matrix_info
            return self.vector_type_name_for_dtype_width(dtype, rows)
        vector_info = self.vector_type_info(array_type)
        if vector_info is not None:
            # Indexing a vector yields its scalar type.
            return MOJO_DTYPE_INFO[vector_info[0]][0]
        return None
    if isinstance(expr, BinaryOpNode):
        left_type = self.expression_result_type(expr.left)
        right_type = self.expression_result_type(expr.right)
        if self.vector_type_info(left_type) is not None:
            return left_type
        if self.vector_type_info(right_type) is not None:
            return right_type
        return left_type or right_type
    if isinstance(expr, FunctionCallNode):
        func_name = self.function_call_name(expr)
        # Vector and matrix constructors return the type they are named after.
        if func_name in self.vector_constructor_info:
            return func_name
        if func_name in MOJO_MATRIX_TYPES:
            return func_name
        return self.function_return_types.get(func_name)
    if isinstance(expr, MemberAccessNode):
        obj_type = self.expression_result_type(expr.object)
        # Struct fields take precedence over swizzle interpretation.
        if obj_type in self.struct_types:
            return self.struct_types[obj_type].get(expr.member)
        swizzle_indices = self.get_swizzle_indices(expr.member)
        if swizzle_indices is not None:
            return self.swizzle_result_type(obj_type, len(swizzle_indices))
    # Remaining node kinds are recognised by class name.
    if hasattr(expr, "__class__") and "Identifier" in str(expr.__class__):
        return self.variable_types.get(getattr(expr, "name", ""))
    if hasattr(expr, "__class__") and "Literal" in str(expr.__class__):
        literal_type = getattr(getattr(expr, "literal_type", None), "name", None)
        if literal_type:
            return literal_type
    return None
def function_call_name(self, expr):
    """Return the called function's name as a string, or None.

    The callee is ``expr.function`` when present and not None, otherwise
    ``expr.name``. A callee with a ``name`` attribute yields that name; a
    string callee is returned as is.
    """
    callee = getattr(expr, "function", None)
    callee = expr.name if callee is None else callee
    if hasattr(callee, "name"):
        return callee.name
    return callee if isinstance(callee, str) else None
def vector_type_info(self, type_name):
    """Return the ``vector_constructor_info`` entry for ``type_name``, or None."""
    known_vectors = self.vector_constructor_info
    return known_vectors[type_name] if type_name in known_vectors else None
def matrix_type_info(self, type_name):
    """Return the ``MOJO_MATRIX_TYPES`` entry for ``type_name``, or None."""
    return MOJO_MATRIX_TYPES[type_name] if type_name in MOJO_MATRIX_TYPES else None
def vector_type_name_for_dtype_width(self, dtype, width):
    """Return the vector type name formed from the dtype's prefix and ``width``."""
    prefix = MOJO_DTYPE_INFO[dtype][1]
    return "{}{}".format(prefix, width)
def swizzle_result_type(self, obj_type, component_count):
    """Return the type name produced by swizzling ``component_count`` lanes.

    Unknown object types are treated as ``DType.float32``. One lane yields
    the scalar type; more lanes yield the dtype prefix plus the count.
    """
    info = self.vector_type_info(obj_type)
    dtype = info[0] if info else "DType.float32"
    fallback = MOJO_DTYPE_INFO["DType.float32"]
    scalar_type, prefix, _ = MOJO_DTYPE_INFO.get(dtype, fallback)
    return scalar_type if component_count == 1 else f"{prefix}{component_count}"
def get_swizzle_indices(self, member):
    """Return lane indices for a swizzle member name, or None.

    A member qualifies when every character belongs to one and the same
    entry of ``SWIZZLE_SETS``; the first such entry is used.
    """
    if not member:
        return None
    matching = next(
        (
            lookup
            for lookup in SWIZZLE_SETS.values()
            if all(component in lookup for component in member)
        ),
        None,
    )
    if matching is None:
        return None
    return [matching[component] for component in member]
def expression_mojo_dtype(self, expr):
    """Return the Mojo dtype of ``expr``: vector dtype, scalar dtype, or None."""
    type_label = self.expression_result_type(expr)
    vector_info = self.vector_type_info(type_label)
    if vector_info is None:
        return MOJO_SCALAR_DTYPES.get(type_label)
    return vector_info[0]
def is_literal_expression(self, expr):
    """Return True when the string form of ``expr``'s class contains ``Literal``."""
    if not hasattr(expr, "__class__"):
        return False
    return str(expr.__class__).find("Literal") != -1
def cast_scalar_text(self, expr_text, source_dtype, target_dtype):
    """Wrap ``expr_text`` in a parenthesised ``.cast[...]()`` when dtypes differ.

    The text is returned unchanged when either dtype is None or both match.
    """
    both_known = source_dtype is not None and target_dtype is not None
    if both_known and not source_dtype == target_dtype:
        return f"({expr_text}).cast[{target_dtype}]()"
    return expr_text
def cast_vector_component(self, component, source_dtype, target_dtype):
    """Append ``.cast[...]()`` to a lane expression when dtypes differ.

    The text is returned unchanged when either dtype is None or both match.
    """
    both_known = source_dtype is not None and target_dtype is not None
    if both_known and not source_dtype == target_dtype:
        return f"{component}.cast[{target_dtype}]()"
    return component
def generate_constructor_scalar_expression(self, expr, target_dtype):
    """Emit a scalar constructor argument, cast to ``target_dtype`` if needed.

    Literal expressions are emitted without a cast.
    """
    rendered = self.generate_expression(expr)
    if not self.is_literal_expression(expr):
        rendered = self.cast_scalar_text(
            rendered, self.expression_mojo_dtype(expr), target_dtype
        )
    return rendered
def vector_components_for_expression(self, expr, target_dtype=None):
    """Expand a vector-valued expression into per-lane expression texts.

    Returns a list with one text per lane (each cast to ``target_dtype``
    when the source dtype is known and differs), or None when ``expr`` is
    not a vector of width greater than one.

    A member access whose member name looks like a swizzle is expanded to
    lane reads on the object, unless the object's type is a known struct:
    there the member is a field, and indexing the struct would be wrong.
    The object text is generated only when it is actually used.
    """
    if isinstance(expr, MemberAccessNode):
        swizzle_indices = self.get_swizzle_indices(expr.member)
        if swizzle_indices is not None:
            object_type = self.expression_result_type(expr.object)
            if object_type not in self.struct_types:
                obj = self.generate_expression(expr.object)
                source_info = self.vector_type_info(object_type)
                source_dtype = source_info[0] if source_info is not None else None
                return [
                    self.cast_vector_component(
                        f"{obj}[{index}]", source_dtype, target_dtype
                    )
                    for index in swizzle_indices
                ]
    expr_type = self.expression_result_type(expr)
    info = self.vector_type_info(expr_type)
    if info is None:
        return None
    source_dtype, source_width, _, _ = info
    if source_width <= 1:
        return None
    expr_text = self.generate_expression(expr)
    return [
        self.cast_vector_component(
            f"{expr_text}[{index}]", source_dtype, target_dtype
        )
        for index in range(source_width)
    ]
def generate_swizzle(self, source_expr, obj, obj_type, member, swizzle_indices):
    """Emit the Mojo expression for a swizzle on ``obj``.

    One lane becomes a plain index. For a known vector type whose source
    expression is duplicate sensitive, a swizzle helper is registered and
    called. Otherwise a SIMD value is built inline from indexed reads, with
    a pad lane appended for three-lane swizzles.
    """
    lane_count = len(swizzle_indices)
    if lane_count == 1:
        return f"{obj}[{swizzle_indices[0]}]"

    info = self.vector_type_info(obj_type)
    if info:
        dtype, source_width = info[0], info[2]
    else:
        # Unknown object type: assume a four-wide float32 vector.
        dtype, source_width = "DType.float32", 4

    if info is not None and self.is_duplicate_sensitive_expression(source_expr):
        helper_name = self.swizzle_helper_name(dtype, source_width, member)
        self.required_swizzle_helpers.add((dtype, source_width, member))
        return f"{helper_name}({obj})"

    pad_literal = MOJO_DTYPE_INFO.get(dtype, MOJO_DTYPE_INFO["DType.float32"])[2]
    reads = [f"{obj}[{lane}]" for lane in swizzle_indices]
    if lane_count == 3:
        reads.append(pad_literal)
    storage_width = 2 if lane_count == 2 else 4
    return f"SIMD[{dtype}, {storage_width}]({', '.join(reads)})"
def is_duplicate_sensitive_expression(self, expr):
    """Return True when ``expr`` is, or is rooted in, a call or operator node.

    Function calls, binary operations and ternaries count directly. Unary
    operations and member accesses are judged by their operand/object, and
    array accesses by both the array and the index.
    """
    node = expr
    while True:
        if isinstance(node, (FunctionCallNode, BinaryOpNode, TernaryOpNode)):
            return True
        if isinstance(node, UnaryOpNode):
            node = node.operand
        elif isinstance(node, MemberAccessNode):
            node = node.object
        elif isinstance(node, ArrayAccessNode):
            return self.is_duplicate_sensitive_expression(
                node.array
            ) or self.is_duplicate_sensitive_expression(node.index)
        else:
            return False
def format_literal(self, value, literal_type=None):
    """Format a literal value as Mojo source text.

    Booleans become ``True``/``False``, including the strings ``true`` and
    ``false`` (any case) when ``literal_type`` is ``"bool"``. Other strings
    are escaped and double-quoted; everything else goes through ``str``.
    """
    if isinstance(value, bool):
        return str(value)
    if isinstance(value, str):
        if literal_type == "bool":
            spelled = {"true": "True", "false": "False"}.get(value.lower())
            if spelled is not None:
                return spelled
        return '"' + self.escape_literal(value) + '"'
    return str(value)
def escape_literal(self, value):
    """Escape newlines, carriage returns, tabs and bare double quotes.

    A double quote directly preceded by a backslash in the input is left
    untouched; backslashes themselves are passed through unchanged.
    """
    control_escapes = {"\n": "\\n", "\r": "\\r", "\t": "\\t"}
    pieces = []
    previous = None
    for char in str(value):
        if char in control_escapes:
            pieces.append(control_escapes[char])
        elif char == '"' and previous != "\\":
            pieces.append('\\"')
        else:
            pieces.append(char)
        previous = char
    return "".join(pieces)
def map_type(self, vtype):
    """Map a CrossGL type name or type node to a Mojo type string.

    None maps to ``Float32``. Array types map to ``InlineArray`` when a
    truthy size is parsed and to ``List`` otherwise. Matrix types are
    registered as required and map to their generated struct name. All
    other names go through ``type_mapping``, falling back to the name.
    """
    if vtype is None:
        return "Float32"

    is_node = hasattr(vtype, "name") or hasattr(vtype, "element_type")
    type_text = self.convert_type_node_to_string(vtype) if is_node else str(vtype)

    if "[" in type_text and "]" in type_text:
        element, extent = parse_array_type(type_text)
        mapped_element = self.map_type(element)
        if not extent:
            return f"List[{mapped_element}]"
        return f"InlineArray[{mapped_element}, {extent}]"

    if type_text in MOJO_MATRIX_TYPES:
        matrix_dtype, columns, rows = MOJO_MATRIX_TYPES[type_text]
        self.required_matrix_types.add((matrix_dtype, columns, rows))
        return self.matrix_type_name(matrix_dtype, columns, rows)

    return self.type_mapping.get(type_text, type_text)
def map_operator(self, op):
    """Translate an operator token name to its Mojo spelling.

    Names without an entry are returned unchanged.
    """
    arithmetic = {
        "PLUS": "+",
        "MINUS": "-",
        "MULTIPLY": "*",
        "DIVIDE": "/",
        "MOD": "%",
    }
    bitwise = {
        "BITWISE_XOR": "^",
        "BITWISE_OR": "|",
        "BITWISE_AND": "&",
        "BITWISE_SHIFT_RIGHT": ">>",
        "BITWISE_SHIFT_LEFT": "<<",
    }
    comparison = {
        "LESS_THAN": "<",
        "GREATER_THAN": ">",
        "LESS_EQUAL": "<=",
        "GREATER_EQUAL": ">=",
        "EQUAL": "==",
        "NOT_EQUAL": "!=",
    }
    logical = {
        "AND": "and",
        "OR": "or",
        "LOGICAL_AND": "and",
        "LOGICAL_OR": "or",
        "NOT": "not",
    }
    assignment = {
        "EQUALS": "=",
        "ASSIGN_ADD": "+=",
        "ASSIGN_SUB": "-=",
        "ASSIGN_MUL": "*=",
        "ASSIGN_DIV": "/=",
        "ASSIGN_MOD": "%=",
        "ASSIGN_XOR": "^=",
        "ASSIGN_OR": "|=",
        "ASSIGN_AND": "&=",
        "ASSIGN_SHIFT_LEFT": "<<=",
        "ASSIGN_SHIFT_RIGHT": ">>=",
    }
    for table in (arithmetic, bitwise, comparison, logical, assignment):
        if op in table:
            return table[op]
    return op
def map_semantic(self, semantic):
    """Map a CrossGL semantic to the Mojo backend attribute name.

    A falsy semantic maps to the empty string; a semantic without an entry
    in ``semantic_map`` is returned unchanged.
    """
    if not semantic:
        return ""
    return self.semantic_map.get(semantic, semantic)