Source code for crosstl.translator.codegen.stage_utils
"""Utilities for normalizing and matching shader stage qualifiers."""
# Every function qualifier that is treated as a shader *stage* (as opposed to
# some other kind of qualifier). Both long ("ray_any_hit") and short
# ("anyhit") spellings of the ray-tracing stages are listed.
STAGE_QUALIFIER_NAMES = frozenset(
    (
        # Classic graphics pipeline, in pipeline order.
        "vertex",
        "tessellation_control",
        "tessellation_evaluation",
        "geometry",
        "fragment",
        # General-purpose compute.
        "compute",
        # Mesh-shading pipeline; "task", "amplification" and "object" are the
        # names different shading languages use for the stage before "mesh".
        "task",
        "amplification",
        "object",
        "mesh",
        # Ray-tracing stages, long form.
        "ray_generation",
        "ray_intersection",
        "ray_any_hit",
        "ray_closest_hit",
        "ray_miss",
        "ray_callable",
        # Ray-tracing stages, short form.
        "intersection",
        "anyhit",
        "closesthit",
        "miss",
        "callable",
    )
)
[docs]
def normalize_stage_name(stage):
    """Reduce a stage value to a bare, lowercase name.

    ``None`` is passed through unchanged. Any other object is converted with
    ``str()``; everything up to and including the last ``"."`` is discarded
    (so a value whose string form is ``"ShaderStage.VERTEX"`` yields
    ``"vertex"``) and the remainder is lowercased.
    """
    if stage is None:
        return None
    # rpartition yields ("", "", text) when there is no dot, so the tail is
    # always the segment after the final dot (or the whole text).
    _, _, tail = str(stage).rpartition(".")
    return tail.lower()
[docs]
def stage_matches(target_stage, stage):
    """Decide whether ``stage`` belongs in output generated for ``target_stage``.

    Both arguments go through ``normalize_stage_name``. When the normalized
    target is ``None`` no filtering applies and the result is always True;
    otherwise the normalized names must be equal.
    """
    wanted = normalize_stage_name(target_stage)
    actual = normalize_stage_name(stage)
    if wanted is None:
        return True
    return actual == wanted
[docs]
# Alternate spellings found in STAGE_QUALIFIER_NAMES, mapped to one canonical
# (already normalized, lowercase) stage name. Names not listed here are their
# own canonical form.
_STAGE_QUALIFIER_ALIASES = {
    "amplification": "task",
    "object": "task",
    "intersection": "ray_intersection",
    "anyhit": "ray_any_hit",
    "closesthit": "ray_closest_hit",
    "miss": "ray_miss",
    "callable": "ray_callable",
}


def should_emit_qualified_function(target_stage, qualifier):
    """Return whether a function carrying ``qualifier`` belongs in the output.

    Both arguments are normalized with ``normalize_stage_name``. The function
    is emitted (True) when:

    * no target stage is set (normalized target is ``None``), or
    * the qualifier is not a stage name from ``STAGE_QUALIFIER_NAMES``
      (e.g. ``None`` or an unrelated qualifier), or
    * the qualifier names the same stage as the target.

    Stage names are compared after mapping alternate spellings (e.g.
    ``anyhit`` / ``ray_any_hit``) to one canonical name. Without that mapping,
    a function written with the short spelling would be dropped when the
    target uses the long one.
    """
    target_stage = normalize_stage_name(target_stage)
    qualifier = normalize_stage_name(qualifier)
    if target_stage is None or qualifier not in STAGE_QUALIFIER_NAMES:
        return True
    canonical_qualifier = _STAGE_QUALIFIER_ALIASES.get(qualifier, qualifier)
    canonical_target = _STAGE_QUALIFIER_ALIASES.get(target_stage, target_stage)
    return canonical_qualifier == canonical_target
[docs]
def compute_local_size(execution_config=None):
    """Extract the workgroup size as a 3-tuple of strings.

    ``execution_config`` is anything with a dict-style ``.get``; a falsy value
    (including the default ``None``) is treated as an empty dict.

    The keys ``local_size``, ``workgroup_size`` and ``numthreads`` are tried in
    that order. The first one holding a list or tuple with at least three
    items wins, and only its first three items are used. If none qualifies,
    the per-axis keys ``local_size_x`` / ``local_size_y`` / ``local_size_z``
    are read, each defaulting to 1. Every component is rendered through
    ``compute_local_size_value``.
    """
    settings = execution_config if execution_config else {}

    # Lazily look keys up so the search stops at the first usable sequence.
    candidates = (
        settings.get(name) for name in ("local_size", "workgroup_size", "numthreads")
    )
    for candidate in candidates:
        if isinstance(candidate, (list, tuple)) and len(candidate) >= 3:
            x, y, z = candidate[:3]
            return (
                compute_local_size_value(x),
                compute_local_size_value(y),
                compute_local_size_value(z),
            )

    per_axis_keys = ("local_size_x", "local_size_y", "local_size_z")
    return tuple(compute_local_size_value(settings.get(key, 1)) for key in per_axis_keys)
[docs]
def compute_local_size_value(value):
    """Render a single workgroup-size component as a string.

    If ``value`` has a ``value`` attribute, that attribute is what gets
    stringified. Presumably this unwraps AST literal nodes or enum members;
    confirm against callers. Otherwise ``value`` itself is passed to ``str()``.
    """
    try:
        inner = value.value
    except AttributeError:
        inner = value
    return str(inner)