"""HIP to CrossGL Code Generator"""
from .HipAst import (
ASTNode,
ShaderNode,
FunctionNode,
KernelNode,
KernelLaunchNode,
StructNode,
VariableNode,
AssignmentNode,
BinaryOpNode,
UnaryOpNode,
FunctionCallNode,
AtomicOperationNode,
CaseNode,
CastNode,
DesignatedInitializerNode,
DoWhileNode,
InitializerListNode,
SyncNode,
MemberAccessNode,
ArrayAccessNode,
IfNode,
ForNode,
WhileNode,
ReturnNode,
BreakNode,
ContinueNode,
PreprocessorNode,
SwitchNode,
TernaryOpNode,
TypeAliasNode,
HipBuiltinNode,
)
class HipToCrossGLConverter:
    """Serialize HIP backend AST nodes back into CrossGL source.

    Dispatch is name based: ``visit`` looks up ``visit_<NodeClassName>`` and
    falls back to ``generic_visit``. Statement visitors append lines to
    ``self.output`` through ``emit`` and return ``None``; expression visitors
    return the rendered CrossGL text. ``generate`` resets all visitor state,
    so a single instance can be reused for several ASTs.
    """

    # HIP built-in vector types and their CrossGL equivalents.
    VECTOR_TYPE_MAPPING = {
        "float2": "vec2<f32>",
        "float3": "vec3<f32>",
        "float4": "vec4<f32>",
        "double2": "vec2<f64>",
        "double3": "vec3<f64>",
        "double4": "vec4<f64>",
        "int2": "vec2<i32>",
        "int3": "vec3<i32>",
        "int4": "vec4<i32>",
        "uint2": "vec2<u32>",
        "uint3": "vec3<u32>",
        "uint4": "vec4<u32>",
        "char2": "vec2<i8>",
        "char3": "vec3<i8>",
        "char4": "vec4<i8>",
        "uchar2": "vec2<u8>",
        "uchar3": "vec3<u8>",
        "uchar4": "vec4<u8>",
        "short2": "vec2<i16>",
        "short3": "vec3<i16>",
        "short4": "vec4<i16>",
        "ushort2": "vec2<u16>",
        "ushort3": "vec3<u16>",
        "ushort4": "vec4<u16>",
        "long2": "vec2<i64>",
        "long3": "vec3<i64>",
        "long4": "vec4<i64>",
        "ulong2": "vec2<u64>",
        "ulong3": "vec3<u64>",
        "ulong4": "vec4<u64>",
        "longlong2": "vec2<i64>",
        "longlong3": "vec3<i64>",
        "longlong4": "vec4<i64>",
        "ulonglong2": "vec2<u64>",
        "ulonglong3": "vec3<u64>",
        "ulonglong4": "vec4<u64>",
    }
    # Constructor-style calls: both ``float3(...)`` and ``make_float3(...)``.
    VECTOR_CONSTRUCTOR_MAPPING = {
        **VECTOR_TYPE_MAPPING,
        **{f"make_{name}": mapped for name, mapped in VECTOR_TYPE_MAPPING.items()},
    }
    # Scalar HIP/C++ type names (after qualifier stripping) -> CrossGL scalars.
    SCALAR_TYPE_MAPPING = {
        "void": "void",
        "bool": "bool",
        "char": "i8",
        "signed char": "i8",
        "unsigned char": "u8",
        "short": "i16",
        "unsigned short": "u16",
        "int": "i32",
        "unsigned": "u32",
        "unsigned int": "u32",
        "long": "i64",
        "unsigned long": "u64",
        "long long": "i64",
        "unsigned long long": "u64",
        "int8_t": "i8",
        "uint8_t": "u8",
        "int16_t": "i16",
        "uint16_t": "u16",
        "int32_t": "i32",
        "uint32_t": "u32",
        "int64_t": "i64",
        "uint64_t": "u64",
        "float": "f32",
        "double": "f64",
        "size_t": "u32",
    }
    # Math, functional-cast and sync function names with a CrossGL counterpart.
    BUILTIN_FUNCTION_MAPPING = {
        # Single-precision math
        "sqrtf": "sqrt",
        "powf": "pow",
        "sinf": "sin",
        "cosf": "cos",
        "tanf": "tan",
        "logf": "log",
        "expf": "exp",
        "fabsf": "abs",
        "fminf": "min",
        "fmaxf": "max",
        "floorf": "floor",
        "ceilf": "ceil",
        # Double-precision math
        "sqrt": "sqrt",
        "pow": "pow",
        "sin": "sin",
        "cos": "cos",
        "tan": "tan",
        "log": "log",
        "exp": "exp",
        "fabs": "abs",
        "fmin": "min",
        "fmax": "max",
        "floor": "floor",
        "ceil": "ceil",
        # Functional-style casts
        "bool": "bool",
        "char": "i8",
        "short": "i16",
        "int": "i32",
        "long": "i64",
        "float": "f32",
        "double": "f64",
        "size_t": "u32",
        # Synchronization
        "__syncthreads": "workgroupBarrier",
        "__threadfence": "memoryBarrier",
    }
    # HIP thread-indexing built-ins -> CrossGL compute built-ins.
    BUILTIN_VARIABLE_MAPPING = {
        "threadIdx": "gl_LocalInvocationID",
        "blockIdx": "gl_WorkGroupID",
        "gridDim": "gl_NumWorkGroups",
        "blockDim": "gl_WorkGroupSize",
    }
    # Whitespace-separated tokens removed by ``strip_type_qualifiers``.
    TYPE_QUALIFIERS = frozenset(
        {"const", "volatile", "__restrict__", "restrict", "&", "&&"}
    )
    # Space-free spellings recognised as a packed kernel-argument array.
    PACKED_ARGUMENT_TYPES = frozenset({"void*[]", "void**"})

    def __init__(self):
        """Initialize HIP-to-CrossGL visitor state."""
        self.indent_level = 0
        self.output = []
        # Stack of {variable name: initializer elements} for packed
        # ``void*[]`` / ``void**`` argument arrays, one dict per function.
        self.packed_argument_scopes = []
        # Stack of names declared with a unique_ptr type; the base set is
        # never popped.
        self.unique_ptr_scopes = [set()]
        # Stack of {alias name: aliased type} registered by TypeAliasNode.
        self.type_alias_scopes = [{}]

    def generate(self, ast_node):
        """Generate complete CrossGL source from a parsed HIP AST.

        All visitor state is reset first. Returns the emitted lines joined
        with newlines.
        """
        self.output = []
        self.indent_level = 0
        self.packed_argument_scopes = []
        self.unique_ptr_scopes = [set()]
        self.type_alias_scopes = [{}]
        self.visit(ast_node)
        return "\n".join(self.output)

    def visit(self, node):
        """Dispatch a HIP backend AST node to its converter method."""
        method_name = f"visit_{type(node).__name__}"
        visitor = getattr(self, method_name, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node):
        """Fallback converter for primitive values, lists, and unknown nodes.

        Strings are returned unchanged, lists are visited element-wise, and
        anything else is rendered with ``str()``.
        """
        if isinstance(node, str):
            return node
        if isinstance(node, list):
            return [self.visit(item) for item in node]
        return str(node)

    def emit(self, code):
        """Append a line of CrossGL output using the current indentation level.

        Each indent level adds one space; whitespace-only lines are stored as
        empty strings.
        """
        if code.strip():
            self.output.append(" " * self.indent_level + code)
        else:
            self.output.append("")

    def emit_statement(self, stmt):
        """Render and append one converted statement (or a list of them).

        HIP runtime calls become comments. Otherwise the node is visited, and
        a string result is emitted with a trailing ``;``. Visitors that emit
        their own lines return ``None``, which adds nothing further.
        """
        if isinstance(stmt, list):
            for item in stmt:
                self.emit_statement(item)
            return
        if self.emit_hip_runtime_call_statement(stmt):
            return
        result = self.visit(stmt)
        if isinstance(result, str) and result.strip():
            self.emit(f"{result};")

    def emit_hip_runtime_call_statement(self, stmt):
        """Emit comment lines for a recognised HIP runtime call statement.

        Returns True when ``stmt`` was handled, False otherwise.
        """
        if not isinstance(stmt, FunctionCallNode):
            return False
        comments = self.format_hip_runtime_call(stmt)
        if comments is None:
            return False
        for comment in comments:
            self.emit(comment)
        return True

    def format_hip_runtime_call(self, node):
        """Describe a HIP host runtime call as a list of CrossGL comment lines.

        Returns ``None`` when the call is not a recognised runtime API, or has
        too few arguments, so the caller falls back to normal call rendering.
        """
        name = getattr(node, "name", None)
        # Cheap pre-filter: only ``hip*`` names are handled below. This avoids
        # rendering (and possibly emitting) the arguments of ordinary calls
        # twice, once here and once in visit_FunctionCallNode.
        if not isinstance(name, str) or not name.startswith("hip"):
            return None
        call_args = self._call_arguments(node)
        args = [self.visit(arg) for arg in call_args]
        if name in {"hipMalloc", "hipMallocManaged", "hipHostMalloc"}:
            if len(args) >= 2:
                # First argument is typically ``(void**)&ptr``; report ``ptr``.
                target = self.format_runtime_pointer_target(call_args[0])
                return [f"// HIP memory allocate: {target}, bytes: {args[1]}"]
        elif name in {"hipFree", "hipHostFree"}:
            if args:
                return [f"// HIP memory free: {args[0]}"]
        elif name in {"hipMemcpy", "hipMemcpyAsync"}:
            if len(args) >= 4:
                # Argument order is (dst, src, bytes, kind[, stream]).
                comment = (
                    f"// HIP memory copy: {args[1]} -> {args[0]}, "
                    f"bytes: {args[2]}, kind: {args[3]}"
                )
                if len(args) >= 5:
                    comment += f", stream: {args[4]}"
                return [comment]
        elif name in {"hipMemset", "hipMemsetAsync"}:
            if len(args) >= 3:
                comment = (
                    f"// HIP memory set: {args[0]}, value: {args[1]}, "
                    f"bytes: {args[2]}"
                )
                if len(args) >= 4:
                    comment += f", stream: {args[3]}"
                return [comment]
        elif name == "hipStreamSynchronize":
            if args:
                return [f"// HIP synchronize: {args[0]}"]
        elif name in {"hipStreamCreate", "hipStreamCreateWithFlags"}:
            if args:
                stream = self.format_runtime_pointer_target(call_args[0])
                comment = f"// HIP stream create: {stream}"
                if len(args) >= 2:
                    comment += f", flags: {args[1]}"
                return [comment]
        elif name == "hipStreamDestroy":
            if args:
                return [f"// HIP stream destroy: {args[0]}"]
        elif name in {"hipEventCreate", "hipEventCreateWithFlags"}:
            if args:
                event = self.format_runtime_pointer_target(call_args[0])
                comment = f"// HIP event create: {event}"
                if len(args) >= 2:
                    comment += f", flags: {args[1]}"
                return [comment]
        elif name == "hipEventRecord":
            if args:
                comment = f"// HIP event record: {args[0]}"
                if len(args) >= 2:
                    comment += f", stream: {args[1]}"
                return [comment]
        elif name == "hipEventSynchronize":
            if args:
                return [f"// HIP event synchronize: {args[0]}"]
        elif name == "hipEventElapsedTime":
            if len(args) >= 3:
                output = self.format_runtime_pointer_target(call_args[0])
                return [
                    f"// HIP event elapsed time: {args[1]} -> {args[2]}, "
                    f"output: {output}"
                ]
        elif name == "hipEventDestroy":
            if args:
                return [f"// HIP event destroy: {args[0]}"]
        elif name == "hipEventQuery":
            if args:
                return [f"// HIP event query: {args[0]}"]
        elif name == "hipStreamWaitEvent":
            if len(args) >= 2:
                comment = f"// HIP stream wait event: {args[0]} waits for {args[1]}"
                if len(args) >= 3:
                    comment += f", flags: {args[2]}"
                return [comment]
        return None

    def format_runtime_pointer_target(self, arg):
        """Render an out-pointer argument, stripping casts and a leading ``&``."""
        if isinstance(arg, CastNode):
            return self.format_runtime_pointer_target(arg.expression)
        if isinstance(arg, UnaryOpNode) and arg.op == "&":
            return self.visit(arg.operand)
        return self.visit(arg)

    def format_statement_fragment(self, stmt):
        """Render a statement without a trailing ``;`` (for ``for`` headers).

        Returns ``""`` for ``None`` or for nodes whose visitor does not
        return a string.
        """
        if stmt is None:
            return ""
        if isinstance(stmt, VariableNode):
            var_type = self.convert_hip_type_to_crossgl(getattr(stmt, "vtype", "int"))
            value = getattr(stmt, "value", None)
            if self._has_value(value):
                return f"var {stmt.name}: {var_type} = {self.visit(value)}"
            return f"var {stmt.name}: {var_type}"
        if isinstance(stmt, AssignmentNode):
            left = self.visit(stmt.left)
            right = self.visit(stmt.right)
            operator = getattr(stmt, "operator", "=")
            return f"{left} {operator} {right}"
        result = self.visit(stmt)
        return result if isinstance(result, str) else ""

    def visit_HipProgramNode(self, node):
        """Render a HIP program AST as a sequence of CrossGL declarations.

        Kernels become compute-shader functions. Functions, structs, variables
        and type aliases are each followed by a blank line. Any other
        top-level node is visited with its return value discarded.
        """
        self.emit("// HIP to CrossGL conversion")
        for stmt in node.statements:
            if self._is_kernel_function(stmt):
                self.emit(f"// Kernel: {stmt.name}")
                self.visit_kernel_as_compute_shader(stmt)
                self.emit("")
            elif isinstance(stmt, FunctionNode):
                self.emit(f"// Function: {stmt.name}")
                self.visit(stmt)
                self.emit("")
            elif isinstance(stmt, (StructNode, VariableNode, TypeAliasNode)):
                self.visit(stmt)
                self.emit("")
            else:
                self.visit(stmt)

    def _is_kernel_function(self, node):
        """Return True for a KernelNode or a FunctionNode qualified ``__global__``."""
        if isinstance(node, KernelNode):
            return True
        if not isinstance(node, FunctionNode):
            return False
        return "__global__" in (getattr(node, "qualifiers", None) or [])

    @staticmethod
    def _param_type_and_name(param):
        """Return ``(raw_type, name)`` for a dict-style or node-style parameter."""
        if isinstance(param, dict):
            return param.get("type", "int"), param.get("name", "param")
        return getattr(param, "vtype", "int"), getattr(param, "name", "param")

    def visit_FunctionNode(self, node):
        """Render a HIP function node as a CrossGL function.

        Functions whose qualifiers include ``__device__`` produce no output.
        """
        # Skip device functions in CrossGL (they become inline)
        if "__device__" in (getattr(node, "qualifiers", None) or []):
            return
        return_type = self.convert_hip_type_to_crossgl(
            getattr(node, "return_type", "void")
        )
        node_params = getattr(node, "params", None) or []
        params = []
        for param in node_params:
            raw_type, param_name = self._param_type_and_name(param)
            params.append(f"{self.convert_hip_type_to_crossgl(raw_type)} {param_name}")
        param_str = ", ".join(params)
        self.emit(f"{return_type} {node.name}({param_str}) {{")
        self.indent_level += 1
        self.push_packed_argument_scope()
        self.push_type_alias_scope()
        self.push_unique_ptr_scope()
        try:
            for param in node_params:
                self.register_unique_ptr_parameter(param)
            self._emit_body(getattr(node, "body", None))
        finally:
            # Scopes and indentation are restored even if a visitor raises.
            self.pop_unique_ptr_scope()
            self.pop_type_alias_scope()
            self.pop_packed_argument_scope()
            self.indent_level -= 1
        self.emit("}")

    def visit_kernel_as_compute_shader(self, kernel):
        """Render a HIP kernel as a CrossGL compute shader block.

        Pointer parameters become ``var<storage, read_write>`` array bindings
        in ``@group(0)``. The binding index is the parameter's position in the
        signature, so scalar parameters also consume an index. Other
        parameters are emitted as ``<type> <name>``.
        """
        self.emit("@compute")
        self.emit("@workgroup_size(1, 1, 1) // Default workgroup size")
        kernel_params = getattr(kernel, "params", None) or []
        params = []
        for param in kernel_params:
            raw_type, param_name = self._param_type_and_name(param)
            if "*" in raw_type:
                element_type = self.convert_hip_pointer_element_type(raw_type)
                binding = len(params)
                params.append(
                    f"@group(0) @binding({binding}) "
                    f"var<storage, read_write> {param_name}: array<{element_type}>"
                )
            else:
                param_type = self.convert_hip_type_to_crossgl(raw_type)
                params.append(f"{param_type} {param_name}")
        self.emit(f"fn {kernel.name}(")
        self.indent_level += 1
        last_index = len(params) - 1
        for index, param in enumerate(params):
            self.emit(param if index == last_index else f"{param},")
        self.indent_level -= 1
        self.emit(") {")
        # Built-in invocation ids exposed under HIP-flavoured local names.
        self.indent_level += 1
        self.emit("let thread_id = gl_GlobalInvocationID;")
        self.emit("let block_id = gl_WorkGroupID;")
        self.emit("let thread_local_id = gl_LocalInvocationID;")
        self.emit("let block_dim = gl_WorkGroupSize;")
        self.emit("")
        body = getattr(kernel, "body", None)
        if body:
            self.push_packed_argument_scope()
            self.push_type_alias_scope()
            self.push_unique_ptr_scope()
            try:
                for param in kernel_params:
                    self.register_unique_ptr_parameter(param)
                self.emit_statement(body)
            finally:
                self.pop_unique_ptr_scope()
                self.pop_type_alias_scope()
                self.pop_packed_argument_scope()
        self.indent_level -= 1
        self.emit("}")

    def visit_StructNode(self, node):
        """Render a struct; only VariableNode members are emitted."""
        self.emit(f"struct {node.name} {{")
        self.indent_level += 1
        for member in getattr(node, "members", None) or []:
            if isinstance(member, VariableNode):
                member_type = self.convert_hip_type_to_crossgl(
                    getattr(member, "vtype", "int")
                )
                self.emit(f"{member_type} {member.name};")
        self.indent_level -= 1
        self.emit("};")

    def visit_VariableNode(self, node):
        """Emit a ``var`` declaration and record packed-argument/unique_ptr info."""
        var_type = self.convert_hip_type_to_crossgl(getattr(node, "vtype", "int"))
        self.register_packed_argument_list(node)
        self.register_unique_ptr_name(node.name, getattr(node, "vtype", "int"))
        value = getattr(node, "value", None)
        if self._has_value(value):
            self.emit(f"var {node.name}: {var_type} = {self.visit(value)};")
        else:
            self.emit(f"var {node.name}: {var_type};")

    def visit_KernelLaunchNode(self, node):
        """Emit a ``<<<...>>>`` kernel launch as comment lines only."""
        kernel_name = self.visit(node.kernel_name)
        config = [self.visit(node.blocks), self.visit(node.threads)]
        if node.shared_mem is not None:
            config.append(self.visit(node.shared_mem))
        if node.stream is not None:
            config.append(self.visit(node.stream))
        self.emit(f"// Kernel launch: {kernel_name}<<<{', '.join(config)}>>>()")
        if node.args:
            # A single packed ``void* args[]`` argument is expanded to the
            # elements it was initialised with, when known.
            args = self.resolve_packed_launch_args(node.args)
            args_str = ", ".join(self.format_kernel_launch_arg(arg) for arg in args)
            self.emit(f"// Arguments: {args_str}")

    def push_packed_argument_scope(self):
        """Open a new packed-argument scope."""
        self.packed_argument_scopes.append({})

    def pop_packed_argument_scope(self):
        """Close the innermost packed-argument scope, if any."""
        if self.packed_argument_scopes:
            self.packed_argument_scopes.pop()

    def push_unique_ptr_scope(self):
        """Open a new unique_ptr name scope."""
        self.unique_ptr_scopes.append(set())

    def pop_unique_ptr_scope(self):
        """Close the innermost unique_ptr scope, never the base one."""
        if len(self.unique_ptr_scopes) > 1:
            self.unique_ptr_scopes.pop()

    def push_type_alias_scope(self):
        """Open a new type-alias scope."""
        self.type_alias_scopes.append({})

    def pop_type_alias_scope(self):
        """Close the innermost type-alias scope, never the base one."""
        if len(self.type_alias_scopes) > 1:
            self.type_alias_scopes.pop()

    def register_type_alias(self, name, alias_type):
        """Record ``name`` as an alias for ``alias_type`` in the current scope."""
        self.type_alias_scopes[-1][name] = alias_type

    def resolve_type_alias(self, type_name):
        """Follow alias chains (innermost scope first) to the underlying type.

        Qualifiers are stripped at every hop. A cycle stops at the first
        repeated name, so ``typedef Foo Foo``-style aliases terminate.
        """
        type_name = self.strip_type_qualifiers(type_name)
        seen = set()
        while type_name not in seen:
            seen.add(type_name)
            for scope in reversed(self.type_alias_scopes):
                if type_name in scope:
                    type_name = self.strip_type_qualifiers(scope[type_name])
                    break
            else:
                return type_name
        return type_name

    def register_unique_ptr_parameter(self, param):
        """Record a parameter name if its declared type is a unique_ptr."""
        if isinstance(param, dict):
            self.register_unique_ptr_name(param.get("name", ""), param.get("type", ""))
        else:
            self.register_unique_ptr_name(
                getattr(param, "name", ""), getattr(param, "vtype", "")
            )

    def register_unique_ptr_name(self, name, type_name):
        """Record ``name`` in the current scope if ``type_name`` is a unique_ptr."""
        if self.is_unique_ptr_type_name(type_name):
            self.unique_ptr_scopes[-1].add(name)

    def is_unique_ptr_expression(self, expr):
        """Return True if ``expr`` is a plain name registered as a unique_ptr."""
        if not isinstance(expr, str):
            return False
        return any(expr in scope for scope in reversed(self.unique_ptr_scopes))

    def register_packed_argument_list(self, node):
        """Remember the initializer elements of a packed ``void*[]`` variable."""
        if not self.packed_argument_scopes:
            return
        if self.is_packed_argument_list(node):
            self.packed_argument_scopes[-1][node.name] = (
                self.get_initializer_list_elements(node.value)
            )

    def is_packed_argument_list(self, node):
        """Return True for a ``void*[]``/``void**`` variable with a brace initializer."""
        if self.get_initializer_list_elements(getattr(node, "value", None)) is None:
            return False
        compact_type = str(getattr(node, "vtype", "") or "").replace(" ", "")
        return compact_type in self.PACKED_ARGUMENT_TYPES

    def get_initializer_list_elements(self, value):
        """Return the elements of a (possibly cast) initializer list, else None."""
        if isinstance(value, InitializerListNode):
            return value.elements
        if isinstance(value, CastNode) and isinstance(
            value.expression, InitializerListNode
        ):
            return value.expression.elements
        return None

    def resolve_packed_launch_args(self, args):
        """Expand a single packed-argument launch argument to its elements.

        Handles inline compound literals and names registered through
        ``register_packed_argument_list``. Otherwise ``args`` is returned
        unchanged.
        """
        if len(args) != 1:
            return args
        compound_elements = self.get_packed_compound_literal_elements(args[0])
        if compound_elements is not None:
            return compound_elements
        packed_arg_name = self.get_packed_argument_name(args[0])
        if packed_arg_name is None:
            return args
        for scope in reversed(self.packed_argument_scopes):
            if packed_arg_name in scope:
                return scope[packed_arg_name]
        return args

    def get_packed_argument_name(self, arg):
        """Return the bare name behind ``arg`` (looking through casts), else None."""
        if isinstance(arg, str):
            return arg
        if isinstance(arg, CastNode):
            return self.get_packed_argument_name(arg.expression)
        return None

    def get_packed_compound_literal_elements(self, arg):
        """Return elements of a ``(void*[]){...}`` compound literal, else None."""
        if not isinstance(arg, CastNode):
            return None
        compact_type = str(arg.target_type).replace(" ", "")
        if compact_type not in self.PACKED_ARGUMENT_TYPES:
            return None
        return self.get_initializer_list_elements(arg.expression)

    def format_kernel_launch_arg(self, arg):
        """Render a launch argument, dropping a leading address-of ``&``."""
        if isinstance(arg, UnaryOpNode) and arg.op == "&":
            return self.visit(arg.operand)
        return self.visit(arg)

    def visit_AssignmentNode(self, node):
        """Emit an assignment statement.

        A chained assignment ``a = b = c`` is emitted as ``b = c;`` followed
        by ``a = b;``. The nested visitor emits its own line and returns None,
        so its result cannot be used as a value.
        """
        left = self.visit(node.left)
        operator = getattr(node, "operator", "=")
        if isinstance(node.right, AssignmentNode):
            self.visit(node.right)
            right = self.visit(node.right.left)
        else:
            right = self.visit(node.right)
        self.emit(f"{left} {operator} {right};")

    def visit_BinaryOpNode(self, node):
        """Render a parenthesised binary expression."""
        left = self.visit(node.left)
        right = self.visit(node.right)
        return f"({left} {node.op} {right})"

    def visit_UnaryOpNode(self, node):
        """Render a unary expression; ``*_POST`` ops or ``postfix`` go after the operand."""
        operand = self.visit(node.operand)
        if isinstance(node.op, str) and node.op.endswith("_POST"):
            return f"({operand}{node.op[:-5]})"
        if getattr(node, "postfix", False):
            return f"({operand}{node.op})"
        return f"({node.op}{operand})"

    @staticmethod
    def _call_arguments(node):
        """Return whichever of a call node's ``args``/``arguments`` is non-empty."""
        return getattr(node, "args", None) or getattr(node, "arguments", None) or []

    def visit_FunctionCallNode(self, node):
        """Render a function call, translating builtins and unique_ptr helpers.

        ``ptr.get()`` on a known unique_ptr renders as ``ptr``. ``make_unique``
        becomes ``new<T>``/``new_array<T>``, and a one-argument unique_ptr
        constructor renders as its argument.
        """
        if self.is_get_method_call(node):
            return self.visit(node.name.object)
        args = [self.visit(arg) for arg in self._call_arguments(node)]
        args_str = ", ".join(args)
        if hasattr(node, "name"):
            func_name = node.name
        elif hasattr(node, "function"):
            func_name = str(node.function)
        else:
            func_name = "unknown"
        if not isinstance(func_name, str):
            func_name = self.visit(func_name)
        make_unique = self.format_make_unique_call(func_name, args)
        if make_unique is not None:
            return make_unique
        unique_ptr_init = self.format_unique_ptr_constructor_call(func_name, args)
        if unique_ptr_init is not None:
            return unique_ptr_init
        crossgl_func = self.convert_hip_builtin_function(func_name)
        return f"{crossgl_func}({args_str})"

    def is_get_method_call(self, node):
        """Return True for an argument-less ``.get()`` on a registered unique_ptr."""
        return (
            isinstance(getattr(node, "name", None), MemberAccessNode)
            and node.name.member == "get"
            and not getattr(node, "args", [])
            and self.is_unique_ptr_expression(node.name.object)
        )

    def format_make_unique_call(self, function_name, args):
        """Render ``[std::]make_unique<T>(...)`` as ``new<T>``/``new_array<T>``, else None."""
        base_name, template_args = self.parse_cpp_template(function_name)
        if base_name.split("::")[-1] != "make_unique" or not template_args:
            return None
        target_type, is_array = self.unwrap_array_template_type(template_args[0])
        target_type = self.convert_hip_type_to_crossgl(target_type)
        args_str = ", ".join(args)
        if is_array:
            return f"new_array<{target_type}>({args_str})"
        return f"new<{target_type}>({args_str})"

    def format_unique_ptr_constructor_call(self, function_name, args):
        """Render a one-argument unique_ptr construction as just its argument, else None."""
        base_name, _ = self.parse_cpp_template(function_name)
        if len(args) != 1:
            return None
        if base_name.split("::")[-1] != "unique_ptr" and not self.is_unique_ptr_type_name(
            function_name
        ):
            return None
        return args[0]

    def visit_NewNode(self, node):
        """Render ``new T(...)`` / ``new T[n]`` as ``new<T>(...)`` / ``new_array<T>(n)``."""
        target_type = self.convert_hip_type_to_crossgl(node.target_type)
        if node.is_array:
            size = self.visit(node.size) if node.size is not None else ""
            return f"new_array<{target_type}>({size})"
        args = ", ".join(self.visit(arg) for arg in node.args)
        return f"new<{target_type}>({args})"

    def visit_DeleteNode(self, node):
        """Emit ``delete``/``delete[]`` as a comment."""
        target = self.visit(node.expression)
        if node.is_array:
            self.emit(f"// delete array: {target}")
        else:
            self.emit(f"// delete: {target}")

    def visit_TypeAliasNode(self, node):
        """Register a type alias and emit it as a ``typedef``."""
        self.register_type_alias(node.name, node.alias_type)
        alias_type = self.convert_hip_type_to_crossgl(node.alias_type)
        self.emit(f"typedef {alias_type} {node.name};")

    def visit_MemberAccessNode(self, node):
        """Render ``object.member``."""
        obj = self.visit(node.object)
        return f"{obj}.{node.member}"

    def visit_ArrayAccessNode(self, node):
        """Render ``array[index]``."""
        array = self.visit(node.array)
        index = self.visit(node.index)
        return f"{array}[{index}]"

    def visit_InitializerListNode(self, node):
        """Render a brace initializer ``{a, b, ...}``."""
        elements = ", ".join(self.visit(element) for element in node.elements)
        return f"{{{elements}}}"

    def visit_DesignatedInitializerNode(self, node):
        """Render ``[i]``/``.field`` designators followed by ``= value``."""
        designators = []
        for kind, target in node.designators:
            if kind == "index":
                designators.append(f"[{self.visit(target)}]")
            else:
                designators.append(f".{target}")
        value = self.visit(node.value)
        return f"{''.join(designators)} = {value}"

    def visit_SyncNode(self, node):
        """Emit a synchronization primitive or an explanatory comment."""
        if node.sync_type == "__syncthreads":
            self.emit("workgroupBarrier();")
        elif node.sync_type == "hipDeviceSynchronize":
            self.emit("// HIP device synchronize")
        elif node.sync_type == "__syncwarp":
            self.emit("// Warp sync not directly supported in CrossGL")
        else:
            self.emit(f"// {node.sync_type}();")

    def visit_HipBuiltinNode(self, node):
        """Render ``threadIdx``/``blockIdx``/... (optionally ``.x``) as CrossGL built-ins."""
        base_name = self.BUILTIN_VARIABLE_MAPPING.get(
            node.builtin_name, node.builtin_name
        )
        component = getattr(node, "component", None)
        if component:
            return f"{base_name}.{component}"
        return base_name

    @staticmethod
    def _has_value(value):
        """Return True when ``value`` should be rendered.

        ``None`` and empty strings/lists/tuples count as absent. Any other
        value, including the number ``0``, is present.
        """
        if value is None:
            return False
        if isinstance(value, (str, list, tuple)):
            return len(value) > 0
        return True

    def _emit_body(self, body):
        """Emit a statement or list of statements; a falsy body emits nothing."""
        if body:
            self.emit_statement(body)

    def visit_ReturnNode(self, node):
        """Emit ``return value;`` or a bare ``return;``."""
        value = getattr(node, "value", None)
        if self._has_value(value):
            self.emit(f"return {self.visit(value)};")
        else:
            self.emit("return;")

    def visit_BreakNode(self, node):
        """Emit ``break;``."""
        self.emit("break;")

    def visit_ContinueNode(self, node):
        """Emit ``continue;``."""
        self.emit("continue;")

    def visit_IfNode(self, node):
        """Emit an ``if`` with an optional ``else`` block."""
        condition = self.visit(node.condition)
        self.emit(f"if ({condition}) {{")
        self.indent_level += 1
        self._emit_body(getattr(node, "if_body", None))
        self.indent_level -= 1
        else_body = getattr(node, "else_body", None)
        if else_body:
            self.emit("} else {")
            self.indent_level += 1
            self._emit_body(else_body)
            self.indent_level -= 1
        self.emit("}")

    def visit_ForNode(self, node):
        """Emit a C-style ``for`` loop.

        A list-valued ``init`` is emitted as separate statements inside an
        extra ``{ ... }`` block, and the loop header's init clause is left empty.
        """
        init_node = getattr(node, "init", None)
        scoped_init = isinstance(init_node, list)
        if scoped_init:
            self.emit("{")
            self.indent_level += 1
            for stmt in init_node:
                self.emit_statement(stmt)
            init = ""
        else:
            init = self.format_statement_fragment(init_node)
        condition_node = getattr(node, "condition", None)
        condition = self.visit(condition_node) if self._has_value(condition_node) else ""
        update = self.format_statement_fragment(getattr(node, "update", None))
        self.emit(f"for ({init}; {condition}; {update}) {{")
        self.indent_level += 1
        self._emit_body(getattr(node, "body", None))
        self.indent_level -= 1
        self.emit("}")
        if scoped_init:
            self.indent_level -= 1
            self.emit("}")

    def visit_RangeForNode(self, node):
        """Emit a range-based loop as ``for name in iterable { ... }``."""
        iterable = self.visit(node.iterable)
        self.emit(f"for {node.name} in {iterable} {{")
        self.indent_level += 1
        self._emit_body(getattr(node, "body", None))
        self.indent_level -= 1
        self.emit("}")

    def visit_WhileNode(self, node):
        """Emit a ``while`` loop."""
        condition = self.visit(node.condition)
        self.emit(f"while ({condition}) {{")
        self.indent_level += 1
        self._emit_body(getattr(node, "body", None))
        self.indent_level -= 1
        self.emit("}")

    def visit_DoWhileNode(self, node):
        """Emit a ``do { ... } while (cond);`` loop."""
        condition = self.visit(node.condition)
        self.emit("do {")
        self.indent_level += 1
        self._emit_body(getattr(node, "body", None))
        self.indent_level -= 1
        self.emit(f"}} while ({condition});")

    def visit_SwitchNode(self, node):
        """Emit a ``switch``; the ``default`` block, if any, is emitted last."""
        expression = self.visit(node.expression)
        self.emit(f"switch ({expression}) {{")
        self.indent_level += 1
        for case in getattr(node, "cases", None) or []:
            self.visit(case)
        default_case = getattr(node, "default_case", None)
        if default_case:
            self.emit("default:")
            self.indent_level += 1
            # emit_statement accepts a single statement or a list of them.
            self.emit_statement(default_case)
            self.indent_level -= 1
        self.indent_level -= 1
        self.emit("}")

    def visit_CaseNode(self, node):
        """Emit one ``case value:`` label and its (possibly empty) body."""
        value = self.visit(node.value)
        self.emit(f"case {value}:")
        self.indent_level += 1
        self._emit_body(getattr(node, "body", None))
        self.indent_level -= 1

    def visit_TernaryOpNode(self, node):
        """Render ``(cond ? a : b)``."""
        condition = self.visit(node.condition)
        true_expr = self.visit(node.true_expr)
        false_expr = self.visit(node.false_expr)
        return f"({condition} ? {true_expr} : {false_expr})"

    def visit_CastNode(self, node):
        """Render a cast as a constructor-style call ``T(expr)``."""
        target_type = self.convert_hip_type_to_crossgl(node.target_type)
        expression = self.visit(node.expression)
        return f"{target_type}({expression})"

    def convert_hip_type_to_crossgl(self, hip_type):
        """Map a HIP type name to the closest CrossGL type name.

        Qualifiers are stripped first. Then, in order: unique_ptr becomes
        ``ptr<T>``, array suffixes become ``array<T[, N]>``, pointers become
        nested ``ptr<...>``, and scalar/vector names are looked up. Unknown
        names pass through unchanged; ``None`` maps to ``void``.
        """
        if hip_type is None:
            return "void"
        if not isinstance(hip_type, str):
            hip_type = str(hip_type)
        hip_type = self.strip_type_qualifiers(hip_type)
        type_mapping = {
            **self.SCALAR_TYPE_MAPPING,
            **self.VECTOR_TYPE_MAPPING,
            "dim3": "vec3<u32>",
        }
        unique_ptr_type = self.convert_unique_ptr_type(hip_type)
        if unique_ptr_type is not None:
            return unique_ptr_type
        if self.has_array_suffix(hip_type):
            return self.convert_hip_array_type(hip_type, type_mapping)
        if "*" in hip_type:
            return self.convert_hip_pointer_type(hip_type)
        return type_mapping.get(hip_type, hip_type)

    def convert_unique_ptr_type(self, hip_type):
        """Return ``ptr<T>`` for ``unique_ptr<T>``/``unique_ptr<T[]>``, else None."""
        base_name, template_args = self.parse_cpp_template(hip_type)
        if not self.is_unique_ptr_base_name(base_name) or not template_args:
            return None
        target_type, _ = self.unwrap_array_template_type(template_args[0])
        return f"ptr<{self.convert_hip_type_to_crossgl(target_type)}>"

    def is_unique_ptr_type_name(self, type_name):
        """Return True if ``type_name`` (after alias resolution) is ``unique_ptr<...>``."""
        type_name = self.strip_type_qualifiers(type_name)
        type_name = self.resolve_type_alias(type_name)
        base_name, template_args = self.parse_cpp_template(type_name)
        return self.is_unique_ptr_base_name(base_name) and bool(template_args)

    def is_unique_ptr_base_name(self, base_name):
        """Return True when the last ``::`` component is ``unique_ptr``."""
        return base_name.split("::")[-1] == "unique_ptr"

    def has_array_suffix(self, type_name):
        """Return True if ``type_name`` has a ``[`` outside template brackets."""
        depth = 0
        for char in str(type_name):
            if char == "<":
                depth += 1
            elif char == ">":
                depth -= 1
            elif char == "[" and depth == 0:
                return True
        return False

    def unwrap_array_template_type(self, type_name):
        """Split ``T[]`` into ``(T, True)``; other types give ``(type, False)``."""
        type_name = type_name.strip()
        if type_name.endswith("[]"):
            return type_name[:-2].strip(), True
        return type_name, False

    def parse_cpp_template(self, text):
        """Split ``Base<A, B>`` into ``("Base", ["A", "B"])``.

        Text without a trailing template argument list gives ``(text, [])``;
        non-strings are stringified first.
        """
        if not isinstance(text, str):
            return str(text), []
        start = text.find("<")
        if start == -1 or not text.endswith(">"):
            return text, []
        base_name = text[:start].strip()
        args = self.split_cpp_template_args(text[start + 1 : -1])
        return base_name, args

    def split_cpp_template_args(self, args_text):
        """Split template arguments on top-level commas (nested ``<>`` respected)."""
        args = []
        depth = 0
        start = 0
        for index, char in enumerate(args_text):
            if char == "<":
                depth += 1
            elif char == ">":
                depth -= 1
            elif char == "," and depth == 0:
                args.append(args_text[start:index].strip())
                start = index + 1
        tail = args_text[start:].strip()
        if tail:
            args.append(tail)
        return args

    def convert_hip_pointer_type(self, hip_type):
        """Convert a HIP pointer type into nested CrossGL pointer syntax."""
        pointer_depth = hip_type.count("*")
        base_type = hip_type.replace("*", "").strip()
        mapped_type = self.convert_hip_type_to_crossgl(base_type)
        for _ in range(pointer_depth):
            mapped_type = f"ptr<{mapped_type}>"
        return mapped_type

    def convert_hip_pointer_element_type(self, hip_type):
        """Return the element type a pointer points at (one ``*`` level removed)."""
        pointer_depth = hip_type.count("*")
        base_type = hip_type.replace("*", "").strip()
        mapped_type = self.convert_hip_type_to_crossgl(base_type)
        for _ in range(max(0, pointer_depth - 1)):
            mapped_type = f"ptr<{mapped_type}>"
        return mapped_type

    def strip_type_qualifiers(self, type_name):
        """Drop cv/restrict/reference tokens and normalise internal whitespace."""
        return " ".join(
            part for part in str(type_name).split() if part not in self.TYPE_QUALIFIERS
        )

    def convert_hip_array_type(self, hip_type, type_mapping):
        """Convert ``T[a][b]`` into ``array<array<T, b>, a>`` (unsized: ``array<T>``)."""
        base_type = hip_type.split("[", 1)[0].strip()
        dimensions = []
        remainder = hip_type[len(base_type) :].strip()
        while remainder.startswith("["):
            close_index = remainder.find("]")
            if close_index == -1:
                break
            dimensions.append(remainder[1:close_index].strip())
            remainder = remainder[close_index + 1 :].strip()
        mapped_type = type_mapping.get(base_type)
        if mapped_type is None:
            mapped_type = self.convert_hip_type_to_crossgl(base_type)
        for size in reversed(dimensions):
            if size:
                mapped_type = f"array<{mapped_type}, {size}>"
            else:
                mapped_type = f"array<{mapped_type}>"
        return mapped_type

    def convert_hip_builtin_function(self, func_name):
        """Convert HIP built-in functions to CrossGL equivalents.

        Unknown names are returned unchanged.
        """
        function_mapping = {
            **self.BUILTIN_FUNCTION_MAPPING,
            **self.VECTOR_CONSTRUCTOR_MAPPING,
            "dim3": "vec3<u32>",
        }
        return function_mapping.get(func_name, func_name)

    def visit_EnumNode(self, node):
        """Emit an ``enum`` with optional explicit variant values."""
        self.emit(f"enum {node.name} {{")
        self.indent_level += 1
        for variant in getattr(node, "variants", None) or []:
            value = getattr(variant, "value", None)
            if self._has_value(value):
                self.emit(f"{variant.name} = {self.visit(value)},")
            else:
                self.emit(f"{variant.name},")
        self.indent_level -= 1
        self.emit("}")

    # Legacy method for backwards compatibility
    def convert(self, node):
        """Legacy convert method for compatibility; same as ``generate``."""
        return self.generate(node)
def hip_to_crossgl(hip_ast) -> str:
    """Convert a parsed HIP AST into a CrossGL source string.

    Convenience wrapper: builds a fresh ``HipToCrossGLConverter`` for each
    call, so no visitor state is shared between conversions, and returns the
    result of its ``generate`` method.
    """
    converter = HipToCrossGLConverter()
    return converter.generate(hip_ast)