"""Helpers for recognizing resource query expressions in CrossGL AST nodes."""
from ..ast import ArrayAccessNode, FunctionCallNode
class ResourceQueryMixin:
    """Helpers for resource query metadata propagation and code generation.

    The mixin does not stand alone: its methods call the following members on
    ``self`` and expect the host class to provide them:

    * ``get_expression_name(expr)`` and ``get_expression_type(expr)``
    * ``resource_base_type(type_name)``
    * ``query_metadata_expression(expr)``
    * ``convert_type_node_to_string(type_node)``
    * ``helper_functions`` (a dict of helper name -> helper source)

    The ``generate_*`` methods additionally set
    ``self.resource_query_info_required = True`` when they emit a helper call.
    """

    # Built-in calls whose first argument is the resource being queried.
    query_function_names = {
        "textureSize",
        "imageSize",
        "textureSamples",
        "imageSamples",
        "textureQueryLevels",
    }

    # Attributes never followed while walking an AST; following them would
    # loop back up the tree or leave the syntax proper.
    _QUERY_SKIPPED_ATTRIBUTES = frozenset({"parent", "annotations"})

    # type name -> (dimension fields, supports a mip argument, multisampled)
    _DIMENSION_QUERY_SPECS = {
        "sampler1D": (("width",), True, False),
        "sampler2D": (("width", "height"), True, False),
        "sampler2DShadow": (("width", "height"), True, False),
        "sampler2DArray": (("width", "height", "elements"), True, False),
        "sampler2DArrayShadow": (("width", "height", "elements"), True, False),
        "sampler3D": (("width", "height", "depth"), True, False),
        "samplerCube": (("width", "height"), True, False),
        "samplerCubeShadow": (("width", "height"), True, False),
        "samplerCubeArray": (("width", "height", "elements"), True, False),
        "samplerCubeArrayShadow": (("width", "height", "elements"), True, False),
        "sampler2DMS": (("width", "height"), False, True),
        "sampler2DMSArray": (("width", "height", "elements"), False, True),
        "image2D": (("width", "height"), False, False),
        "iimage2D": (("width", "height"), False, False),
        "uimage2D": (("width", "height"), False, False),
        "image3D": (("width", "height", "depth"), False, False),
        "iimage3D": (("width", "height", "depth"), False, False),
        "uimage3D": (("width", "height", "depth"), False, False),
        "imageCube": (("width", "height"), False, False),
        "image2DArray": (("width", "height", "elements"), False, False),
        "iimage2DArray": (("width", "height", "elements"), False, False),
        "uimage2DArray": (("width", "height", "elements"), False, False),
        "image2DMS": (("width", "height"), False, True),
        "iimage2DMS": (("width", "height"), False, True),
        "uimage2DMS": (("width", "height"), False, True),
        "image2DMSArray": (("width", "height", "elements"), False, True),
        "iimage2DMSArray": (("width", "height", "elements"), False, True),
        "uimage2DMSArray": (("width", "height", "elements"), False, True),
    }

    # ------------------------------------------------------------------
    # AST shape helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _query_node_sequence(node, current_attr, legacy_attr):
        """Return ``node.<current_attr>``, else ``node.<legacy_attr>``, else ``[]``.

        A value of ``None`` is treated like a missing attribute so callers can
        always iterate over the result.
        """
        value = getattr(node, current_attr, None)
        if value is None:
            value = getattr(node, legacy_attr, None)
        return [] if value is None else value

    def _query_parameters(self, func):
        """Return the parameter list of a function-like node."""
        return self._query_node_sequence(func, "parameters", "params")

    def _query_call_arguments(self, call):
        """Return the argument list of a call node."""
        return self._query_node_sequence(call, "arguments", "args")

    # ------------------------------------------------------------------
    # Requirement collection
    # ------------------------------------------------------------------

    def collect_resource_query_requirements(self, node):
        """Collect global and function-parameter resources needing metadata.

        A resource needs metadata when it is the first argument of a call
        listed in ``query_function_names``, or when it is passed to a
        function parameter that itself needs metadata (propagated through
        the call graph until nothing changes).

        Args:
            node: Root of the AST subtree to inspect.

        Returns:
            A ``(global_names, names_by_function)`` tuple.  ``global_names``
            is the set of queried names that are not parameters of the
            function they are used in.  ``names_by_function`` maps a function
            name to the set of its parameters that need metadata; functions
            with no such parameter are omitted.
        """
        # Unnamed functions are dropped; with duplicate names the last
        # definition encountered wins.
        functions_by_name = {}
        for func in self.query_collect_functions(node):
            name = getattr(func, "name", None)
            if name:
                functions_by_name[name] = func

        param_names = {}
        for func_name, func in functions_by_name.items():
            names = {
                getattr(param, "name", None)
                for param in self._query_parameters(func)
            }
            names.discard(None)
            param_names[func_name] = {name for name in names if name}

        # Walk every body exactly once; the fix-point loop below only needs
        # the (callee name, argument list) pairs.
        calls_by_function = {
            func_name: [
                (self.raw_function_call_name(call), self._query_call_arguments(call))
                for call in self.query_walk_nodes(getattr(func, "body", []))
                if isinstance(call, FunctionCallNode)
            ]
            for func_name, func in functions_by_name.items()
        }

        global_query_names = set()
        function_param_query_names = {
            func_name: set() for func_name in functions_by_name
        }

        def mark_resource_name(func_name, resource_name):
            """Record a resource for ``func_name``; return True if it is new."""
            if not resource_name:
                return False
            if resource_name in param_names.get(func_name, ()):
                target = function_param_query_names[func_name]
            else:
                target = global_query_names
            if resource_name in target:
                return False
            target.add(resource_name)
            return True

        # Pass 1: resources queried directly inside each function.
        for func_name, calls in calls_by_function.items():
            for callee_name, raw_args in calls:
                if callee_name in self.query_function_names and raw_args:
                    mark_resource_name(
                        func_name, self.get_expression_name(raw_args[0])
                    )

        # Pass 2: propagate requirements from callees to callers until stable.
        changed = True
        while changed:
            changed = False
            for caller_name, calls in calls_by_function.items():
                for callee_name, raw_args in calls:
                    callee_required = function_param_query_names.get(callee_name)
                    if not callee_required:
                        # Unknown callee, or one with no metadata parameters.
                        continue
                    callee_params = self._query_parameters(
                        functions_by_name[callee_name]
                    )
                    # zip() stops at the shorter sequence, so parameters
                    # without a matching argument are ignored.
                    for param, arg in zip(callee_params, raw_args):
                        if getattr(param, "name", None) not in callee_required:
                            continue
                        if mark_resource_name(
                            caller_name, self.get_expression_name(arg)
                        ):
                            changed = True

        return (
            global_query_names,
            {
                func_name: names
                for func_name, names in function_param_query_names.items()
                if names
            },
        )

    def collect_resource_query_names(self, node):
        """Collect resource names used directly in resource query calls.

        Args:
            node: Root of the AST subtree to inspect.

        Returns:
            The set of names returned by ``get_expression_name`` for the first
            argument of every call listed in ``query_function_names``.
        """
        query_names = set()
        for current in self.query_walk_nodes(node):
            if not isinstance(current, FunctionCallNode):
                continue
            if self.raw_function_call_name(current) not in self.query_function_names:
                continue
            raw_args = self._query_call_arguments(current)
            if not raw_args:
                continue
            resource_name = self.get_expression_name(raw_args[0])
            if resource_name:
                query_names.add(resource_name)
        return query_names

    def query_collect_functions(self, root):
        """Collect function-like nodes from an AST subtree.

        A node is function-like when it has a ``body`` and a parameter list
        under either the current (``parameters``) or legacy (``params``) name.
        """
        return [
            node
            for node in self.query_walk_nodes(root)
            if hasattr(node, "body")
            and (hasattr(node, "parameters") or hasattr(node, "params"))
        ]

    def query_walk_nodes(self, root):
        """Yield AST nodes in pre-order while avoiding parent/annotation cycles.

        ``None`` and ``str``/``int``/``float``/``bool`` values are skipped.
        Lists, tuples, sets and dict values are traversed but not yielded.
        Every other object is yielded once (tracked by ``id``) and its
        instance attributes are traversed, except ``parent`` and
        ``annotations``.

        The walk uses an explicit stack instead of recursion so that deeply
        nested trees do not exhaust the interpreter's recursion limit.
        """
        skipped = self._QUERY_SKIPPED_ATTRIBUTES
        visited = set()
        # Each stack entry is an iterator over sibling values still to visit.
        stack = [iter((root,))]
        while stack:
            for value in stack[-1]:
                if value is None or isinstance(value, (str, int, float, bool)):
                    continue
                if isinstance(value, dict):
                    stack.append(iter(value.values()))
                    break
                if isinstance(value, (list, tuple, set)):
                    stack.append(iter(value))
                    break
                value_id = id(value)
                if value_id in visited:
                    continue
                visited.add(value_id)
                yield value
                if hasattr(value, "__dict__"):
                    stack.append(
                        child
                        for key, child in vars(value).items()
                        if key not in skipped
                    )
                    break
            else:
                # Current level exhausted: resume the parent's siblings.
                stack.pop()

    def raw_function_call_name(self, node):
        """Return a function call name without backend formatting.

        Reads ``node.function`` (falling back to ``node.name``); returns its
        ``name`` attribute if it has one, the value itself if it is a string,
        and ``None`` otherwise.
        """
        func_expr = getattr(node, "function", getattr(node, "name", None))
        if hasattr(func_expr, "name"):
            return func_expr.name
        if isinstance(func_expr, str):
            return func_expr
        return None

    def get_parameter_type(self, param):
        """Return a parameter type from current or legacy AST shapes.

        Prefers ``param_type``, then ``vtype``; defaults to ``"void"``.
        """
        for attr in ("param_type", "vtype"):
            if hasattr(param, attr):
                return getattr(param, attr)
        return "void"

    def get_variable_node_type(self, node):
        """Return a variable type from current or legacy AST shapes.

        Prefers ``var_type``, then ``vtype``; defaults to ``None``.
        """
        for attr in ("var_type", "vtype"):
            if hasattr(node, attr):
                return getattr(node, attr)
        return None

    # ------------------------------------------------------------------
    # Type helpers
    # ------------------------------------------------------------------

    def query_type_name(self, type_name):
        """Return a string type name suitable for resource query decisions.

        Objects exposing ``name`` or ``element_type`` are converted through
        ``convert_type_node_to_string``; anything else goes through ``str``.
        """
        if hasattr(type_name, "name") or hasattr(type_name, "element_type"):
            return self.convert_type_node_to_string(type_name)
        return str(type_name)

    def query_array_suffix(self, type_name):
        """Return the array suffix for a resource type, if present.

        The suffix is everything from the first ``[`` onward; an empty string
        is returned unless the type name contains both ``[`` and ``]``.
        """
        type_name = self.query_type_name(type_name)
        if "[" in type_name and "]" in type_name:
            return type_name[type_name.find("["):]
        return ""

    def is_sampled_resource_type(self, type_name):
        """Return whether a resource type is sampler-backed."""
        return isinstance(type_name, str) and type_name.startswith("sampler")

    def dimension_query_spec(self, type_name):
        """Return dimension/mip/sample metadata for a queryable resource type.

        Returns:
            ``{"dimensions": tuple, "mip": bool, "samples": bool}`` for a known
            type name, or ``None`` for an unknown or non-string one.
        """
        if not isinstance(type_name, str):
            return None
        spec = self._DIMENSION_QUERY_SPECS.get(type_name)
        if spec is None:
            return None
        dimensions, mip, samples = spec
        return {"dimensions": dimensions, "mip": mip, "samples": samples}

    # ------------------------------------------------------------------
    # Helper source construction
    # ------------------------------------------------------------------

    def require_helper_function(self, name, body):
        """Register a helper function body if it has not already been added."""
        self.helper_functions.setdefault(name, body)

    def query_return_type(self, dimensions):
        """Return an integer vector type for a dimension query result.

        One dimension yields ``"int"``; ``n`` dimensions yield ``"int<n>"``.
        """
        count = len(dimensions)
        return "int" if count == 1 else f"int{count}"

    def query_constructor(self, return_type, values):
        """Return a constructor expression for a query result.

        A scalar ``"int"`` result is the first value itself; any other type is
        built with ``make_<return_type>(...)``.
        """
        if return_type == "int":
            return values[0]
        return f"make_{return_type}({', '.join(values)})"

    def query_dimension_expression(self, dimension, mip_arg):
        """Return a metadata field expression for one queried dimension.

        Only ``width``, ``height`` and ``depth`` are scaled by the mip level;
        other fields (for example ``elements``) are read unchanged.
        """
        value = f"info.{dimension}"
        if mip_arg is not None and dimension in {"width", "height", "depth"}:
            return f"cgl_lod_extent({value}, {mip_arg})"
        return value

    def query_helper_prefix(self):
        """Return helper code shared by mip-aware dimension queries."""
        return (
            "__device__ inline int cgl_lod_extent(int extent, int mipLevel)\n"
            "{\n"
            " int shifted = extent >> mipLevel;\n"
            " return shifted > 1 ? shifted : 1;\n"
            "}"
        )

    def build_dimension_query_helper(self, helper_name, spec):
        """Build helper code for a texture/image dimension query.

        Args:
            helper_name: Name of the emitted helper function.
            spec: A mapping as returned by ``dimension_query_spec``.
        """
        return_type = self.query_return_type(spec["dimensions"])
        params = "CglResourceQueryInfo info"
        mip_arg = None
        if spec["mip"]:
            params += ", int mipLevel"
            mip_arg = "mipLevel"
        values = [
            self.query_dimension_expression(dimension, mip_arg)
            for dimension in spec["dimensions"]
        ]
        return_value = self.query_constructor(return_type, values)
        return (
            f"__device__ inline {return_type} {helper_name}({params})\n"
            "{\n"
            f" return {return_value};\n"
            "}"
        )

    def build_sample_count_query_helper(self, helper_name):
        """Build helper code for a sample-count query."""
        return (
            f"__device__ inline int {helper_name}(CglResourceQueryInfo info)\n"
            "{\n"
            " return info.samples;\n"
            "}"
        )

    def build_texture_query_levels_helper(self, helper_name, spec):
        """Build helper code for texture mip-level queries.

        Resources without mip support always report one level; otherwise the
        helper returns ``info.levels``, or 1 when that is not positive.
        """
        result = "info.levels > 0 ? info.levels : 1" if spec["mip"] else "1"
        return (
            f"__device__ inline int {helper_name}(CglResourceQueryInfo info)\n"
            "{\n"
            f" return {result};\n"
            "}"
        )

    def ensure_query_prefix_helper(self):
        """Ensure shared query helper code has been registered."""
        self.require_helper_function("cgl_lod_extent", self.query_helper_prefix())

    # ------------------------------------------------------------------
    # Call generation
    # ------------------------------------------------------------------

    def generate_dimension_query(self, func_name, raw_args, args):
        """Generate a helper call for a dimension query, if supported.

        Args:
            func_name: Name of the query being translated; it becomes part of
                the helper name.
            raw_args: Argument AST nodes; the first one is the resource.
            args: Already generated argument expressions; ``args[1]`` is used
                as the mip level for mip-capable resources (default ``"0"``).

        Returns:
            The helper call expression, or ``None`` when there are no
            arguments, the resource type is unknown, or no metadata
            expression is available for the resource.
        """
        if not raw_args:
            return None
        resource_type = self.resource_base_type(self.get_expression_type(raw_args[0]))
        spec = self.dimension_query_spec(resource_type)
        metadata_expr = self.query_metadata_expression(raw_args[0])
        if spec is None or metadata_expr is None:
            return None

        self.resource_query_info_required = True
        helper_name = f"cgl_{func_name}_{resource_type}"
        uses_mip = spec["mip"]
        if uses_mip:
            # The dimension helper calls cgl_lod_extent, so register it first.
            self.ensure_query_prefix_helper()
        self.require_helper_function(
            helper_name, self.build_dimension_query_helper(helper_name, spec)
        )
        if uses_mip:
            lod = args[1] if len(args) > 1 else "0"
            return f"{helper_name}({metadata_expr}, {lod})"
        return f"{helper_name}({metadata_expr})"

    def generate_sample_count_query(self, func_name, raw_args, args):
        """Generate a helper call for a sample-count query, if supported.

        Returns ``None`` when there are no arguments, the resource type is
        unknown or not multisampled, or no metadata expression is available.
        ``args`` is accepted for signature parity and is not used.
        """
        if not raw_args:
            return None
        resource_type = self.resource_base_type(self.get_expression_type(raw_args[0]))
        spec = self.dimension_query_spec(resource_type)
        metadata_expr = self.query_metadata_expression(raw_args[0])
        if spec is None or not spec["samples"] or metadata_expr is None:
            return None

        self.resource_query_info_required = True
        helper_name = f"cgl_{func_name}_{resource_type}"
        self.require_helper_function(
            helper_name, self.build_sample_count_query_helper(helper_name)
        )
        return f"{helper_name}({metadata_expr})"

    def generate_texture_query_levels(self, raw_args):
        """Generate a helper call for ``textureQueryLevels``, if supported.

        Returns ``None`` when there are no arguments, the resource is not
        sampler-backed, its type is unknown, or no metadata expression is
        available.
        """
        if not raw_args:
            return None
        resource_type = self.resource_base_type(self.get_expression_type(raw_args[0]))
        if not self.is_sampled_resource_type(resource_type):
            return None
        spec = self.dimension_query_spec(resource_type)
        metadata_expr = self.query_metadata_expression(raw_args[0])
        if spec is None or metadata_expr is None:
            return None

        self.resource_query_info_required = True
        helper_name = f"cgl_textureQueryLevels_{resource_type}"
        self.require_helper_function(
            helper_name, self.build_texture_query_levels_helper(helper_name, spec)
        )
        return f"{helper_name}({metadata_expr})"