Source code for crosstl.backend.Rust.RustAst

"""Rust AST Node definitions"""

from ..common_ast import (
    ASTNode,
    ArrayAccessNode,
    AssignmentNode,
    AttributeNode,
    BinaryOpNode,
    BreakNode,
    CastNode,
    ContinueNode,
    FunctionCallNode,
    IfNode,
    MemberAccessNode,
    ReturnNode,
    TernaryOpNode,
    UnaryOpNode,
    VariableNode,
    VectorConstructorNode,
    WhileNode,
)


class ShaderNode(ASTNode):
    """Top-level node holding everything parsed from one Rust module.

    Every argument is a collection of child nodes. A falsy argument
    (``None`` or an empty collection) is replaced by a fresh empty list.
    Non-empty values are stored as given; ``__repr__`` calls ``len`` on
    each, so they are expected to be sized.
    """

    def __init__(
        self,
        structs=None,
        functions=None,
        global_variables=None,
        impl_blocks=None,
        use_statements=None,
        traits=None,
        enums=None,
        type_aliases=None,
    ):
        self.structs = structs if structs else []
        self.functions = functions if functions else []
        self.global_variables = global_variables if global_variables else []
        self.impl_blocks = impl_blocks if impl_blocks else []
        self.use_statements = use_statements if use_statements else []
        self.traits = traits if traits else []
        self.enums = enums if enums else []
        self.type_aliases = type_aliases if type_aliases else []

    def __repr__(self):
        # Only child counts are shown. Note the label "globals" is shorter
        # than the attribute name ``global_variables``.
        sections = (
            ("structs", self.structs),
            ("functions", self.functions),
            ("globals", self.global_variables),
            ("impl_blocks", self.impl_blocks),
            ("use_statements", self.use_statements),
            ("traits", self.traits),
            ("enums", self.enums),
            ("type_aliases", self.type_aliases),
        )
        summary = ", ".join(f"{label}={len(items)}" for label, items in sections)
        return f"ShaderNode({summary})"
class StructNode(ASTNode):
    """A Rust ``struct`` declaration.

    Args:
        name: Struct identifier, stored as given.
        members: Field nodes, stored as given (``__repr__`` calls ``len`` on it).
        attributes: Attributes on the struct; a falsy value becomes ``[]``.
        visibility: Visibility marker, stored as given (default ``None``).
        generics: Generic parameters; a falsy value becomes ``[]``.
    """

    def __init__(self, name, members, attributes=None, visibility=None, generics=None):
        self.name = name
        self.members = members
        self.attributes = attributes if attributes else []
        self.visibility = visibility
        self.generics = generics if generics else []

    def __repr__(self):
        member_count = len(self.members)
        return (
            f"StructNode(name={self.name}, members={member_count}, "
            f"visibility={self.visibility})"
        )
class EnumVariantNode(ASTNode):
    """A single variant inside a Rust ``enum``.

    Args:
        name: Variant identifier.
        kind: Free-form label for the variant's shape (default ``"unit"``);
            other values are chosen by whoever builds the node.
        fields: Payload fields; a falsy value becomes ``[]``.
        value: Explicit value, stored as given (default ``None``); presumably
            an explicit discriminant expression — confirm against the parser.
        attributes: Attributes on the variant; a falsy value becomes ``[]``.
    """

    def __init__(self, name, kind="unit", fields=None, value=None, attributes=None):
        self.name = name
        self.kind = kind
        self.fields = fields if fields else []
        self.value = value
        self.attributes = attributes if attributes else []

    def __repr__(self):
        field_count = len(self.fields)
        return (
            f"EnumVariantNode(name={self.name}, kind={self.kind}, "
            f"fields={field_count}, value={self.value})"
        )
class EnumNode(ASTNode):
    """A Rust ``enum`` declaration.

    Args:
        name: Enum identifier.
        variants: Variant nodes, stored as given (``__repr__`` calls ``len``).
        attributes: Attributes on the enum; a falsy value becomes ``[]``.
        visibility: Visibility marker, stored as given (default ``None``).
        generics: Generic parameters; a falsy value becomes ``[]``.
        where_clauses: ``where`` constraints; a falsy value becomes ``[]``.
    """

    def __init__(
        self,
        name,
        variants,
        attributes=None,
        visibility=None,
        generics=None,
        where_clauses=None,
    ):
        self.name = name
        self.variants = variants
        self.attributes = attributes if attributes else []
        self.visibility = visibility
        self.generics = generics if generics else []
        self.where_clauses = where_clauses if where_clauses else []

    def __repr__(self):
        variant_count = len(self.variants)
        return (
            f"EnumNode(name={self.name}, variants={variant_count}, "
            f"visibility={self.visibility})"
        )
class FunctionNode(ASTNode):
    """A Rust function (``fn``) definition.

    Args:
        return_type: Declared return type, stored as given.
        name: Function identifier.
        params: Parameter nodes, stored as given (``__repr__`` calls ``len``).
        body: Function body, stored as given.
        attributes: Attributes on the function; a falsy value becomes ``[]``.
        visibility: Visibility marker, stored as given (default ``None``).
        generics: Generic parameters; a falsy value becomes ``[]``.
        where_clauses: ``where`` constraints; a falsy value becomes ``[]``.
    """

    def __init__(
        self,
        return_type,
        name,
        params,
        body,
        attributes=None,
        visibility=None,
        generics=None,
        where_clauses=None,
    ):
        self.return_type = return_type
        self.name = name
        self.params = params
        self.body = body
        self.attributes = attributes if attributes else []
        self.visibility = visibility
        self.generics = generics if generics else []
        self.where_clauses = where_clauses if where_clauses else []

    def __repr__(self):
        param_count = len(self.params)
        return (
            f"FunctionNode(name={self.name}, return_type={self.return_type}, "
            f"params={param_count}, visibility={self.visibility})"
        )
# Rust-specific nodes
class ImplNode(ASTNode):
    """A Rust ``impl`` block, either inherent or ``impl Trait for Type``.

    Args:
        struct_name: Name of the type the block implements for.
        methods: Method nodes, stored as given (``__repr__`` calls ``len``).
        trait_name: Implemented trait, or ``None`` for an inherent impl.
        generics: Generic parameters; a falsy value becomes ``[]``.
        where_clauses: ``where`` constraints; a falsy value becomes ``[]``.
        type_aliases: Associated type definitions; a falsy value becomes ``[]``.
    """

    def __init__(
        self,
        struct_name,
        methods,
        trait_name=None,
        generics=None,
        where_clauses=None,
        type_aliases=None,
    ):
        self.struct_name = struct_name
        self.methods = methods
        # Same object exposed under a second name. Mutating the list is seen
        # through both names, but rebinding one attribute does not update
        # the other.
        self.functions = self.methods
        self.trait_name = trait_name
        self.generics = generics if generics else []
        self.where_clauses = where_clauses if where_clauses else []
        self.type_aliases = type_aliases if type_aliases else []

    def __repr__(self):
        if not self.trait_name:
            return f"ImplNode(for={self.struct_name}, methods={len(self.methods)})"
        return (
            f"ImplNode(trait={self.trait_name}, for={self.struct_name}, "
            f"methods={len(self.methods)})"
        )
class AssociatedTypeNode(ASTNode):
    """An associated ``type`` declared inside a Rust trait.

    Args:
        name: Associated type identifier.
        bounds: Trait bounds; a falsy value becomes ``[]``.
        default_type: Default type, stored as given (default ``None``).
        where_clauses: ``where`` constraints; a falsy value becomes ``[]``.
    """

    def __init__(self, name, bounds=None, default_type=None, where_clauses=None):
        self.name = name
        self.bounds = bounds if bounds else []
        self.default_type = default_type
        self.where_clauses = where_clauses if where_clauses else []

    def __repr__(self):
        return "AssociatedTypeNode(name={}, bounds={}, default_type={})".format(
            self.name, self.bounds, self.default_type
        )
class TypeAliasNode(ASTNode):
    """A Rust ``type Name = Target;`` alias declaration.

    Args:
        name: Alias identifier.
        alias_type: The aliased type, stored as given.
        generics: Generic parameters; a falsy value becomes ``[]``.
        visibility: Visibility marker, stored as given (default ``None``).
        where_clauses: ``where`` constraints; a falsy value becomes ``[]``.
        attributes: Attributes on the alias; a falsy value becomes ``[]``.
    """

    def __init__(
        self,
        name,
        alias_type,
        generics=None,
        visibility=None,
        where_clauses=None,
        attributes=None,
    ):
        self.name = name
        self.alias_type = alias_type
        self.generics = generics if generics else []
        self.visibility = visibility
        self.where_clauses = where_clauses if where_clauses else []
        self.attributes = attributes if attributes else []

    def __repr__(self):
        return "TypeAliasNode(name={}, alias_type={}, generics={})".format(
            self.name, self.alias_type, self.generics
        )
class TraitNode(ASTNode):
    """A Rust ``trait`` definition.

    Args:
        name: Trait identifier.
        methods: Method nodes, stored as given (``__repr__`` calls ``len``).
        generics: Generic parameters; a falsy value becomes ``[]``.
        visibility: Visibility marker, stored as given (default ``None``).
        where_clauses: ``where`` constraints; a falsy value becomes ``[]``.
        associated_types: Associated type nodes; a falsy value becomes ``[]``.
        *args: Accepted for call compatibility and otherwise ignored.
        **kwargs: Each becomes an attribute, unless the instance already
            has an attribute of that name.
    """

    def __init__(
        self,
        name,
        methods,
        generics=None,
        visibility=None,
        where_clauses=None,
        associated_types=None,
        *args,
        **kwargs,
    ):
        self.name = name
        self.methods = methods
        self.generics = generics if generics else []
        self.visibility = visibility
        self.where_clauses = where_clauses if where_clauses else []
        self.associated_types = associated_types if associated_types else []
        # Compatibility shim: attach extra keyword data without ever
        # overwriting an attribute that already exists (including ones
        # inherited from the class).
        for extra_name in kwargs:
            if hasattr(self, extra_name):
                continue
            setattr(self, extra_name, kwargs[extra_name])

    def __repr__(self):
        method_count = len(self.methods)
        return f"TraitNode(name={self.name}, methods={method_count})"
class LetNode(ASTNode):
    """A Rust ``let`` binding, optionally ``mut`` and optionally typed.

    Args:
        name: Bound name.
        value: Initializer, stored as given.
        var_type: Declared type, or ``None`` (default). Also exposed as ``vtype``.
        is_mutable: Whether the binding was declared ``mut``.
    """

    def __init__(self, name, value, var_type=None, is_mutable=False):
        self.name = name
        self.value = value
        self.var_type = var_type
        # Second name for the same value.
        self.vtype = self.var_type
        self.is_mutable = is_mutable

    def __repr__(self):
        if self.is_mutable:
            prefix = "mut "
        else:
            prefix = ""
        return f"LetNode({prefix}{self.name}: {self.var_type} = {self.value})"
class LoopNode(ASTNode):
    """A Rust ``loop { ... }`` (unconditional loop).

    Args:
        body: Loop body, stored as given.
        label: Optional loop label, stored as given (default ``None``).
    """

    def __init__(self, body, label=None):
        self.body = body
        self.label = label

    def __repr__(self):
        return "LoopNode(label={}, body={})".format(self.label, self.body)
class ForNode(ASTNode):
    """A Rust ``for pattern in iterable { body }`` loop.

    Args:
        pattern: Loop binding pattern, stored as given.
        iterable: Expression being iterated, stored as given.
        body: Loop body, stored as given.
    """

    def __init__(self, pattern, iterable, body):
        self.pattern = pattern
        self.iterable = iterable
        self.body = body

    def __repr__(self):
        return "ForNode(pattern={}, iterable={}, body={})".format(
            self.pattern, self.iterable, self.body
        )
class MatchNode(ASTNode):
    """A Rust ``match`` expression.

    Args:
        expression: The scrutinee, stored as given.
        arms: Match arm nodes, stored as given (``__repr__`` calls ``len``).
    """

    def __init__(self, expression, arms):
        self.expression = expression
        self.arms = arms

    def __repr__(self):
        arm_count = len(self.arms)
        return f"MatchNode(expression={self.expression}, arms={arm_count})"
class MatchArmNode(ASTNode):
    """One ``pattern [if guard] => body`` arm of a ``match``.

    Args:
        pattern: Arm pattern, stored as given.
        guard: Guard expression, stored as given; presumably ``None`` when
            the arm has no ``if`` guard — confirm against the parser.
        body: Arm body, stored as given.
    """

    def __init__(self, pattern, guard, body):
        self.pattern = pattern
        self.guard = guard
        self.body = body

    def __repr__(self):
        return "MatchArmNode(pattern={}, guard={}, body={})".format(
            self.pattern, self.guard, self.body
        )
class UseNode(ASTNode):
    """A Rust ``use`` declaration.

    Args:
        path: Imported path, stored as given.
        alias: ``as`` rename, or ``None`` (default).
        items: Grouped items for ``use path::{a, b}``. Unlike most nodes in
            this module, ``None`` is kept as ``None`` rather than turned
            into an empty list.
    """

    def __init__(self, path, alias=None, items=None):
        self.path = path
        self.alias = alias
        self.items = items

    def __repr__(self):
        return "UseNode(path={}, alias={}, items={})".format(
            self.path, self.alias, self.items
        )
class GenericParameterNode(ASTNode):
    """A generic parameter such as ``T``, ``T: Bound`` or ``T = Default``.

    Args:
        name: Parameter identifier.
        bounds: Trait bounds; a falsy value becomes ``[]``.
        default: Default type, stored as given (default ``None``).
    """

    def __init__(self, name, bounds=None, default=None):
        self.name = name
        self.bounds = bounds or []
        self.default = default

    def __repr__(self):
        # Show the default only when one is set. Without it, ``T`` and
        # ``T = f32`` would print identically, while output for parameters
        # without a default stays unchanged.
        suffix = "" if self.default is None else f", default={self.default}"
        return f"GenericParameterNode(name={self.name}, bounds={self.bounds}{suffix})"
class RangeNode(ASTNode):
    """A Rust range expression: ``a..b``, ``a..=b``, or an open form.

    Args:
        start: Lower-bound expression; presumably ``None`` when the bound
            is omitted (``..b``) — confirm against the parser.
        end: Upper-bound expression; presumably ``None`` when omitted
            (``a..``).
        inclusive: ``True`` for ``..=``, ``False`` for ``..``.
        step: Optional step, stored as given (default ``None``).
    """

    def __init__(self, start, end, inclusive=False, step=None):
        self.start = start
        self.end = end
        self.inclusive = inclusive
        self.step = step

    def __repr__(self):
        op = "..=" if self.inclusive else ".."
        # Render a missing bound as nothing so open ranges read like Rust
        # source ("..5", "3..") instead of "None..5".
        lower = "" if self.start is None else self.start
        upper = "" if self.end is None else self.end
        # Previously the step was silently omitted from the representation.
        step = "" if self.step is None else f", step={self.step}"
        return f"RangeNode({lower}{op}{upper}{step})"
class TupleNode(ASTNode):
    """A Rust tuple expression ``(a, b, ...)``.

    Args:
        elements: Element expressions, stored as given.
    """

    def __init__(self, elements):
        self.elements = elements

    def __repr__(self):
        return "TupleNode(elements={})".format(self.elements)
class ArrayNode(ASTNode):
    """A Rust array literal.

    Args:
        elements: Element expressions, stored as given (``__repr__`` calls
            ``len`` on it).
        size: Stored as given (default ``None``); presumably the length
            operand of a ``[value; N]`` form — confirm against the parser.
    """

    def __init__(self, elements, size=None):
        self.elements = elements
        self.size = size

    def __repr__(self):
        element_count = len(self.elements)
        return f"ArrayNode(elements={element_count}, size={self.size})"
class ReferenceNode(ASTNode):
    """A borrow expression: ``&expr`` or ``&mut expr``.

    Args:
        expression: The borrowed expression, stored as given.
        is_mutable: ``True`` for ``&mut``.
    """

    def __init__(self, expression, is_mutable=False):
        self.expression = expression
        self.is_mutable = is_mutable

    def __repr__(self):
        if self.is_mutable:
            qualifier = "mut "
        else:
            qualifier = ""
        return f"ReferenceNode(&{qualifier}{self.expression})"
class DereferenceNode(ASTNode):
    """A dereference expression ``*expr``.

    Args:
        expression: The dereferenced expression, stored as given.
    """

    def __init__(self, expression):
        self.expression = expression

    def __repr__(self):
        return "DereferenceNode(*{})".format(self.expression)
class BlockNode(ASTNode):
    """A braced block expression ``{ ... }``.

    Args:
        statements: Contained statements, stored as given (``__repr__``
            calls ``len`` on it).
        returns_value: Flag stored as given (default ``False``); presumably
            set when the block ends in a tail expression — confirm.
    """

    def __init__(self, statements, returns_value=False):
        self.statements = statements
        self.returns_value = returns_value

    def __repr__(self):
        statement_count = len(self.statements)
        return (
            f"BlockNode(statements={statement_count}, "
            f"returns_value={self.returns_value})"
        )
class ConstNode(ASTNode):
    """A Rust ``const NAME: Type = value;`` item.

    Args:
        name: Constant identifier.
        const_type: Declared type; also exposed as ``vtype``.
        value: Initializer expression, stored as given.
        visibility: Visibility marker, stored as given (default ``None``).
    """

    def __init__(self, name, const_type, value, visibility=None):
        self.name = name
        self.const_type = const_type
        # Second name for the same value.
        self.vtype = self.const_type
        self.value = value
        self.visibility = visibility

    def __repr__(self):
        return "ConstNode(name={}, const_type={}, value={})".format(
            self.name, self.const_type, self.value
        )
class StaticNode(ASTNode):
    """A Rust ``static [mut] NAME: Type = value;`` item.

    Args:
        name: Static identifier.
        static_type: Declared type; also exposed as ``vtype``.
        value: Initializer expression, stored as given.
        is_mutable: ``True`` for ``static mut``.
        visibility: Visibility marker, stored as given (default ``None``).
    """

    def __init__(self, name, static_type, value, is_mutable=False, visibility=None):
        self.name = name
        self.static_type = static_type
        # Second name for the same value.
        self.vtype = self.static_type
        self.value = value
        self.is_mutable = is_mutable
        self.visibility = visibility

    def __repr__(self):
        if self.is_mutable:
            qualifier = "mut "
        else:
            qualifier = ""
        return f"StaticNode({qualifier}{self.name}: {self.static_type} = {self.value})"
class StructInitializationNode(ASTNode):
    """A struct literal such as ``Name { field: value, ... }``.

    Args:
        struct_name: Name of the struct being built.
        fields: Mapping of field name to value expression, stored as given.
    """

    def __init__(self, struct_name, fields):
        self.struct_name = struct_name
        self.fields = fields

    def __repr__(self):
        return "StructInitializationNode(struct_name={}, fields={})".format(
            self.struct_name, self.fields
        )