Source code for crosstl.backend.DirectX.preprocessor

"""Preprocessor support for DirectX HLSL source imports."""

import os
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class Macro:
    """A single HLSL preprocessor macro definition.

    The same record describes both macro flavours: an object-like macro
    (``#define NAME value``) leaves ``params`` as ``None``, while a
    function-like macro (``#define NAME(a, b) body``) stores its parameter
    names in ``params`` -- possibly as an empty list for ``NAME()``.
    """

    # Identifier the macro is registered under.
    name: str
    # Parameter names for function-like macros; None marks an object-like macro.
    params: Optional[List[str]] = None
    # Replacement text substituted wherever the macro is expanded.
    replacement: str = ""
    # True when the parameter list ends with a variadic ("...") parameter.
    is_variadic: bool = False

    def is_function_like(self) -> bool:
        """Return True when the macro was declared with a parameter list."""
        if self.params is None:
            return False
        return True
class HLSLPreprocessor:
    """Small HLSL preprocessor used before lexing imported source files.

    Handles line continuations, conditional directives (``#if``, ``#ifdef``,
    ``#ifndef``, ``#elif``, ``#else``, ``#endif``), ``#define`` / ``#undef``
    for object-like and function-like macros (including ``#`` stringizing,
    ``##`` pasting and variadic parameters), ``#include``, ``#line``,
    ``#error`` and ``#warning``. Unknown directives are passed through.

    Args:
        include_paths: Directories searched for ``#include`` targets.
        defines: Predefined object-like macros, mapping name to replacement.
        strict: When True, missing includes raise ``FileNotFoundError``,
            ``#warning`` raises ``SyntaxError`` and exceeding
            ``max_expansion_depth`` raises ``SyntaxError``.
        max_expansion_depth: Maximum number of macros that may be nested in
            one expansion chain.
    """

    _DIRECTIVE_RE = re.compile(r"#\s*([A-Za-z_][A-Za-z0-9_]*)\s*(.*)")
    _INCLUDE_RE = re.compile(r"\s*([<\"])([^>\"]+)[>\"]")
    _MACRO_NAME_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")

    def __init__(
        self,
        include_paths: Optional[List[str]] = None,
        defines: Optional[Dict[str, str]] = None,
        strict: bool = False,
        max_expansion_depth: int = 64,
    ):
        self.include_paths = include_paths or []
        self.macros: Dict[str, "Macro"] = {}
        self.strict = strict
        self.max_expansion_depth = max_expansion_depth
        if defines:
            for name, value in defines.items():
                self.macros[name] = Macro(
                    name=name, params=None, replacement=str(value)
                )

    def preprocess(self, code: str, file_path: Optional[str] = None) -> str:
        """Return ``code`` with directives processed and macros expanded.

        Args:
            code: Raw HLSL source text.
            file_path: Path of the file ``code`` came from; used to resolve
                quoted ``#include`` targets relative to that file.

        Raises:
            SyntaxError: On unbalanced conditionals, ``#error`` or an invalid
                ``#if`` expression.
            FileNotFoundError: In strict mode, when an include is not found.
        """
        logical_lines = self._split_logical_lines(code)
        return "\n".join(self._process_lines(logical_lines, file_path))

    def _split_logical_lines(self, code: str) -> List[str]:
        """Join physical lines that end with a backslash continuation."""
        logical_lines: List[str] = []
        buffer = ""
        for line in code.splitlines():
            stripped = line.rstrip()
            if stripped.endswith("\\"):
                # Drop the backslash and keep accumulating.
                buffer += stripped[:-1]
                continue
            logical_lines.append(buffer + line)
            buffer = ""
        if buffer:
            # The source ended in the middle of a continuation.
            logical_lines.append(buffer)
        return logical_lines

    def _process_lines(self, lines: List[str], file_path: Optional[str]) -> List[str]:
        """Run directives over ``lines`` and return the surviving output lines.

        Each conditional frame records whether its parent region is active,
        whether the current branch is active, and whether any branch of the
        group has already been taken.
        """
        output: List[str] = []
        conditional_stack: List[Dict[str, bool]] = []
        current_line = 1

        for raw_line in lines:
            stripped = raw_line.lstrip()
            active = all(frame["active"] for frame in conditional_stack)

            if not stripped.startswith("#"):
                if active:
                    output.append(self._expand_macros(raw_line, current_line, False))
                current_line += 1
                continue

            directive, rest = self._parse_directive(stripped)
            if directive in ("if", "ifdef", "ifndef"):
                # Conditions inside a skipped region are not evaluated, so
                # malformed expressions in dead code cannot raise.
                condition = active and self._evaluate_condition(
                    directive, rest, current_line
                )
                conditional_stack.append(
                    {
                        "parent_active": active,
                        "active": condition,
                        "branch_taken": condition,
                    }
                )
            elif directive == "elif":
                if not conditional_stack:
                    raise SyntaxError("#elif without #if")
                frame = conditional_stack[-1]
                if frame["parent_active"] and not frame["branch_taken"]:
                    condition = bool(
                        self._evaluate_expression(
                            self._expand_macros(rest, current_line, True)
                        )
                    )
                    frame["active"] = condition
                    frame["branch_taken"] = condition
                else:
                    frame["active"] = False
            elif directive == "else":
                if not conditional_stack:
                    raise SyntaxError("#else without #if")
                frame = conditional_stack[-1]
                frame["active"] = frame["parent_active"] and not frame["branch_taken"]
                frame["branch_taken"] = True
            elif directive == "endif":
                if not conditional_stack:
                    raise SyntaxError("#endif without #if")
                conditional_stack.pop()
            elif not active:
                # Every remaining directive is ignored inside a skipped region.
                pass
            elif directive == "define":
                self._handle_define(rest)
            elif directive == "undef":
                self.macros.pop(rest.strip(), None)
            elif directive == "include":
                included_text = self._handle_include(rest, file_path)
                if included_text is not None:
                    # Quoted includes inside the included file must resolve
                    # relative to that file, not to the including one.
                    nested_path = self._find_include(rest, file_path) or file_path
                    nested_lines = self._split_logical_lines(included_text)
                    output.extend(self._process_lines(nested_lines, nested_path))
            elif directive == "line":
                line_override = self._handle_line_directive(rest)
                if line_override is not None:
                    current_line = line_override
            elif directive in ("error", "warning"):
                if directive == "error" or self.strict:
                    raise SyntaxError(f"#{directive}: {rest.strip()}")
            else:
                # Unknown directives (e.g. #pragma) pass through untouched.
                output.append(raw_line)
            current_line += 1

        if conditional_stack:
            raise SyntaxError("Unterminated #if block")
        return output

    def _parse_directive(self, line: str) -> Tuple[str, str]:
        """Split ``#name rest`` into ``(name, rest)``; ``("", "")`` if no name."""
        match = self._DIRECTIVE_RE.match(line)
        if not match:
            return "", ""
        return match.group(1), match.group(2)

    def _evaluate_condition(self, directive: str, rest: str, line_num: int) -> bool:
        """Evaluate the condition of an ``#if`` / ``#ifdef`` / ``#ifndef`` line."""
        if directive == "ifdef":
            return rest.strip() in self.macros
        if directive == "ifndef":
            return rest.strip() not in self.macros
        return bool(
            self._evaluate_expression(self._expand_macros(rest, line_num, True))
        )

    def _evaluate_expression(self, expr: str) -> int:
        """Evaluate an already macro-expanded integer expression.

        Raises:
            SyntaxError: If the expression is malformed, including numeric
                literals the tokenizer cannot convert.
        """
        try:
            tokenizer = _ExpressionTokenizer(expr)
            parser = _ExpressionParser(tokenizer)
            return parser.parse_expression()
        except ValueError as err:
            raise SyntaxError(
                f"Invalid preprocessor expression: {expr.strip()}"
            ) from err

    def _handle_define(self, rest: str):
        """Register the macro described by the text following ``#define``."""
        rest = rest.lstrip()
        name_match = self._MACRO_NAME_RE.match(rest)
        if not name_match:
            return
        name = name_match.group(0)
        after = rest[name_match.end() :]
        # Only a "(" directly after the name starts a parameter list.
        if after.startswith("("):
            params, remainder, is_variadic = self._parse_macro_params(after)
            self.macros[name] = Macro(
                name=name,
                params=params,
                replacement=remainder.lstrip(),
                is_variadic=is_variadic,
            )
        else:
            self.macros[name] = Macro(
                name=name, params=None, replacement=after.lstrip()
            )

    def _parse_macro_params(self, text: str) -> Tuple[List[str], str, bool]:
        """Parse ``(a, b, ...)`` at the start of ``text``.

        Returns:
            ``(params, remainder, is_variadic)`` where ``remainder`` is the
            text after the closing parenthesis. A trailing ``...`` is stored
            as ``__VA_ARGS__``.
        """
        assert text[0] == "("
        depth = 0
        params_text = text[1:]  # used when the parenthesis is never closed
        remainder = ""
        for index, ch in enumerate(text):
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth == 0:
                    params_text = text[1:index]
                    remainder = text[index + 1 :]
                    break
        params = [p.strip() for p in params_text.split(",") if p.strip()]
        is_variadic = False
        if params and params[-1] == "...":
            params[-1] = "__VA_ARGS__"
            is_variadic = True
        elif params and params[-1].endswith("..."):
            # Named variadic form ("args..."): keep the name and add __VA_ARGS__.
            params[-1] = params[-1].replace("...", "")
            params.append("__VA_ARGS__")
            is_variadic = True
        return params, remainder, is_variadic

    def _find_include(self, rest: str, file_path: Optional[str]) -> Optional[str]:
        """Return the path of the first existing file matching an include.

        Quoted includes search the including file's directory first; both
        quoted and angle-bracket includes then search ``include_paths``.
        """
        match = self._INCLUDE_RE.match(rest)
        if not match:
            return None
        delimiter, target = match.group(1), match.group(2)
        search_paths: List[str] = []
        if delimiter == '"' and file_path:
            search_paths.append(os.path.dirname(file_path))
        search_paths.extend(self.include_paths)
        for base in search_paths:
            candidate = os.path.join(base, target)
            if os.path.isfile(candidate):
                return candidate
        return None

    def _handle_include(self, rest: str, file_path: Optional[str]) -> Optional[str]:
        """Return the text of the included file, or None if it is unavailable.

        Raises:
            FileNotFoundError: In strict mode, when the target is not found.
        """
        match = self._INCLUDE_RE.match(rest)
        if not match:
            return None
        candidate = self._find_include(rest, file_path)
        if candidate is not None:
            with open(candidate, "r", encoding="utf-8") as handle:
                return handle.read()
        if self.strict:
            raise FileNotFoundError(f"Include not found: {match.group(2)}")
        return None

    def _handle_line_directive(self, rest: str) -> Optional[int]:
        """Return the line number given to ``#line``, or None if unparsable."""
        parts = rest.strip().split()
        if not parts:
            return None
        try:
            return int(parts[0])
        except ValueError:
            return None

    def _expand_macros(
        self,
        text: str,
        line_num: int,
        in_expression: bool,
        _expanding: Optional[frozenset] = None,
    ) -> str:
        """Expand every macro invocation found in ``text``.

        String literals and comments are copied verbatim. Numeric literals
        are copied as whole tokens so suffixes and hex digits (``10u``,
        ``0x1F``) are never mistaken for identifiers. A macro is not
        re-expanded inside its own expansion, so self-referential
        definitions terminate.

        Args:
            text: Source text to scan.
            line_num: Current line number, forwarded to nested expansions.
            in_expression: True for ``#if`` / ``#elif`` expressions, where
                ``defined`` is evaluated and unknown identifiers become ``0``.
            _expanding: Internal. Names of the macros currently being
                expanded further up the call chain.

        Raises:
            SyntaxError: In strict mode, when more than
                ``max_expansion_depth`` macros are nested.
        """
        expanding = frozenset() if _expanding is None else _expanding
        parts: List[str] = []
        length = len(text)
        i = 0
        while i < length:
            ch = text[i]
            if ch in "\"'":
                literal, consumed = self._read_string(text, i)
                parts.append(literal)
                i += consumed
                continue
            if text.startswith("//", i):
                parts.append(text[i:])
                break
            if text.startswith("/*", i):
                end = text.find("*/", i + 2)
                if end == -1:
                    parts.append(text[i:])
                    break
                parts.append(text[i : end + 2])
                i = end + 2
                continue
            if ch.isdigit():
                consumed = self._read_number(text, i)
                parts.append(text[i : i + consumed])
                i += consumed
                continue
            if not (ch.isalpha() or ch == "_"):
                parts.append(ch)
                i += 1
                continue

            ident, consumed = self._read_identifier(text, i)
            i += consumed
            if in_expression and ident == "defined":
                value, consumed_def = self._parse_defined(text, i)
                parts.append("1" if value else "0")
                i += consumed_def
                continue

            macro = self.macros.get(ident)
            if (
                macro is not None
                and ident not in expanding
                and len(expanding) >= self.max_expansion_depth
            ):
                if self.strict:
                    raise SyntaxError(
                        f"Macro expansion depth exceeded while expanding {ident}"
                    )
                macro = None  # best effort: leave the identifier unexpanded
            if macro is None or ident in expanding:
                parts.append("0" if in_expression else ident)
                continue

            if macro.is_function_like():
                j = i
                while j < length and text[j].isspace():
                    j += 1
                if j >= length or text[j] != "(":
                    # A function-like macro name without "(" is plain text.
                    parts.append(ident)
                    continue
                args, consumed_args = self._parse_macro_args(text, j)
                i = j + consumed_args
                # Arguments are expanded before substitution, under the
                # caller's guard set, so F(F(x)) still expands both levels.
                # Identifiers are kept as-is here; the rescan below applies
                # the expression rules.
                expanded_args = [
                    self._expand_macros(arg, line_num, False, expanding)
                    for arg in args
                ]
                replaced = self._expand_function_macro(macro, args, expanded_args)
            else:
                replaced = macro.replacement if macro.replacement is not None else ""
            parts.append(
                self._expand_macros(
                    replaced, line_num, in_expression, expanding | {ident}
                )
            )
        return "".join(parts)

    def _read_number(self, text: str, start: int) -> int:
        """Return the length of the numeric token beginning at ``start``."""
        i = start + 1
        while i < len(text) and (text[i].isalnum() or text[i] in "_."):
            i += 1
        return i - start

    def _read_identifier(self, text: str, start: int) -> Tuple[str, int]:
        """Return ``(identifier, length)`` for the identifier at ``start``."""
        i = start
        while i < len(text) and (text[i].isalnum() or text[i] == "_"):
            i += 1
        return text[start:i], i - start

    def _read_string(self, text: str, start: int) -> Tuple[str, int]:
        """Return ``(literal, length)`` for the quoted literal at ``start``.

        Backslash escapes are skipped; an unterminated literal runs to the
        end of ``text``.
        """
        quote = text[start]
        i = start + 1
        while i < len(text):
            if text[i] == "\\":
                i += 2
                continue
            if text[i] == quote:
                return text[start : i + 1], i - start + 1
            i += 1
        return text[start:], len(text) - start

    def _parse_defined(self, text: str, start: int) -> Tuple[bool, int]:
        """Evaluate the operand of ``defined`` that begins at ``start``.

        Accepts both ``defined NAME`` and ``defined(NAME)``.

        Returns:
            ``(is_defined, consumed)`` where ``consumed`` counts characters
            from ``start``.
        """
        i = start
        while i < len(text) and text[i].isspace():
            i += 1
        if i < len(text) and text[i] == "(":
            i += 1
            while i < len(text) and text[i].isspace():
                i += 1
            ident, consumed = self._read_identifier(text, i)
            i += consumed
            while i < len(text) and text[i] != ")":
                i += 1
            if i < len(text) and text[i] == ")":
                i += 1
            return ident in self.macros, i - start
        ident, consumed = self._read_identifier(text, i)
        i += consumed
        return ident in self.macros, i - start

    def _parse_macro_args(self, text: str, start: int) -> Tuple[List[str], int]:
        """Split the parenthesised argument list that begins at ``start``.

        Commas inside nested parentheses or inside string/character literals
        do not separate arguments.

        Returns:
            ``(args, consumed)`` where ``consumed`` includes both parentheses.
        """
        assert text[start] == "("
        args: List[str] = []
        current: List[str] = []
        depth = 0
        i = start
        while i < len(text):
            ch = text[i]
            if ch in "\"'":
                # Copy literals whole so their contents are never split.
                literal, consumed = self._read_string(text, i)
                current.append(literal)
                i += consumed
                continue
            if ch == "(":
                depth += 1
                if depth > 1:
                    current.append(ch)
            elif ch == ")":
                depth -= 1
                if depth == 0:
                    args.append("".join(current).strip())
                    return args, i - start + 1
                current.append(ch)
            elif ch == "," and depth == 1:
                args.append("".join(current).strip())
                current = []
            else:
                current.append(ch)
            i += 1
        # Unterminated list: return the arguments completed so far.
        return args, i - start

    def _bind_macro_args(
        self, params: List[str], is_variadic: bool, args: List[str]
    ) -> Dict[str, str]:
        """Map parameter names to argument texts; missing arguments become ""."""
        param_map: Dict[str, str] = {}
        if is_variadic and len(args) >= len(params):
            fixed_count = len(params) - 1
            for idx in range(fixed_count):
                param_map[params[idx]] = args[idx] if idx < len(args) else ""
            param_map["__VA_ARGS__"] = ", ".join(args[fixed_count:])
        else:
            for idx, name in enumerate(params):
                param_map[name] = args[idx] if idx < len(args) else ""
        return param_map

    def _expand_function_macro(
        self,
        macro: "Macro",
        args: List[str],
        expanded_args: Optional[List[str]] = None,
    ) -> str:
        """Substitute ``args`` into the replacement text of ``macro``.

        Args:
            macro: The function-like macro being invoked.
            args: Argument texts exactly as written at the call site.
            expanded_args: The same arguments after macro expansion. When
                omitted, ``args`` is used for every substitution.
        """
        params = macro.params or []
        raw_map = self._bind_macro_args(params, macro.is_variadic, args)
        if expanded_args is None:
            value_map = raw_map
        else:
            value_map = self._bind_macro_args(
                params, macro.is_variadic, expanded_args
            )
        return self._replace_params(macro.replacement, value_map, raw_map)

    def _uses_raw_argument(self, tokens: List[Tuple[str, str]], index: int) -> bool:
        """Return True if the parameter at ``index`` must not be pre-expanded.

        That is the case for the left operand of ``##`` and for the operand
        of ``defined``.
        """
        j = index + 1
        while j < len(tokens) and tokens[j][0] == "ws":
            j += 1
        if j < len(tokens) and tokens[j][0] == "paste":
            return True
        k = index - 1
        while k >= 0 and tokens[k][0] == "ws":
            k -= 1
        if k >= 0 and tokens[k] == ("sym", "("):
            k -= 1
            while k >= 0 and tokens[k][0] == "ws":
                k -= 1
        return k >= 0 and tokens[k] == ("ident", "defined")

    def _replace_params(
        self,
        replacement: str,
        param_map: Dict[str, str],
        raw_map: Optional[Dict[str, str]] = None,
    ) -> str:
        """Substitute parameters in ``replacement``, applying ``#`` and ``##``.

        Args:
            replacement: Macro replacement text.
            param_map: Parameter name to substituted text.
            raw_map: Parameter name to unexpanded argument text, used for
                ``#``, ``##`` and ``defined`` operands. Defaults to
                ``param_map``.
        """
        raw = param_map if raw_map is None else raw_map
        tokens = self._tokenize_replacement(replacement)
        count = len(tokens)
        output: List[str] = []
        i = 0
        while i < count:
            tok_type, tok_val = tokens[i]
            if tok_type == "hash" and i + 1 < count:
                next_type, next_val = tokens[i + 1]
                if next_type == "ident" and next_val in param_map:
                    output.append(self._stringize(raw[next_val]))
                    i += 2
                    continue
            if tok_type == "paste" and output:
                # Whitespace on either side of ## is not part of the result.
                j = i + 1
                while j < count and tokens[j][0] == "ws":
                    j += 1
                if j < count:
                    while output and output[-1].isspace():
                        output.pop()
                    left = output.pop() if output else ""
                    output.append(left + self._token_value(tokens[j], raw))
                    i = j + 1
                    continue
            if (
                tok_type == "ident"
                and tok_val in param_map
                and self._uses_raw_argument(tokens, i)
            ):
                output.append(raw[tok_val])
            else:
                output.append(self._token_value(tokens[i], param_map))
            i += 1
        return "".join(output)

    def _stringize(self, value: str) -> str:
        """Return ``value`` as a string literal, as produced by ``#``."""
        collapsed = re.sub(r"\s+", " ", value.strip())
        escaped = collapsed.replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def _tokenize_replacement(self, text: str) -> List[Tuple[str, str]]:
        """Split replacement text into ``(kind, value)`` tokens.

        Kinds are ``paste``, ``hash``, ``ws``, ``ident``, ``literal``, ``sym``.
        """
        tokens: List[Tuple[str, str]] = []
        i = 0
        while i < len(text):
            ch = text[i]
            if text.startswith("##", i):
                tokens.append(("paste", "##"))
                i += 2
            elif ch == "#":
                tokens.append(("hash", "#"))
                i += 1
            elif ch.isspace():
                start = i
                while i < len(text) and text[i].isspace():
                    i += 1
                tokens.append(("ws", text[start:i]))
            elif ch.isalpha() or ch == "_":
                ident, consumed = self._read_identifier(text, i)
                tokens.append(("ident", ident))
                i += consumed
            elif ch in "\"'":
                literal, consumed = self._read_string(text, i)
                tokens.append(("literal", literal))
                i += consumed
            else:
                tokens.append(("sym", ch))
                i += 1
        return tokens

    def _token_value(self, token: Tuple[str, str], param_map: Dict[str, str]) -> str:
        """Return the text of ``token``, substituting it if it is a parameter."""
        tok_type, tok_val = token
        if tok_type == "ident" and tok_val in param_map:
            return param_map[tok_val]
        return tok_val
class _ExpressionTokenizer:
    """Tokenizer for integer expressions inside conditional directives.

    Tokens are ``(kind, value)`` pairs. ``kind`` is ``"NUMBER"`` (with an
    integer ``value``), ``"EOF"``, or the operator/punctuation text itself
    (with ``value`` set to None).
    """

    # Two-character operators; tested before single characters.
    _MULTI_CHAR_OPERATORS = ("||", "&&", "==", "!=", "<=", ">=", "<<", ">>")

    def __init__(self, expr: str):
        self.expr = expr
        self.pos = 0

    def next_token(self) -> Tuple[str, Optional[int]]:
        """Return the next token and advance past it."""
        expr = self.expr
        end = len(expr)
        while self.pos < end and expr[self.pos].isspace():
            self.pos += 1
        if self.pos >= end:
            return "EOF", None

        for op in self._MULTI_CHAR_OPERATORS:
            if expr.startswith(op, self.pos):
                self.pos += len(op)
                return op, None

        start = self.pos
        ch = expr[start]
        self.pos += 1
        if ch.isdigit():
            # Take prefix letters, digits and suffixes as one literal.
            while self.pos < end and expr[self.pos].isalnum():
                self.pos += 1
            return "NUMBER", self._parse_number(expr[start : self.pos])
        if ch.isalpha() or ch == "_":
            # Any identifier still present evaluates to 0.
            while self.pos < end and (
                expr[self.pos].isalnum() or expr[self.pos] == "_"
            ):
                self.pos += 1
            return "NUMBER", 0
        return ch, None

    def _parse_number(self, text: str) -> int:
        """Convert a numeric literal to an int; unparsable text yields 0.

        Handles ``u``/``l`` suffixes, ``0x`` hexadecimal, ``0b`` binary and
        leading-zero octal (falling back to decimal). Plain decimal integers
        are parsed exactly; other forms such as exponents are truncated
        through ``float``.
        """
        text = re.sub(r"[uUlL]+$", "", text)
        try:
            if text.startswith(("0x", "0X")):
                return int(text, 16)
            if text.startswith(("0b", "0B")):
                return int(text, 2)
            if text.startswith("0") and len(text) > 1:
                try:
                    return int(text, 8)
                except ValueError:
                    return int(text, 10)
            try:
                # Exact for integers of any size; float() loses precision
                # above 2**53.
                return int(text, 10)
            except ValueError:
                return int(float(text))
        except (ValueError, OverflowError):
            return 0


class _ExpressionParser:
    """Recursive-descent evaluator for preprocessor integer expressions.

    Precedence, lowest to highest: ``?:``, ``||``, ``&&``, ``|``, ``^``,
    ``&``, equality, relational, shift, additive, multiplicative, unary,
    primary. Division and modulo truncate toward zero as in C, and both
    yield 0 when the divisor is 0.
    """

    def __init__(self, tokenizer: _ExpressionTokenizer):
        self.tokenizer = tokenizer
        self.current = self.tokenizer.next_token()

    def _eat(self, token_type: str):
        """Consume the current token, which must be of kind ``token_type``.

        Raises:
            SyntaxError: If the current token is of a different kind.
        """
        if self.current[0] != token_type:
            raise SyntaxError(f"Expected {token_type}, got {self.current[0]}")
        self.current = self.tokenizer.next_token()

    def parse_expression(self) -> int:
        """Evaluate the expression starting at the current token."""
        return self._parse_conditional()

    def _parse_conditional(self) -> int:
        """Parse ``cond ? a : b`` (right-associative)."""
        condition = self._parse_logical_or()
        if self.current[0] != "?":
            return condition
        self._eat("?")
        when_true = self.parse_expression()
        self._eat(":")
        when_false = self._parse_conditional()
        return when_true if condition else when_false

    def _parse_logical_or(self) -> int:
        left = self._parse_logical_and()
        while self.current[0] == "||":
            self._eat("||")
            right = self._parse_logical_and()
            left = 1 if (left or right) else 0
        return left

    def _parse_logical_and(self) -> int:
        left = self._parse_bitwise_or()
        while self.current[0] == "&&":
            self._eat("&&")
            right = self._parse_bitwise_or()
            left = 1 if (left and right) else 0
        return left

    def _parse_bitwise_or(self) -> int:
        left = self._parse_bitwise_xor()
        while self.current[0] == "|":
            self._eat("|")
            left = left | self._parse_bitwise_xor()
        return left

    def _parse_bitwise_xor(self) -> int:
        left = self._parse_bitwise_and()
        while self.current[0] == "^":
            self._eat("^")
            left = left ^ self._parse_bitwise_and()
        return left

    def _parse_bitwise_and(self) -> int:
        left = self._parse_equality()
        while self.current[0] == "&":
            self._eat("&")
            left = left & self._parse_equality()
        return left

    def _parse_equality(self) -> int:
        left = self._parse_relational()
        while self.current[0] in ("==", "!="):
            op = self.current[0]
            self._eat(op)
            right = self._parse_relational()
            matched = left == right if op == "==" else left != right
            left = 1 if matched else 0
        return left

    def _parse_relational(self) -> int:
        left = self._parse_shift()
        while self.current[0] in ("<", "<=", ">", ">="):
            op = self.current[0]
            self._eat(op)
            right = self._parse_shift()
            if op == "<":
                holds = left < right
            elif op == "<=":
                holds = left <= right
            elif op == ">":
                holds = left > right
            else:
                holds = left >= right
            left = 1 if holds else 0
        return left

    def _parse_shift(self) -> int:
        left = self._parse_additive()
        while self.current[0] in ("<<", ">>"):
            op = self.current[0]
            self._eat(op)
            right = self._parse_additive()
            left = left << right if op == "<<" else left >> right
        return left

    def _parse_additive(self) -> int:
        left = self._parse_multiplicative()
        while self.current[0] in ("+", "-"):
            op = self.current[0]
            self._eat(op)
            right = self._parse_multiplicative()
            left = left + right if op == "+" else left - right
        return left

    @staticmethod
    def _truncating_div(left: int, right: int) -> int:
        """Divide rounding toward zero, unlike Python's flooring ``//``."""
        quotient = abs(left) // abs(right)
        return quotient if (left < 0) == (right < 0) else -quotient

    def _parse_multiplicative(self) -> int:
        left = self._parse_unary()
        while self.current[0] in ("*", "/", "%"):
            op = self.current[0]
            self._eat(op)
            right = self._parse_unary()
            if op == "*":
                left = left * right
            elif right == 0:
                # Division or modulo by zero evaluates to 0 instead of raising.
                left = 0
            elif op == "/":
                left = self._truncating_div(left, right)
            else:
                # The remainder takes the sign of the dividend, as in C.
                left = left - right * self._truncating_div(left, right)
        return left

    def _parse_unary(self) -> int:
        if self.current[0] in ("+", "-", "!", "~"):
            op = self.current[0]
            self._eat(op)
            value = self._parse_unary()
            if op == "+":
                return value
            if op == "-":
                return -value
            if op == "!":
                return 0 if value else 1
            return ~value
        return self._parse_primary()

    def _parse_primary(self) -> int:
        """Parse a number or parenthesised expression; anything else is 0."""
        if self.current[0] == "NUMBER":
            value = self.current[1] or 0
            self._eat("NUMBER")
            return value
        if self.current[0] == "(":
            self._eat("(")
            value = self.parse_expression()
            self._eat(")")
            return value
        # Unexpected tokens are left unconsumed and evaluate to 0.
        return 0