"""Reverse code generator that emits CrossGL from Rust AST nodes."""
from .RustAst import *
from .RustParser import *
from .RustLexer import *
class RustToCrossGLConverter:
    """Serialize Rust backend AST nodes back into CrossGL source.

    :meth:`generate` converts a whole parsed module, while :meth:`visit`
    converts a single node. Type spellings are translated by :meth:`map_type`
    (``type_map`` plus any registered ``type`` aliases), attribute names by
    ``semantic_map``/``attribute_map`` and called function names by
    ``function_map``. Each nesting level of emitted code is indented with
    :attr:`INDENT_UNIT`.
    """

    #: Text emitted for one level of indentation in the generated CrossGL.
    INDENT_UNIT = "    "

    # Rust integer-literal suffixes that may trail an array repeat count
    # (``[0.0; 4usize]``); they are stripped before the count is parsed.
    _INTEGER_SUFFIXES = (
        "usize",
        "isize",
        "u128",
        "i128",
        "u64",
        "i64",
        "u32",
        "i32",
        "u16",
        "i16",
        "u8",
        "i8",
    )

    def __init__(self):
        """Initialize Rust-to-CrossGL type, semantic, function and attribute maps."""
        # Rust type spellings -> CrossGL type names. Generic Vec2/3/4 spellings
        # without an exact entry fall back to their base name in map_type().
        self.type_map = {
            # Rust primitive types to CrossGL
            "()": "void",
            "f32": "float",
            "f64": "double",
            "i32": "int",
            "i64": "int64_t",
            "u32": "uint",
            "u64": "uint64_t",
            "i8": "int8_t",
            "u8": "uint8_t",
            "i16": "int16_t",
            "u16": "uint16_t",
            "bool": "bool",
            # Pointer-sized integers are narrowed to 32-bit CrossGL integers.
            "usize": "uint",
            "isize": "int",
            # Rust vector types to CrossGL
            "Vec2<f32>": "vec2",
            "Vec3<f32>": "vec3",
            "Vec4<f32>": "vec4",
            "Vec2<i32>": "ivec2",
            "Vec3<i32>": "ivec3",
            "Vec4<i32>": "ivec4",
            "Vec2<u32>": "uvec2",
            "Vec3<u32>": "uvec3",
            "Vec4<u32>": "uvec4",
            "Vec2<bool>": "bvec2",
            "Vec3<bool>": "bvec3",
            "Vec4<bool>": "bvec4",
            # Simplified vector names
            "Vec2": "vec2",
            "Vec3": "vec3",
            "Vec4": "vec4",
            # Rust matrix types to CrossGL
            "Mat2<f32>": "mat2",
            "Mat3<f32>": "mat3",
            "Mat4<f32>": "mat4",
            "Mat2": "mat2",
            "Mat3": "mat3",
            "Mat4": "mat4",
            # GPU-specific types
            "Texture2D": "sampler2D",
            "TextureCube": "samplerCube",
            "Sampler": "sampler",
            # Reference types (stripped in CrossGL)
            "&f32": "float",
            "&i32": "int",
            "&u32": "uint",
            "&bool": "bool",
            "&mut f32": "float",
            "&mut i32": "int",
            "&mut u32": "uint",
            "&mut bool": "bool",
        }
        # Field attribute names -> CrossGL semantic names (used by
        # get_semantic_from_attributes()).
        self.semantic_map = {
            # Rust shader attributes to CrossGL semantics
            "vertex_position": "Position",
            "vertex_normal": "Normal",
            "vertex_tangent": "Tangent",
            "vertex_binormal": "Binormal",
            "vertex_texcoord": "TexCoord",
            "vertex_texcoord0": "TexCoord0",
            "vertex_texcoord1": "TexCoord1",
            "vertex_texcoord2": "TexCoord2",
            "vertex_texcoord3": "TexCoord3",
            "vertex_color": "Color",
            "instance_id": "InstanceID",
            "vertex_id": "VertexID",
            # Fragment shader semantics
            "fragment_position": "gl_Position",
            "fragment_color": "gl_FragColor",
            "fragment_depth": "gl_FragDepth",
            "front_face": "gl_IsFrontFace",
            "primitive_id": "gl_PrimitiveID",
            "point_coord": "gl_PointCoord",
            # Compute shader semantics
            "local_invocation_id": "gl_LocalInvocationID",
            "global_invocation_id": "gl_GlobalInvocationID",
            "workgroup_id": "gl_WorkGroupID",
            "num_workgroups": "gl_NumWorkGroups",
        }
        # Called function names -> CrossGL built-ins (identity unless renamed).
        self.function_map = {
            # Rust math functions to CrossGL
            "abs": "abs",
            "min": "min",
            "max": "max",
            "clamp": "clamp",
            "floor": "floor",
            "ceil": "ceil",
            "round": "round",
            "sqrt": "sqrt",
            "pow": "pow",
            "exp": "exp",
            "exp2": "exp2",
            "log": "log",
            "log2": "log2",
            "sin": "sin",
            "cos": "cos",
            "tan": "tan",
            "asin": "asin",
            "acos": "acos",
            "atan": "atan",
            "atan2": "atan2",
            "sinh": "sinh",
            "cosh": "cosh",
            "tanh": "tanh",
            "degrees": "degrees",
            "radians": "radians",
            # Vector operations
            "dot": "dot",
            "cross": "cross",
            "length": "length",
            "normalize": "normalize",
            "distance": "distance",
            "reflect": "reflect",
            "refract": "refract",
            "faceforward": "faceforward",
            # Matrix operations
            "transpose": "transpose",
            "determinant": "determinant",
            "inverse": "inverse",
            # Texture sampling
            "sample": "texture",
            "sample_level": "textureLod",
            "sample_grad": "textureGrad",
            "sample_offset": "textureOffset",
        }
        # Rust attribute names -> CrossGL qualifiers; the vertex/fragment/compute
        # entries mark shader entry points (see get_shader_type_from_attributes()).
        self.attribute_map = {
            # Rust shader attributes to CrossGL qualifiers
            "vertex_shader": "vertex",
            "fragment_shader": "fragment",
            "compute_shader": "compute",
            "binding": "binding",
            "location": "location",
            "group": "group",
            "builtin": "builtin",
            "stage": "stage",
            "workgroup_size": "workgroup_size",
        }
        # Current nesting depth used by get_indent() and visit_StructNode().
        self.indentation = 0
        # Kept for backward compatibility; not read by the methods below.
        self.code = []
        # Registered Rust `type` aliases, keyed by alias name.
        self.type_aliases = {}

    def get_indent(self):
        """Return the indentation prefix for the current ``indentation`` depth."""
        return self.INDENT_UNIT * self.indentation

    def visit(self, node):
        """Dispatch a Rust AST node to its visitor and return the CrossGL text.

        Declaration and operator nodes have dedicated ``visit_*`` methods. Any
        other node is routed to a ``generate_<ClassName>`` method if one
        exists, and otherwise rendered with :meth:`generate_expression`.
        """
        handlers = (
            (StructNode, self.visit_StructNode),
            (FunctionNode, self.visit_FunctionNode),
            (BinaryOpNode, self.visit_BinaryOpNode),
            (UnaryOpNode, self.visit_UnaryOpNode),
            (ImplNode, self.visit_ImplNode),
            (UseNode, self.visit_UseNode),
            (ConstNode, self.visit_ConstNode),
            (StaticNode, self.visit_StaticNode),
        )
        for node_type, handler in handlers:
            if isinstance(node, node_type):
                return handler(node)
        method = getattr(self, f"generate_{type(node).__name__}", None)
        if method is not None:
            return method(node)
        return self.generate_expression(node)

    def generate(self, ast):
        """Generate complete CrossGL source from a parsed Rust AST.

        Everything is wrapped in ``shader main { ... }``, in this order:
        ``use`` comments, type aliases, global ``const``/``static`` items,
        structs, free functions (shader entry points become stage blocks), and
        finally ``impl`` methods.
        """
        self.type_aliases = {}
        unit = self.INDENT_UNIT
        aliases = getattr(ast, "type_aliases", []) or []
        parts = ["shader main {\n"]

        for use_stmt in ast.use_statements:
            parts.append(f"{unit}// use {use_stmt.path}\n")

        # Register every alias before emitting any, so an alias declaration may
        # refer to an alias that appears later in the source.
        for alias in aliases:
            self.register_type_alias(alias)
        for alias in aliases:
            parts.append(self.generate_type_alias(alias))

        for global_var in ast.global_variables:
            if isinstance(global_var, ConstNode):
                parts.append(self._render_const(global_var, indent=1))
            elif isinstance(global_var, StaticNode):
                parts.append(self._render_static(global_var, indent=1))

        for struct in ast.structs:
            if isinstance(struct, StructNode):
                parts.append(self._render_struct(struct, indent=1))
                parts.append("\n")

        for func in ast.functions:
            shader_type = self.get_shader_type_from_attributes(func.attributes)
            if shader_type:
                # Entry points keep only their body, inside a stage block.
                parts.append(f"{unit}// {shader_type.title()} Shader\n")
                parts.append(f"{unit}{shader_type} {{\n")
                parts.append(self.generate_function_body(func.body, indent=2))
                parts.append(f"{unit}}}\n\n")
            else:
                parts.append(self.generate_function(func, indent=1))

        for impl_block in ast.impl_blocks:
            parts.append(f"{unit}// Implementation for {impl_block.struct_name}\n")
            for func in impl_block.functions:
                parts.append(
                    self.generate_function(
                        func, indent=1, struct_name=impl_block.struct_name
                    )
                )

        parts.append("}\n")
        return "".join(parts)

    def register_type_alias(self, alias):
        """Record a Rust ``type`` alias so :meth:`map_type` can resolve it."""
        if getattr(alias, "name", None):
            self.type_aliases[alias.name] = alias

    def generate_type_alias(self, alias):
        """Render a non-generic alias as ``typedef <type> <name>;``.

        Returns ``""`` for generic aliases (they are expanded at each use
        site by :meth:`resolve_type_alias`) and for aliases without a target.
        """
        if getattr(alias, "generics", None) or not getattr(alias, "alias_type", None):
            return ""
        declarator = self.format_typed_declarator(alias.alias_type, alias.name)
        return f"{self.INDENT_UNIT}typedef {declarator};\n"

    def generate_function(self, func, indent=1, struct_name=None):
        """Render one Rust function node as a CrossGL function.

        When ``struct_name`` is given (functions from an ``impl`` block), the
        function is emitted as ``<struct_name>_<name>``.
        """
        outer = self.INDENT_UNIT * indent
        params = ", ".join(
            self.format_typed_declarator(param.vtype, param.name)
            for param in func.params
        )
        return_type = self.map_type(func.return_type)
        func_name = f"{struct_name}_{func.name}" if struct_name else func.name
        body = self.generate_function_body(func.body, indent=indent + 1)
        return f"{outer}{return_type} {func_name}({params}) {{\n{body}{outer}}}\n\n"

    def generate_function_body(self, body, indent=1):
        """Render a statement body, one statement per line, ``indent`` levels deep.

        ``body`` may be a list of statements, a single statement or ``None``.
        Plain strings are emitted verbatim followed by ``;``; expressions that
        render to an empty string produce no line.
        """
        pad = self.INDENT_UNIT * indent
        lines = []
        for stmt in self._as_statement_list(body):
            if isinstance(stmt, str):
                lines.append(f"{pad}{stmt};\n")
            elif isinstance(stmt, LetNode):
                lines.append(self.generate_let_statement(stmt, indent))
            elif isinstance(stmt, AssignmentNode):
                lines.append(f"{pad}{self.generate_assignment(stmt)};\n")
            elif isinstance(stmt, ReturnNode):
                # `return 0` / `return false` must keep their value.
                if self._has_value(stmt.value):
                    value = self.generate_expression(stmt.value)
                    lines.append(f"{pad}return {value};\n")
                else:
                    lines.append(f"{pad}return;\n")
            elif isinstance(stmt, IfNode):
                lines.append(self.generate_if_statement(stmt, indent))
            elif isinstance(stmt, ForNode):
                lines.append(self.generate_for_loop(stmt, indent))
            elif isinstance(stmt, WhileNode):
                lines.append(self.generate_while_loop(stmt, indent))
            elif isinstance(stmt, LoopNode):
                lines.append(self.generate_loop(stmt, indent))
            elif isinstance(stmt, MatchNode):
                lines.append(self.generate_match_statement(stmt, indent))
            elif isinstance(stmt, BreakNode):
                lines.append(f"{pad}break;\n")
            elif isinstance(stmt, ContinueNode):
                lines.append(f"{pad}continue;\n")
            else:
                expr = self.generate_expression(stmt)
                if expr:
                    lines.append(f"{pad}{expr};\n")
        return "".join(lines)

    def generate_let_statement(self, stmt, indent):
        """Render a ``let`` binding.

        With a type annotation this becomes a typed declaration
        (``float x = 1.0;``); without one only ``name = value;`` is emitted.
        A missing initializer yields ``<declaration>;``.
        """
        pad = self.INDENT_UNIT * indent
        if stmt.vtype:
            target = self.format_typed_declarator(stmt.vtype, stmt.name)
        else:
            target = stmt.name
        if self._has_value(stmt.value):
            return f"{pad}{target} = {self.generate_expression(stmt.value)};\n"
        return f"{pad}{target};\n"

    def generate_assignment(self, node):
        """Render ``left <operator> right`` without a trailing semicolon."""
        left = self.generate_expression(node.left)
        right = self.generate_expression(node.right)
        return f"{left} {node.operator} {right}"

    def generate_if_statement(self, node, indent):
        """Render an ``if``/``else`` statement at ``indent`` levels.

        A lone nested ``if`` in the else branch is emitted as an ``else if``
        chain aligned with this statement.
        """
        pad = self.INDENT_UNIT * indent
        condition = self.generate_expression(node.condition)
        code = f"{pad}if ({condition}) {{\n"
        code += self.generate_function_body(node.if_body, indent + 1)
        code += f"{pad}}}"
        else_body = node.else_body
        if not else_body:
            return code + "\n"
        chained_if = self._as_else_if(else_body)
        if chained_if is not None:
            # Render the nested `if` at *this* depth so its body and closing
            # brace line up with ours; only its leading pad is stripped since it
            # continues the `} else ` line. It supplies its own final newline.
            nested = self.generate_if_statement(chained_if, indent)
            return code + " else " + nested.lstrip()
        if not isinstance(else_body, list):
            else_body = [else_body]
        code += " else {\n"
        code += self.generate_function_body(else_body, indent + 1)
        return code + f"{pad}}}\n"

    def generate_for_loop(self, node, indent):
        """Render ``for <pattern> in <iterable>`` as a C-style counting loop.

        A ``start..end`` / ``start..=end`` range (a node whose ``op`` is
        ``..``/``..=`` with ``left``/``right`` operands) uses its bounds; any
        other iterable is treated as an exclusive upper bound starting at 0.
        """
        pad = self.INDENT_UNIT * indent
        var = self.generate_expression(node.pattern)
        bounds = self._range_bounds(node.iterable)
        if bounds is None:
            start, end, comparison = "0", self.generate_expression(node.iterable), "<"
        else:
            start, end, comparison = bounds
        header = (
            f"{pad}for (int {var} = {start}; {var} {comparison} {end}; {var}++) {{\n"
        )
        body = self.generate_function_body(node.body, indent + 1)
        return f"{header}{body}{pad}}}\n"

    def generate_while_loop(self, node, indent):
        """Render a ``while`` loop at ``indent`` levels."""
        pad = self.INDENT_UNIT * indent
        condition = self.generate_expression(node.condition)
        body = self.generate_function_body(node.body, indent + 1)
        return f"{pad}while ({condition}) {{\n{body}{pad}}}\n"

    def generate_loop(self, node, indent):
        """Render a Rust infinite ``loop`` as ``while (true)``."""
        pad = self.INDENT_UNIT * indent
        body = self.generate_function_body(node.body, indent + 1)
        return f"{pad}while (true) {{\n{body}{pad}}}\n"

    def generate_match_statement(self, node, indent):
        """Render a ``match`` as a ``switch`` with one ``case`` per arm.

        The ``_`` wildcard arm becomes ``default:``; every arm ends with
        ``break;``.
        """
        pad = self.INDENT_UNIT * indent
        case_pad = self.INDENT_UNIT * (indent + 1)
        body_pad = self.INDENT_UNIT * (indent + 2)
        lines = [f"{pad}switch ({self.generate_expression(node.expression)}) {{\n"]
        for arm in node.arms:
            pattern = self.generate_expression(arm.pattern)
            label = "default:" if pattern == "_" else f"case {pattern}:"
            lines.append(f"{case_pad}{label}\n")
            lines.append(self.generate_function_body(arm.body, indent + 2))
            lines.append(f"{body_pad}break;\n")
        lines.append(f"{pad}}}\n")
        return "".join(lines)

    def generate_expression(self, expr):
        """Render a Rust backend expression node (or raw literal) as CrossGL.

        Strings are emitted verbatim, booleans as ``true``/``false``, numbers
        via ``str()`` and ``None`` as ``""``. Unrecognized objects fall back
        to ``str(expr)``.
        """
        if expr is None:
            return ""
        if isinstance(expr, str):
            return expr
        if isinstance(expr, bool):
            return "true" if expr else "false"
        if isinstance(expr, (int, float)):
            return str(expr)
        if isinstance(expr, VariableNode):
            return expr.name
        if isinstance(expr, BinaryOpNode):
            return self._format_binary(expr)
        if isinstance(expr, UnaryOpNode):
            return self._format_unary(expr)
        if isinstance(expr, FunctionCallNode):
            return self._format_call(expr)
        if isinstance(expr, MemberAccessNode):
            return f"{self.generate_expression(expr.object)}.{expr.member}"
        if isinstance(expr, ArrayAccessNode):
            array = self.generate_expression(expr.array)
            index = self.generate_expression(expr.index)
            return f"{array}[{index}]"
        if isinstance(expr, VectorConstructorNode):
            return f"{self.map_type(expr.type_name)}({self._join_args(expr.args)})"
        if isinstance(expr, TernaryOpNode):
            condition = self.generate_expression(expr.condition)
            true_expr = self.generate_expression(expr.true_expr)
            false_expr = self.generate_expression(expr.false_expr)
            return f"({condition} ? {true_expr} : {false_expr})"
        if isinstance(expr, CastNode):
            target_type = self.map_type(expr.target_type)
            return f"({target_type}){self.generate_expression(expr.expression)}"
        if isinstance(expr, (ReferenceNode, DereferenceNode)):
            # `&x` and `*x` are emitted as plain `x`.
            return self.generate_expression(expr.expression)
        if isinstance(expr, TupleNode):
            # Tuples are emitted as a parenthesized element list.
            return f"({self._join_args(expr.elements)})"
        if isinstance(expr, ArrayNode):
            return self._format_array_literal(expr)
        if isinstance(expr, BlockNode):
            # Only the block's trailing expression is rendered.
            if self._has_value(expr.expression):
                return self.generate_expression(expr.expression)
            return ""
        return str(expr)

    def _format_binary(self, node):
        """Render a binary operation fully parenthesized: ``(left op right)``."""
        left = self.generate_expression(node.left)
        right = self.generate_expression(node.right)
        return f"({left} {node.op} {right})"

    def _format_unary(self, node):
        """Render a prefix unary operation parenthesized: ``(<op><operand>)``."""
        return f"({node.op}{self.generate_expression(node.operand)})"

    def _format_call(self, node):
        """Render a call, turning ``Type::new(...)`` into a constructor when mapped."""
        if isinstance(node.name, str):
            constructor = self.format_path_constructor_call(node.name, node.args)
            if constructor is not None:
                return constructor
            func_name = self.map_function(node.name)
        else:
            func_name = self.generate_expression(node.name)
        return f"{func_name}({self._join_args(node.args)})"

    def _format_array_literal(self, node):
        """Render an array literal as ``{a, b, ...}``, expanding ``[x; N]`` repeats."""
        if node.size is not None and len(node.elements) == 1:
            repeated = self.expand_repeated_array_literal(node.elements[0], node.size)
            if repeated is not None:
                return repeated
        return "{" + self._join_args(node.elements) + "}"

    def _join_args(self, items):
        """Render each item and join them with ``", "``."""
        return ", ".join(self.generate_expression(item) for item in items)

    def format_path_constructor_call(self, function_name, args):
        """Render ``Type::new(args)`` as ``<crossgl type>(args)``.

        Returns None when the call is not ``...::new`` or the type has no
        CrossGL mapping, so the caller falls back to a plain function call.
        """
        suffix = "::new"
        if not function_name.endswith(suffix):
            return None
        # `Vec3::<f32>::new` uses a turbofish; normalize it to `Vec3<f32>`.
        type_name = function_name[: -len(suffix)].replace("::<", "<")
        if not type_name:
            return None
        mapped_type = self.map_type(type_name)
        if mapped_type == type_name:
            return None
        return f"{mapped_type}({self._join_args(args)})"

    def map_type(self, rust_type):
        """Map a Rust type spelling to the closest CrossGL type name.

        References (``&T``, ``&mut T``, ``&'a T``) are dropped. Rust arrays
        ``[T; N]`` become ``T[N]``, and nested arrays keep Rust's outer-first
        order (``[[f32; 4]; 3]`` -> ``float[3][4]``). Registered aliases are
        resolved. ``Vec2/3/4<...>`` spellings without an exact entry fall back
        to the base vector type. Unknown types are returned unchanged, and a
        falsy type maps to ``void``.
        """
        if not rust_type:
            return "void"
        dereferenced = self.strip_reference_type(rust_type)
        if dereferenced != rust_type:
            return self.map_type(dereferenced)
        rust_array = self._split_rust_array_type(rust_type)
        if rust_array is not None:
            element_type, length = rust_array
            mapped_element = self.map_type(element_type)
            inner = self.split_array_type(mapped_element)
            if inner:
                base_type, inner_suffix = inner
                return f"{base_type}[{length}]{inner_suffix}"
            return f"{mapped_element}[{length}]"
        array_parts = self.split_array_type(rust_type)
        if array_parts:
            base_type, array_suffix = array_parts
            return f"{self.map_type(base_type)}{array_suffix}"
        resolved_alias = self.resolve_type_alias(rust_type)
        if resolved_alias is not None:
            return resolved_alias
        if "<" in rust_type and ">" in rust_type:
            base_type = rust_type.split("<")[0]
            if base_type in ("Vec2", "Vec3", "Vec4"):
                return self.type_map.get(
                    rust_type, self.type_map.get(base_type, rust_type)
                )
        return self.type_map.get(rust_type, rust_type)

    def _find_type_alias(self, type_name):
        """Return the registered alias for ``type_name`` or its last path segment."""
        for candidate in self.type_alias_lookup_names(type_name):
            alias = self.type_aliases.get(candidate)
            if alias is not None:
                return alias
        return None

    def resolve_type_alias(self, rust_type):
        """Resolve ``rust_type`` through the registered aliases.

        A non-generic alias resolves to its own name, because it is emitted as
        a ``typedef``. A generic alias used with matching arguments is expanded
        and mapped. Returns None when no alias applies.
        """
        alias = self._find_type_alias(rust_type)
        if alias is not None:
            # A generic alias used without arguments cannot be expanded.
            return None if getattr(alias, "generics", None) else alias.name
        generic = self.parse_generic_type(rust_type)
        if generic is None:
            return None
        base_name, args = generic
        alias = self._find_type_alias(base_name)
        if alias is None:
            return None
        generics = getattr(alias, "generics", None) or []
        if not generics or len(generics) != len(args):
            return None
        substituted = self.substitute_type_parameters(alias.alias_type, generics, args)
        if substituted == rust_type:
            return None
        return self.map_type(substituted)

    def type_alias_lookup_names(self, type_name):
        """Return alias lookup keys: the full name, then the last ``::`` segment."""
        names = [type_name]
        if "::" in type_name:
            names.append(type_name.rsplit("::", 1)[-1])
        return names

    def parse_generic_type(self, rust_type):
        """Split ``Name<A, B>`` into ``("Name", ["A", "B"])``; else return None."""
        if not rust_type or "<" not in rust_type or not rust_type.endswith(">"):
            return None
        base_name, remainder = rust_type.split("<", 1)
        base_name = base_name.strip()
        if not base_name:
            return None
        args = self.split_generic_arguments(remainder[:-1])
        if not args:
            return None
        return base_name, args

    def split_generic_arguments(self, args_text):
        """Split a generic argument list on top-level commas only."""
        args = []
        current = []
        depth = 0
        for char in args_text:
            if char == "," and depth == 0:
                arg = "".join(current).strip()
                if arg:
                    args.append(arg)
                current = []
                continue
            if char in "<[(":
                depth += 1
            elif char in ">])":
                depth = max(0, depth - 1)
            current.append(char)
        arg = "".join(current).strip()
        if arg:
            args.append(arg)
        return args

    def substitute_type_parameters(self, alias_type, generics, args):
        """Replace each generic parameter name in ``alias_type`` with its argument."""
        substitutions = {
            self.generic_parameter_name(generic): arg
            for generic, arg in zip(generics, args)
        }
        result = alias_type
        for name, replacement in substitutions.items():
            result = self.replace_type_identifier(result, name, replacement)
        return result

    def generic_parameter_name(self, generic):
        """Return the parameter name of a declaration such as ``T: Copy``."""
        return generic.split(":", 1)[0].strip()

    def replace_type_identifier(self, text, name, replacement):
        """Replace whole-identifier occurrences of ``name`` in ``text``."""
        result = []
        index = 0
        name_len = len(name)
        while index < len(text):
            if (
                text.startswith(name, index)
                and self.is_type_identifier_boundary(text, index - 1)
                and self.is_type_identifier_boundary(text, index + name_len)
            ):
                result.append(replacement)
                index += name_len
            else:
                result.append(text[index])
                index += 1
        return "".join(result)

    def is_type_identifier_boundary(self, text, index):
        """Return True if ``index`` is out of range or not an identifier character."""
        if index < 0 or index >= len(text):
            return True
        return not (text[index].isalnum() or text[index] == "_")

    def strip_reference_type(self, type_name):
        """Remove a leading ``&``, an optional lifetime and an optional ``mut``.

        Only one reference level is removed per call; :meth:`map_type` recurses
        until none remain. Non-reference types are returned unchanged.
        """
        if not type_name.startswith("&"):
            return type_name
        remainder = type_name[1:].lstrip()
        if remainder.startswith("'"):
            # Drop an explicit lifetime such as `'a` in `&'a T`.
            _, _, remainder = remainder.partition(" ")
            remainder = remainder.lstrip()
        if remainder.startswith("mut "):
            remainder = remainder[4:]
        return remainder.strip()

    @staticmethod
    def _split_rust_array_type(type_name):
        """Split a Rust array type ``[T; N]`` into ``(T, N)``; else return None."""
        if not (type_name.startswith("[") and type_name.endswith("]")):
            return None
        last = len(type_name) - 1
        depth = 0
        separator = None
        for index, char in enumerate(type_name):
            if char in "<[(":
                depth += 1
            elif char in ">])":
                depth -= 1
                if depth == 0 and index != last:
                    # The opening bracket closes early: not a single array type.
                    return None
            elif char == ";" and depth == 1:
                separator = index
        if separator is None:
            return None
        element = type_name[1:separator].strip()
        length = type_name[separator + 1 : last].strip()
        if not element or not length:
            return None
        return element, length

    def split_array_type(self, type_name):
        """Split a C-style array type ``T[N]...`` into ``("T", "[N]...")``."""
        if not type_name or "[" not in type_name:
            return None
        base_type = type_name.split("[", 1)[0]
        array_suffix = type_name[len(base_type) :]
        if not base_type or not array_suffix:
            return None
        return base_type, array_suffix

    def format_typed_declarator(self, type_name, name):
        """Return ``<type> <name>``, moving array dimensions after the name."""
        mapped_type = self.map_type(type_name)
        array_parts = self.split_array_type(mapped_type)
        if array_parts:
            base_type, array_suffix = array_parts
            return f"{base_type} {name}{array_suffix}"
        return f"{mapped_type} {name}"

    def _parse_array_length(self, text):
        """Parse a Rust integer literal (suffix, hex and ``_`` allowed); else None."""
        text = text.strip()
        for suffix in self._INTEGER_SUFFIXES:
            if text.endswith(suffix) and len(text) > len(suffix):
                text = text[: -len(suffix)].rstrip("_")
                break
        try:
            return int(text, 0)
        except ValueError:
            pass
        try:
            # Base-0 parsing rejects leading zeros ("010"), which Rust reads as decimal.
            return int(text)
        except ValueError:
            return None

    def expand_repeated_array_literal(self, element, size):
        """Expand ``[value; N]`` into ``{value, value, ...}``.

        Returns None when ``N`` is not a non-negative integer literal (for
        example a named constant).
        """
        count = self._parse_array_length(self.generate_expression(size))
        if count is None or count < 0:
            return None
        value = self.generate_expression(element)
        return "{" + ", ".join([value] * count) + "}"

    def map_function(self, rust_func):
        """Return the CrossGL name for ``rust_func`` (unchanged if unmapped)."""
        return self.function_map.get(rust_func, rust_func)

    def get_shader_type_from_attributes(self, attributes):
        """Return ``"vertex"``, ``"fragment"`` or ``"compute"`` for entry points, else None."""
        for attr in attributes or ():
            if isinstance(attr, AttributeNode):
                mapped = self.attribute_map.get(attr.name)
                if mapped in ("vertex", "fragment", "compute"):
                    return mapped
        return None

    def get_semantic_from_attributes(self, attributes):
        """Return the ``" @ <semantic>"`` suffix for the first matching attribute, or ``""``."""
        for attr in attributes or ():
            if isinstance(attr, AttributeNode):
                if attr.name in self.semantic_map:
                    return f" @ {self.semantic_map[attr.name]}"
                if attr.name == "location" and attr.args:
                    return f" @ location({attr.args[0]})"
                if attr.name == "binding" and attr.args:
                    return f" @ binding({attr.args[0]})"
        return ""

    @staticmethod
    def _has_value(value):
        """Return whether ``value`` is a present expression.

        This follows plain truthiness (``None`` and ``""`` are absent), except
        that numeric and boolean literals such as ``0`` or ``False`` always
        count as present.
        """
        return isinstance(value, (int, float)) or bool(value)

    @staticmethod
    def _as_statement_list(body):
        """Normalize a statement body to a list.

        ``None`` becomes ``[]``. A lone string or non-iterable node becomes a
        one-element list, so it is rendered as one statement instead of being
        iterated character by character (or rejected).
        """
        if body is None:
            return []
        if isinstance(body, str):
            return [body]
        try:
            return list(body)
        except TypeError:
            return [body]

    @staticmethod
    def _as_else_if(else_body):
        """Return the ``IfNode`` when an else branch is exactly one nested ``if``."""
        if isinstance(else_body, IfNode):
            return else_body
        if (
            isinstance(else_body, list)
            and len(else_body) == 1
            and isinstance(else_body[0], IfNode)
        ):
            return else_body[0]
        return None

    def _range_bounds(self, iterable):
        """Return ``(start, end, comparison)`` for a ``..``/``..=`` range node.

        Returns None unless ``iterable`` has an ``op`` of ``..`` or ``..=``
        and a present ``right`` bound. A missing ``left`` bound starts at 0.
        """
        op = getattr(iterable, "op", None)
        if op not in ("..", "..=") or not self._has_value(getattr(iterable, "right", None)):
            return None
        left = getattr(iterable, "left", None)
        start = self.generate_expression(left) if self._has_value(left) else "0"
        end = self.generate_expression(iterable.right)
        return start, end, "<=" if op == "..=" else "<"

    def _render_struct(self, node, indent=0):
        """Render a struct declaration with one member per line at ``indent`` levels."""
        outer = self.INDENT_UNIT * indent
        inner = self.INDENT_UNIT * (indent + 1)
        lines = [f"{outer}struct {node.name} {{\n"]
        for member in node.members:
            semantic = self.get_semantic_from_attributes(member.attributes)
            declarator = self.format_typed_declarator(member.vtype, member.name)
            lines.append(f"{inner}{declarator}{semantic};\n")
        lines.append(f"{outer}}}\n")
        return "".join(lines)

    def _render_const(self, node, indent=0):
        """Render ``const <declarator> = <value>;`` at ``indent`` levels."""
        declarator = self.format_typed_declarator(node.vtype, node.name)
        value = self.generate_expression(node.value)
        return f"{self.INDENT_UNIT * indent}const {declarator} = {value};\n"

    def _render_static(self, node, indent=0):
        """Render ``static [mut ]<declarator> = <value>;`` at ``indent`` levels."""
        mutability = "mut " if node.is_mutable else ""
        declarator = self.format_typed_declarator(node.vtype, node.name)
        value = self.generate_expression(node.value)
        return f"{self.INDENT_UNIT * indent}static {mutability}{declarator} = {value};\n"

    def visit_StructNode(self, node):
        """Render a struct at the current ``indentation`` depth."""
        return self._render_struct(node, self.indentation)

    def visit_FunctionNode(self, node):
        """Render a function; shader entry points become a bare stage block."""
        shader_type = self.get_shader_type_from_attributes(node.attributes)
        if shader_type:
            return f"{shader_type} {{\n{self.generate_function_body(node.body, 1)}}}\n"
        return self.generate_function(node, 0)

    def visit_BinaryOpNode(self, node):
        """Render a binary operation as ``(left op right)``."""
        return self._format_binary(node)

    def visit_UnaryOpNode(self, node):
        """Render a unary operation as ``(<op><operand>)``."""
        return self._format_unary(node)

    def visit_ImplNode(self, node):
        """Render an ``impl`` block as ``<Struct>_<method>`` free functions."""
        code = f"// Implementation for {node.struct_name}\n"
        for func in node.functions:
            code += self.generate_function(func, 0, node.struct_name)
        return code

    def visit_UseNode(self, node):
        """Render a ``use`` declaration as a comment."""
        return f"// use {node.path}\n"

    def visit_ConstNode(self, node):
        """Render a top-level ``const`` item."""
        return self._render_const(node)

    def visit_StaticNode(self, node):
        """Render a top-level ``static`` item."""
        return self._render_static(node)