"""CrossGL-to-Rust code generator."""
from ..ast import (
ArrayNode,
ArrayAccessNode,
ArrayLiteralNode,
AssignmentNode,
BinaryOpNode,
CbufferNode,
ForNode,
FunctionCallNode,
FunctionNode,
IdentifierNode,
IfNode,
LiteralNode,
MemberAccessNode,
ReturnNode,
ShaderNode,
StructNode,
TernaryOpNode,
UnaryOpNode,
VariableNode,
)
from .array_utils import parse_array_type, format_array_type, get_array_size_from_node
[docs]
class RustCodeGen:
"""Emit Rust-like GPU shader source from the shared CrossGL AST."""
def __init__(self):
"""Initialize Rust type maps and expression-generation state."""
self.current_shader = None
self.type_mapping = {
# Scalar Types
"void": "()",
"int": "i32",
"short": "i16",
"long": "i64",
"uint": "u32",
"ushort": "u16",
"ulong": "u64",
"float": "f32",
"double": "f64",
"half": "f16",
"bool": "bool",
"string": "&'static str",
"char": "char",
# Vector Types (using GPU-style vector types)
"vec2<f32>": "Vec2<f32>",
"vec3<f32>": "Vec3<f32>",
"vec4<f32>": "Vec4<f32>",
"vec2<f64>": "Vec2<f64>",
"vec3<f64>": "Vec3<f64>",
"vec4<f64>": "Vec4<f64>",
"vec2<i32>": "Vec2<i32>",
"vec3<i32>": "Vec3<i32>",
"vec4<i32>": "Vec4<i32>",
"vec2<u32>": "Vec2<u32>",
"vec3<u32>": "Vec3<u32>",
"vec4<u32>": "Vec4<u32>",
"vec2<bool>": "Vec2<bool>",
"vec3<bool>": "Vec3<bool>",
"vec4<bool>": "Vec4<bool>",
"vec2": "Vec2<f32>",
"vec3": "Vec3<f32>",
"vec4": "Vec4<f32>",
"ivec2": "Vec2<i32>",
"ivec3": "Vec3<i32>",
"ivec4": "Vec4<i32>",
"uvec2": "Vec2<u32>",
"uvec3": "Vec3<u32>",
"uvec4": "Vec4<u32>",
"dvec2": "Vec2<f64>",
"dvec3": "Vec3<f64>",
"dvec4": "Vec4<f64>",
"bvec2": "Vec2<bool>",
"bvec3": "Vec3<bool>",
"bvec4": "Vec4<bool>",
"bool2": "Vec2<bool>",
"bool3": "Vec3<bool>",
"bool4": "Vec4<bool>",
# Matrix Types
"mat2": "Mat2<f32>",
"mat3": "Mat3<f32>",
"mat4": "Mat4<f32>",
"mat2x2": "Mat2<f32>",
"mat2x3": "Mat2x3<f32>",
"mat2x4": "Mat2x4<f32>",
"mat3x2": "Mat3x2<f32>",
"mat3x3": "Mat3<f32>",
"mat3x4": "Mat3x4<f32>",
"mat4x2": "Mat4x2<f32>",
"mat4x3": "Mat4x3<f32>",
"mat4x4": "Mat4<f32>",
"dmat2": "Mat2<f64>",
"dmat3": "Mat3<f64>",
"dmat4": "Mat4<f64>",
"dmat2x2": "Mat2<f64>",
"dmat2x3": "Mat2x3<f64>",
"dmat2x4": "Mat2x4<f64>",
"dmat3x2": "Mat3x2<f64>",
"dmat3x3": "Mat3<f64>",
"dmat3x4": "Mat3x4<f64>",
"dmat4x2": "Mat4x2<f64>",
"dmat4x3": "Mat4x3<f64>",
"dmat4x4": "Mat4<f64>",
# Texture Types
"sampler2D": "Texture2D<f32>",
"samplerCube": "TextureCube<f32>",
"sampler": "Sampler",
}
self.semantic_map = {
# Vertex attributes
"gl_VertexID": "vertex_id",
"gl_InstanceID": "instance_id",
"gl_Position": "position",
"gl_PointSize": "point_size",
"gl_ClipDistance": "clip_distance",
# Fragment attributes
"gl_FragColor": "target(0)",
"gl_FragColor0": "target(0)",
"gl_FragColor1": "target(1)",
"gl_FragColor2": "target(2)",
"gl_FragColor3": "target(3)",
"gl_FragDepth": "depth(any)",
"gl_FragCoord": "position",
"gl_FrontFacing": "front_facing",
"gl_PointCoord": "point_coord",
# Standard vertex semantics
"POSITION": "position",
"NORMAL": "normal",
"TANGENT": "tangent",
"BINORMAL": "binormal",
"TEXCOORD": "texcoord",
"TEXCOORD0": "texcoord(0)",
"TEXCOORD1": "texcoord(1)",
"TEXCOORD2": "texcoord(2)",
"TEXCOORD3": "texcoord(3)",
"COLOR": "color",
"COLOR0": "color(0)",
"COLOR1": "color(1)",
}
# Function mapping for common shader functions
self.function_map = {
"texture": "sample",
"normalize": "normalize",
"dot": "dot",
"cross": "cross",
"length": "length",
"reflect": "reflect",
"refract": "refract",
"sin": "sin",
"cos": "cos",
"tan": "tan",
"sqrt": "sqrt",
"pow": "pow",
"abs": "abs",
"min": "min",
"max": "max",
"clamp": "clamp",
"mix": "lerp",
"smoothstep": "smoothstep",
"step": "step",
"floor": "floor",
"ceil": "ceil",
"fract": "fract",
"mod": "modulo",
}
self.variable_types = {}
self.current_return_type = None
[docs]
def generate(self, ast):
"""Generate complete Rust-like shader source for a CrossGL AST."""
self.variable_types = {}
self.current_return_type = None
code = "// Generated Rust GPU Shader Code\n"
code += "use gpu::*;\n"
code += "use math::*;\n\n"
structs = getattr(ast, "structs", [])
for node in structs:
if isinstance(node, StructNode):
code += self.generate_struct(node)
global_vars = getattr(ast, "global_variables", [])
for node in global_vars:
if isinstance(node, ArrayNode):
code += self.generate_array_declaration(node)
else:
# Handle both old and new AST variable structures
if hasattr(node, "var_type"):
var_type = self.convert_type_node_to_string(node.var_type)
elif hasattr(node, "vtype"):
var_type = node.vtype
else:
var_type = "float"
self.register_variable_type(node.name, var_type)
initial_value = getattr(
node, "initial_value", getattr(node, "value", None)
)
if initial_value is not None:
init_expr = self.generate_expression_with_type(
initial_value, var_type
)
else:
init_expr = "Default::default()"
code += (
f"static {node.name}: {self.map_type(var_type)} = "
f"{init_expr};\n"
)
cbuffers = self.get_cbuffer_nodes(ast)
if cbuffers:
code += "// Constant Buffers\n"
code += self.generate_cbuffers(ast)
functions = getattr(ast, "functions", [])
for func in functions:
# Handle both old and new AST function structures
if hasattr(func, "qualifiers") and func.qualifiers:
qualifier = func.qualifiers[0] if func.qualifiers else None
else:
qualifier = getattr(func, "qualifier", None)
if qualifier == "vertex":
code += "// Vertex Shader\n"
code += self.generate_function(func, shader_type="vertex")
elif qualifier == "fragment":
code += "// Fragment Shader\n"
code += self.generate_function(func, shader_type="fragment")
elif qualifier == "compute":
code += "// Compute Shader\n"
code += self.generate_function(func, shader_type="compute")
else:
code += self.generate_function(func)
# Handle shader stages (new AST structure)
if hasattr(ast, "stages") and ast.stages:
for stage_type, stage in ast.stages.items():
if hasattr(stage, "entry_point"):
stage_name = str(stage_type).split(".")[-1].lower()
code += f"// {stage_name.title()} Shader\n"
code += self.generate_function(
stage.entry_point, shader_type=stage_name
)
if hasattr(stage, "local_functions"):
for func in stage.local_functions:
code += self.generate_function(func)
return code
def generate_struct(self, node):
code = f"#[repr(C)]\n#[derive(Debug, Clone, Copy)]\n"
code += f"pub struct {node.name} {{\n"
members = getattr(node, "members", [])
for member in members:
if isinstance(member, ArrayNode):
element_type = getattr(
member, "element_type", getattr(member, "vtype", "float")
)
if member.size:
code += f" pub {member.name}: [{self.map_type_to_rust(element_type)}; {member.size}],\n"
else:
code += f" pub {member.name}: Vec<{self.map_type_to_rust(element_type)}>,\n"
else:
if hasattr(member, "member_type"):
member_type = self.convert_type_node_to_string(member.member_type)
elif hasattr(member, "vtype"):
member_type = member.vtype
else:
member_type = "float"
semantic = None
if hasattr(member, "semantic"):
semantic = member.semantic
elif hasattr(member, "attributes"):
semantic = self.extract_semantic_from_attributes(member.attributes)
semantic_comment = (
f" // {self.map_semantic(semantic)}" if semantic else ""
)
code += f" pub {member.name}: {self.map_type(member_type)},{semantic_comment}\n"
code += "}\n\n"
return code
def convert_type_node_to_string(self, type_node) -> str:
"""Convert new AST TypeNode to string representation."""
if type_node.__class__.__name__ == "ArrayType":
element_type = self.convert_type_node_to_string(type_node.element_type)
size = self.format_array_size(type_node.size)
return (
f"{element_type}[{size}]" if size is not None else f"{element_type}[]"
)
if hasattr(type_node, "name"):
generic_args = getattr(type_node, "generic_args", [])
if generic_args:
args = ", ".join(
self.convert_type_node_to_string(arg) for arg in generic_args
)
return f"{type_node.name}<{args}>"
return type_node.name
elif hasattr(type_node, "element_type") and hasattr(type_node, "size"):
element_type = self.convert_type_node_to_string(type_node.element_type)
size = type_node.size
if element_type == "float":
return f"vec{size}"
elif element_type == "int":
return f"ivec{size}"
elif element_type == "uint":
return f"uvec{size}"
elif element_type == "double":
return f"dvec{size}"
elif element_type == "bool":
return f"bvec{size}"
else:
return f"{element_type}{size}"
elif hasattr(type_node, "element_type") and hasattr(type_node, "rows"):
element_type = self.convert_type_node_to_string(type_node.element_type)
prefix = "dmat" if element_type == "double" else "mat"
if type_node.rows == type_node.cols:
return f"{prefix}{type_node.rows}"
return f"{prefix}{type_node.rows}x{type_node.cols}"
else:
return str(type_node)
def format_array_size(self, size):
if size is None:
return None
if hasattr(size, "value"):
return size.value
return size
def extract_semantic_from_attributes(self, attributes):
"""Extract semantic information from new AST attributes."""
semantic_attrs = [
"position",
"color",
"texcoord",
"normal",
"tangent",
"binormal",
"POSITION",
"COLOR",
"TEXCOORD",
"NORMAL",
"TANGENT",
"BINORMAL",
"TEXCOORD0",
"TEXCOORD1",
"TEXCOORD2",
"TEXCOORD3",
]
for attr in attributes:
if hasattr(attr, "name") and attr.name in semantic_attrs:
return attr.name
return None
def get_member_type(self, member):
if hasattr(member, "member_type"):
return self.convert_type_node_to_string(member.member_type)
if hasattr(member, "vtype"):
return member.vtype
return "float"
def get_cbuffer_nodes(self, ast):
nodes = []
seen = set()
for attr in ("cbuffers", "constants"):
for node in getattr(ast, attr, None) or []:
node_id = id(node)
if node_id not in seen:
nodes.append(node)
seen.add(node_id)
return nodes
def map_type_to_rust(self, type_str):
"""Enhanced type mapping for Rust."""
# Handle vector types first
if type_str.startswith("float") and len(type_str) > 5:
size = type_str[5:]
if size.isdigit():
return f"Vec{size}<f32>"
elif type_str.startswith("int") and len(type_str) > 3:
size = type_str[3:]
if size.isdigit():
return f"Vec{size}<i32>"
# Standard type mapping
type_map = {
"void": "()",
"bool": "bool",
"int": "i32",
"uint": "u32",
"float": "f32",
"double": "f64",
"vec2": "Vec2<f32>",
"vec3": "Vec3<f32>",
"vec4": "Vec4<f32>",
"ivec2": "Vec2<i32>",
"ivec3": "Vec3<i32>",
"ivec4": "Vec4<i32>",
"uvec2": "Vec2<u32>",
"uvec3": "Vec3<u32>",
"uvec4": "Vec4<u32>",
"mat2": "Mat2<f32>",
"mat3": "Mat3<f32>",
"mat4": "Mat4<f32>",
"float2": "Vec2<f32>",
"float3": "Vec3<f32>",
"float4": "Vec4<f32>",
}
return type_map.get(type_str, type_str)
def generate_cbuffers(self, ast):
code = ""
cbuffers = self.get_cbuffer_nodes(ast)
for node in cbuffers:
if isinstance(node, StructNode):
code += f"#[repr(C)]\n#[derive(Debug, Clone, Copy)]\n"
code += f"pub struct {node.name} {{\n"
for member in node.members:
if isinstance(member, ArrayNode):
if member.size:
code += f" pub {member.name}: [{self.map_type(member.element_type)}; {member.size}],\n"
else:
code += f" pub {member.name}: Vec<{self.map_type(member.element_type)}>,\n"
else:
code += f" pub {member.name}: {self.map_type(self.get_member_type(member))},\n"
code += "}\n\n"
code += self.generate_cbuffer_member_statics(node.members)
elif hasattr(node, "name") and hasattr(node, "members"): # CbufferNode
code += f"#[repr(C)]\n#[derive(Debug, Clone, Copy)]\n"
code += f"pub struct {node.name} {{\n"
for member in node.members:
if isinstance(member, ArrayNode):
if member.size:
code += f" pub {member.name}: [{self.map_type(member.element_type)}; {member.size}],\n"
else:
code += f" pub {member.name}: Vec<{self.map_type(member.element_type)}>,\n"
else:
code += f" pub {member.name}: {self.map_type(self.get_member_type(member))},\n"
code += "}\n\n"
code += self.generate_cbuffer_member_statics(node.members)
return code
def generate_cbuffer_member_statics(self, members):
code = ""
for member in members:
if isinstance(member, ArrayNode):
if member.size:
member_type = (
f"[{self.map_type(member.element_type)}; {member.size}]"
)
else:
member_type = f"Vec<{self.map_type(member.element_type)}>"
else:
member_type = self.map_type(self.get_member_type(member))
code += f"static {member.name}: {member_type} = Default::default();\n"
return code
def generate_function(self, func, indent=0, shader_type=None):
"""Render one CrossGL function or shader entry point as Rust code."""
code = ""
code += " " * indent
saved_variable_types = self.variable_types.copy()
saved_return_type = self.current_return_type
param_list = getattr(func, "parameters", getattr(func, "params", []))
params = []
for p in param_list:
if hasattr(p, "param_type"):
param_type = self.convert_type_node_to_string(p.param_type)
elif hasattr(p, "vtype"):
param_type = p.vtype
else:
param_type = "float"
self.register_variable_type(p.name, param_type)
params.append(f"{p.name}: {self.map_type(param_type)}")
params_str = ", ".join(params) if params else ""
if hasattr(func, "return_type"):
return_type = self.convert_type_node_to_string(func.return_type)
else:
return_type = "void"
self.current_return_type = return_type
if shader_type == "vertex":
code += f"#[vertex_shader]\n"
elif shader_type == "fragment":
code += f"#[fragment_shader]\n"
elif shader_type == "compute":
code += f"#[compute_shader]\n"
code += f"pub fn {func.name}({params_str}) -> {self.map_type(return_type)} {{\n"
body = getattr(func, "body", [])
if hasattr(body, "statements"):
for stmt in body.statements:
code += self.generate_statement(stmt, indent + 1)
elif isinstance(body, list):
for stmt in body:
code += self.generate_statement(stmt, indent + 1)
code += " " * indent + "}\n\n"
self.variable_types = saved_variable_types
self.current_return_type = saved_return_type
return code
def generate_param_attributes(self, param):
"""Generate Rust GPU parameter attributes based on semantic"""
if not param.semantic:
return ""
semantic = param.semantic.lower()
if "position" in semantic:
return "#[location(0)] "
elif "normal" in semantic:
return "#[location(1)] "
elif "texcoord" in semantic:
if "texcoord0" in semantic:
return "#[location(2)] "
elif "texcoord1" in semantic:
return "#[location(3)] "
else:
return "#[location(2)] "
elif "color" in semantic:
return "#[location(4)] "
elif "gl_position" in semantic:
return "#[builtin(position)] "
elif "gl_fragcolor" in semantic:
return "#[location(0)] "
return ""
def generate_statement(self, stmt, indent=0):
"""Render a single CrossGL statement as Rust code."""
indent_str = " " * indent
if isinstance(stmt, VariableNode):
if hasattr(stmt, "var_type"):
vtype = stmt.var_type
elif hasattr(stmt, "vtype"):
vtype = stmt.vtype
else:
vtype = "f32"
self.register_variable_type(stmt.name, vtype)
if hasattr(stmt, "initial_value") and stmt.initial_value is not None:
init_expr = self.generate_expression_with_type(
stmt.initial_value, vtype
)
return f"{indent_str}let mut {stmt.name}: {self.map_type(vtype)} = {init_expr};\n"
else:
return f"{indent_str}let mut {stmt.name}: {self.map_type(vtype)};\n"
elif isinstance(stmt, ArrayNode):
return self.generate_array_declaration(stmt, indent)
elif isinstance(stmt, AssignmentNode):
return f"{indent_str}{self.generate_assignment(stmt)};\n"
elif isinstance(stmt, IfNode):
return self.generate_if(stmt, indent)
elif isinstance(stmt, ForNode):
return self.generate_for(stmt, indent)
elif isinstance(stmt, ReturnNode):
if hasattr(stmt, "value") and stmt.value is not None:
# Handle both single values and lists
if isinstance(stmt.value, list):
# Multiple return values (tuple)
values = ", ".join(
self.generate_expression(val) for val in stmt.value
)
return f"{indent_str}return ({values});\n"
else:
# Single return value
if isinstance(stmt.value, ArrayLiteralNode):
return_expr = self.generate_expression_with_type(
stmt.value, self.current_return_type
)
return f"{indent_str}return {return_expr};\n"
return (
f"{indent_str}return {self.generate_expression(stmt.value)};\n"
)
else:
# Void return
return f"{indent_str}return;\n"
elif hasattr(stmt, "__class__") and "ExpressionStatement" in str(
stmt.__class__
):
# Handle ExpressionStatementNode
if hasattr(stmt, "expression"):
return f"{indent_str}{self.generate_expression(stmt.expression)};\n"
else:
return f"{indent_str}{self.generate_expression(stmt)};\n"
elif isinstance(stmt, ArrayAccessNode):
# ArrayAccessNode as statement - likely misclassified
return f"{indent_str}// Unhandled ArrayAccessNode: {stmt}\n"
else:
# Try to generate as expression
expr_result = self.generate_expression(stmt)
if expr_result and expr_result.strip():
return f"{indent_str}{expr_result};\n"
else:
return f"{indent_str}// Unhandled statement: {type(stmt).__name__}\n"
def generate_array_declaration(self, node, indent=0):
indent_str = " " * indent
element_type = self.map_type(node.element_type)
size = get_array_size_from_node(node)
if size is None:
return f"{indent_str}let {node.name}: Vec<{element_type}> = Vec::new();\n"
else:
return f"{indent_str}let {node.name}: [{element_type}; {size}] = [Default::default(); {size}];\n"
def generate_expression_with_type(self, expr, target_type):
if isinstance(expr, ArrayLiteralNode):
return self.generate_array_literal_expression(expr, target_type)
return self.generate_expression(expr)
def is_array_type_name(self, type_name):
return type_name is not None and "[" in str(type_name) and "]" in str(type_name)
def generate_array_literal_expression(self, expr, target_type=None):
elements = [self.generate_expression(element) for element in expr.elements]
if self.is_array_type_name(target_type):
_, size = parse_array_type(str(target_type))
if size is None:
return f"vec![{', '.join(elements)}]"
elements = elements[:size]
while len(elements) < size:
elements.append("Default::default()")
return f"[{', '.join(elements)}]"
def generate_assignment(self, node):
# Handle both old and new AST assignment structures
if hasattr(node, "target") and hasattr(node, "value"):
# New AST structure
lhs = self.generate_expression(node.target)
lhs_type = self.expression_result_type(node.target)
rhs = self.generate_expression_with_type(node.value, lhs_type)
op = getattr(node, "operator", "=")
else:
# Old AST structure
lhs = self.generate_expression(node.left)
lhs_type = self.expression_result_type(node.left)
rhs = self.generate_expression_with_type(node.right, lhs_type)
op = getattr(node, "operator", "=")
return f"{lhs} {op} {rhs}"
def generate_if(self, node, indent):
indent_str = " " * indent
condition = self.generate_expression(
node.condition if hasattr(node, "condition") else node.if_condition
)
code = f"{indent_str}if {condition} {{\n"
if_body = getattr(node, "then_branch", getattr(node, "if_body", None))
if hasattr(if_body, "statements"):
for stmt in if_body.statements:
code += self.generate_statement(stmt, indent + 1)
elif isinstance(if_body, list):
for stmt in if_body:
code += self.generate_statement(stmt, indent + 1)
code += f"{indent_str}}}"
else_branch = getattr(node, "else_branch", None)
if else_branch:
if hasattr(else_branch, "__class__") and "If" in str(else_branch.__class__):
# Generate else if by recursively generating the nested if with else if prefix
elif_condition = self.generate_expression(
else_branch.condition
if hasattr(else_branch, "condition")
else else_branch.if_condition
)
code += f" else if {elif_condition} {{\n"
# Generate elif body
elif_body = getattr(
else_branch, "then_branch", getattr(else_branch, "if_body", None)
)
if hasattr(elif_body, "statements"):
for stmt in elif_body.statements:
code += self.generate_statement(stmt, indent + 1)
elif isinstance(elif_body, list):
for stmt in elif_body:
code += self.generate_statement(stmt, indent + 1)
code += f"{indent_str}}}"
nested_else = getattr(else_branch, "else_branch", None)
if nested_else:
if hasattr(nested_else, "__class__") and "If" in str(
nested_else.__class__
):
# Another else if - recursively handle
remaining_code = self.generate_if(nested_else, indent)
# Remove the "if" prefix and replace with "else if"
remaining_lines = remaining_code.split("\n")
if remaining_lines[0].strip().startswith("if "):
remaining_lines[0] = remaining_lines[0].replace(
"if ", " else if ", 1
)
code += "\n".join(
remaining_lines[1:]
) # Skip first line as we already handled it
else:
# Final else clause
code += f" else {{\n"
if hasattr(nested_else, "statements"):
for stmt in nested_else.statements:
code += self.generate_statement(stmt, indent + 1)
elif isinstance(nested_else, list):
for stmt in nested_else:
code += self.generate_statement(stmt, indent + 1)
else:
code += self.generate_statement(nested_else, indent + 1)
code += f"{indent_str}}}"
else:
code += f" else {{\n"
if hasattr(else_branch, "statements"):
for stmt in else_branch.statements:
code += self.generate_statement(stmt, indent + 1)
elif isinstance(else_branch, list):
for stmt in else_branch:
code += self.generate_statement(stmt, indent + 1)
else:
code += self.generate_statement(else_branch, indent + 1)
code += f"{indent_str}}}"
code += "\n"
return code
def generate_for(self, node, indent):
indent_str = " " * indent
init = self.generate_statement(node.init, 0).strip()
if init.endswith(";"):
init = init[:-1]
condition = self.generate_expression(node.condition)
update = self.generate_expression(node.update)
code = f"{indent_str}{init};\n"
code += f"{indent_str}while {condition} {{\n"
if hasattr(node.body, "statements"):
for stmt in node.body.statements:
code += self.generate_statement(stmt, indent + 1)
elif isinstance(node.body, list):
for stmt in node.body:
code += self.generate_statement(stmt, indent + 1)
else:
code += self.generate_statement(node.body, indent + 1)
# Add update at the end of the loop
code += f"{indent_str} {update};\n"
code += f"{indent_str}}}\n"
return code
def generate_expression(self, expr):
"""Render a CrossGL expression as Rust expression syntax."""
if expr is None:
return ""
elif isinstance(expr, str):
return expr
elif isinstance(expr, (int, float, bool)):
if isinstance(expr, bool):
return "true" if expr else "false"
return str(expr)
elif hasattr(expr, "__class__") and "Literal" in str(expr.__class__):
if hasattr(expr, "value"):
literal_type = getattr(
getattr(expr, "literal_type", None), "name", None
)
return self.format_literal(expr.value, literal_type)
return str(expr)
elif hasattr(expr, "__class__") and "Identifier" in str(expr.__class__):
return getattr(expr, "name", str(expr))
elif isinstance(expr, VariableNode):
if hasattr(expr, "name"):
return expr.name
else:
return str(expr)
elif hasattr(expr, "__class__") and "BinaryOp" in str(expr.__class__):
left = self.generate_expression(getattr(expr, "left", ""))
right = self.generate_expression(getattr(expr, "right", ""))
op = getattr(expr, "operator", getattr(expr, "op", "+"))
return f"({left} {self.map_operator(op)} {right})"
elif isinstance(expr, AssignmentNode):
return self.generate_assignment(expr)
elif isinstance(expr, ArrayLiteralNode):
return self.generate_array_literal_expression(expr)
elif hasattr(expr, "__class__") and "UnaryOp" in str(expr.__class__):
operand = self.generate_expression(getattr(expr, "operand", ""))
op = getattr(expr, "operator", getattr(expr, "op", "+"))
op = self.map_operator(op)
if op in ["++", "--"]:
assignment_op = "+=" if op == "++" else "-="
return f"{operand} {assignment_op} 1"
return f"({op}{operand})"
elif hasattr(expr, "__class__") and "ArrayAccess" in str(expr.__class__):
array_expr = getattr(expr, "array_expr", getattr(expr, "array", ""))
index_expr = getattr(expr, "index_expr", getattr(expr, "index", ""))
array = self.generate_expression(array_expr)
index = self.generate_expression(index_expr)
return f"{array}[{index}]"
elif hasattr(expr, "__class__") and "FunctionCall" in str(expr.__class__):
func_expr = getattr(expr, "function", getattr(expr, "name", "unknown"))
func_name = None
if hasattr(func_expr, "name"):
func_name = func_expr.name
callee = func_name
elif isinstance(func_expr, str):
func_name = func_expr
callee = func_expr
else:
callee = self.generate_expression(func_expr)
args = getattr(expr, "arguments", getattr(expr, "args", []))
func_name = self.function_map.get(func_name, func_name)
scalar_cast = self.generate_scalar_constructor_call(func_name, args)
if scalar_cast is not None:
return scalar_cast
vector_info = self.vector_type_info(func_name)
if vector_info:
rust_type = self.map_type(func_name)
generated_args = [self.generate_expression(arg) for arg in args]
if len(generated_args) == 1:
arg_type = self.expression_result_type(args[0])
if arg_type is not None and not self.vector_type_info(arg_type):
generated_args *= vector_info["size"]
args_str = ", ".join(generated_args)
return f"{rust_type}::new({args_str})"
if func_name in [
"mat2",
"mat3",
"mat4",
"mat2x2",
"mat2x3",
"mat2x4",
"mat3x2",
"mat3x3",
"mat3x4",
"mat4x2",
"mat4x3",
"mat4x4",
"dmat2",
"dmat3",
"dmat4",
"dmat2x2",
"dmat2x3",
"dmat2x4",
"dmat3x2",
"dmat3x3",
"dmat3x4",
"dmat4x2",
"dmat4x3",
"dmat4x4",
]:
rust_type = self.map_type(func_name)
args_str = ", ".join(self.generate_expression(arg) for arg in args)
return f"{rust_type}::new({args_str})"
args_str = ", ".join(self.generate_expression(arg) for arg in args)
return f"{callee}({args_str})"
elif hasattr(expr, "__class__") and "MemberAccess" in str(expr.__class__):
obj_expr = getattr(expr, "object_expr", getattr(expr, "object", ""))
member = getattr(expr, "member", "")
obj = self.generate_expression(obj_expr)
return f"{obj}.{member}"
elif hasattr(expr, "__class__") and "TernaryOp" in str(expr.__class__):
condition = self.generate_expression(getattr(expr, "condition", ""))
true_expr = self.generate_expression(getattr(expr, "true_expr", ""))
false_expr = self.generate_expression(getattr(expr, "false_expr", ""))
return f"(if {condition} {{ {true_expr} }} else {{ {false_expr} }})"
else:
return str(expr)
def generate_scalar_constructor_call(self, func_name, args):
rust_type = self.scalar_constructor_type(func_name)
if rust_type is None or len(args) != 1:
return None
arg = args[0]
arg_expr = self.generate_expression(arg)
if rust_type == "bool":
arg_type = self.expression_result_type(arg)
if arg_type == "bool":
return arg_expr
zero_literal = "0.0" if arg_type in {"float", "double", "half"} else "0"
return f"({arg_expr} != {zero_literal})"
return f"({arg_expr} as {rust_type})"
def scalar_constructor_type(self, func_name):
scalar_types = {
"bool": "bool",
"char": "char",
"short": "i16",
"ushort": "u16",
"int": "i32",
"uint": "u32",
"long": "i64",
"ulong": "u64",
"float": "f32",
"double": "f64",
"half": "f16",
"i16": "i16",
"u16": "u16",
"i32": "i32",
"u32": "u32",
"i64": "i64",
"u64": "u64",
"f16": "f16",
"f32": "f32",
"f64": "f64",
}
return scalar_types.get(func_name)
def format_literal(self, value, literal_type=None):
if isinstance(value, bool):
return "true" if value else "false"
if literal_type == "bool" and isinstance(value, str):
lower_value = value.lower()
if lower_value in {"true", "false"}:
return lower_value
if literal_type == "char":
escaped = self.escape_literal(value, quote="'")
return f"'{escaped}'"
if isinstance(value, str):
escaped = self.escape_literal(value, quote='"')
return f'"{escaped}"'
return str(value)
def register_variable_type(self, name, type_name):
if not name or type_name is None:
return
if hasattr(type_name, "name") or hasattr(type_name, "element_type"):
type_name = self.convert_type_node_to_string(type_name)
else:
type_name = str(type_name)
self.variable_types[name] = type_name
def get_expression_name(self, expr):
if isinstance(expr, IdentifierNode):
return expr.name
if isinstance(expr, VariableNode):
return expr.name
if isinstance(expr, str):
return expr
if isinstance(expr, ArrayAccessNode):
array_expr = getattr(expr, "array_expr", getattr(expr, "array", None))
return self.get_expression_name(array_expr)
return None
def expression_result_type(self, expr):
if expr is None:
return None
if isinstance(expr, (IdentifierNode, VariableNode, ArrayAccessNode)):
return self.variable_types.get(self.get_expression_name(expr))
if isinstance(expr, LiteralNode):
literal_type = getattr(getattr(expr, "literal_type", None), "name", None)
if literal_type:
return literal_type
if isinstance(expr.value, bool):
return "bool"
if isinstance(expr.value, float):
return "float"
if isinstance(expr.value, int):
return "int"
return None
if isinstance(expr, FunctionCallNode):
func_expr = getattr(expr, "function", getattr(expr, "name", None))
func_name = getattr(func_expr, "name", func_expr)
if isinstance(func_name, str) and self.vector_type_info(func_name):
return func_name
return None
if isinstance(expr, BinaryOpNode):
left_type = self.expression_result_type(expr.left)
right_type = self.expression_result_type(expr.right)
if self.vector_type_info(left_type):
return left_type
if self.vector_type_info(right_type):
return right_type
return left_type or right_type
if isinstance(expr, UnaryOpNode):
return self.expression_result_type(expr.operand)
if isinstance(expr, TernaryOpNode):
return self.expression_result_type(
expr.true_expr
) or self.expression_result_type(expr.false_expr)
if isinstance(expr, MemberAccessNode):
object_expr = getattr(expr, "object_expr", getattr(expr, "object", None))
object_type = self.expression_result_type(object_expr)
vector_info = self.vector_type_info(object_type)
if not vector_info:
return None
member = getattr(expr, "member", "")
if len(member) == 1:
return vector_info["component_type"]
if all(component in "xyzwrgba" for component in member):
return self.vector_type_for_components(
vector_info["component_type"], len(member)
)
return None
def vector_type_info(self, type_name):
if type_name is None:
return None
if hasattr(type_name, "name") or hasattr(type_name, "element_type"):
type_name = self.convert_type_node_to_string(type_name)
else:
type_name = str(type_name)
mapped_type = self.map_type(type_name)
vector_details = {
"Vec2<f32>": ("float", 2),
"Vec3<f32>": ("float", 3),
"Vec4<f32>": ("float", 4),
"Vec2<f64>": ("double", 2),
"Vec3<f64>": ("double", 3),
"Vec4<f64>": ("double", 4),
"Vec2<i32>": ("int", 2),
"Vec3<i32>": ("int", 3),
"Vec4<i32>": ("int", 4),
"Vec2<u32>": ("uint", 2),
"Vec3<u32>": ("uint", 3),
"Vec4<u32>": ("uint", 4),
"Vec2<bool>": ("bool", 2),
"Vec3<bool>": ("bool", 3),
"Vec4<bool>": ("bool", 4),
}
details = vector_details.get(mapped_type)
if details is None:
return None
component_type, size = details
return {"component_type": component_type, "size": size}
def vector_type_for_components(self, component_type, component_count):
if component_count < 2 or component_count > 4:
return component_type
prefixes = {
"float": "vec",
"double": "dvec",
"int": "ivec",
"uint": "uvec",
"bool": "bvec",
}
prefix = prefixes.get(component_type)
if prefix is None:
return None
return f"{prefix}{component_count}"
def escape_literal(self, value, quote):
text = str(value)
escaped = []
for index, char in enumerate(text):
if char == "\n":
escaped.append("\\n")
elif char == "\r":
escaped.append("\\r")
elif char == "\t":
escaped.append("\\t")
elif char == quote and (index == 0 or text[index - 1] != "\\"):
escaped.append("\\" + char)
else:
escaped.append(char)
return "".join(escaped)
def map_type(self, vtype):
"""Map a CrossGL type name or type node to a Rust type string."""
if vtype is None:
return "f32"
if hasattr(vtype, "name") or hasattr(vtype, "element_type"):
vtype_str = self.convert_type_node_to_string(vtype)
else:
vtype_str = str(vtype)
if "[" in vtype_str and "]" in vtype_str:
base_type, size = parse_array_type(vtype_str)
base_mapped = self.type_mapping.get(base_type, base_type)
if size:
return f"[{base_mapped}; {size}]"
else:
return f"Vec<{base_mapped}>"
return self.type_mapping.get(vtype_str, vtype_str)
def map_operator(self, op):
op_map = {
"PLUS": "+",
"MINUS": "-",
"MULTIPLY": "*",
"DIVIDE": "/",
"BITWISE_XOR": "^",
"BITWISE_OR": "|",
"BITWISE_AND": "&",
"LESS_THAN": "<",
"GREATER_THAN": ">",
"ASSIGN_ADD": "+=",
"ASSIGN_SUB": "-=",
"ASSIGN_MUL": "*=",
"ASSIGN_DIV": "/=",
"ASSIGN_MOD": "%=",
"ASSIGN_XOR": "^=",
"ASSIGN_OR": "|=",
"ASSIGN_AND": "&=",
"LESS_EQUAL": "<=",
"GREATER_EQUAL": ">=",
"EQUAL": "==",
"NOT_EQUAL": "!=",
"AND": "&&",
"OR": "||",
"EQUALS": "=",
"ASSIGN_SHIFT_LEFT": "<<=",
"ASSIGN_SHIFT_RIGHT": ">>=",
"LOGICAL_AND": "&&",
"LOGICAL_OR": "||",
"BITWISE_SHIFT_RIGHT": ">>",
"BITWISE_SHIFT_LEFT": "<<",
"MOD": "%",
"NOT": "!",
}
return op_map.get(op, op)
def map_semantic(self, semantic):
"""Map a CrossGL semantic to the Rust backend attribute name."""
if semantic:
return self.semantic_map.get(semantic, semantic)
return ""