Source code for crosstl.backend.GLSL.preprocessor

"""Preprocessor support for GLSL source imports."""

import re
from typing import Dict, List, Optional

from crosstl.backend.DirectX.preprocessor import HLSLPreprocessor


class GLSLPreprocessor:
    """GLSL preprocessor wrapper that preserves #version placement rules.

    The actual preprocessing is delegated to an ``HLSLPreprocessor``; this
    class only adds a check, run on both the input and the preprocessed
    output, that a ``#version`` directive (when present) is preceded by
    nothing but whitespace and comments.
    """

    # Only spaces and horizontal tabs may separate '#' from the directive
    # name; using ``\s`` here would let the match run onto the next line and
    # misread a lone '#' followed by a line starting with "version".
    _DIRECTIVE_RE = re.compile(r"#[ \t]*([A-Za-z_][A-Za-z0-9_]*)")

    def __init__(
        self,
        include_paths: Optional[List[str]] = None,
        defines: Optional[Dict[str, str]] = None,
        strict: bool = True,
        max_expansion_depth: int = 64,
    ):
        """Create the wrapper and the underlying preprocessor.

        Args:
            include_paths: Forwarded unchanged to ``HLSLPreprocessor``.
            defines: Forwarded unchanged to ``HLSLPreprocessor``.
            strict: Forwarded to ``HLSLPreprocessor``; also decides whether a
                misplaced ``#version`` raises ``SyntaxError`` here.
            max_expansion_depth: Forwarded unchanged to ``HLSLPreprocessor``.
        """
        self.strict = strict
        self._preprocessor = HLSLPreprocessor(
            include_paths=include_paths,
            defines=defines,
            strict=strict,
            max_expansion_depth=max_expansion_depth,
        )

    def preprocess(self, code: str, file_path: Optional[str] = None) -> str:
        """Preprocess ``code`` and return the result.

        The #version placement check runs before and after delegation, so a
        violation in either the raw input or the preprocessed output is
        reported.

        Raises:
            SyntaxError: In strict mode, when ``#version`` is preceded by
                anything other than whitespace or comments.
        """
        self._ensure_version_first(code)
        processed = self._preprocessor.preprocess(code, file_path=file_path)
        self._ensure_version_first(processed)
        return processed

    def _ensure_version_first(self, code: str) -> None:
        """Raise ``SyntaxError`` in strict mode if ``#version`` is misplaced.

        Source without any ``#version`` directive is accepted as-is.
        """
        version_index = self._find_version_index(code)
        if version_index is None:
            return
        if self.strict and self._has_tokens_before(code, version_index):
            raise SyntaxError("#version must appear before any other tokens")

    def _find_version_index(self, code: str) -> Optional[int]:
        """Return the index of the first ``#version`` directive, or ``None``.

        The whole source is scanned, not just its leading directives, so a
        ``#version`` that follows ordinary code is still located and can be
        reported as misplaced. Comments are skipped, and a '#' only counts as
        a directive when nothing but whitespace or block comments precede it
        on its line.
        """
        pos = 0
        length = len(code)
        at_line_start = True
        while pos < length:
            ch = code[pos]
            if ch == "\n":
                at_line_start = True
                pos += 1
            elif ch.isspace():
                pos += 1
            elif code.startswith("//", pos):
                # The helper consumes the terminating newline as well.
                pos = self._skip_line_comment(code, pos)
                at_line_start = True
            elif code.startswith("/*", pos):
                # A block comment acts like whitespace: it neither starts a
                # new logical line nor counts as a token, so the line-start
                # state is deliberately left untouched.
                pos = self._skip_block_comment(code, pos)
            elif ch == "#" and at_line_start:
                if self._read_directive(code, pos) == "version":
                    return pos
                # Any other directive: skip the rest of its line.
                pos = self._skip_line_comment(code, pos)
                at_line_start = True
            else:
                at_line_start = False
                pos += 1
        return None

    def _has_tokens_before(self, code: str, version_index: int) -> bool:
        """Return True if anything but whitespace/comments precedes the index."""
        pos = 0
        while pos < version_index:
            if code[pos].isspace():
                pos += 1
            elif code.startswith("//", pos):
                pos = self._skip_line_comment(code, pos)
            elif code.startswith("/*", pos):
                pos = self._skip_block_comment(code, pos)
            else:
                return True
        return False

    def _read_directive(self, code: str, start: int) -> str:
        """Return the directive name for the '#' at ``start`` ("" if none)."""
        # Matching at an offset avoids copying the tail of the source.
        match = self._DIRECTIVE_RE.match(code, start)
        return match.group(1) if match else ""

    def _skip_line_comment(self, code: str, start: int) -> int:
        """Return the index just past the next newline at or after ``start``.

        Returns ``len(code)`` when no newline follows.
        """
        end = code.find("\n", start)
        return len(code) if end == -1 else end + 1

    def _skip_block_comment(self, code: str, start: int) -> int:
        """Return the index just past the ``*/`` closing the comment at ``start``.

        Returns ``len(code)`` when the comment is unterminated.
        """
        end = code.find("*/", start + 2)
        return len(code) if end == -1 else end + 2