"""Reverse code generator that emits CrossGL from Mojo AST nodes."""
from .MojoAst import *
from .MojoParser import *
from .MojoLexer import *
class MojoToCrossGLConverter:
    """Serialize Mojo backend AST nodes back into CrossGL source.

    Generated code is indented with four spaces per nesting level.
    """

    # One level of output indentation.
    _INDENT = "    "

    # Decorator names that mark a function as a shader-stage entry point.
    # They select the enclosing ``vertex``/``fragment``/``compute`` block and
    # are therefore not re-emitted as semantic annotations.
    _STAGE_ATTRIBUTES = {
        "vertex": ("vertex", "vertex_main"),
        "fragment": ("fragment", "fragment_main"),
        "compute": ("compute", "compute_main"),
    }

    # Mojo's Python-style word operators and their C-style CrossGL spelling.
    _OPERATOR_MAP = {"and": "&&", "or": "||", "not": "!"}

    def __init__(self):
        """Initialize Mojo-to-CrossGL type, semantic, and function mappings."""
        self.type_map = {
            # Scalar Types
            "void": "void",
            "Int": "int",
            "Int8": "int8_t",
            "Int16": "int16_t",
            "Int32": "int",
            "Int64": "int64_t",
            "UInt": "uint",
            "UInt8": "uint8_t",
            "UInt16": "uint16_t",
            "UInt32": "uint",
            "UInt64": "uint64_t",
            "Float": "float",
            "Float16": "half",
            "Float32": "float",
            "Float64": "double",
            "Bool": "bool",
            # Vector Types (Mojo SIMD types)
            "SIMD[DType.float32, 2]": "vec2",
            "SIMD[DType.float32, 3]": "vec3",
            "SIMD[DType.float32, 4]": "vec4",
            "SIMD[DType.int32, 2]": "ivec2",
            "SIMD[DType.int32, 3]": "ivec3",
            "SIMD[DType.int32, 4]": "ivec4",
            "SIMD[DType.uint32, 2]": "uvec2",
            "SIMD[DType.uint32, 3]": "uvec3",
            "SIMD[DType.uint32, 4]": "uvec4",
            "SIMD[DType.bool, 2]": "bvec2",
            "SIMD[DType.bool, 3]": "bvec3",
            "SIMD[DType.bool, 4]": "bvec4",
            # Matrix Types
            "Matrix[DType.float32, 2, 2]": "mat2",
            "Matrix[DType.float32, 3, 3]": "mat3",
            "Matrix[DType.float32, 4, 4]": "mat4",
            "Matrix[DType.float16, 2, 2]": "half2x2",
            "Matrix[DType.float16, 3, 3]": "half3x3",
            "Matrix[DType.float16, 4, 4]": "half4x4",
            # Texture Types (hypothetical Mojo texture types)
            "Texture2D": "sampler2D",
            "TextureCube": "samplerCube",
            "Texture3D": "sampler3D",
            # Common simplified aliases
            "float2": "vec2",
            "float3": "vec3",
            "float4": "vec4",
            "int2": "ivec2",
            "int3": "ivec3",
            "int4": "ivec4",
            "uint2": "uvec2",
            "uint3": "uvec3",
            "uint4": "uvec4",
            "bool2": "bvec2",
            "bool3": "bvec3",
            "bool4": "bvec4",
            "mat2": "mat2",
            "mat3": "mat3",
            "mat4": "mat4",
        }
        self.semantic_map = {
            # Vertex attributes
            "position": "Position",
            "normal": "Normal",
            "tangent": "Tangent",
            "binormal": "Binormal",
            "texcoord": "TexCoord",
            "texcoord0": "TexCoord0",
            "texcoord1": "TexCoord1",
            "texcoord2": "TexCoord2",
            "texcoord3": "TexCoord3",
            "color": "Color",
            "color0": "Color0",
            "color1": "Color1",
            "color2": "Color2",
            "color3": "Color3",
            # Built-in variables
            "vertex_id": "gl_VertexID",
            "instance_id": "gl_InstanceID",
            "gl_position": "gl_Position",
            "gl_fragcolor": "gl_FragColor",
            "gl_fragdepth": "gl_FragDepth",
            "gl_frontfacing": "gl_IsFrontFace",
            "gl_pointcoord": "gl_PointCoord",
            "gl_primitiveId": "gl_PrimitiveID",
        }
        self.function_map = {
            # Math functions
            "math.sqrt": "sqrt",
            "math.pow": "pow",
            "math.sin": "sin",
            "math.cos": "cos",
            "math.tan": "tan",
            "math.abs": "abs",
            "math.floor": "floor",
            "math.ceil": "ceil",
            "math.min": "min",
            "math.max": "max",
            "math.clamp": "clamp",
            # SIMD functions
            "simd.dot": "dot",
            "simd.cross": "cross",
            "simd.normalize": "normalize",
            "simd.length": "length",
            "simd.reflect": "reflect",
            "simd.refract": "refract",
        }

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_absent(value):
        """Return True for an empty AST slot: ``None`` or an empty string.

        Falsy literals such as ``0`` or ``False`` are real values and are
        deliberately *not* treated as absent.
        """
        return value is None or (isinstance(value, str) and not value)

    @staticmethod
    def _with_leading_space(text):
        """Prefix a non-empty annotation with a separating space."""
        return f" {text}" if text else ""

    def _map_operator(self, op):
        """Translate Mojo word operators (``and``/``or``/``not``) to CrossGL."""
        return self._OPERATOR_MAP.get(op, op)

    # ------------------------------------------------------------------
    # Program level
    # ------------------------------------------------------------------

    def generate(self, ast):
        """Generate complete CrossGL source from a parsed Mojo AST.

        Args:
            ast: Root node. Its ``functions`` attribute, when present, holds
                the top-level import, struct, constant-buffer and function
                nodes.

        Returns:
            The CrossGL program wrapped in a ``shader main { ... }`` block:
            imports (as comments), then structs, constant buffers and
            functions. Stage entry points are wrapped in their stage block.
        """
        nodes = getattr(ast, "functions", None) or []
        ind = self._INDENT
        parts = ["shader main {\n"]

        imports = [n for n in nodes if isinstance(n, ImportNode)]
        if imports:
            parts.append(f"{ind}// Imports\n")
            for imp in imports:
                line = f"{ind}// import {imp.module_name}"
                if imp.alias:
                    line += f" as {imp.alias}"
                parts.append(line + "\n")
            parts.append("\n")

        for struct_node in [n for n in nodes if isinstance(n, StructNode)]:
            parts.append(f"{ind}struct {struct_node.name} {{\n")
            for member in struct_node.members:
                semantic = self.map_semantic(getattr(member, "attributes", None))
                parts.append(
                    f"{ind * 2}{self.map_type(member.vtype)} {member.name}"
                    f"{self._with_leading_space(semantic)};\n"
                )
            parts.append(f"{ind}}}\n\n")

        cbuffers = [n for n in nodes if isinstance(n, ConstantBufferNode)]
        if cbuffers:
            parts.append(f"{ind}// Constant Buffers\n")
            for cbuffer in cbuffers:
                parts.append(f"{ind}cbuffer {cbuffer.name} {{\n")
                for member in cbuffer.members:
                    parts.append(
                        f"{ind * 2}{self.map_type(member.vtype)} {member.name};\n"
                    )
                parts.append(f"{ind}}}\n\n")

        for func in [n for n in nodes if isinstance(n, FunctionNode)]:
            stage = self._shader_stage(func)
            if stage is None:
                parts.append(self.generate_function(func, indent=1))
            else:
                parts.append(f"{ind}// {stage.capitalize()} Shader\n")
                parts.append(f"{ind}{stage} {{\n")
                parts.append(self.generate_function(func, indent=2))
                parts.append(f"{ind}}}\n\n")

        parts.append("}\n")
        return "".join(parts)

    def _shader_stage(self, func):
        """Return "vertex", "fragment" or "compute" for an entry point, else None.

        Stages are tried in that order. A function matches a stage when its
        ``qualifier`` equals the stage name, it carries one of the stage's
        decorators, or the stage name occurs anywhere in its name.
        """
        checks = (
            ("vertex", self.has_vertex_attribute),
            ("fragment", self.has_fragment_attribute),
            ("compute", self.has_compute_attribute),
        )
        qualifier = getattr(func, "qualifier", None)
        name = getattr(func, "name", None)
        name = name if isinstance(name, str) else ""
        for stage, has_attribute in checks:
            if qualifier == stage or has_attribute(func) or stage in name:
                return stage
        return None

    def _has_stage_attribute(self, func, stage):
        """Return True if ``func`` has a decorator named for ``stage``."""
        names = self._STAGE_ATTRIBUTES[stage]
        attributes = getattr(func, "attributes", None) or ()
        return any(getattr(attr, "name", None) in names for attr in attributes)

    def has_vertex_attribute(self, func):
        """Return True if ``func`` is decorated ``vertex`` or ``vertex_main``."""
        return self._has_stage_attribute(func, "vertex")

    def has_fragment_attribute(self, func):
        """Return True if ``func`` is decorated ``fragment`` or ``fragment_main``."""
        return self._has_stage_attribute(func, "fragment")

    def has_compute_attribute(self, func):
        """Return True if ``func`` is decorated ``compute`` or ``compute_main``."""
        return self._has_stage_attribute(func, "compute")

    # ------------------------------------------------------------------
    # Functions and statements
    # ------------------------------------------------------------------

    def generate_function(self, func, indent=1):
        """Render one Mojo function node as a CrossGL function.

        Args:
            func: Function node with ``name``, ``return_type`` and optional
                ``params``, ``attributes`` and ``body``.
            indent: Nesting level of the signature line.

        Returns:
            The function source followed by a blank line.
        """
        indent_str = self._INDENT * indent
        params = []
        for p in getattr(func, "params", None) or []:
            param_str = f"{self.map_type(p.vtype)} {p.name}"
            semantic = self.map_semantic(getattr(p, "attributes", None))
            if semantic:
                param_str += f" {semantic}"
            params.append(param_str)
        return_type = self.map_type(func.return_type) if func.return_type else "void"

        # Stage decorators are already expressed by the enclosing stage
        # block, so only the remaining attributes may supply a semantic.
        stage_names = {
            name for names in self._STAGE_ATTRIBUTES.values() for name in names
        }
        semantic_attrs = [
            attr
            for attr in getattr(func, "attributes", None) or []
            if getattr(attr, "name", None) not in stage_names
        ]
        semantic = self._with_leading_space(self.map_semantic(semantic_attrs))

        lines = [
            f"{indent_str}{return_type} {func.name}({', '.join(params)}){semantic} {{\n"
        ]
        body = getattr(func, "body", None)
        if body:
            lines.append(self.generate_function_body(body, indent + 1))
        lines.append(f"{indent_str}}}\n\n")
        return "".join(lines)

    def generate_function_body(self, body, indent=0):
        """Render a sequence of statements, each prefixed by ``indent`` levels."""
        indent_str = self._INDENT * indent
        return "".join(
            indent_str + self._generate_statement(stmt, indent) for stmt in body
        )

    def _generate_statement(self, stmt, indent):
        """Render one statement, without its leading indentation."""
        if isinstance(stmt, str):
            return f"{stmt};\n"
        if isinstance(stmt, VariableDeclarationNode):
            return self.generate_variable_declaration(stmt) + ";\n"
        if isinstance(stmt, VariableNode):
            if getattr(stmt, "vtype", None):
                return f"{self.map_type(stmt.vtype)} {stmt.name};\n"
            return f"{stmt.name};\n"
        if isinstance(stmt, AssignmentNode):
            return self.generate_assignment(stmt) + ";\n"
        if isinstance(stmt, ReturnNode):
            value = getattr(stmt, "value", None)
            if self._is_absent(value):
                return "return;\n"
            return f"return {self.generate_expression(value)};\n"
        if isinstance(stmt, ForNode):
            return self.generate_for_loop(stmt, indent)
        if isinstance(stmt, WhileNode):
            return self.generate_while_loop(stmt, indent)
        if isinstance(stmt, IfNode):
            return self.generate_if_statement(stmt, indent)
        if isinstance(stmt, SwitchNode):
            return self.generate_switch_statement(stmt, indent)
        # Any expression evaluated for its side effects is an expression
        # statement in CrossGL.
        if isinstance(
            stmt,
            (
                BinaryOpNode,
                UnaryOpNode,
                FunctionCallNode,
                MemberAccessNode,
                ArrayAccessNode,
                TernaryOpNode,
            ),
        ):
            return f"{self.generate_expression(stmt)};\n"
        # For any unhandled statement type
        return f"// Unhandled statement type: {type(stmt).__name__}\n"

    def generate_variable_declaration(self, node):
        """Render a ``var``/``let`` declaration, with initializer when present.

        Any ``var_type`` other than ``"var"`` is emitted as ``let``. Falsy
        initializers such as ``0`` or ``False`` are kept.
        """
        keyword = "var" if node.var_type == "var" else "let"
        initial_value = getattr(node, "initial_value", None)
        if self._is_absent(initial_value):
            return f"{keyword} {node.name}"
        return f"{keyword} {node.name} = {self.generate_expression(initial_value)}"

    def generate_assignment(self, node):
        """Render ``left <op> right``; the operator defaults to ``=``."""
        left = self.generate_expression(node.left)
        right = self.generate_expression(node.right)
        op = getattr(node, "operator", None) or "="
        return f"{left} {op} {right}"

    def generate_for_loop(self, node, indent):
        """Render a C-style ``for`` loop; a missing condition becomes ``true``."""
        indent_str = self._INDENT * indent
        init = "" if self._is_absent(node.init) else self.generate_expression(node.init)
        condition = (
            "true"
            if self._is_absent(node.condition)
            else self.generate_expression(node.condition)
        )
        update = (
            "" if self._is_absent(node.update) else self.generate_expression(node.update)
        )
        code = f"for ({init}; {condition}; {update}) {{\n"
        body = getattr(node, "body", None)
        if body:
            code += self.generate_function_body(body, indent + 1)
        return code + indent_str + "}\n"

    def generate_while_loop(self, node, indent):
        """Render a ``while`` loop; a missing condition becomes ``true``."""
        indent_str = self._INDENT * indent
        condition = (
            "true"
            if self._is_absent(node.condition)
            else self.generate_expression(node.condition)
        )
        code = f"while ({condition}) {{\n"
        body = getattr(node, "body", None)
        if body:
            code += self.generate_function_body(body, indent + 1)
        return code + indent_str + "}\n"

    def generate_if_statement(self, node, indent):
        """Render an if / else-if / else chain terminated by one newline."""
        return self._generate_if_chain(node, indent) + "\n"

    def _generate_if_chain(self, node, indent):
        """Render an if statement without a trailing newline.

        Omitting the newline lets ``else if`` branches chain onto the same
        line without emitting a stray blank line.
        """
        indent_str = self._INDENT * indent
        code = f"if ({self.generate_expression(node.condition)}) {{\n"
        if_body = getattr(node, "if_body", None)
        if if_body:
            code += self.generate_function_body(if_body, indent + 1)
        code += indent_str + "}"

        else_body = getattr(node, "else_body", None)
        if not else_body:
            return code
        if isinstance(else_body, list):
            if len(else_body) == 1 and isinstance(else_body[0], IfNode):
                return code + " else " + self._generate_if_chain(else_body[0], indent)
        elif isinstance(else_body, IfNode):
            return code + " else " + self._generate_if_chain(else_body, indent)
        else:
            # A single non-list statement used as the else branch.
            else_body = [else_body]
        return (
            code
            + " else {\n"
            + self.generate_function_body(else_body, indent + 1)
            + indent_str
            + "}"
        )

    def generate_switch_statement(self, node, indent):
        """Render a ``switch``; a case whose ``condition`` is None is ``default``.

        Every case is closed with an explicit ``break;``.
        """
        indent_str = self._INDENT * indent
        case_indent = indent_str + self._INDENT
        parts = [f"switch ({self.generate_expression(node.expression)}) {{\n"]
        for case in getattr(node, "cases", None) or []:
            condition = getattr(case, "condition", None)
            if condition is None:
                parts.append(f"{case_indent}default:\n")
            else:
                parts.append(
                    f"{case_indent}case {self.generate_expression(condition)}:\n"
                )
            body = getattr(case, "body", None)
            if body:
                parts.append(self.generate_function_body(body, indent + 2))
            parts.append(f"{case_indent}{self._INDENT}break;\n")
        parts.append(f"{indent_str}}}\n")
        return "".join(parts)

    # ------------------------------------------------------------------
    # Expressions
    # ------------------------------------------------------------------

    def generate_expression(self, expr):
        """Render a Mojo backend expression node as CrossGL syntax.

        Plain strings pass through unchanged. Python ``bool`` values become
        ``true``/``false`` and other numbers are rendered with ``str``.
        Unrecognized nodes yield a ``/* Unhandled expression: ... */``
        placeholder.
        """
        if expr is None:
            return ""
        # bool must be tested before int (bool subclasses int) and emitted in
        # lowercase: str(True) == "True" is not a valid CrossGL literal.
        if isinstance(expr, bool):
            return "true" if expr else "false"
        if isinstance(expr, str):
            return expr
        if isinstance(expr, (int, float)):
            return str(expr)
        if isinstance(expr, VariableNode):
            if getattr(expr, "vtype", None):
                return f"{self.map_type(expr.vtype)} {expr.name}"
            return expr.name
        if isinstance(expr, VariableDeclarationNode):
            return self.generate_variable_declaration(expr)
        if isinstance(expr, AssignmentNode):
            return self.generate_assignment(expr)
        if isinstance(expr, BinaryOpNode):
            left = self.generate_expression(expr.left)
            right = self.generate_expression(expr.right)
            op = self._map_operator(getattr(expr, "op", "+"))
            return f"({left} {op} {right})"
        if isinstance(expr, UnaryOpNode):
            operand = self.generate_expression(expr.operand)
            op = self._map_operator(getattr(expr, "op", "+"))
            return f"({op}{operand})"
        if isinstance(expr, FunctionCallNode):
            args = ", ".join(
                self.generate_expression(arg)
                for arg in getattr(expr, "args", None) or []
            )
            name = expr.name
            if not isinstance(name, str):
                # e.g. a callee such as ``math.sqrt`` parsed as a member access.
                name = self.generate_expression(name)
            return f"{self.map_function(name)}({args})"
        if isinstance(expr, MemberAccessNode):
            return f"{self.generate_expression(expr.object)}.{expr.member}"
        if isinstance(expr, TernaryOpNode):
            condition = self.generate_expression(expr.condition)
            true_expr = self.generate_expression(expr.true_expr)
            false_expr = self.generate_expression(expr.false_expr)
            return f"({condition} ? {true_expr} : {false_expr})"
        if isinstance(expr, VectorConstructorNode):
            args = ", ".join(
                self.generate_expression(arg)
                for arg in getattr(expr, "args", None) or []
            )
            return f"{self.map_type(expr.type_name)}({args})"
        if isinstance(expr, ArrayAccessNode):
            array = self.generate_expression(expr.array)
            index = self.generate_expression(expr.index)
            return f"{array}[{index}]"
        # For any unhandled expression type
        return f"/* Unhandled expression: {type(expr).__name__} */"

    # ------------------------------------------------------------------
    # Name mapping
    # ------------------------------------------------------------------

    def map_type(self, mojo_type):
        """Map a Mojo type name to the closest CrossGL type name.

        ``None`` maps to ``void``. Parameterized types also match when written
        with different spacing around commas (``SIMD[DType.float32,4]``).
        Unknown types are returned unchanged.
        """
        if mojo_type is None:
            return "void"
        mapped = self.type_map.get(mojo_type)
        if mapped is None and isinstance(mojo_type, str) and "," in mojo_type:
            canonical = ", ".join(part.strip() for part in mojo_type.split(","))
            mapped = self.type_map.get(canonical)
        return mojo_type if mapped is None else mapped

    def map_semantic(self, attributes):
        """Map Mojo decorators or attributes to a CrossGL semantic annotation.

        Returns ``"@ <Semantic>"`` for the first attribute that has a
        ``name``, or ``""`` when there is none. Unknown names pass through.
        """
        for attr in attributes or ():
            if hasattr(attr, "name"):
                return f"@ {self.semantic_map.get(attr.name, attr.name)}"
        return ""

    def map_function(self, func_name):
        """Map a Mojo function name (e.g. ``math.sqrt``) to its CrossGL name."""
        return self.function_map.get(func_name, func_name)