"""HIP Parser for converting HIP tokens to AST"""
from typing import List
from .HipLexer import HipLexer, Token
from .HipAst import (
ASTNode,
ShaderNode,
FunctionNode,
KernelNode,
StructNode,
VariableNode,
AssignmentNode,
BinaryOpNode,
UnaryOpNode,
FunctionCallNode,
AtomicOperationNode,
KernelLaunchNode,
CastNode,
CaseNode,
DesignatedInitializerNode,
DeleteNode,
DoWhileNode,
InitializerListNode,
SyncNode,
MemberAccessNode,
NewNode,
ArrayAccessNode,
IfNode,
ForNode,
RangeForNode,
WhileNode,
ReturnNode,
BreakNode,
ContinueNode,
PreprocessorNode,
SwitchNode,
TernaryOpNode,
TypeAliasNode,
HipBuiltinNode,
)
[docs]
class HipProgramNode(ASTNode):
    """Top-level AST node that holds every statement of a parsed HIP program."""

    def __init__(self, statements=None):
        """Store ``statements``; a falsy argument is replaced by a new empty list."""
        self.statements = statements if statements else []

    def __repr__(self):
        """Return a debugging string that embeds the statement list."""
        contents = self.statements
        return f"HipProgramNode(statements={contents})"
[docs]
class HipParser:
    """Parse HIP tokens into the HIP backend AST.

    The class-level sets below hold token *type* names; the parsing methods
    test the current token's ``type`` against them.
    """

    # Specifiers accepted in front of an unqualified function definition.
    FUNCTION_SPECIFIER_TOKENS = {"STATIC", "INLINE", "EXTERN"}
    # Qualifiers consumed at the very start of a type.
    TYPE_QUALIFIER_TOKENS = {"CONST", "VOLATILE", "UNSIGNED", "SIGNED", "__RESTRICT__"}
    # Qualifiers that may also follow a type name, pointer or reference marker.
    POSTFIX_TYPE_QUALIFIER_TOKENS = {"__RESTRICT__"}
    # Token types treated as reference markers inside a type.
    TYPE_REFERENCE_TOKENS = {"AMPERSAND", "AND"}
    # Callee names that turn ``name<T>(expr)`` into a cast node.
    CPP_NAMED_CASTS = {"static_cast", "reinterpret_cast", "const_cast", "dynamic_cast"}
    # Token types of scalar built-in types.
    BUILTIN_TYPE_TOKENS = {
        "INT",
        "FLOAT",
        "DOUBLE",
        "BOOL",
        "VOID",
        "CHAR",
        "SHORT",
        "LONG",
        "HIPERROR",
        "SIZE_T",
    }
    # Token types of the 2-, 3- and 4-component vector types.
    VECTOR_TYPE_TOKENS = {
        "FLOAT2",
        "FLOAT3",
        "FLOAT4",
        "INT2",
        "INT3",
        "INT4",
        "DOUBLE2",
        "DOUBLE3",
        "DOUBLE4",
        "UINT2",
        "UINT3",
        "UINT4",
        "CHAR2",
        "CHAR3",
        "CHAR4",
        "UCHAR2",
        "UCHAR3",
        "UCHAR4",
        "SHORT2",
        "SHORT3",
        "SHORT4",
        "USHORT2",
        "USHORT3",
        "USHORT4",
        "LONG2",
        "LONG3",
        "LONG4",
        "ULONG2",
        "ULONG3",
        "ULONG4",
        "LONGLONG2",
        "LONGLONG3",
        "LONGLONG4",
        "ULONGLONG2",
        "ULONGLONG3",
        "ULONGLONG4",
    }
def __init__(self, tokens: List[Token]):
"""Initialize the parser with a token stream from ``HipLexer``."""
self.tokens = tokens
self.pos = 0
self.current_token = self.tokens[0] if tokens else None
self.block_depth = 0
self.type_aliases = set()
[docs]
def error(self, message: str):
"""Raise a syntax error annotated with the current token."""
token_info = (
f"at token '{self.current_token.value}'"
if self.current_token
else "at end of input"
)
raise SyntaxError(f"Parse error {token_info}: {message}")
[docs]
def advance(self):
"""Advance to the next token or mark the parser as finished."""
if self.pos < len(self.tokens) - 1:
self.pos += 1
self.current_token = self.tokens[self.pos]
else:
self.current_token = None
[docs]
def peek(self, offset: int = 1):
"""Look ahead at the next token without advancing"""
peek_pos = self.pos + offset
if peek_pos < len(self.tokens):
return self.tokens[peek_pos]
return None
[docs]
def consume(self, expected_type: str):
"""Consume and return the current token when its type matches."""
if not self.current_token:
self.error(f"Expected {expected_type} but reached end of input")
if self.current_token.type != expected_type:
self.error(f"Expected {expected_type}, got {self.current_token.type}")
token = self.current_token
self.advance()
return token
[docs]
def match(self, *token_types: str) -> bool:
"""Return whether the current token type is one of ``token_types``."""
if not self.current_token:
return False
return self.current_token.type in token_types
[docs]
def skip_newlines(self):
"""Advance past newline tokens between declarations/statements."""
while self.match("NEWLINE"):
self.advance()
def is_builtin_type_token(self, token=None):
token = token or self.current_token
if not token:
return False
if token.type == "FLOAT":
return token.value == "float"
if token.type == "CHAR":
return token.value == "char"
return token.type in self.BUILTIN_TYPE_TOKENS
def is_type_token(self, token=None, allow_identifier=True):
token = token or self.current_token
if not token:
return False
if self.is_builtin_type_token(token):
return True
if token.type in self.VECTOR_TYPE_TOKENS:
return True
return allow_identifier and token.type == "IDENTIFIER"
[docs]
def parse(self):
    """Parse all remaining tokens and return them wrapped in a ``HipProgramNode``.

    Newline and semicolon tokens between statements are skipped; list results
    from ``parse_statement`` are spliced into the program, and falsy results
    are dropped.
    """
    program_statements = []
    while self.current_token:
        if self.match("NEWLINE", "SEMICOLON"):
            self.advance()
            continue
        parsed = self.parse_statement()
        if not parsed:
            continue
        if isinstance(parsed, list):
            program_statements.extend(parsed)
        else:
            program_statements.append(parsed)
    return HipProgramNode(program_statements)
[docs]
def parse_statement(self):
    """Parse one statement or declaration starting at the current token.

    The checks run in a fixed order: keyword-led constructs first, then
    declarations, and finally a plain expression statement as the fallback.

    Returns:
        ``None`` at end of input or for a lone newline/semicolon; a list for
        blocks, typedef declarations and multi-declarator variable
        declarations; otherwise the result of the delegated ``parse_*`` method.
    """
    if not self.current_token:
        return None
    # Skip newlines and semicolons
    if self.match("NEWLINE", "SEMICOLON"):
        self.advance()
        return None
    # Parse preprocessor directives
    if self.match("HASH"):
        return self.parse_preprocessor()
    # Parse device/host/global qualifiers
    if self.match("__DEVICE__", "__HOST__", "__GLOBAL__"):
        return self.parse_function_with_qualifier()
    # Parse struct definitions
    if self.match("STRUCT"):
        return self.parse_struct()
    # Parse class definitions
    if self.match("CLASS"):
        return self.parse_class()
    # Parse ``typedef`` / ``using`` aliases
    if self.is_type_alias_start():
        return self.parse_type_alias()
    # Parse return statements
    if self.match("RETURN"):
        return self.parse_return_statement()
    # Parse control flow statements
    if self.match("IF"):
        return self.parse_if_statement()
    elif self.match("FOR"):
        return self.parse_for_statement()
    elif self.match("WHILE"):
        return self.parse_while_statement()
    elif self.match("DO"):
        return self.parse_do_while_statement()
    elif self.match("SWITCH"):
        return self.parse_switch_statement()
    # ``break`` / ``continue``: the trailing semicolon is optional
    if self.match("BREAK"):
        self.advance()
        if self.match("SEMICOLON"):
            self.advance()
        return BreakNode()
    if self.match("CONTINUE"):
        self.advance()
        if self.match("SEMICOLON"):
            self.advance()
        return ContinueNode()
    if self.match("SYNCTHREADS", "SYNCWARP"):
        return self.parse_sync_statement()
    if self.match("LBRACE"):
        return self.parse_block()
    # ``delete`` arrives as an IDENTIFIER token, so it is matched by value
    if self.is_identifier_value("delete"):
        return self.parse_delete_statement()
    # Try to parse function or variable declaration
    # Inside a block the variable-declaration check runs before the function
    # check; outside a block the function check runs first.
    if self.block_depth > 0 and self.is_variable_declaration():
        declarations = self.parse_variable_declaration_list()
        return declarations if len(declarations) > 1 else declarations[0]
    elif self.is_function_declaration():
        return self.parse_simple_function()
    elif self.is_variable_declaration():
        declarations = self.parse_variable_declaration_list()
        return declarations if len(declarations) > 1 else declarations[0]
    else:
        # Parse expression statement
        return self.parse_expression_statement()
def is_identifier_value(self, value):
return self.match("IDENTIFIER") and self.current_token.value == value
def is_type_alias_start(self):
return self.match("TYPEDEF") or self.is_identifier_value("using")
def parse_type_alias(self):
    """Dispatch to the ``typedef`` parser on a ``TYPEDEF`` token, else to the ``using`` parser."""
    if self.match("TYPEDEF"):
        handler = self.parse_typedef_alias
    else:
        handler = self.parse_using_alias
    return handler()
def parse_typedef_alias(self):
    """Parse ``typedef <type> name[, declarator...]`` into a list of alias nodes.

    The first declarator keeps the full parsed type. Later ones start from
    that type with trailing pointer/reference/restrict markers stripped and
    may add their own prefix. A trailing semicolon is consumed if present.
    """
    self.consume("TYPEDEF")
    declared_type = self.parse_type()
    shared_base = self.strip_declarator_markers(declared_type)
    result = [self.parse_type_alias_declarator(declared_type, allow_prefix=False)]
    while True:
        if not self.match("COMMA"):
            break
        self.advance()
        result.append(
            self.parse_type_alias_declarator(shared_base, allow_prefix=True)
        )
    if self.match("SEMICOLON"):
        self.advance()
    return result
def parse_type_alias_declarator(self, base_type, allow_prefix):
    """Parse one typedef declarator and record its name in ``type_aliases``.

    When ``allow_prefix`` is true, pointer/reference/restrict markers before
    the name are folded into the type; an array suffix after the name is
    always appended.
    """
    if allow_prefix:
        prefixed_type = self.parse_declarator_prefix(base_type)
    else:
        prefixed_type = base_type
    alias_name = self.consume("IDENTIFIER").value
    full_type = prefixed_type + self.parse_array_suffix()
    self.type_aliases.add(alias_name)
    return TypeAliasNode(full_type, alias_name)
def parse_using_alias(self):
self.advance()
if self.match("NAMESPACE"):
self.skip_until_semicolon()
return None
name = self.consume("IDENTIFIER").value
self.consume("ASSIGN")
self.skip_newlines()
alias_type = self.parse_type()
if self.match("SEMICOLON"):
self.advance()
self.type_aliases.add(name)
return TypeAliasNode(alias_type, name)
def skip_until_semicolon(self):
while self.current_token and not self.match("SEMICOLON"):
self.advance()
if self.match("SEMICOLON"):
self.advance()
def parse_delete_statement(self):
    """Parse ``delete expr`` or ``delete[] expr`` into a ``DeleteNode``.

    A trailing semicolon is consumed if present.
    """
    self.advance()  # step over the ``delete`` identifier
    array_form = self.match("LBRACKET")
    if array_form:
        self.consume("LBRACKET")
        self.consume("RBRACKET")
    operand = self.parse_unary_expression()
    if self.match("SEMICOLON"):
        self.advance()
    return DeleteNode(operand, array_form)
[docs]
def parse_preprocessor(self):
    """Parse ``#directive rest-of-line`` into a ``PreprocessorNode``.

    The token after ``#`` is the directive name. The text of every following
    token up to the next ``NEWLINE`` (or end of input) is joined with single
    spaces as the content; the newline itself is not consumed.

    Raises:
        SyntaxError: when nothing follows the ``#``.
    """
    self.consume("HASH")
    if not self.current_token:
        self.error("Expected preprocessor directive after #")
    directive_name = self.current_token.value
    self.advance()
    # Parse the rest of the line
    rest_of_line = []
    while self.current_token:
        if self.match("NEWLINE"):
            break
        rest_of_line.append(self.current_token.value)
        self.advance()
    return PreprocessorNode(directive_name, " ".join(rest_of_line))
def parse_function_with_qualifier(self):
    """Parse a function that starts with HIP execution-space qualifiers.

    Collects leading ``__device__`` / ``__host__`` / ``__global__`` /
    ``__forceinline__`` / ``__noinline__`` tokens, then the return type, name
    and parameter list. The body is a parsed block, or ``None`` for a
    prototype (whose semicolon, if present, is consumed).

    Returns:
        A ``KernelNode`` when ``"__global__"`` is among the qualifier texts
        (the qualifier list is not passed to it), otherwise a ``FunctionNode``
        carrying the qualifiers.
    """
    qualifiers = []
    while self.match(
        "__DEVICE__", "__HOST__", "__GLOBAL__", "__FORCEINLINE__", "__NOINLINE__"
    ):
        qualifiers.append(self.current_token.value)
        self.advance()
    return_type = self.parse_type()
    name = self.consume("IDENTIFIER").value
    self.consume("LPAREN")
    params = self.parse_parameter_list()
    self.consume("RPAREN")
    body = None
    if self.match("LBRACE"):
        body = self.parse_block()
    elif self.match("SEMICOLON"):
        self.advance()
    # __global__ qualifier marks a kernel; build only the node that is returned
    # instead of always creating a FunctionNode that kernels then discard.
    if "__global__" in qualifiers:
        return KernelNode(return_type, name, params, body)
    return FunctionNode(return_type, name, params, body, qualifiers)
def parse_simple_function(self):
    """Parse a function without HIP qualifiers into a ``FunctionNode``.

    Leading ``static`` / ``inline`` / ``extern`` specifiers are collected as
    qualifiers. The body is a parsed block, or ``None`` for a prototype (whose
    semicolon, if present, is consumed).
    """
    specifiers = []
    while self.match(*self.FUNCTION_SPECIFIER_TOKENS):
        specifiers.append(self.current_token.value)
        self.advance()
    result_type = self.parse_type()
    function_name = self.consume("IDENTIFIER").value
    self.consume("LPAREN")
    parameters = self.parse_parameter_list()
    self.consume("RPAREN")
    if self.match("LBRACE"):
        function_body = self.parse_block()
    else:
        function_body = None
        if self.match("SEMICOLON"):
            self.advance()
    return FunctionNode(
        result_type, function_name, parameters, function_body, specifiers
    )
def parse_struct(self):
    """Parse ``struct [Name] [{ members }] [;]`` into a ``StructNode``.

    The name is ``None`` for an anonymous struct. Members that fail to parse
    are dropped (``parse_struct_member`` returns ``None`` for them).
    """
    self.consume("STRUCT")
    struct_name = None
    if self.match("IDENTIFIER"):
        struct_name = self.current_token.value
        self.advance()
    fields = []
    if self.match("LBRACE"):
        self.consume("LBRACE")
        while self.current_token and not self.match("RBRACE"):
            if self.match("NEWLINE", "SEMICOLON"):
                self.advance()
            else:
                field = self.parse_struct_member()
                if field:
                    fields.append(field)
        self.consume("RBRACE")
    if self.match("SEMICOLON"):
        self.advance()
    return StructNode(struct_name, fields)
[docs]
def parse_class(self):
    """Parse ``class Name [{ members }] [;]``; the result is a ``StructNode``.

    Access specifiers (``public`` / ``private`` / ``protected``, with their
    optional colon) are skipped. Members go through ``parse_struct_member``.
    """
    self.consume("CLASS")
    class_name = self.consume("IDENTIFIER").value
    fields = []
    if self.match("LBRACE"):
        self.consume("LBRACE")
        while self.current_token and not self.match("RBRACE"):
            if self.match("NEWLINE", "SEMICOLON"):
                self.advance()
            elif self.match("PUBLIC", "PRIVATE", "PROTECTED"):
                # Skip access specifiers
                self.advance()
                if self.match("COLON"):
                    self.advance()
            else:
                field = self.parse_struct_member()
                if field:
                    fields.append(field)
        self.consume("RBRACE")
    if self.match("SEMICOLON"):
        self.advance()
    # Treat class as struct for simplicity
    return StructNode(class_name, fields)
def parse_struct_member(self):
try:
member_type = self.parse_type()
name = self.consume("IDENTIFIER").value
member_type += self.parse_array_suffix()
if self.match("SEMICOLON"):
self.advance()
return VariableNode(member_type, name)
except Exception:
while self.current_token and not self.match("SEMICOLON", "RBRACE"):
self.advance()
if self.match("SEMICOLON"):
self.advance()
return None
def parse_variable_declaration(self, consume_semicolon=True):
    """Parse a single-declarator variable declaration into a ``VariableNode``.

    Leading ``__shared__`` / ``__constant__`` / ``__device__`` / ``static`` /
    ``extern`` tokens are collected as qualifiers. An array suffix after the
    name is appended to the type text.

    The initializer is parsed by ``parse_variable_initializer``, the same
    helper the multi-declarator path uses. All three forms are therefore
    handled identically, including newlines after ``=``:
    ``= expr``, ``(args)`` and ``{...}``.

    Args:
        consume_semicolon: When true, a trailing semicolon is consumed if present.
    """
    qualifiers = []
    while self.match(
        "__SHARED__", "__CONSTANT__", "__DEVICE__", "STATIC", "EXTERN"
    ):
        qualifiers.append(self.current_token.value)
        self.advance()
    var_type = self.parse_type()
    name = self.consume("IDENTIFIER").value
    var_type += self.parse_array_suffix()
    # Delegate rather than re-implementing the initializer forms here.
    value = self.parse_variable_initializer(var_type)
    if consume_semicolon and self.match("SEMICOLON"):
        self.advance()
    return VariableNode(var_type, name, value, qualifiers)
def parse_variable_declaration_list(self, consume_semicolon=True):
    """Parse ``[qualifiers] type a[, b...]`` into a list of variable nodes.

    The first declarator keeps the full parsed type. Later ones start from
    that type with trailing pointer/reference/restrict markers stripped and
    may add their own prefix.

    Args:
        consume_semicolon: When true, a trailing semicolon is consumed if present.
    """
    storage_tokens = ("__SHARED__", "__CONSTANT__", "__DEVICE__", "STATIC", "EXTERN")
    qualifiers = []
    while self.match(*storage_tokens):
        qualifiers.append(self.current_token.value)
        self.advance()
    declared_type = self.parse_type()
    shared_base = self.strip_declarator_markers(declared_type)
    result = [
        self.parse_variable_declarator(declared_type, qualifiers, allow_prefix=False)
    ]
    while True:
        if not self.match("COMMA"):
            break
        self.advance()
        result.append(
            self.parse_variable_declarator(shared_base, qualifiers, allow_prefix=True)
        )
    if consume_semicolon and self.match("SEMICOLON"):
        self.advance()
    return result
def parse_variable_declarator(self, base_type, qualifiers, allow_prefix):
    """Parse one declarator (name, array suffix, initializer) into a ``VariableNode``.

    Each node receives its own copy of ``qualifiers``.
    """
    if allow_prefix:
        declared_type = self.parse_declarator_prefix(base_type)
    else:
        declared_type = base_type
    variable_name = self.consume("IDENTIFIER").value
    declared_type = declared_type + self.parse_array_suffix()
    initial_value = self.parse_variable_initializer(declared_type)
    return VariableNode(declared_type, variable_name, initial_value, list(qualifiers))
def parse_declarator_prefix(self, base_type):
parts = [base_type] if base_type else []
while self.match(
"ASTERISK",
"STAR",
*self.TYPE_REFERENCE_TOKENS,
*self.POSTFIX_TYPE_QUALIFIER_TOKENS,
):
if self.match(*self.POSTFIX_TYPE_QUALIFIER_TOKENS):
parts.append(self.current_token.value)
self.advance()
continue
parts.append(self.current_token.value)
self.advance()
self.parse_postfix_type_qualifiers(parts)
return " ".join(parts)
def parse_variable_initializer(self, var_type):
if self.match("ASSIGN"):
self.advance()
self.skip_newlines()
return self.parse_expression()
if self.match("LPAREN"):
return FunctionCallNode(var_type, self.parse_parenthesized_argument_list())
if self.match("LBRACE"):
return self.parse_initializer_list()
return None
def strip_declarator_markers(self, var_type):
parts = str(var_type).split()
while parts and parts[-1] in {"*", "&", "&&", "__restrict__", "restrict"}:
parts.pop()
return " ".join(parts)
def parse_array_suffix(self):
suffixes = []
while self.match("LBRACKET"):
self.consume("LBRACKET")
if not self.match("RBRACKET"):
size = self.parse_expression()
suffixes.append(f"[{self.expression_to_text(size)}]")
else:
suffixes.append("[]")
self.consume("RBRACKET")
return "".join(suffixes)
def parse_parenthesized_argument_list(self):
    """Consume ``( args )`` and return the result of ``parse_argument_list``."""
    self.consume("LPAREN")
    arguments = self.parse_argument_list()
    self.consume("RPAREN")
    return arguments
def parse_type(self):
type_parts = []
while self.match(*self.TYPE_QUALIFIER_TOKENS):
type_parts.append(self.current_token.value)
self.advance()
if (
self.is_builtin_type_token()
or self.match(*self.VECTOR_TYPE_TOKENS)
or self.match("IDENTIFIER")
):
type_parts.append(self.parse_type_name())
else:
type_parts.append("int") # Default type
self.parse_postfix_type_qualifiers(type_parts)
while self.match("ASTERISK", "STAR"):
type_parts.append("*")
self.advance()
self.parse_postfix_type_qualifiers(type_parts)
while self.match(*self.TYPE_REFERENCE_TOKENS):
type_parts.append(self.current_token.value)
self.advance()
self.parse_postfix_type_qualifiers(type_parts)
array_suffix = self.parse_array_suffix()
if array_suffix:
type_parts.append(array_suffix)
return " ".join(type_parts)
def parse_type_without_array_suffix(self):
type_parts = []
while self.match(*self.TYPE_QUALIFIER_TOKENS):
type_parts.append(self.current_token.value)
self.advance()
if (
self.is_builtin_type_token()
or self.match(*self.VECTOR_TYPE_TOKENS)
or self.match("IDENTIFIER")
):
type_parts.append(self.parse_type_name())
else:
type_parts.append("int")
self.parse_postfix_type_qualifiers(type_parts)
while self.match("ASTERISK", "STAR"):
type_parts.append("*")
self.advance()
self.parse_postfix_type_qualifiers(type_parts)
while self.match(*self.TYPE_REFERENCE_TOKENS):
type_parts.append(self.current_token.value)
self.advance()
self.parse_postfix_type_qualifiers(type_parts)
return " ".join(type_parts)
def parse_postfix_type_qualifiers(self, type_parts):
while self.match(*self.POSTFIX_TYPE_QUALIFIER_TOKENS):
type_parts.append(self.current_token.value)
self.advance()
def parse_type_name(self):
type_name = self.current_token.value
self.advance()
while self.match("SCOPE"):
self.consume("SCOPE")
member = self.consume("IDENTIFIER").value
type_name += f"::{member}"
if self.match("LT"):
type_name += self.parse_template_suffix()
return type_name
def parse_parameter_list(self):
params = []
self.skip_newlines()
if self.match("RPAREN"):
return params
while True:
self.skip_newlines()
if self.match("RPAREN"):
break
param_type = self.parse_type()
param_name = ""
if self.match("IDENTIFIER"):
param_name = self.current_token.value
self.advance()
param_type += self.parse_array_suffix()
params.append({"type": param_type, "name": param_name})
if self.match("COMMA"):
self.advance()
self.skip_newlines()
else:
break
return params
def parse_return_statement(self):
    """Parse ``return [expr][;]`` into a ``ReturnNode`` (value ``None`` when omitted)."""
    self.consume("RETURN")
    if self.match("SEMICOLON"):
        returned = None
    else:
        returned = self.parse_expression()
    if self.match("SEMICOLON"):
        self.advance()
    return ReturnNode(returned)
def parse_if_statement(self):
    """Parse ``if (cond) body [else body]`` into an ``IfNode``.

    Each body is either a braced block or a single statement; newlines before
    a body are skipped. ``else_body`` is ``None`` without an ``else``.
    """
    self.consume("IF")
    self.consume("LPAREN")
    test = self.parse_expression()
    self.consume("RPAREN")
    self.skip_newlines()
    if self.match("LBRACE"):
        then_branch = self.parse_block()
    else:
        then_branch = self.parse_statement()
    else_branch = None
    if self.match("ELSE"):
        self.advance()
        self.skip_newlines()
        if self.match("LBRACE"):
            else_branch = self.parse_block()
        else:
            else_branch = self.parse_statement()
    return IfNode(test, then_branch, else_branch)
def parse_for_statement(self):
    """Parse a ``for`` loop into a ``ForNode``, or delegate a range-based ``for``.

    Each of the three header clauses is ``None`` when left empty. A
    declaration clause with several declarators is kept as a list, a single
    declarator as the node itself.
    """
    self.consume("FOR")
    self.consume("LPAREN")
    if self.is_range_for_statement():
        return self.parse_range_for_statement()

    initializer = None
    if not self.match("SEMICOLON"):
        if self.is_variable_declaration():
            declared = self.parse_variable_declaration_list(consume_semicolon=False)
            initializer = declared if len(declared) > 1 else declared[0]
        else:
            initializer = self.parse_expression()
    self.consume("SEMICOLON")

    test = None if self.match("SEMICOLON") else self.parse_expression()
    self.consume("SEMICOLON")

    step = None if self.match("RPAREN") else self.parse_expression()
    self.consume("RPAREN")

    self.skip_newlines()
    loop_body = self.parse_statement()
    return ForNode(initializer, test, step, loop_body)
def is_range_for_statement(self):
    """Look ahead, without consuming, for ``type name[array] :`` at the current position."""
    tokens = self.tokens
    cursor = self.skip_range_for_type_at_pos(self.pos)
    if cursor is None:
        return False
    if not (cursor < len(tokens) and tokens[cursor].type == "IDENTIFIER"):
        return False
    cursor = self.skip_array_suffix_at_pos(cursor + 1)
    return cursor < len(tokens) and tokens[cursor].type == "COLON"
def skip_range_for_type_at_pos(self, index):
    """Return the token index just past a type that starts at ``index``.

    This is pure lookahead over ``self.tokens``; the parser position is not
    changed. It mirrors the shape ``parse_type`` accepts: qualifiers, a type
    name with optional ``::member`` parts and template arguments, pointer or
    reference markers with optional postfix qualifiers, then an array suffix.

    Returns:
        ``None`` when no type token is found at ``index`` or the template
        skip fails; otherwise the result of ``skip_array_suffix_at_pos``.
    """
    # Leading qualifiers such as ``const``.
    while (
        index < len(self.tokens)
        and self.tokens[index].type in self.TYPE_QUALIFIER_TOKENS
    ):
        index += 1
    if index >= len(self.tokens) or not self.is_type_token(self.tokens[index]):
        return None
    index += 1
    # ``::member`` pairs following the type name.
    while (
        index + 1 < len(self.tokens)
        and self.tokens[index].type == "SCOPE"
        and self.tokens[index + 1].type == "IDENTIFIER"
    ):
        index += 2
    if index < len(self.tokens) and self.tokens[index].type == "LT":
        index = self.skip_template_at_pos(index)
        if index is None:
            return None
    index = self.skip_postfix_type_qualifiers_at_pos(index)
    # Pointer / reference markers, each optionally followed by qualifiers.
    while index < len(self.tokens) and self.tokens[index].type in {
        "ASTERISK",
        "STAR",
        *self.TYPE_REFERENCE_TOKENS,
    }:
        index += 1
        index = self.skip_postfix_type_qualifiers_at_pos(index)
    return self.skip_array_suffix_at_pos(index)
def parse_range_for_statement(self):
    """Parse ``type name : iterable) body`` into a ``RangeForNode``.

    The caller has already consumed ``for (``.
    """
    element_type = self.parse_type()
    element_name = self.consume("IDENTIFIER").value
    self.consume("COLON")
    source = self.parse_expression()
    self.consume("RPAREN")
    self.skip_newlines()
    loop_body = self.parse_statement()
    return RangeForNode(element_type, element_name, source, loop_body)
def parse_while_statement(self):
    """Parse ``while (cond) body`` into a ``WhileNode``; newlines before the body are skipped."""
    self.consume("WHILE")
    self.consume("LPAREN")
    test = self.parse_expression()
    self.consume("RPAREN")
    self.skip_newlines()
    return WhileNode(test, self.parse_statement())
def parse_do_while_statement(self):
    """Parse ``do body while (cond)[;]`` into a ``DoWhileNode``.

    Newlines around the body are skipped; the trailing semicolon is optional.
    """
    self.consume("DO")
    self.skip_newlines()
    loop_body = self.parse_statement()
    self.skip_newlines()
    self.consume("WHILE")
    self.consume("LPAREN")
    test = self.parse_expression()
    self.consume("RPAREN")
    if self.match("SEMICOLON"):
        self.advance()
    return DoWhileNode(loop_body, test)
def parse_switch_statement(self):
    """Parse ``switch (expr) { case ...: ... default: ... }`` into a ``SwitchNode``.

    Each ``case`` becomes a ``CaseNode`` whose body runs until the next
    ``case``, ``default`` or ``}``. The ``default`` body (``None`` when
    absent) runs until the next ``case`` or ``}``. Any other token found
    directly inside the braces is skipped.
    """
    self.consume("SWITCH")
    self.consume("LPAREN")
    subject = self.parse_expression()
    self.consume("RPAREN")
    self.skip_newlines()
    self.consume("LBRACE")
    case_nodes = []
    default_body = None
    while self.current_token and not self.match("RBRACE"):
        if self.match("NEWLINE", "SEMICOLON"):
            self.advance()
            continue
        if self.match("CASE"):
            self.advance()
            label = self.parse_expression()
            self.consume("COLON")
            case_body = []
            while self.current_token and not self.match(
                "CASE", "DEFAULT", "RBRACE"
            ):
                parsed = self.parse_statement()
                if parsed:
                    case_body.append(parsed)
            case_nodes.append(CaseNode(label, case_body))
            continue
        if self.match("DEFAULT"):
            self.advance()
            self.consume("COLON")
            default_body = []
            while self.current_token and not self.match("CASE", "RBRACE"):
                parsed = self.parse_statement()
                if parsed:
                    default_body.append(parsed)
            continue
        self.advance()
    self.consume("RBRACE")
    return SwitchNode(subject, case_nodes, default_body)
def parse_sync_statement(self):
    """Parse a synchronization call such as ``__syncthreads(args)[;]`` into a ``SyncNode``.

    The node receives the text of the sync token and the parsed argument list.
    """
    sync_name = self.current_token.value
    self.advance()
    self.consume("LPAREN")
    arguments = []
    if not self.match("RPAREN"):
        while True:
            arguments.append(self.parse_expression())
            if not self.match("COMMA"):
                break
            self.advance()
    self.consume("RPAREN")
    if self.match("SEMICOLON"):
        self.advance()
    return SyncNode(sync_name, arguments)
def parse_block(self):
self.consume("LBRACE")
statements = []
self.block_depth += 1
try:
while self.current_token and not self.match("RBRACE"):
stmt = self.parse_statement()
if stmt:
if isinstance(stmt, list):
statements.extend(stmt)
else:
statements.append(stmt)
finally:
self.block_depth -= 1
self.consume("RBRACE")
return statements
def parse_expression_statement(self):
expr = self.parse_expression()
if self.match("SEMICOLON"):
self.advance()
return expr
def parse_expression(self):
return self.parse_assignment_expression()
def parse_assignment_expression(self):
left = self.parse_ternary_expression()
if self.match(
"ASSIGN",
"PLUS_ASSIGN",
"MINUS_ASSIGN",
"MULTIPLY_ASSIGN",
"DIVIDE_ASSIGN",
"STAR_ASSIGN",
"SLASH_ASSIGN",
"PERCENT_ASSIGN",
"AND_ASSIGN",
"OR_ASSIGN",
"XOR_ASSIGN",
"LSHIFT_ASSIGN",
"RSHIFT_ASSIGN",
):
op = self.current_token.value
self.advance()
right = self.parse_assignment_expression()
return AssignmentNode(left, right, op)
return left
def parse_ternary_expression(self):
expr = self.parse_logical_or_expression()
if self.match("QUESTION"):
self.advance()
true_expr = self.parse_expression()
self.consume("COLON")
false_expr = self.parse_expression()
return TernaryOpNode(expr, true_expr, false_expr)
return expr
def parse_logical_or_expression(self):
left = self.parse_logical_and_expression()
while self.match("LOGICAL_OR", "OR"):
op = self.current_token.value
self.advance()
right = self.parse_logical_and_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_logical_and_expression(self):
left = self.parse_bitwise_or_expression()
while self.match("LOGICAL_AND", "AND"):
op = self.current_token.value
self.advance()
right = self.parse_bitwise_or_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_bitwise_or_expression(self):
left = self.parse_bitwise_xor_expression()
while self.match("BITWISE_OR", "PIPE"):
op = self.current_token.value
self.advance()
right = self.parse_bitwise_xor_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_bitwise_xor_expression(self):
left = self.parse_bitwise_and_expression()
while self.match("BITWISE_XOR", "XOR"):
op = self.current_token.value
self.advance()
right = self.parse_bitwise_and_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_bitwise_and_expression(self):
left = self.parse_equality_expression()
while self.match("BITWISE_AND", "AMPERSAND"):
op = self.current_token.value
self.advance()
right = self.parse_equality_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_equality_expression(self):
left = self.parse_relational_expression()
while self.match("EQ", "NE"):
op = self.current_token.value
self.advance()
right = self.parse_relational_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_relational_expression(self):
left = self.parse_shift_expression()
while self.match("LT", "LE", "GT", "GE"):
op = self.current_token.value
self.advance()
right = self.parse_shift_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_shift_expression(self):
left = self.parse_additive_expression()
while self.match("SHIFT_LEFT", "SHIFT_RIGHT", "LSHIFT", "RSHIFT"):
op = self.current_token.value
self.advance()
right = self.parse_additive_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_additive_expression(self):
left = self.parse_multiplicative_expression()
while self.match("PLUS", "MINUS"):
op = self.current_token.value
self.advance()
right = self.parse_multiplicative_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_multiplicative_expression(self):
left = self.parse_unary_expression()
while self.match("MULTIPLY", "STAR", "DIVIDE", "SLASH", "MODULO", "PERCENT"):
op = self.current_token.value
self.advance()
right = self.parse_unary_expression()
left = BinaryOpNode(left, op, right)
return left
def parse_unary_expression(self):
if self.match(
"PLUS",
"MINUS",
"NOT",
"BITWISE_NOT",
"INCREMENT",
"DECREMENT",
"STAR",
"AMPERSAND",
):
op = self.current_token.value
self.advance()
operand = self.parse_unary_expression()
return UnaryOpNode(op, operand)
return self.parse_postfix_expression()
def parse_postfix_expression(self):
    """Parse a primary expression followed by any number of postfix operations.

    Operations are applied left to right as they appear: indexing, ``::``
    qualification, template arguments, ``.`` / ``->`` member access, calls,
    kernel launches and postfix ``++`` / ``--``.
    """
    expr = self.parse_primary_expression()
    while True:
        if self.match("LBRACKET"):
            self.consume("LBRACKET")
            index = self.parse_expression()
            self.consume("RBRACKET")
            expr = ArrayAccessNode(expr, index)
        elif self.match("SCOPE"):
            # ``a::b`` is folded into a single name string.
            self.consume("SCOPE")
            member = self.consume("IDENTIFIER").value
            expr = self.append_qualified_name(expr, "::", member)
        elif self.match("LT") and self.is_template_suffix():
            # ``<`` is taken as template arguments only when the lookahead
            # says so; otherwise it is left for the relational level.
            expr = self.append_template_suffix(expr)
        elif self.match("DOT"):
            self.consume("DOT")
            member = self.consume("IDENTIFIER").value
            expr = MemberAccessNode(expr, member, False)
        elif self.match("ARROW"):
            self.consume("ARROW")
            member = self.consume("IDENTIFIER").value
            expr = MemberAccessNode(expr, member, True)
        elif self.match("LPAREN"):
            self.consume("LPAREN")
            args = self.parse_argument_list()
            self.consume("RPAREN")
            expr = self.parse_function_call_node(expr, args)
        elif self.match("KERNEL_LAUNCH_START"):
            expr = self.parse_kernel_launch(expr)
        elif self.match("INCREMENT", "DECREMENT"):
            # Postfix forms are tagged by appending "_POST" to the operator text.
            op = self.current_token.value
            self.advance()
            expr = UnaryOpNode(op + "_POST", expr)
        else:
            break
    return expr
def append_qualified_name(self, base, separator, member):
if isinstance(base, str):
return f"{base}{separator}{member}"
return f"{self.expression_to_text(base)}{separator}{member}"
def is_template_suffix(self):
index = self.pos
if self.tokens[index].type != "LT":
return False
depth = 0
while index < len(self.tokens):
token_type = self.tokens[index].type
if token_type == "LT":
depth += 1
elif token_type == "GT":
depth -= 1
if depth == 0:
next_type = (
self.tokens[index + 1].type
if index + 1 < len(self.tokens)
else "EOF"
)
return next_type in {"LPAREN", "SCOPE", "DOT"}
elif token_type == "RSHIFT":
depth -= 2
if depth == 0:
next_type = (
self.tokens[index + 1].type
if index + 1 < len(self.tokens)
else "EOF"
)
return next_type in {"LPAREN", "SCOPE", "DOT"}
if depth < 0:
return False
elif token_type in {"SEMICOLON", "ASSIGN"}:
return False
index += 1
return False
def append_template_suffix(self, base):
suffix = self.parse_template_suffix()
if isinstance(base, str):
return f"{base}{suffix}"
return f"{self.expression_to_text(base)}{suffix}"
def parse_template_suffix(self):
self.consume("LT")
parts = []
depth = 1
while depth > 0:
token_type = self.current_token.type
token_value = self.current_token.value
if token_type == "LT":
depth += 1
parts.append(token_value)
self.consume("LT")
elif token_type == "GT":
depth -= 1
if depth == 0:
self.consume("GT")
break
parts.append(token_value)
self.consume("GT")
elif token_type == "RSHIFT":
self.consume("RSHIFT")
for _ in range(2):
depth -= 1
if depth == 0:
break
parts.append(">")
else:
parts.append(token_value)
self.consume(token_type)
return f"<{self.format_template_parts(parts)}>"
def format_template_parts(self, parts):
formatted = []
previous = None
for part in parts:
if part == ",":
formatted.append(", ")
else:
if self.needs_template_part_space(previous, part):
formatted.append(" ")
formatted.append(part)
previous = part
return "".join(formatted).strip()
def needs_template_part_space(self, previous, current):
if previous is None:
return False
if previous in {"<", "[", "(", "::", ","}:
return False
if current in {">", "]", ")", ",", "::", "<", "[", "(", "*", "&", "&&"}:
return False
if previous in {"*", "&", "&&"}:
return False
return self.is_template_word(previous) and self.is_template_word(current)
def is_template_word(self, part):
return part.replace("_", "").isalnum()
def parse_function_call_node(self, function_name, args):
    """Build the AST node for a call that has already been parsed.

    In order of precedence: a C++ named cast becomes a ``CastNode``;
    ``hipLaunchKernelGGL`` with at least five arguments becomes a
    ``KernelLaunchNode`` (first five arguments positional, the rest as the
    kernel's arguments); anything else is a ``FunctionCallNode``.
    """
    cast_node = self.parse_cpp_named_cast_call(function_name, args)
    if cast_node is not None:
        return cast_node
    is_launch_macro = function_name == "hipLaunchKernelGGL" and len(args) >= 5
    if not is_launch_macro:
        return FunctionCallNode(function_name, args)
    leading = args[:5]
    return KernelLaunchNode(
        leading[0],
        leading[1],
        leading[2],
        leading[3],
        leading[4],
        args[5:],
    )
def parse_cpp_named_cast_call(self, function_name, args):
if not isinstance(function_name, str) or len(args) != 1:
return None
for cast_name in self.CPP_NAMED_CASTS:
prefix = f"{cast_name}<"
if function_name.startswith(prefix) and function_name.endswith(">"):
return CastNode(function_name[len(prefix) : -1], args[0])
return None
def parse_kernel_launch(self, kernel_name):
    """Parse ``<<<blocks, threads[, shared_mem[, stream]]>>>(args)`` after a kernel name.

    Newlines around each launch parameter are skipped. ``shared_mem`` and
    ``stream`` are ``None`` when omitted.
    """

    def launch_parameter():
        # One launch-configuration expression with surrounding newlines skipped.
        self.skip_newlines()
        parameter = self.parse_expression()
        self.skip_newlines()
        return parameter

    self.consume("KERNEL_LAUNCH_START")
    grid = launch_parameter()
    self.consume("COMMA")
    block = launch_parameter()
    shared_bytes = None
    stream_handle = None
    if self.match("COMMA"):
        self.advance()
        shared_bytes = launch_parameter()
        if self.match("COMMA"):
            self.advance()
            stream_handle = launch_parameter()
    self.consume("KERNEL_LAUNCH_END")
    self.consume("LPAREN")
    kernel_args = self.parse_argument_list()
    self.consume("RPAREN")
    return KernelLaunchNode(
        kernel_name, grid, block, shared_bytes, stream_handle, kernel_args
    )
def parse_primary_expression(self):
    """Parse the smallest expression unit at the current token.

    Returns:
        A plain string for identifiers, literals and keyword constants; a
        ``HipBuiltinNode`` for thread/block index built-ins; a ``CastNode``
        for a C-style cast; the inner expression for a parenthesized
        expression; or an initializer list for ``{...}``.

    Raises:
        SyntaxError: via ``error`` when no alternative matches.
    """
    if self.match("IDENTIFIER"):
        # ``new`` arrives as an IDENTIFIER token, so it is matched by value.
        if self.current_token.value == "new":
            return self.parse_new_expression()
        name = self.current_token.value
        self.advance()
        # Check for HIP built-in variables
        if name in ["threadIdx", "blockIdx", "blockDim", "gridDim"]:
            component = None
            if self.match("DOT"):
                self.advance()
                if self.match("IDENTIFIER"):
                    component = self.current_token.value
                    self.advance()
            return HipBuiltinNode(name, component)
        return name
    elif self.match("THREADIDX", "BLOCKIDX", "BLOCKDIM", "GRIDDIM"):
        # Same built-ins when they arrive as dedicated token types.
        name = self.current_token.value
        self.advance()
        component = None
        if self.match("DOT"):
            self.advance()
            if self.match("IDENTIFIER"):
                component = self.current_token.value
                self.advance()
        return HipBuiltinNode(name, component)
    elif self.match("INTEGER", "FLOAT_NUM", "FLOAT", "STRING"):
        value = self.current_token.value
        self.advance()
        return value
    elif self.match("TRUE", "FALSE", "NULL", "NULLPTR", "HIPSUCCESS"):
        value = self.current_token.value
        self.advance()
        return value
    elif self.match("CHAR") and self.current_token.value != "char":
        # A CHAR token whose text is not the keyword ``char`` is taken as a value.
        value = self.current_token.value
        self.advance()
        return value
    elif self.match("SYNCTHREADS", "SYNCWARP"):
        value = self.current_token.value
        self.advance()
        return value
    elif self.match("LPAREN"):
        # ``(type) expr`` cast versus a parenthesized expression.
        if self.is_cast_expression():
            self.consume("LPAREN")
            target_type = self.parse_type()
            self.consume("RPAREN")
            expr = self.parse_unary_expression()
            return CastNode(target_type, expr)
        self.consume("LPAREN")
        expr = self.parse_expression()
        self.consume("RPAREN")
        return expr
    elif self.match("LBRACE"):
        return self.parse_initializer_list()
    else:
        self.error(
            f"Unexpected token in expression: {self.current_token.type if self.current_token else 'EOF'}"
        )
def parse_new_expression(self):
    """Parse a ``new`` expression and return a ``NewNode``.

    The current token is consumed unconditionally (the caller has already
    identified it as ``new``). After the type, either an array form
    ``[size]`` (size may be omitted) or an optional parenthesised argument
    list is accepted.
    """
    self.advance()  # step over the "new" token
    allocated_type = self.parse_type_without_array_suffix()

    if not self.match("LBRACKET"):
        # Scalar form: constructor arguments are optional.
        ctor_args = []
        if self.match("LPAREN"):
            self.consume("LPAREN")
            ctor_args = self.parse_argument_list()
            self.consume("RPAREN")
        return NewNode(allocated_type, args=ctor_args)

    # Array form: "[" [size] "]".
    self.consume("LBRACKET")
    count = None if self.match("RBRACKET") else self.parse_expression()
    self.consume("RBRACKET")
    return NewNode(allocated_type, size=count, is_array=True)
def parse_initializer_list(self):
    """Parse ``{ elem, elem, ... }`` and return an ``InitializerListNode``.

    Newlines are skipped between elements and a trailing comma before the
    closing brace is accepted.
    """
    self.consume("LBRACE")
    self.skip_newlines()
    items = []
    while self.current_token and not self.match("RBRACE"):
        items.append(self.parse_initializer_element())
        self.skip_newlines()
        if not self.match("COMMA"):
            break
        self.advance()
        # After a comma the loop condition handles a trailing "}".
        self.skip_newlines()
    self.consume("RBRACE")
    return InitializerListNode(items)
def parse_initializer_element(self):
    """Parse one initializer-list element.

    An element starting with ``[`` or ``.`` is handled as a designated
    initializer; everything else is an ordinary expression.
    """
    starts_with_designator = self.match("LBRACKET", "DOT")
    return (
        self.parse_designated_initializer()
        if starts_with_designator
        else self.parse_expression()
    )
def parse_designated_initializer(self):
    """Parse ``[index]`` / ``.field`` designators followed by ``= value``.

    Returns a ``DesignatedInitializerNode`` holding the designator path as a
    list of ``("index", expr)`` and ``("field", name)`` tuples plus the
    parsed value expression.
    """
    path = []
    while self.match("LBRACKET", "DOT"):
        if self.match("DOT"):
            self.consume("DOT")
            member = self.consume("IDENTIFIER").value
            path.append(("field", member))
            continue
        self.consume("LBRACKET")
        subscript = self.parse_expression()
        self.consume("RBRACKET")
        path.append(("index", subscript))

    self.skip_newlines()
    self.consume("ASSIGN")
    self.skip_newlines()
    initial = self.parse_expression()
    return DesignatedInitializerNode(path, initial)
def expression_to_text(self, expr):
    """Render a parsed expression as source-like text.

    Strings are returned unchanged; built-in, binary and unary nodes are
    formatted recursively; any other value falls back to ``str(expr)``.
    """
    if isinstance(expr, str):
        return expr

    if isinstance(expr, HipBuiltinNode):
        base = expr.builtin_name
        return f"{base}.{expr.component}" if expr.component else base

    if isinstance(expr, BinaryOpNode):
        lhs = self.expression_to_text(expr.left)
        rhs = self.expression_to_text(expr.right)
        return f"({lhs} {expr.op} {rhs})"

    if isinstance(expr, UnaryOpNode):
        operand_text = self.expression_to_text(expr.operand)
        return f"{expr.op}{operand_text}"

    return str(expr)
def is_cast_expression(self):
    """Report whether the tokens at the cursor look like ``(type)``.

    Speculatively scans ``(``, optional type qualifiers, a non-identifier
    type token, any ``*`` markers and ``[...]`` suffixes, and checks that a
    ``)`` follows. The parser position and current token are always restored
    before returning, including when the scan raises.
    """
    if not self.match("LPAREN"):
        return False

    checkpoint = self.pos

    def scan_parenthesised_type():
        self.advance()  # past "("
        while self.match(*self.TYPE_QUALIFIER_TOKENS):
            self.advance()
        if not self.is_type_token(allow_identifier=False):
            return False
        self.advance()
        while self.match("ASTERISK", "STAR"):
            self.advance()
        while self.match("LBRACKET"):
            self.advance()
            # Skip the bracket contents without interpreting them.
            while self.current_token and not self.match("RBRACKET"):
                self.advance()
            self.consume("RBRACKET")
        return self.match("RPAREN")

    try:
        return scan_parenthesised_type()
    finally:
        # Rewind: this method is a pure lookahead.
        self.pos = checkpoint
        self.current_token = (
            self.tokens[checkpoint] if checkpoint < len(self.tokens) else None
        )
def parse_argument_list(self):
    """Parse comma-separated expressions up to (not including) ``)``.

    Returns the parsed expressions as a list; the list is empty when the
    next significant token is already ``RPAREN``. Newlines around arguments
    are skipped. The closing parenthesis is left for the caller.
    """
    self.skip_newlines()
    collected = []
    if self.match("RPAREN"):
        return collected

    expecting_more = True
    while expecting_more:
        self.skip_newlines()
        collected.append(self.parse_expression())
        self.skip_newlines()
        expecting_more = self.match("COMMA")
        if expecting_more:
            self.advance()
    return collected
def is_function_declaration(self) -> bool:
    """Heuristically detect ``[specifiers] type identifier (`` at the cursor.

    Looks ahead by index only; the parser position is not changed here.
    """
    tokens = self.tokens
    count = len(tokens)

    cursor = self.pos
    while cursor < count and tokens[cursor].type in self.FUNCTION_SPECIFIER_TOKENS:
        cursor += 1

    cursor = self.skip_type_at_pos(cursor)
    if cursor is None or cursor + 1 >= count:
        return False
    return (
        tokens[cursor].type == "IDENTIFIER"
        and tokens[cursor + 1].type == "LPAREN"
    )
def is_variable_declaration(self) -> bool:
    """Heuristically detect a variable declaration at the cursor.

    Accepts optional storage/qualifier tokens, then a type, then an
    IDENTIFIER that is immediately followed by one of ``;``, ``=``, ``[``,
    ``(``, ``{`` or ``,``. Looks ahead by index only.
    """
    leading_tokens = {
        "__SHARED__",
        "__CONSTANT__",
        "__DEVICE__",
        "STATIC",
        "EXTERN",
        "CONST",
        "VOLATILE",
        "UNSIGNED",
        "SIGNED",
    }
    declarator_followers = {
        "SEMICOLON",
        "ASSIGN",
        "LBRACKET",
        "LPAREN",
        "LBRACE",
        "COMMA",
    }
    tokens = self.tokens
    count = len(tokens)

    cursor = self.pos
    while cursor < count and tokens[cursor].type in leading_tokens:
        cursor += 1

    cursor = self.skip_type_at_pos(cursor)
    if cursor is None or cursor >= count:
        return False
    if tokens[cursor].type != "IDENTIFIER":
        return False

    cursor += 1  # the token right after the declared name decides
    return cursor < count and tokens[cursor].type in declarator_followers
def skip_type_at_pos(self, index):
"""Scan past a type that starts at token ``index`` using indices only.

Returns the index of the first token after the type (after any array
suffix), or ``None`` when the token at ``index`` is not accepted by
``is_type_token`` or a template argument list cannot be skipped.
``self.pos`` is not assigned in this method.
"""
# Leading qualifiers (TYPE_QUALIFIER_TOKENS) may precede the base type.
while (
index < len(self.tokens)
and self.tokens[index].type in self.TYPE_QUALIFIER_TOKENS
):
index += 1
if index >= len(self.tokens) or not self.is_type_token(self.tokens[index]):
return None
# Remember the base type token; it decides below whether pointer or
# reference markers may follow.
type_token = self.tokens[index].type
type_value = self.tokens[index].value
has_qualified_suffix = False
index += 1
# Consume "SCOPE IDENTIFIER" pairs (presumably ``::name`` qualification;
# the SCOPE token is defined by the lexer, not visible here).
while (
index + 1 < len(self.tokens)
and self.tokens[index].type == "SCOPE"
and self.tokens[index + 1].type == "IDENTIFIER"
):
has_qualified_suffix = True
index += 2
# An LT token directly after the name is treated as a template argument list.
if index < len(self.tokens) and self.tokens[index].type == "LT":
has_qualified_suffix = True
index = self.skip_template_at_pos(index)
if index is None:
return None
index = self.skip_postfix_type_qualifiers_at_pos(index)
# A bare IDENTIFIER only takes pointer/reference markers when it was
# scope-qualified or templated, is ``auto``, or is a known type alias.
# NOTE(review): presumably this keeps ``a * b`` from being read as a
# pointer declaration; confirm against the callers.
can_have_pointer_suffix = (
type_token != "IDENTIFIER"
or has_qualified_suffix
or type_value == "auto"
or type_value in self.type_aliases
)
# Skip pointer/reference markers; postfix qualifiers may follow them.
while (
can_have_pointer_suffix
and index < len(self.tokens)
and self.tokens[index].type
in {
"ASTERISK",
"STAR",
*self.TYPE_REFERENCE_TOKENS,
}
):
index += 1
index = self.skip_postfix_type_qualifiers_at_pos(index)
return self.skip_array_suffix_at_pos(index)
def skip_postfix_type_qualifiers_at_pos(self, index):
    """Return the first index at or after ``index`` that is not a postfix qualifier.

    Tokens whose type is in ``POSTFIX_TYPE_QUALIFIER_TOKENS`` are skipped.
    When ``index`` is already past the end it is returned unchanged.
    """
    tokens = self.tokens
    postfix_kinds = self.POSTFIX_TYPE_QUALIFIER_TOKENS
    for cursor in range(index, len(tokens)):
        if tokens[cursor].type not in postfix_kinds:
            return cursor
    # Either every remaining token was a qualifier or index was out of range.
    return max(index, len(tokens))
def skip_template_at_pos(self, index):
    """Skip a balanced ``< ... >`` run starting at ``index``.

    Returns the index just past the closing token, or ``None`` when the
    brackets never balance, an ``RSHIFT`` closes more levels than are open,
    or a ``SEMICOLON`` / ``ASSIGN`` token is met first. ``RSHIFT`` counts as
    two closing angle brackets.
    """
    tokens = self.tokens
    nesting = 0
    for cursor in range(index, len(tokens)):
        kind = tokens[cursor].type
        if kind == "LT":
            nesting += 1
            continue
        if kind in ("SEMICOLON", "ASSIGN"):
            return None

        if kind == "GT":
            closed = 1
        elif kind == "RSHIFT":
            closed = 2
        else:
            continue

        nesting -= closed
        if nesting == 0:
            return cursor + 1
        # Only the two-level close is checked for over-closing.
        if closed == 2 and nesting < 0:
            return None
    return None
def skip_array_suffix_at_pos(self, index):
    """Skip consecutive ``[...]`` array suffixes starting at ``index``.

    Returns the index of the first token after the last complete suffix.
    Brackets are matched by depth, so a size expression that itself contains
    a subscript (``[a[0]]``) is skipped as one suffix instead of stopping at
    the inner ``]``. If a suffix is never closed, ``len(self.tokens)`` is
    returned, matching the behaviour for flat unterminated brackets.
    """
    tokens = self.tokens
    total = len(tokens)
    while index < total and tokens[index].type == "LBRACKET":
        depth = 0
        # Walk to the RBRACKET that closes the bracket opened at `index`.
        while index < total:
            kind = tokens[index].type
            if kind == "LBRACKET":
                depth += 1
            elif kind == "RBRACKET":
                depth -= 1
                if depth == 0:
                    break
            index += 1
        if index >= total:
            # Unterminated suffix: nothing more to skip.
            break
        index += 1  # step past the closing "]"
    return index
[docs]
def parse_hip_code(code: str) -> HipProgramNode:
    """Tokenize ``code`` with ``HipLexer`` and parse it with ``HipParser``.

    Returns whatever ``HipParser.parse`` produces for the token stream.
    """
    token_stream = HipLexer(code).tokenize()
    return HipParser(token_stream).parse()