"""Registry for source-language parsers and reverse code generators."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Sequence, Tuple, Any
import os
from .lexer import Lexer as CglLexer
from .parser import Parser as CglParser
def _normalize_source_name(name: str) -> str:
if not isinstance(name, str):
raise TypeError(f"Source name must be a string, got {type(name)}")
return name.strip().lower()
def _normalize_extension(ext: str) -> str:
ext = ext.strip().lower()
if not ext:
return ext
return ext if ext.startswith(".") else f".{ext}"
def _extract_tokens(lexer) -> Any:
if hasattr(lexer, "tokens") and lexer.tokens:
return lexer.tokens
if hasattr(lexer, "tokenize"):
result = lexer.tokenize()
if result is not None:
return result
if hasattr(lexer, "tokens") and lexer.tokens:
return lexer.tokens
if hasattr(lexer, "get_tokens"):
result = lexer.get_tokens()
if result is not None:
return result
if hasattr(lexer, "token_generator"):
return list(lexer.token_generator())
raise ValueError(f"Unsupported lexer interface: {type(lexer)}")
[docs]
@dataclass(frozen=True)
class SourceSpec:
"""Descriptor for a source language frontend.
A source spec connects file extensions and aliases to a lazily imported
lexer/parser pair. Specs can also provide a reverse code generator factory
when the source language can be converted back into CrossGL.
"""
name: str
extensions: Sequence[str]
load_lexer_parser: Callable[[], Tuple[type, type]]
reverse_codegen_factory: Optional[Callable[[], Any]] = None
aliases: Sequence[str] = ()
[docs]
def parse(self, code: str):
"""Parse source code into that source backend's AST."""
lexer_cls, parser_cls = self.load_lexer_parser()
lexer = lexer_cls(code)
tokens = _extract_tokens(lexer)
parser = parser_cls(tokens)
return parser.parse()
[docs]
class SourceRegistry:
"""Lookup table for source parsers by name, alias, and extension."""
def __init__(self) -> None:
self._by_name: Dict[str, SourceSpec] = {}
self._by_alias: Dict[str, str] = {}
self._by_extension: Dict[str, str] = {}
[docs]
def register(self, spec: SourceSpec, *, overwrite: bool = False) -> SourceSpec:
"""Register a source spec and all of its aliases/extensions."""
name = _normalize_source_name(spec.name)
if name in self._by_name and not overwrite:
existing = self._by_name[name]
if existing.load_lexer_parser is spec.load_lexer_parser:
return existing
raise ValueError(f"Source '{name}' already registered")
self._by_name[name] = spec
for alias in spec.aliases:
alias_key = _normalize_source_name(alias)
if alias_key in self._by_alias and not overwrite:
if self._by_alias[alias_key] == name:
continue
raise ValueError(f"Source alias '{alias_key}' already registered")
self._by_alias[alias_key] = name
for ext in spec.extensions:
ext_key = _normalize_extension(ext)
if not ext_key:
continue
if ext_key in self._by_extension and not overwrite:
if self._by_extension[ext_key] == name:
continue
raise ValueError(f"Extension '{ext_key}' already registered")
self._by_extension[ext_key] = name
return spec
[docs]
def resolve_name(self, name: str) -> Optional[str]:
"""Resolve a source name or alias to its canonical registry name."""
if not name:
return None
key = _normalize_source_name(name)
if key in self._by_name:
return key
return self._by_alias.get(key)
[docs]
def get(self, name: str) -> Optional[SourceSpec]:
"""Return the source spec registered for a name or alias."""
resolved = self.resolve_name(name)
if not resolved:
return None
return self._by_name.get(resolved)
[docs]
def get_by_extension(self, path_or_ext: str) -> Optional[SourceSpec]:
"""Return the source spec registered for a file path or extension."""
ext = path_or_ext
if path_or_ext:
looks_like_path = os.path.basename(path_or_ext) != path_or_ext
looks_like_filename = not path_or_ext.startswith(".") and "." in path_or_ext
if looks_like_path or looks_like_filename:
_, ext = os.path.splitext(path_or_ext)
ext_key = _normalize_extension(ext or "")
name = self._by_extension.get(ext_key)
if not name:
return None
return self._by_name.get(name)
[docs]
def names(self) -> Sequence[str]:
"""Return registered canonical source names in sorted order."""
return sorted(self._by_name.keys())
[docs]
def extensions(self) -> Sequence[str]:
"""Return registered source file extensions in sorted order."""
return sorted(self._by_extension.keys())
SOURCE_REGISTRY = SourceRegistry()
def _load_cgl():
return CglLexer, CglParser
def _load_directx():
from crosstl.backend.DirectX import HLSLLexer, HLSLParser
return HLSLLexer, HLSLParser
def _load_metal():
from crosstl.backend.Metal import MetalLexer, MetalParser
return MetalLexer, MetalParser
def _load_glsl():
from crosstl.backend.GLSL import GLSLLexer, GLSLParser
return GLSLLexer, GLSLParser
def _load_slang():
from crosstl.backend.slang import SlangLexer, SlangParser
return SlangLexer, SlangParser
def _load_spirv():
from crosstl.backend.SPIRV import VulkanLexer, VulkanParser
return VulkanLexer, VulkanParser
def _load_mojo():
from crosstl.backend.Mojo import MojoLexer, MojoParser
return MojoLexer, MojoParser
def _load_rust():
from crosstl.backend.Rust import RustLexer, RustParser
return RustLexer, RustParser
def _load_cuda():
from crosstl.backend.CUDA import CudaLexer, CudaParser
return CudaLexer, CudaParser
def _load_hip():
from crosstl.backend.HIP import HipLexer, HipParser
return HipLexer, HipParser
def _reverse_directx():
from crosstl.backend.DirectX.DirectxCrossGLCodeGen import HLSLToCrossGLConverter
return HLSLToCrossGLConverter()
def _reverse_metal():
from crosstl.backend.Metal.MetalCrossGLCodeGen import MetalToCrossGLConverter
return MetalToCrossGLConverter()
def _reverse_glsl():
from crosstl.backend.GLSL.openglCrossglCodegen import GLSLToCrossGLConverter
return GLSLToCrossGLConverter()
def _reverse_slang():
from crosstl.backend.slang.SlangCrossGLCodeGen import SlangToCrossGLConverter
return SlangToCrossGLConverter()
def _reverse_spirv():
from crosstl.backend.SPIRV.VulkanCrossGLCodeGen import VulkanToCrossGLConverter
return VulkanToCrossGLConverter()
def _reverse_mojo():
from crosstl.backend.Mojo.MojoCrossGLCodeGen import MojoToCrossGLConverter
return MojoToCrossGLConverter()
def _reverse_rust():
from crosstl.backend.Rust.RustCrossGLCodeGen import RustToCrossGLConverter
return RustToCrossGLConverter()
def _reverse_cuda():
from crosstl.backend.CUDA.CudaCrossGLCodeGen import CudaToCrossGLConverter
return CudaToCrossGLConverter()
def _reverse_hip():
from crosstl.backend.HIP.HipCrossGLCodeGen import HipToCrossGLConverter
return HipToCrossGLConverter()
[docs]
def register_default_sources() -> None:
"""Register the built-in CrossGL and native source frontends."""
def _register(spec: SourceSpec) -> None:
try:
SOURCE_REGISTRY.register(spec)
except ValueError:
return
_register(
SourceSpec(
name="cgl",
extensions=(".cgl",),
load_lexer_parser=_load_cgl,
aliases=("crossgl",),
)
)
_register(
SourceSpec(
name="directx",
extensions=(".hlsl",),
load_lexer_parser=_load_directx,
reverse_codegen_factory=_reverse_directx,
aliases=("hlsl", "dx"),
)
)
_register(
SourceSpec(
name="metal",
extensions=(".metal",),
load_lexer_parser=_load_metal,
reverse_codegen_factory=_reverse_metal,
aliases=("metal",),
)
)
_register(
SourceSpec(
name="opengl",
extensions=(".glsl",),
load_lexer_parser=_load_glsl,
reverse_codegen_factory=_reverse_glsl,
aliases=("glsl", "ogl"),
)
)
_register(
SourceSpec(
name="slang",
extensions=(".slang",),
load_lexer_parser=_load_slang,
reverse_codegen_factory=_reverse_slang,
aliases=("slang",),
)
)
_register(
SourceSpec(
name="vulkan",
extensions=(".spv", ".spirv"),
load_lexer_parser=_load_spirv,
reverse_codegen_factory=_reverse_spirv,
aliases=("spirv", "spv"),
)
)
_register(
SourceSpec(
name="mojo",
extensions=(".mojo",),
load_lexer_parser=_load_mojo,
reverse_codegen_factory=_reverse_mojo,
aliases=("mojo",),
)
)
_register(
SourceSpec(
name="rust",
extensions=(".rs", ".rust"),
load_lexer_parser=_load_rust,
reverse_codegen_factory=_reverse_rust,
aliases=("rust", "rs"),
)
)
_register(
SourceSpec(
name="cuda",
extensions=(".cu", ".cuh", ".cuda"),
load_lexer_parser=_load_cuda,
reverse_codegen_factory=_reverse_cuda,
aliases=("cuda", "cu"),
)
)
_register(
SourceSpec(
name="hip",
extensions=(".hip",),
load_lexer_parser=_load_hip,
reverse_codegen_factory=_reverse_hip,
aliases=("hip",),
)
)