# Source code for crosstl.formatter

"""Formatting helpers for generated shader and compute source code."""

import os
import subprocess
import tempfile
from pathlib import Path
from enum import Enum
import logging
import shutil

logger = logging.getLogger(__name__)


class ShaderLanguage(Enum):
    """Supported shader languages for formatting.

    Each member's value is the lowercase name accepted by
    ``ShaderLanguage(value)`` lookups.
    """

    # Graphics shading languages and SPIR-V.
    HLSL = "hlsl"
    GLSL = "glsl"
    METAL = "metal"
    SPIRV = "spirv"
    SLANG = "slang"

    # Compute / general-purpose targets.
    MOJO = "mojo"
    RUST = "rust"
    CUDA = "cuda"
    HIP = "hip"

    # Fallback when the language cannot be determined.
    UNKNOWN = "unknown"
class CodeFormatter:
    """Formats shader code using appropriate external tools.

    C-like languages are formatted with ``clang-format``; SPIR-V assembly is
    round-tripped through ``spirv-as`` / ``spirv-dis``. Formatting is
    best-effort: when a tool is missing or fails, the input text (or, for
    SPIR-V, a heuristically indented copy of it) is returned instead of
    raising.
    """

    # Opcodes the SPIR-V fallback pretty-printer keeps flush-left; every other
    # instruction is indented beneath them.
    _SPIRV_FLUSH_OPCODES = frozenset({"OpFunction", "OpFunctionEnd", "OpLabel"})

    def __init__(self, clang_format_path=None, spirv_tools_path=None):
        """Discover external formatter tools or use explicit tool paths.

        Args:
            clang_format_path: Path to the ``clang-format`` executable. When
                omitted it is looked up on ``PATH``.
            spirv_tools_path: Either a directory containing the SPIRV-Tools
                executables (``spirv-as``, ``spirv-dis``, ``spirv-val``), or,
                for backward compatibility, a single path used verbatim for all
                three tools. When omitted each tool is looked up on ``PATH``.
        """
        self.clang_format_path = clang_format_path or shutil.which("clang-format")

        def find_spirv_tool(name):
            # A directory has to be searched per tool; using it verbatim would
            # point spirv-as, spirv-dis and spirv-val at the directory itself.
            if spirv_tools_path and os.path.isdir(spirv_tools_path):
                return shutil.which(name, path=os.fspath(spirv_tools_path))
            return spirv_tools_path or shutil.which(name)

        self.spirv_as_path = find_spirv_tool("spirv-as")
        self.spirv_dis_path = find_spirv_tool("spirv-dis")
        self.spirv_val_path = find_spirv_tool("spirv-val")

        self.has_clang = bool(self.clang_format_path)
        # spirv-val is only needed by validate_spirv, which checks for it itself.
        self.has_spirv_tools = bool(self.spirv_as_path and self.spirv_dis_path)

        if not self.has_clang:
            logger.warning(
                "clang-format not found. Install it for C-like shader formatting."
            )
        if not self.has_spirv_tools:
            logger.warning(
                "SPIRV-Tools not found. Install them for proper SPIR-V handling."
            )

    def detect_language(self, file_path):
        """Detect shader language from the (case-insensitive) file extension.

        Returns:
            The matching ShaderLanguage member, or ShaderLanguage.UNKNOWN.
        """
        extension_map = {
            ".hlsl": ShaderLanguage.HLSL,
            ".hlsli": ShaderLanguage.HLSL,
            ".fx": ShaderLanguage.HLSL,
            ".glsl": ShaderLanguage.GLSL,
            ".vert": ShaderLanguage.GLSL,
            ".frag": ShaderLanguage.GLSL,
            ".comp": ShaderLanguage.GLSL,
            ".geom": ShaderLanguage.GLSL,
            ".tese": ShaderLanguage.GLSL,
            ".tesc": ShaderLanguage.GLSL,
            ".metal": ShaderLanguage.METAL,
            ".spv": ShaderLanguage.SPIRV,
            # SPIR-V assembly; also the suffix this class uses for temp files.
            ".spvasm": ShaderLanguage.SPIRV,
            ".spirv": ShaderLanguage.SPIRV,
            ".vulkan": ShaderLanguage.SPIRV,
            ".slang": ShaderLanguage.SLANG,
            ".mojo": ShaderLanguage.MOJO,
            ".rs": ShaderLanguage.RUST,
            ".rust": ShaderLanguage.RUST,
            ".cu": ShaderLanguage.CUDA,
            ".cuh": ShaderLanguage.CUDA,
            ".cuda": ShaderLanguage.CUDA,
            ".hip": ShaderLanguage.HIP,
        }
        return extension_map.get(Path(file_path).suffix.lower(), ShaderLanguage.UNKNOWN)

    def format_code(self, code, language=None, file_path=None):
        """Format source text for a shader language, falling back to input text.

        Args:
            code: Source text to format.
            language: A ShaderLanguage member or its string value
                (case-insensitive, e.g. ``"hlsl"`` or ``"HLSL"``). Unrecognized
                strings are treated as ShaderLanguage.UNKNOWN.
            file_path: Used to detect the language from its extension when
                ``language`` is None.

        Returns:
            The formatted text, or the original/fallback text when no formatter
            applies or formatting fails.
        """
        if language is None and file_path:
            language = self.detect_language(file_path)
        if isinstance(language, str):
            try:
                language = ShaderLanguage(language.strip().lower())
            except ValueError:
                language = ShaderLanguage.UNKNOWN

        clang_languages = (
            ShaderLanguage.HLSL,
            ShaderLanguage.GLSL,
            ShaderLanguage.METAL,
            ShaderLanguage.SLANG,
            ShaderLanguage.RUST,
            ShaderLanguage.CUDA,
            ShaderLanguage.HIP,
        )
        if language in clang_languages:
            return self._format_with_clang(code, language)
        if language == ShaderLanguage.SPIRV:
            return self._format_spirv(code)
        logger.warning("No formatter available for %s", language)
        return code

    def _format_with_clang(self, code, language):
        """Format C-like shader code with clang-format.

        The code is written to a UTF-8 temp file (clang-format reads UTF-8)
        whose suffix matches the language, formatted in place, and read back.
        Returns ``code`` unchanged if clang-format is missing or fails.
        """
        if not self.has_clang:
            logger.warning("clang-format not available for code formatting")
            return code

        # language -> (predefined clang-format style, temp-file suffix)
        style, ext = {
            ShaderLanguage.HLSL: ("Microsoft", ".hlsl"),
            ShaderLanguage.GLSL: ("Google", ".glsl"),
            ShaderLanguage.METAL: ("LLVM", ".metal"),
            ShaderLanguage.SLANG: ("Microsoft", ".slang"),
            ShaderLanguage.RUST: ("LLVM", ".rs"),
            ShaderLanguage.CUDA: ("Google", ".cu"),
            ShaderLanguage.HIP: ("Google", ".hip"),
        }.get(language, ("LLVM", ".txt"))

        tmp_path = None
        try:
            # delete=False so the file can be reopened by clang-format (needed
            # on Windows); it is removed in the finally clause.
            with tempfile.NamedTemporaryFile(
                suffix=ext, mode="w+", encoding="utf-8", delete=False
            ) as tmp:
                tmp_path = tmp.name
                tmp.write(code)

            cmd = [self.clang_format_path, "-style=" + style, "-i", tmp_path]
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                logger.error("clang-format failed: %s", result.stderr)
                return code

            with open(tmp_path, "r", encoding="utf-8") as f:
                return f.read()
        except Exception as e:
            logger.error("Error formatting with clang-format: %s", e)
            return code
        finally:
            self._remove_files(tmp_path)

    def _format_spirv(self, code):
        """Format SPIR-V assembly code using spirv-as and spirv-dis.

        The text is assembled to a binary and disassembled back, producing the
        tools' canonical layout. If either step fails, the heuristic
        ``_make_spirv_readable`` output is returned instead.
        """
        if not self.has_spirv_tools:
            logger.warning("SPIRV-Tools not available for SPIR-V formatting")
            return code

        # Initialized before the try so the finally clause never sees an
        # unbound name when temp-file creation itself fails.
        tmp_in_path = None
        tmp_out_path = None
        try:
            with tempfile.NamedTemporaryFile(
                suffix=".spvasm", mode="w+", encoding="utf-8", delete=False
            ) as tmp_in:
                tmp_in_path = tmp_in.name
                tmp_in.write(code)
            tmp_out_path = tmp_in_path + ".spv"

            assemble_cmd = [self.spirv_as_path, tmp_in_path, "-o", tmp_out_path]
            result = subprocess.run(assemble_cmd, capture_output=True, text=True)
            if result.returncode != 0:
                logger.error("spirv-as failed: %s", result.stderr)
                return self._make_spirv_readable(code)

            disassemble_cmd = [self.spirv_dis_path, tmp_out_path, "--no-color"]
            result = subprocess.run(disassemble_cmd, capture_output=True, text=True)
            if result.returncode != 0:
                logger.error("spirv-dis failed: %s", result.stderr)
                return self._make_spirv_readable(code)

            return result.stdout
        except Exception as e:
            logger.error("Error formatting SPIR-V: %s", e)
            return self._make_spirv_readable(code)
        finally:
            self._remove_files(tmp_in_path, tmp_out_path)

    @staticmethod
    def _spirv_opcode(line):
        """Return the opcode of one stripped SPIR-V assembly line.

        Handles both ``OpXxx operands`` and ``%id = OpXxx operands`` forms and
        returns ``""`` when the line has no opcode token.
        """
        head, sep, tail = line.partition("=")
        # Only treat "=" as a result-id assignment when the left side is an id;
        # an "=" inside a string operand (e.g. OpName) must not split the line.
        if sep and head.strip().startswith("%"):
            tokens = tail.split()
        else:
            tokens = line.split()
        return tokens[0] if tokens else ""

    def _make_spirv_readable(self, code):
        """Make SPIR-V code more readable without external tools.

        Blank lines are dropped and each remaining line is stripped.
        Instructions whose opcode is OpFunction, OpFunctionEnd or OpLabel stay
        flush-left; all others are prefixed with one space. The bare lines
        ``OpCapability Shader``, ``OpFunction`` and ``OpFunctionEnd`` stay
        flush-left, while bare ``OpLabel`` / ``OpStore`` lines are indented.
        """
        result = []
        for raw_line in code.split("\n"):
            line = raw_line.strip()
            if not line:
                continue
            if line in ("OpCapability Shader", "OpFunction", "OpFunctionEnd"):
                result.append(line)
            elif line in ("OpLabel", "OpStore"):
                result.append(" " + line)
            elif self._spirv_opcode(line) in self._SPIRV_FLUSH_OPCODES:
                # Match the opcode token exactly so OpFunctionCall and
                # OpFunctionParameter are not mistaken for function boundaries.
                result.append(line)
            else:
                result.append(" " + line)
        return "\n".join(result)

    def validate_spirv(self, code):
        """Validate SPIR-V assembly text with spirv-as and spirv-val.

        The text is assembled for the ``vulkan1.0`` target environment and the
        resulting binary is passed to spirv-val.

        Returns:
            A ``(is_valid, message)`` tuple. ``is_valid`` is False when the
            tools are unavailable, assembly or validation fails, or an error
            occurs; ``message`` then describes the reason.
        """
        if not self.has_spirv_tools:
            logger.warning("SPIRV-Tools not available for validation")
            return False, "SPIRV-Tools not available"

        tmp_path = None
        spv_path = None
        try:
            with tempfile.NamedTemporaryFile(
                suffix=".spvasm", mode="w+", encoding="utf-8", delete=False
            ) as tmp:
                tmp_path = tmp.name
                tmp.write(code)
            spv_path = tmp_path + ".spv"

            # Without -o, spirv-as writes "out.spv" in the current directory,
            # so spirv-val would never see the binary at spv_path.
            assemble_cmd = [
                self.spirv_as_path,
                "--target-env",
                "vulkan1.0",
                tmp_path,
                "-o",
                spv_path,
            ]
            result = subprocess.run(assemble_cmd, capture_output=True, text=True)
            if result.returncode != 0:
                return False, f"Assembly failed: {result.stderr}"

            # has_spirv_tools does not cover spirv-val, so check it here.
            if not self.spirv_val_path:
                return False, "spirv-val not available"

            validate_cmd = [self.spirv_val_path, spv_path]
            result = subprocess.run(validate_cmd, capture_output=True, text=True)
            if result.returncode != 0:
                return False, f"Validation failed: {result.stderr}"
            return True, "Valid SPIR-V code"
        except Exception as e:
            return False, f"Error validating SPIR-V: {e}"
        finally:
            self._remove_files(tmp_path, spv_path)

    @staticmethod
    def _remove_files(*paths):
        """Delete temporary files, skipping ``None`` entries and OS errors."""
        for path in paths:
            if path is None:
                continue
            try:
                os.unlink(path)
            except OSError:
                # Best-effort cleanup: the file may never have been created.
                pass
def format_file(file_path, language=None):
    """Format a shader file in-place.

    The file is read and written as UTF-8, the encoding clang-format and
    SPIRV-Tools operate on, rather than the platform's locale encoding.

    Args:
        file_path: Path of the file to rewrite; also used for language
            detection when ``language`` is None.
        language: Optional ShaderLanguage member or its string value.

    Returns:
        True when the file was read and written back. This includes the case
        where the formatter fell back to the unformatted text. False if reading,
        formatting or writing raised; the error is logged.
    """
    formatter = CodeFormatter()
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            code = f.read()
        formatted_code = formatter.format_code(code, language, file_path)
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(formatted_code)
        return True
    except Exception as e:
        logger.error("Error formatting file %s: %s", file_path, e)
        return False
# Helper function to be called from _crosstl.py
def format_shader_code(code, backend, output_path=None):
    """Format generated code according to the target backend name.

    Args:
        code: Generated source text.
        backend: Backend name such as ``"metal"``, ``"directx"`` or
            ``"vulkan"`` (case-insensitive). Falsy values and the CrossGL
            backends (``"cgl"``, ``"crossgl"``) return ``code`` unchanged.
        output_path: Optional output file path. When ``backend`` is not a
            recognized name, the language is detected from this path instead.

    Returns:
        The formatted text, or ``code`` when no formatting applies.
    """
    if not backend:
        return code

    name = backend.lower()
    if name in ("cgl", "crossgl"):
        # CrossGL sources are passed through verbatim.
        return code

    # Backend aliases grouped by the formatter language they select.
    aliases = (
        (ShaderLanguage.METAL, ("metal",)),
        (ShaderLanguage.HLSL, ("directx", "hlsl")),
        (ShaderLanguage.GLSL, ("opengl", "glsl")),
        (ShaderLanguage.SPIRV, ("vulkan", "spirv", "spv")),
        (ShaderLanguage.MOJO, ("mojo",)),
        (ShaderLanguage.RUST, ("rust",)),
        (ShaderLanguage.CUDA, ("cuda",)),
        (ShaderLanguage.HIP, ("hip",)),
        (ShaderLanguage.SLANG, ("slang",)),
    )
    language = next(
        (lang for lang, names in aliases if name in names),
        None,
    )
    return CodeFormatter().format_code(code, language, output_path)
# Module-level convenience wrapper around CodeFormatter.format_code; tests patch this name.
def format_code(code, language=None, file_path=None):
    """Format code using a freshly constructed CodeFormatter.

    Module-level convenience wrapper around ``CodeFormatter.format_code``.

    Args:
        code: The shader code to format.
        language: ShaderLanguage enum member or its string value. When None,
            the language is detected from ``file_path``.
        file_path: Optional file path used for language detection.

    Returns:
        The formatted code string, or the input unchanged when no formatter
        applies or formatting fails.
    """
    formatter = CodeFormatter()
    return formatter.format_code(code, language, file_path)