Source code for crosstl.backend.Rust.RustCrossGLCodeGen

"""Reverse code generator that emits CrossGL from Rust AST nodes."""

from .RustAst import *
from .RustParser import *
from .RustLexer import *


[docs] class RustToCrossGLConverter: """Serialize Rust backend AST nodes back into CrossGL source.""" def __init__(self): """Initialize Rust-to-CrossGL type and semantic mappings.""" self.type_map = { # Rust primitive types to CrossGL "()": "void", "f32": "float", "f64": "double", "i32": "int", "i64": "int64_t", "u32": "uint", "u64": "uint64_t", "i8": "int8_t", "u8": "uint8_t", "i16": "int16_t", "u16": "uint16_t", "bool": "bool", "usize": "uint", "isize": "int", # Rust vector types to CrossGL "Vec2<f32>": "vec2", "Vec3<f32>": "vec3", "Vec4<f32>": "vec4", "Vec2<i32>": "ivec2", "Vec3<i32>": "ivec3", "Vec4<i32>": "ivec4", "Vec2<u32>": "uvec2", "Vec3<u32>": "uvec3", "Vec4<u32>": "uvec4", "Vec2<bool>": "bvec2", "Vec3<bool>": "bvec3", "Vec4<bool>": "bvec4", # Simplified vector names "Vec2": "vec2", "Vec3": "vec3", "Vec4": "vec4", # Rust matrix types to CrossGL "Mat2<f32>": "mat2", "Mat3<f32>": "mat3", "Mat4<f32>": "mat4", "Mat2": "mat2", "Mat3": "mat3", "Mat4": "mat4", # GPU-specific types "Texture2D": "sampler2D", "TextureCube": "samplerCube", "Sampler": "sampler", # Reference types (stripped in CrossGL) "&f32": "float", "&i32": "int", "&u32": "uint", "&bool": "bool", "&mut f32": "float", "&mut i32": "int", "&mut u32": "uint", "&mut bool": "bool", } self.semantic_map = { # Rust shader attributes to CrossGL semantics "vertex_position": "Position", "vertex_normal": "Normal", "vertex_tangent": "Tangent", "vertex_binormal": "Binormal", "vertex_texcoord": "TexCoord", "vertex_texcoord0": "TexCoord0", "vertex_texcoord1": "TexCoord1", "vertex_texcoord2": "TexCoord2", "vertex_texcoord3": "TexCoord3", "vertex_color": "Color", "instance_id": "InstanceID", "vertex_id": "VertexID", # Fragment shader semantics "fragment_position": "gl_Position", "fragment_color": "gl_FragColor", "fragment_depth": "gl_FragDepth", "front_face": "gl_IsFrontFace", "primitive_id": "gl_PrimitiveID", "point_coord": "gl_PointCoord", # Compute shader semantics "local_invocation_id": "gl_LocalInvocationID", "global_invocation_id": "gl_GlobalInvocationID", "workgroup_id": "gl_WorkGroupID", "num_workgroups": "gl_NumWorkGroups", } self.function_map = { # Rust math functions to CrossGL "abs": "abs", "min": "min", "max": "max", "clamp": "clamp", "floor": "floor", "ceil": "ceil", "round": "round", "sqrt": "sqrt", "pow": "pow", "exp": "exp", "exp2": "exp2", "log": "log", "log2": "log2", "sin": "sin", "cos": "cos", "tan": "tan", "asin": "asin", "acos": "acos", "atan": "atan", "atan2": "atan2", "sinh": "sinh", "cosh": "cosh", "tanh": "tanh", "degrees": "degrees", "radians": "radians", # Vector operations "dot": "dot", "cross": "cross", "length": "length", "normalize": "normalize", "distance": "distance", "reflect": "reflect", "refract": "refract", "faceforward": "faceforward", # Matrix operations "transpose": "transpose", "determinant": "determinant", "inverse": "inverse", # Texture sampling "sample": "texture", "sample_level": "textureLod", "sample_grad": "textureGrad", "sample_offset": "textureOffset", } self.attribute_map = { # Rust shader attributes to CrossGL qualifiers "vertex_shader": "vertex", "fragment_shader": "fragment", "compute_shader": "compute", "binding": "binding", "location": "location", "group": "group", "builtin": "builtin", "stage": "stage", "workgroup_size": "workgroup_size", } self.indentation = 0 self.code = [] self.type_aliases = {}
[docs] def get_indent(self): """Return whitespace for the current indentation level.""" return " " * self.indentation
[docs] def visit(self, node): """Dispatch a Rust AST node to a visitor method when available.""" if isinstance(node, StructNode): return self.visit_StructNode(node) elif isinstance(node, FunctionNode): return self.visit_FunctionNode(node) elif isinstance(node, BinaryOpNode): return self.visit_BinaryOpNode(node) elif isinstance(node, UnaryOpNode): return self.visit_UnaryOpNode(node) elif isinstance(node, ImplNode): return self.visit_ImplNode(node) elif isinstance(node, UseNode): return self.visit_UseNode(node) elif isinstance(node, ConstNode): return self.visit_ConstNode(node) elif isinstance(node, StaticNode): return self.visit_StaticNode(node) # For other node types, use existing methods if hasattr(self, f"generate_{type(node).__name__}"): method = getattr(self, f"generate_{type(node).__name__}") return method(node) return self.generate_expression(node)
[docs] def generate(self, ast): """Generate complete CrossGL source from a parsed Rust AST.""" self.type_aliases = {} code = "shader main {\n" for use_stmt in ast.use_statements: code += f" // use {use_stmt.path}\n" for alias in getattr(ast, "type_aliases", []): self.register_type_alias(alias) for alias in getattr(ast, "type_aliases", []): alias_code = self.generate_type_alias(alias) if alias_code: code += alias_code for global_var in ast.global_variables: if isinstance(global_var, ConstNode): declarator = self.format_typed_declarator( global_var.vtype, global_var.name ) value = self.generate_expression(global_var.value) code += f" const {declarator} = {value};\n" elif isinstance(global_var, StaticNode): mutability = "mut " if global_var.is_mutable else "" declarator = self.format_typed_declarator( global_var.vtype, global_var.name ) value = self.generate_expression(global_var.value) code += f" static {mutability}{declarator} = {value};\n" for struct in ast.structs: if isinstance(struct, StructNode): code += f" struct {struct.name} {{\n" for member in struct.members: semantic = self.get_semantic_from_attributes(member.attributes) declarator = self.format_typed_declarator(member.vtype, member.name) code += f" {declarator}{semantic};\n" code += " }\n\n" for func in ast.functions: shader_type = self.get_shader_type_from_attributes(func.attributes) if shader_type: code += f" // {shader_type.title()} Shader\n" code += f" {shader_type} {{\n" code += self.generate_function_body(func.body, indent=2) code += " }\n\n" else: code += self.generate_function(func, indent=1) for impl_block in ast.impl_blocks: code += f" // Implementation for {impl_block.struct_name}\n" for func in impl_block.functions: code += self.generate_function( func, indent=1, struct_name=impl_block.struct_name ) code += "}\n" return code
def register_type_alias(self, alias): if getattr(alias, "name", None): self.type_aliases[alias.name] = alias def generate_type_alias(self, alias): if getattr(alias, "generics", None) or not getattr(alias, "alias_type", None): return "" declarator = self.format_typed_declarator(alias.alias_type, alias.name) return f" typedef {declarator};\n"
[docs] def generate_function(self, func, indent=1, struct_name=None): """Render one Rust function node as a CrossGL function.""" code = "" indent_str = " " * indent params = [] for param in func.params: params.append(self.format_typed_declarator(param.vtype, param.name)) params_str = ", ".join(params) return_type = self.map_type(func.return_type) if struct_name: func_name = f"{struct_name}_{func.name}" else: func_name = func.name code += f"{indent_str}{return_type} {func_name}({params_str}) {{\n" code += self.generate_function_body(func.body, indent=indent + 1) code += f"{indent_str}}}\n\n" return code
def generate_function_body(self, body, indent=1): code = "" indent_str = " " * indent for stmt in body: if isinstance(stmt, LetNode): code += self.generate_let_statement(stmt, indent) elif isinstance(stmt, AssignmentNode): code += f"{indent_str}{self.generate_assignment(stmt)};\n" elif isinstance(stmt, ReturnNode): if stmt.value: code += ( f"{indent_str}return {self.generate_expression(stmt.value)};\n" ) else: code += f"{indent_str}return;\n" elif isinstance(stmt, IfNode): code += self.generate_if_statement(stmt, indent) elif isinstance(stmt, ForNode): code += self.generate_for_loop(stmt, indent) elif isinstance(stmt, WhileNode): code += self.generate_while_loop(stmt, indent) elif isinstance(stmt, LoopNode): code += self.generate_loop(stmt, indent) elif isinstance(stmt, MatchNode): code += self.generate_match_statement(stmt, indent) elif isinstance(stmt, BreakNode): code += f"{indent_str}break;\n" elif isinstance(stmt, ContinueNode): code += f"{indent_str}continue;\n" elif isinstance(stmt, FunctionCallNode): code += f"{indent_str}{self.generate_expression(stmt)};\n" elif isinstance(stmt, BinaryOpNode): code += f"{indent_str}{self.generate_expression(stmt)};\n" elif isinstance(stmt, str): code += f"{indent_str}{stmt};\n" else: expr = self.generate_expression(stmt) if expr: code += f"{indent_str}{expr};\n" return code def generate_let_statement(self, stmt, indent): indent_str = " " * indent type_str = "" if stmt.vtype: type_str = self.format_typed_declarator(stmt.vtype, stmt.name) elif stmt.value: type_str = "" if stmt.value: value_str = self.generate_expression(stmt.value) if stmt.vtype: return f"{indent_str}{type_str} = {value_str};\n" return f"{indent_str}{stmt.name} = {value_str};\n" else: if stmt.vtype: return f"{indent_str}{type_str};\n" return f"{indent_str}{stmt.name};\n" def generate_assignment(self, node): left = self.generate_expression(node.left) right = self.generate_expression(node.right) return f"{left} {node.operator} {right}" def generate_if_statement(self, node, indent): indent_str = " " * indent condition = self.generate_expression(node.condition) code = f"{indent_str}if ({condition}) {{\n" code += self.generate_function_body(node.if_body, indent + 1) code += f"{indent_str}}}" if node.else_body: if ( isinstance(node.else_body, list) and len(node.else_body) == 1 and isinstance(node.else_body[0], IfNode) ): code += " else " code += self.generate_if_statement(node.else_body[0], 0).lstrip() else: code += " else {\n" if isinstance(node.else_body, list): code += self.generate_function_body(node.else_body, indent + 1) else: code += self.generate_function_body([node.else_body], indent + 1) code += f"{indent_str}}}" code += "\n" return code def generate_for_loop(self, node, indent): indent_str = " " * indent pattern = node.pattern iterable = self.generate_expression(node.iterable) # Convert Rust for-in loop to C-style for loop code = f"{indent_str}for (int {pattern} = 0; {pattern} < {iterable}; {pattern}++) {{\n" code += self.generate_function_body(node.body, indent + 1) code += f"{indent_str}}}\n" return code def generate_while_loop(self, node, indent): indent_str = " " * indent condition = self.generate_expression(node.condition) code = f"{indent_str}while ({condition}) {{\n" code += self.generate_function_body(node.body, indent + 1) code += f"{indent_str}}}\n" return code def generate_loop(self, node, indent): indent_str = " " * indent # Convert Rust infinite loop to while(true) code = f"{indent_str}while (true) {{\n" code += self.generate_function_body(node.body, indent + 1) code += f"{indent_str}}}\n" return code def generate_match_statement(self, node, indent): indent_str = " " * indent expression = self.generate_expression(node.expression) # Convert Rust match to switch statement code = f"{indent_str}switch ({expression}) {{\n" for arm in node.arms: pattern = self.generate_expression(arm.pattern) code += f"{indent_str} case {pattern}:\n" code += self.generate_function_body(arm.body, indent + 2) code += f"{indent_str} break;\n" code += f"{indent_str}}}\n" return code
[docs] def generate_expression(self, expr): """Render a Rust backend expression node as CrossGL syntax.""" if isinstance(expr, str): return expr elif isinstance(expr, VariableNode): return expr.name elif isinstance(expr, BinaryOpNode): left = self.generate_expression(expr.left) right = self.generate_expression(expr.right) return f"({left} {expr.op} {right})" elif isinstance(expr, UnaryOpNode): operand = self.generate_expression(expr.operand) return f"({expr.op}{operand})" elif isinstance(expr, FunctionCallNode): if isinstance(expr.name, str): constructor = self.format_path_constructor_call(expr.name, expr.args) if constructor is not None: return constructor func_name = self.map_function(expr.name) else: func_name = self.generate_expression(expr.name) args = ", ".join(self.generate_expression(arg) for arg in expr.args) return f"{func_name}({args})" elif isinstance(expr, MemberAccessNode): obj = self.generate_expression(expr.object) return f"{obj}.{expr.member}" elif isinstance(expr, ArrayAccessNode): array = self.generate_expression(expr.array) index = self.generate_expression(expr.index) return f"{array}[{index}]" elif isinstance(expr, VectorConstructorNode): args = ", ".join(self.generate_expression(arg) for arg in expr.args) type_name = self.map_type(expr.type_name) return f"{type_name}({args})" elif isinstance(expr, TernaryOpNode): condition = self.generate_expression(expr.condition) true_expr = self.generate_expression(expr.true_expr) false_expr = self.generate_expression(expr.false_expr) return f"({condition} ? {true_expr} : {false_expr})" elif isinstance(expr, CastNode): expression = self.generate_expression(expr.expression) target_type = self.map_type(expr.target_type) return f"({target_type}){expression}" elif isinstance(expr, ReferenceNode): # References are handled differently in CrossGL return self.generate_expression(expr.expression) elif isinstance(expr, DereferenceNode): # Dereferences are typically not needed in CrossGL return self.generate_expression(expr.expression) elif isinstance(expr, TupleNode): # Tuples might not be directly supported, convert to struct or multiple variables elements = ", ".join( self.generate_expression(elem) for elem in expr.elements ) return f"({elements})" elif isinstance(expr, ArrayNode): if expr.size is not None and len(expr.elements) == 1: repeated = self.expand_repeated_array_literal( expr.elements[0], expr.size ) if repeated is not None: return repeated elements = ", ".join( self.generate_expression(elem) for elem in expr.elements ) return f"{{{elements}}}" elif isinstance(expr, BlockNode): # Block expressions might need special handling if expr.expression: return self.generate_expression(expr.expression) else: return "" elif isinstance(expr, (int, float, bool)): return str(expr).lower() if isinstance(expr, bool) else str(expr) else: return str(expr)
def format_path_constructor_call(self, function_name, args): if not function_name.endswith("::new"): return None type_name = function_name[: -len("::new")] mapped_type = self.map_type(type_name) if mapped_type == type_name: return None args_str = ", ".join(self.generate_expression(arg) for arg in args) return f"{mapped_type}({args_str})"
[docs] def map_type(self, rust_type): """Map a Rust type name to the closest CrossGL type name.""" if not rust_type: return "void" referenced_type = self.strip_reference_type(rust_type) if referenced_type != rust_type: return self.map_type(referenced_type) array_parts = self.split_array_type(rust_type) if array_parts: base_type, array_suffix = array_parts return f"{self.map_type(base_type)}{array_suffix}" resolved_alias = self.resolve_type_alias(rust_type) if resolved_alias is not None: return resolved_alias if "<" in rust_type and ">" in rust_type: base_type = rust_type.split("<")[0] if base_type in ["Vec2", "Vec3", "Vec4"]: return self.type_map.get( rust_type, self.type_map.get(base_type, rust_type) ) return self.type_map.get(rust_type, rust_type)
def resolve_type_alias(self, rust_type): for alias_name in self.type_alias_lookup_names(rust_type): alias = self.type_aliases.get(alias_name) if alias is None: continue if getattr(alias, "generics", None): return None return alias.name generic = self.parse_generic_type(rust_type) if generic is None: return None base_name, args = generic alias = None for alias_name in self.type_alias_lookup_names(base_name): alias = self.type_aliases.get(alias_name) if alias is not None: break if alias is None: return None generics = getattr(alias, "generics", []) or [] if not generics or len(generics) != len(args): return None substituted = self.substitute_type_parameters(alias.alias_type, generics, args) if substituted == rust_type: return None return self.map_type(substituted) def type_alias_lookup_names(self, type_name): names = [type_name] if "::" in type_name: names.append(type_name.rsplit("::", 1)[-1]) return names def parse_generic_type(self, rust_type): if not rust_type or "<" not in rust_type or not rust_type.endswith(">"): return None base_name, remainder = rust_type.split("<", 1) base_name = base_name.strip() if not base_name: return None args_text = remainder[:-1] args = self.split_generic_arguments(args_text) if not args: return None return base_name, args def split_generic_arguments(self, args_text): args = [] current = [] depth = 0 for char in args_text: if char == "," and depth == 0: arg = "".join(current).strip() if arg: args.append(arg) current = [] continue if char in "<[(": depth += 1 elif char in ">])": depth = max(0, depth - 1) current.append(char) arg = "".join(current).strip() if arg: args.append(arg) return args def substitute_type_parameters(self, alias_type, generics, args): substitutions = { self.generic_parameter_name(generic): arg for generic, arg in zip(generics, args) } result = alias_type for name, replacement in substitutions.items(): result = self.replace_type_identifier(result, name, replacement) return result def generic_parameter_name(self, generic): return generic.split(":", 1)[0].strip() def replace_type_identifier(self, text, name, replacement): result = [] index = 0 name_len = len(name) while index < len(text): if ( text.startswith(name, index) and self.is_type_identifier_boundary(text, index - 1) and self.is_type_identifier_boundary(text, index + name_len) ): result.append(replacement) index += name_len else: result.append(text[index]) index += 1 return "".join(result) def is_type_identifier_boundary(self, text, index): if index < 0 or index >= len(text): return True return not (text[index].isalnum() or text[index] == "_") def strip_reference_type(self, type_name): if type_name.startswith("&mut "): return type_name[5:].strip() if type_name.startswith("&"): return type_name[1:].strip() return type_name def split_array_type(self, type_name): if not type_name or "[" not in type_name: return None base_type = type_name.split("[", 1)[0] array_suffix = type_name[len(base_type) :] if not base_type or not array_suffix: return None return base_type, array_suffix def format_typed_declarator(self, type_name, name): mapped_type = self.map_type(type_name) array_parts = self.split_array_type(mapped_type) if array_parts: base_type, array_suffix = array_parts return f"{base_type} {name}{array_suffix}" return f"{mapped_type} {name}" def expand_repeated_array_literal(self, element, size): try: count = int(self.generate_expression(size)) except (TypeError, ValueError): return None if count < 0: return None value = self.generate_expression(element) return "{" + ", ".join(value for _ in range(count)) + "}" def map_function(self, rust_func): return self.function_map.get(rust_func, rust_func) def get_shader_type_from_attributes(self, attributes): if not attributes: return None for attr in attributes: if isinstance(attr, AttributeNode): mapped = self.attribute_map.get(attr.name) if mapped in ["vertex", "fragment", "compute"]: return mapped return None def get_semantic_from_attributes(self, attributes): if not attributes: return "" for attr in attributes: if isinstance(attr, AttributeNode): if attr.name in self.semantic_map: return f" @ {self.semantic_map[attr.name]}" elif attr.name == "location" and attr.args: return f" @ location({attr.args[0]})" elif attr.name == "binding" and attr.args: return f" @ binding({attr.args[0]})" return "" def visit_StructNode(self, node): code = f"struct {node.name} {{\n" self.indentation += 1 for member in node.members: semantic = self.get_semantic_from_attributes(member.attributes) type_str = self.map_type(member.vtype) code += f"{self.get_indent()}{type_str} {member.name}{semantic};\n" self.indentation -= 1 code += f"{self.get_indent()}}}\n" return code def visit_FunctionNode(self, node): shader_type = self.get_shader_type_from_attributes(node.attributes) if shader_type: return f"{shader_type} {{\n{self.generate_function_body(node.body, 1)}}}\n" else: return self.generate_function(node, 0) def visit_BinaryOpNode(self, node): left = self.generate_expression(node.left) right = self.generate_expression(node.right) return f"({left} {node.op} {right})" def visit_UnaryOpNode(self, node): operand = self.generate_expression(node.operand) return f"({node.op}{operand})" def visit_ImplNode(self, node): code = f"// Implementation for {node.struct_name}\n" for func in node.functions: code += self.generate_function(func, 0, node.struct_name) return code def visit_UseNode(self, node): return f"// use {node.path}\n" def visit_ConstNode(self, node): type_str = self.map_type(node.vtype) value = self.generate_expression(node.value) return f"const {type_str} {node.name} = {value};\n" def visit_StaticNode(self, node): mutability = "mut " if node.is_mutable else "" type_str = self.map_type(node.vtype) value = self.generate_expression(node.value) return f"static {mutability}{type_str} {node.name} = {value};\n"