Source code for crosstl.translator.codegen.ast_utils

"""AST utilities for CrossGL code generators."""

from typing import Optional, List
from ..ast import (
    TypeNode,
    PrimitiveType,
    VectorType,
    MatrixType,
    ArrayType,
    PointerType,
    ReferenceType,
    FunctionType,
    GenericType,
    NamedType,
    StructMemberNode,
    VariableNode,
    ParameterNode,
    AttributeNode,
    FunctionNode,
    StructNode,
    ExpressionNode,
    StatementNode,
    BlockNode,
)


class ASTUtils:
    """Utilities for AST processing and type conversion.

    Every helper is a ``@staticmethod``; the class only serves as a namespace.
    Backends are selected by name (``"generic"``, ``"metal"``, ``"directx"``,
    ``"opengl"``, ``"vulkan"``, ``"rust"``, ``"cuda"``, ``"hip"``, ``"mojo"``).
    An unrecognized backend name falls back to the generic spelling.
    """

    # Primitive type spellings per backend. Built once at class creation
    # instead of on every ``_map_primitive_type`` call.
    _PRIMITIVE_TYPE_MAPPINGS = {
        "generic": {
            "void": "void",
            "bool": "bool",
            "int": "int",
            "float": "float",
            "double": "double",
            "char": "char",
            "uint": "uint",
        },
        "metal": {
            "void": "void",
            "bool": "bool",
            "int": "int",
            "float": "float",
            "double": "double",
            "char": "int",
            "uint": "uint",
            "half": "half",
        },
        "directx": {
            "void": "void",
            "bool": "bool",
            "int": "int",
            "float": "float",
            "double": "double",
            "char": "int",
            "uint": "uint",
        },
        "opengl": {
            "void": "void",
            "bool": "bool",
            "int": "int",
            "float": "float",
            "double": "double",
            "char": "int",
            "uint": "uint",
        },
        "vulkan": {
            "void": "void",
            "bool": "bool",
            "int": "int",
            "float": "float",
            "double": "double",
            "char": "int",
            "uint": "uint",
        },
        "rust": {
            "void": "()",
            "bool": "bool",
            "int": "i32",
            "float": "f32",
            "double": "f64",
            "char": "i8",
            "uint": "u32",
        },
        "cuda": {
            "void": "void",
            "bool": "bool",
            "int": "int",
            "float": "float",
            "double": "double",
            "char": "char",
            "uint": "unsigned int",
        },
        "hip": {
            "void": "void",
            "bool": "bool",
            "int": "int",
            "float": "float",
            "double": "double",
            "char": "char",
            "uint": "unsigned int",
        },
        "mojo": {
            "void": "None",
            "bool": "Bool",
            "int": "Int32",
            "float": "Float32",
            "double": "Float64",
            "char": "Int8",
            "uint": "UInt32",
        },
    }

    # GLSL vector type prefix keyed by the (already mapped) element type.
    # "double" -> "dvec" is required: GLSL has no ``double3``-style names.
    _GLSL_VECTOR_PREFIXES = {
        "float": "vec",
        "double": "dvec",
        "int": "ivec",
        "uint": "uvec",
        "bool": "bvec",
    }

    # Attribute names treated as semantics. Matching is case-sensitive.
    _SEMANTIC_ATTRIBUTES = frozenset(
        {
            "position",
            "color",
            "texcoord",
            "normal",
            "tangent",
            "binormal",
            "POSITION",
            "COLOR",
            "TEXCOORD",
            "NORMAL",
            "TANGENT",
            "BINORMAL",
            "TEXCOORD0",
            "TEXCOORD1",
            "TEXCOORD2",
            "TEXCOORD3",
            "TEXCOORD4",
            "TEXCOORD5",
            "TEXCOORD6",
            "TEXCOORD7",
            "COLOR0",
            "COLOR1",
        }
    )

    @staticmethod
    def get_type_string(type_node: TypeNode, backend: str = "generic") -> str:
        """Convert a TypeNode to a backend-specific string.

        Composite types (vector, matrix, array, pointer, reference, function,
        named generics) are rendered recursively from their component types.
        A node that matches none of the known type classes is rendered with
        ``str(type_node)``.
        """
        # The isinstance checks keep their original order on purpose: the
        # class hierarchy is defined elsewhere, so reordering could change
        # which branch a subclass hits.
        if isinstance(type_node, PrimitiveType):
            return ASTUtils._map_primitive_type(type_node.name, backend)

        if isinstance(type_node, VectorType):
            element_type = ASTUtils.get_type_string(type_node.element_type, backend)
            return ASTUtils._map_vector_type(element_type, type_node.size, backend)

        if isinstance(type_node, MatrixType):
            element_type = ASTUtils.get_type_string(type_node.element_type, backend)
            return ASTUtils._map_matrix_type(
                element_type, type_node.rows, type_node.cols, backend
            )

        if isinstance(type_node, ArrayType):
            element_type = ASTUtils.get_type_string(type_node.element_type, backend)
            size = type_node.size
            if size is None:
                # Unsized array.
                return f"{element_type}[]"
            if isinstance(size, int):
                return f"{element_type}[{size}]"
            # Non-integer sizes are expression nodes.
            return f"{element_type}[{ASTUtils.expression_to_string(size)}]"

        if isinstance(type_node, PointerType):
            pointee_type = ASTUtils.get_type_string(type_node.pointee_type, backend)
            return ASTUtils._map_pointer_type(
                pointee_type, type_node.is_mutable, backend
            )

        if isinstance(type_node, ReferenceType):
            referenced_type = ASTUtils.get_type_string(
                type_node.referenced_type, backend
            )
            return ASTUtils._map_reference_type(
                referenced_type, type_node.is_mutable, backend
            )

        if isinstance(type_node, FunctionType):
            return_type = ASTUtils.get_type_string(type_node.return_type, backend)
            param_types = [
                ASTUtils.get_type_string(pt, backend) for pt in type_node.param_types
            ]
            return ASTUtils._map_function_type(return_type, param_types, backend)

        if isinstance(type_node, GenericType):
            return type_node.name

        if isinstance(type_node, NamedType):
            if type_node.generic_args:
                args = [
                    ASTUtils.get_type_string(arg, backend)
                    for arg in type_node.generic_args
                ]
                return f"{type_node.name}<{', '.join(args)}>"
            return type_node.name

        return str(type_node)

    @staticmethod
    def _map_primitive_type(type_name: str, backend: str) -> str:
        """Map primitive type names to backend-specific types.

        Unknown backends use the generic table; names missing from the
        selected table are returned unchanged.
        """
        tables = ASTUtils._PRIMITIVE_TYPE_MAPPINGS
        mapping = tables.get(backend, tables["generic"])
        return mapping.get(type_name, type_name)

    @staticmethod
    def _map_vector_type(element_type: str, size: int, backend: str) -> str:
        """Map vector types to backend-specific representations.

        ``element_type`` is the already-mapped element type string.
        """
        if backend in ("opengl", "vulkan"):
            prefix = ASTUtils._GLSL_VECTOR_PREFIXES.get(element_type)
            if prefix is not None:
                return f"{prefix}{size}"
            # Element types with no GLSL vector prefix use the generic form.
        elif backend == "rust":
            return f"Vec{size}<{element_type}>"
        elif backend == "mojo":
            return f"SIMD[DType.{element_type.lower()}, {size}]"
        # metal, directx, cuda, hip and unknown backends share this spelling.
        return f"{element_type}{size}"

    @staticmethod
    def _map_matrix_type(element_type: str, rows: int, cols: int, backend: str) -> str:
        """Map matrix types to backend-specific representations.

        Metal and GLSL name matrices columns-first; the other backends here
        emit ``rows`` before ``cols``.
        """
        if backend == "metal":
            return f"{element_type}{cols}x{rows}"
        if backend in ("opengl", "vulkan"):
            # GLSL double-precision matrices use the "dmat" prefix; every
            # other element type keeps the original "mat" spelling.
            prefix = "dmat" if element_type == "double" else "mat"
            if rows == cols:
                return f"{prefix}{rows}"
            return f"{prefix}{cols}x{rows}"
        if backend == "rust":
            return f"Mat{rows}x{cols}<{element_type}>"
        if backend == "mojo":
            return f"Matrix[DType.{element_type.lower()}, {rows}, {cols}]"
        # directx, cuda, hip and unknown backends share this spelling.
        return f"{element_type}{rows}x{cols}"

    @staticmethod
    def _map_pointer_type(pointee_type: str, is_mutable: bool, backend: str) -> str:
        """Map pointer types to backend-specific representations.

        Only Rust encodes mutability; every other backend emits ``T*``.
        """
        if backend == "rust":
            return f"*{'mut' if is_mutable else 'const'} {pointee_type}"
        return f"{pointee_type}*"

    @staticmethod
    def _map_reference_type(
        referenced_type: str, is_mutable: bool, backend: str
    ) -> str:
        """Map reference types to backend-specific representations.

        Only Rust encodes mutability; every other backend emits ``T&``.
        """
        if backend == "rust":
            return f"&{'mut ' if is_mutable else ''}{referenced_type}"
        return f"{referenced_type}&"

    @staticmethod
    def _map_function_type(
        return_type: str, param_types: List[str], backend: str
    ) -> str:
        """Map function types to backend-specific representations."""
        params = ", ".join(param_types)
        if backend in ("rust", "mojo"):
            return f"fn({params}) -> {return_type}"
        return f"{return_type}({params})"

    @staticmethod
    def get_semantic_from_attributes(attributes: List[AttributeNode]) -> Optional[str]:
        """Return the first recognized semantic attribute name.

        Returns ``None`` when no attribute matches, or when ``attributes`` is
        ``None`` or empty.
        """
        # ``or ()`` tolerates nodes whose attribute list was never populated.
        for attr in attributes or ():
            if attr.name in ASTUtils._SEMANTIC_ATTRIBUTES:
                return attr.name
        return None

    @staticmethod
    def get_member_info(member: StructMemberNode, backend: str = "generic") -> dict:
        """Return normalized metadata for a struct member."""
        return {
            "name": member.name,
            "type": ASTUtils.get_type_string(member.member_type, backend),
            "semantic": ASTUtils.get_semantic_from_attributes(member.attributes),
            "attributes": member.attributes,
            "visibility": member.visibility,
            "default_value": member.default_value,
        }

    @staticmethod
    def get_variable_info(variable: VariableNode, backend: str = "generic") -> dict:
        """Return normalized metadata for a variable declaration."""
        return {
            "name": variable.name,
            "type": ASTUtils.get_type_string(variable.var_type, backend),
            "semantic": ASTUtils.get_semantic_from_attributes(variable.attributes),
            "attributes": variable.attributes,
            "qualifiers": variable.qualifiers,
            "is_mutable": variable.is_mutable,
            "initial_value": variable.initial_value,
            "visibility": variable.visibility,
        }

    @staticmethod
    def get_parameter_info(parameter: ParameterNode, backend: str = "generic") -> dict:
        """Return normalized metadata for a function parameter."""
        return {
            "name": parameter.name,
            "type": ASTUtils.get_type_string(parameter.param_type, backend),
            "semantic": ASTUtils.get_semantic_from_attributes(parameter.attributes),
            "attributes": parameter.attributes,
            "is_mutable": parameter.is_mutable,
            "default_value": parameter.default_value,
        }

    @staticmethod
    def get_function_info(function: FunctionNode, backend: str = "generic") -> dict:
        """Return normalized metadata for a function declaration."""
        return {
            "name": function.name,
            "return_type": ASTUtils.get_type_string(function.return_type, backend),
            "parameters": [
                ASTUtils.get_parameter_info(p, backend) for p in function.parameters
            ],
            "qualifiers": function.qualifiers,
            "attributes": function.attributes,
            "visibility": function.visibility,
            "is_unsafe": function.is_unsafe,
            "is_async": function.is_async,
            "body": function.body,
        }

    @staticmethod
    def expression_to_string(expr: ExpressionNode) -> str:
        """Render a simple expression node to a string for declarations.

        Uses ``expr.value`` if present, otherwise ``expr.name``, otherwise
        ``str(expr)``.
        """
        # Simplified — a full implementation would use an expression visitor.
        for attr_name in ("value", "name"):
            if hasattr(expr, attr_name):
                return str(getattr(expr, attr_name))
        return str(expr)

    @staticmethod
    def is_legacy_ast_node(node) -> bool:
        """Return whether a node uses the older ``vtype`` string shape."""
        # A missing ``vtype`` yields None, which is not a str.
        return isinstance(getattr(node, "vtype", None), str)

    @staticmethod
    def get_legacy_compatible_type(node, backend: str = "generic") -> str:
        """Return a type string for either legacy or current AST nodes.

        Legacy nodes return their ``vtype`` string verbatim. Current nodes are
        rendered from the first of ``var_type``, ``member_type`` or
        ``param_type`` they carry. ``"float"`` is the fallback.
        """
        if ASTUtils.is_legacy_ast_node(node):
            return node.vtype
        for attr_name in ("var_type", "member_type", "param_type"):
            if hasattr(node, attr_name):
                return ASTUtils.get_type_string(getattr(node, attr_name), backend)
        return "float"

    @staticmethod
    def get_legacy_compatible_semantic(node) -> Optional[str]:
        """Return semantic metadata from either legacy or current AST nodes."""
        if ASTUtils.is_legacy_ast_node(node):
            return getattr(node, "semantic", None)
        if hasattr(node, "attributes"):
            return ASTUtils.get_semantic_from_attributes(node.attributes)
        return None

    @staticmethod
    def safe_get_body_statements(body):
        """Extract statements from function body, handling both old and new AST.

        ``None`` yields an empty list, a ``BlockNode`` yields its
        ``statements``, a list is returned as-is, and anything else is wrapped
        in a one-element list.
        """
        if body is None:
            return []
        if isinstance(body, BlockNode):
            return body.statements
        if isinstance(body, list):
            return body
        return [body]

    @staticmethod
    def safe_get_function_qualifier(function: FunctionNode) -> Optional[str]:
        """Get function qualifier, handling both old and new AST.

        Prefers the first entry of a non-empty ``qualifiers`` sequence, then a
        legacy ``qualifier`` attribute, then ``None``.
        """
        qualifiers = getattr(function, "qualifiers", None)
        if qualifiers:
            return qualifiers[0]
        if hasattr(function, "qualifier"):
            return function.qualifier
        return None