Source code for crosstl.formatter
"""Formatting helpers for generated shader and compute source code."""
import os
import subprocess
import tempfile
from pathlib import Path
from enum import Enum
import logging
import shutil
# Module-level logger named after this module; used for all formatter warnings and errors.
logger = logging.getLogger(__name__)
[docs]
class ShaderLanguage(Enum):
    """Closed set of source languages the formatter can be asked to handle.

    Each member's value is the lowercase string name of the language, so a
    member can be looked up from text with ``ShaderLanguage("glsl")``.
    """

    # Shader languages.
    HLSL = "hlsl"
    GLSL = "glsl"
    METAL = "metal"
    SPIRV = "spirv"
    SLANG = "slang"

    # General-purpose / compute languages.
    MOJO = "mojo"
    RUST = "rust"
    CUDA = "cuda"
    HIP = "hip"

    # Fallback for anything that could not be identified.
    UNKNOWN = "unknown"
[docs]
class CodeFormatter:
    """Format generated shader code by delegating to external tools.

    C-like languages are passed through ``clang-format``; SPIR-V assembly is
    round-tripped through ``spirv-as`` / ``spirv-dis``.  Formatting is
    best-effort: when a tool is missing or fails, the input text (or, for
    SPIR-V, a lightly re-indented copy of it) is returned instead of raising.
    """

    def __init__(self, clang_format_path=None, spirv_tools_path=None):
        """Discover external formatter tools or use explicit tool paths.

        Args:
            clang_format_path: Path to the ``clang-format`` executable.  When
                omitted, ``PATH`` is searched.
            spirv_tools_path: Either a directory that contains the SPIRV-Tools
                executables (each tool is then looked up inside it), or a
                single path that is used verbatim for every SPIR-V tool
                (legacy behaviour).  When omitted, ``PATH`` is searched.
        """
        self.clang_format_path = clang_format_path or shutil.which("clang-format")
        self.spirv_as_path = self._resolve_spirv_tool("spirv-as", spirv_tools_path)
        self.spirv_dis_path = self._resolve_spirv_tool("spirv-dis", spirv_tools_path)
        self.spirv_val_path = self._resolve_spirv_tool("spirv-val", spirv_tools_path)

        self.has_clang = bool(self.clang_format_path)
        # Formatting needs both the assembler and the disassembler.
        self.has_spirv_tools = bool(self.spirv_as_path and self.spirv_dis_path)

        if not self.has_clang:
            logger.warning(
                "clang-format not found. Install it for C-like shader formatting."
            )
        if not self.has_spirv_tools:
            logger.warning(
                "SPIRV-Tools not found. Install them for proper SPIR-V handling."
            )

    @staticmethod
    def _resolve_spirv_tool(tool_name, spirv_tools_path):
        """Return the path to use for one SPIRV-Tools executable, or None."""
        if not spirv_tools_path:
            return shutil.which(tool_name)
        if os.path.isdir(spirv_tools_path):
            # A directory was given: each tool is a distinct executable in it.
            return shutil.which(tool_name, path=os.fspath(spirv_tools_path))
        # Anything else is used as-is, exactly as before.
        return spirv_tools_path

    @staticmethod
    def _remove_temp_files(*paths):
        """Delete temporary files, ignoring ones that are unset or already gone."""
        for path in paths:
            if path is None:
                continue
            try:
                os.unlink(path)
            except OSError:
                # Cleanup is best-effort; a leftover temp file must not mask
                # the result (or the error) of the formatting call itself.
                pass

    def detect_language(self, file_path):
        """Detect shader language from file extension.

        Args:
            file_path: Path (string or path-like) whose suffix is inspected,
                case-insensitively.

        Returns:
            The matching ``ShaderLanguage`` member, or
            ``ShaderLanguage.UNKNOWN`` when the suffix is not recognised.
        """
        extension_languages = {
            ".hlsl": ShaderLanguage.HLSL,
            ".fx": ShaderLanguage.HLSL,
            ".glsl": ShaderLanguage.GLSL,
            ".vert": ShaderLanguage.GLSL,
            ".frag": ShaderLanguage.GLSL,
            ".comp": ShaderLanguage.GLSL,
            ".geom": ShaderLanguage.GLSL,
            ".tese": ShaderLanguage.GLSL,
            ".tesc": ShaderLanguage.GLSL,
            ".metal": ShaderLanguage.METAL,
            ".spv": ShaderLanguage.SPIRV,
            ".spirv": ShaderLanguage.SPIRV,
            ".vulkan": ShaderLanguage.SPIRV,
            ".slang": ShaderLanguage.SLANG,
            ".rs": ShaderLanguage.RUST,
            ".rust": ShaderLanguage.RUST,
            ".cu": ShaderLanguage.CUDA,
            ".cuh": ShaderLanguage.CUDA,
            ".cuda": ShaderLanguage.CUDA,
            ".hip": ShaderLanguage.HIP,
        }
        suffix = Path(file_path).suffix.lower()
        return extension_languages.get(suffix, ShaderLanguage.UNKNOWN)

    def format_code(self, code, language=None, file_path=None):
        """Format source text for a shader language, falling back to input text.

        Args:
            code: Source text to format.
            language: ``ShaderLanguage`` member or its string value (matched
                case-insensitively).  When ``None`` and ``file_path`` is
                given, the language is detected from the file extension.
            file_path: Optional path used only for language detection.

        Returns:
            The formatted text, or ``code`` unchanged when no formatter
            applies to the language.
        """
        if language is None and file_path:
            language = self.detect_language(file_path)

        if isinstance(language, str):
            try:
                # Enum values are lowercase; accept "HLSL", " glsl ", etc.
                language = ShaderLanguage(language.strip().lower())
            except ValueError:
                language = ShaderLanguage.UNKNOWN

        clang_languages = (
            ShaderLanguage.HLSL,
            ShaderLanguage.GLSL,
            ShaderLanguage.METAL,
            ShaderLanguage.SLANG,
            ShaderLanguage.RUST,
            ShaderLanguage.CUDA,
            ShaderLanguage.HIP,
        )
        if language in clang_languages:
            return self._format_with_clang(code, language)
        if language == ShaderLanguage.SPIRV:
            return self._format_spirv(code)

        logger.warning(f"No formatter available for {language}")
        return code

    def _format_with_clang(self, code, language):
        """Format C-like shader code with clang-format.

        The code is written to a temporary file whose suffix matches the
        language, formatted in place with a per-language style, and read back.
        Returns ``code`` unchanged if clang-format is unavailable or fails.
        """
        if not self.has_clang:
            logger.warning("clang-format not available for code formatting")
            return code

        style_map = {
            ShaderLanguage.HLSL: "Microsoft",
            ShaderLanguage.GLSL: "Google",
            ShaderLanguage.METAL: "LLVM",
            ShaderLanguage.SLANG: "Microsoft",
            ShaderLanguage.RUST: "LLVM",
            ShaderLanguage.CUDA: "Google",
            ShaderLanguage.HIP: "Google",
        }
        ext_map = {
            ShaderLanguage.HLSL: ".hlsl",
            ShaderLanguage.GLSL: ".glsl",
            ShaderLanguage.METAL: ".metal",
            ShaderLanguage.SLANG: ".slang",
            ShaderLanguage.RUST: ".rs",
            ShaderLanguage.CUDA: ".cu",
            ShaderLanguage.HIP: ".hip",
        }
        style = style_map.get(language, "LLVM")
        ext = ext_map.get(language, ".txt")

        # Bound before the try so the finally clause can always inspect it.
        tmp_path = None
        try:
            # delete=False: the file must outlive the ``with`` so the external
            # process can open it; it is removed in ``finally``.
            with tempfile.NamedTemporaryFile(
                suffix=ext, mode="w+", encoding="utf-8", delete=False
            ) as tmp:
                tmp_path = tmp.name
                tmp.write(code)

            cmd = [self.clang_format_path, "-style=" + style, "-i", tmp_path]
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                logger.error(f"clang-format failed: {result.stderr}")
                return code

            with open(tmp_path, "r", encoding="utf-8") as formatted_file:
                return formatted_file.read()
        except Exception as e:
            logger.error(f"Error formatting with clang-format: {e}")
            return code
        finally:
            self._remove_temp_files(tmp_path)

    def _format_spirv(self, code):
        """Format SPIR-V assembly code using spirv-as and spirv-dis.

        The text is assembled to a binary and disassembled again, which
        yields the tools' canonical layout.  If either step fails, the
        tool-free fallback ``_make_spirv_readable`` is used; if the tools are
        not installed at all, ``code`` is returned unchanged.
        """
        if not self.has_spirv_tools:
            logger.warning("SPIRV-Tools not available for SPIR-V formatting")
            return code

        # Bound before the try: an early failure must not leave these names
        # undefined when the finally clause runs.
        tmp_in_path = None
        tmp_out_path = None
        try:
            with tempfile.NamedTemporaryFile(
                suffix=".spvasm", mode="w+", encoding="utf-8", delete=False
            ) as tmp_in:
                tmp_in_path = tmp_in.name
                tmp_in.write(code)
            tmp_out_path = tmp_in_path + ".spv"

            assemble_cmd = [self.spirv_as_path, tmp_in_path, "-o", tmp_out_path]
            result = subprocess.run(assemble_cmd, capture_output=True, text=True)
            if result.returncode != 0:
                logger.error(f"spirv-as failed: {result.stderr}")
                return self._make_spirv_readable(code)

            disassemble_cmd = [self.spirv_dis_path, tmp_out_path, "--no-color"]
            result = subprocess.run(disassemble_cmd, capture_output=True, text=True)
            if result.returncode != 0:
                logger.error(f"spirv-dis failed: {result.stderr}")
                return self._make_spirv_readable(code)

            return result.stdout
        except Exception as e:
            logger.error(f"Error formatting SPIR-V: {e}")
            return self._make_spirv_readable(code)
        finally:
            self._remove_temp_files(tmp_in_path, tmp_out_path)

    def _make_spirv_readable(self, code):
        """Make SPIR-V code more readable without external tools.

        Blank lines are dropped and every remaining line is stripped.  Lines
        that are exactly ``OpCapability Shader`` or that mention
        ``OpFunction`` / ``OpLabel`` are kept flush left (a bare ``OpLabel``
        line is the exception and is indented); all other lines are indented
        by one space.
        """
        readable = []
        for raw_line in code.split("\n"):
            line = raw_line.strip()
            if not line:
                continue
            # "OpFunctionEnd" contains "OpFunction", so it is covered too.
            mentions_structure = "OpFunction" in line or "OpLabel" in line
            flush_left = line == "OpCapability Shader" or (
                line != "OpLabel" and mentions_structure
            )
            readable.append(line if flush_left else " " + line)
        return "\n".join(readable)

    def validate_spirv(self, code):
        """Validate SPIR-V code.

        The assembly text is assembled with ``spirv-as`` (targeting
        ``vulkan1.0``) and the resulting binary is checked with ``spirv-val``.

        Returns:
            A ``(ok, message)`` tuple; ``ok`` is ``True`` only when both
            tools succeed.
        """
        if not (self.has_spirv_tools and self.spirv_val_path):
            logger.warning("SPIRV-Tools not available for validation")
            return False, "SPIRV-Tools not available"

        asm_path = None
        binary_path = None
        try:
            with tempfile.NamedTemporaryFile(
                suffix=".spvasm", mode="w+", encoding="utf-8", delete=False
            ) as tmp:
                asm_path = tmp.name
                tmp.write(code)
            binary_path = asm_path + ".spv"

            # Without -o, spirv-as writes "out.spv" into the current working
            # directory, so the validator would be handed a missing file.
            assemble_cmd = [
                self.spirv_as_path,
                "--target-env",
                "vulkan1.0",
                asm_path,
                "-o",
                binary_path,
            ]
            result = subprocess.run(assemble_cmd, capture_output=True, text=True)
            if result.returncode != 0:
                return False, f"Assembly failed: {result.stderr}"

            validate_cmd = [self.spirv_val_path, binary_path]
            result = subprocess.run(validate_cmd, capture_output=True, text=True)
            if result.returncode != 0:
                return False, f"Validation failed: {result.stderr}"

            return True, "Valid SPIR-V code"
        except Exception as e:
            return False, f"Error validating SPIR-V: {e}"
        finally:
            # Each file is removed independently so one failure cannot leak
            # the other.
            self._remove_temp_files(asm_path, binary_path)
[docs]
def format_file(file_path, language=None):
    """Format a shader file in-place.

    Args:
        file_path: Path of the file to read, format and overwrite.
        language: Optional ``ShaderLanguage`` member or string; when omitted
            the formatter detects the language from ``file_path``.

    Returns:
        ``True`` when the file was processed (whether or not its content
        changed), ``False`` when reading, formatting or writing raised.
    """
    formatter = CodeFormatter()
    try:
        with open(file_path, "r") as f:
            code = f.read()

        formatted_code = formatter.format_code(code, language, file_path)

        # Opening with "w" truncates the file immediately.  When formatting
        # changed nothing (e.g. no formatter is installed) skip the rewrite so
        # a failed write cannot lose content and the file is left untouched.
        if formatted_code != code:
            with open(file_path, "w") as f:
                f.write(formatted_code)
        return True
    except Exception as e:
        logger.error(f"Error formatting file {file_path}: {e}")
        return False
# Helper function to be called from _crosstl.py
[docs]
def format_shader_code(code, backend, output_path=None):
    """Format generated shader code according to the target backend name.

    Args:
        code: Source text to format.
        backend: Backend name, matched case-insensitively.  A falsy backend
            and the CrossGL names ("cgl", "crossgl") return ``code`` as-is.
        output_path: Optional path forwarded to the formatter, which uses it
            for language detection when the backend name is not recognised.

    Returns:
        The formatted text, or ``code`` unchanged when nothing applies.
    """
    # A falsy backend normalises to "", which can never come from a real name.
    backend_key = backend.lower() if backend else ""
    if backend_key in ("", "cgl", "crossgl"):
        return code

    # Backend aliases grouped by the language they produce.
    aliases_by_language = (
        (ShaderLanguage.METAL, ("metal",)),
        (ShaderLanguage.HLSL, ("directx", "hlsl")),
        (ShaderLanguage.GLSL, ("opengl", "glsl")),
        (ShaderLanguage.SPIRV, ("vulkan", "spirv", "spv")),
        (ShaderLanguage.MOJO, ("mojo",)),
        (ShaderLanguage.RUST, ("rust",)),
        (ShaderLanguage.CUDA, ("cuda",)),
        (ShaderLanguage.HIP, ("hip",)),
        (ShaderLanguage.SLANG, ("slang",)),
    )
    backend_languages = {
        alias: target
        for target, aliases in aliases_by_language
        for alias in aliases
    }

    # Unknown backends resolve to None, leaving detection to the formatter.
    target_language = backend_languages.get(backend_key)
    return CodeFormatter().format_code(code, target_language, output_path)
# Module-level convenience wrapper; tests patch this function by name.
[docs]
def format_code(code, language=None, file_path=None):
    """Format code with a freshly constructed ``CodeFormatter``.

    Module-level shortcut for ``CodeFormatter().format_code(...)``.

    Args:
        code: The shader code to format
        language: ShaderLanguage enum or string
        file_path: Optional file path for language detection

    Returns:
        Formatted code string
    """
    return CodeFormatter().format_code(code, language, file_path)