Source code for crosstl.backend.CUDA.CudaParser

"""CUDA Parser Implementation"""

from .CudaAst import (
    ArrayAccessNode,
    AssignmentNode,
    AtomicOperationNode,
    BinaryOpNode,
    BreakNode,
    CaseNode,
    CastNode,
    ConstantMemoryNode,
    ContinueNode,
    CudaBuiltinNode,
    DesignatedInitializerNode,
    DeleteNode,
    DoWhileNode,
    ForNode,
    FunctionCallNode,
    FunctionNode,
    IfNode,
    InitializerListNode,
    KernelLaunchNode,
    KernelNode,
    MemberAccessNode,
    NewNode,
    PreprocessorNode,
    RangeForNode,
    ReturnNode,
    ShaderNode,
    SharedMemoryNode,
    StructNode,
    SwitchNode,
    SyncNode,
    TernaryOpNode,
    TypeAliasNode,
    UnaryOpNode,
    VariableNode,
    WhileNode,
)
from .CudaLexer import CudaLexer


[docs] class CudaParser: """Parse CUDA tokens into the CUDA backend shader AST.""" TYPE_QUALIFIER_TOKENS = {"CONST", "VOLATILE", "UNSIGNED", "SIGNED", "RESTRICT"} POSTFIX_TYPE_QUALIFIER_TOKENS = {"RESTRICT"} TYPE_REFERENCE_TOKENS = {"BITWISE_AND", "LOGICAL_AND"} CPP_NAMED_CASTS = {"static_cast", "reinterpret_cast", "const_cast", "dynamic_cast"} DECLARATION_QUALIFIER_TOKENS = { "CONST", "VOLATILE", "STATIC", "EXTERN", "SHARED", "CONSTANT", "DEVICE", "MANAGED", "UNSIGNED", "SIGNED", } FUNCTION_SPECIFIER_TOKENS = { "GLOBAL", "DEVICE", "HOST", "INLINE", "STATIC", "EXTERN", "FORCEINLINE", "NOINLINE", } TYPE_TOKENS = { "VOID", "CHAR", "SHORT", "INT", "LONG", "FLOAT", "DOUBLE", "BOOL", "FLOAT2", "FLOAT3", "FLOAT4", "INT2", "INT3", "INT4", "DOUBLE2", "DOUBLE3", "DOUBLE4", "UINT2", "UINT3", "UINT4", "CHAR2", "CHAR3", "CHAR4", "UCHAR2", "UCHAR3", "UCHAR4", "SHORT2", "SHORT3", "SHORT4", "USHORT2", "USHORT3", "USHORT4", "LONG2", "LONG3", "LONG4", "ULONG2", "ULONG3", "ULONG4", "LONGLONG2", "LONGLONG3", "LONGLONG4", "ULONGLONG2", "ULONGLONG3", "ULONGLONG4", "SIZE_T", "IDENTIFIER", } def __init__(self, tokens): """Initialize the parser with a token stream from ``CudaLexer``.""" self.tokens = tokens self.current_index = 0 self.current_token = tokens[0] if tokens else None self.type_aliases = set()
[docs] def parse(self): """Parse the entire CUDA program into a ``ShaderNode``.""" includes = [] functions = [] structs = [] global_variables = [] kernels = [] typedefs = [] while self.current_token[0] != "EOF": if self.current_token[0] == "PREPROCESSOR": includes.append(self.parse_preprocessor()) elif self.is_type_alias_start(): aliases = self.parse_type_alias() if isinstance(aliases, list): typedefs.extend(aliases) elif aliases is not None: typedefs.append(aliases) elif self.current_token[0] == "STRUCT": structs.append(self.parse_struct()) elif ( self.current_token[0] in ["GLOBAL", "DEVICE", "HOST"] or self.peek_function() ): func = self.parse_function() if isinstance(func, KernelNode): kernels.append(func) else: functions.append(func) elif ( self.current_token[0] in ["CONSTANT", "SHARED"] or self.peek_variable() ): global_variables.append(self.parse_global_variable()) else: # Skip unexpected tokens self.eat(self.current_token[0]) return ShaderNode( includes, functions, structs, global_variables, kernels, typedefs=typedefs )
[docs] def peek_function(self): """Check if the next tokens form a function declaration""" # Look ahead for function pattern: [qualifiers] type name ( saved_index = self.current_index while ( saved_index < len(self.tokens) and self.tokens[saved_index][0] in self.FUNCTION_SPECIFIER_TOKENS ): saved_index += 1 while ( saved_index < len(self.tokens) and self.tokens[saved_index][0] in self.TYPE_QUALIFIER_TOKENS ): saved_index += 1 saved_index = self.skip_type_at_index(saved_index) if saved_index is not None: if ( saved_index < len(self.tokens) - 1 and self.tokens[saved_index][0] == "IDENTIFIER" and self.tokens[saved_index + 1][0] == "LPAREN" ): return True return False
[docs] def peek_variable(self): """Check if the next tokens form a variable declaration""" saved_index = self.current_index while ( saved_index < len(self.tokens) and self.tokens[saved_index][0] in self.DECLARATION_QUALIFIER_TOKENS ): saved_index += 1 saved_index = self.skip_type_at_index(saved_index) if saved_index is not None: if ( saved_index < len(self.tokens) and self.tokens[saved_index][0] == "IDENTIFIER" ): saved_index += 1 saved_index = self.skip_array_suffix_at_index(saved_index) if saved_index < len(self.tokens) and self.tokens[saved_index][0] in [ "SEMICOLON", "ASSIGN", "LPAREN", "LBRACE", "COMMA", ]: return True return False
def skip_type_at_index(self, index): while ( index < len(self.tokens) and self.tokens[index][0] in self.TYPE_QUALIFIER_TOKENS ): index += 1 if index >= len(self.tokens) or self.tokens[index][0] not in self.TYPE_TOKENS: return None type_token = self.tokens[index][0] type_value = self.tokens[index][1] has_qualified_suffix = False index += 1 while ( index + 1 < len(self.tokens) and self.tokens[index][0] == "SCOPE" and self.tokens[index + 1][0] == "IDENTIFIER" ): has_qualified_suffix = True index += 2 if index < len(self.tokens) and self.tokens[index][0] == "LESS_THAN": has_qualified_suffix = True index = self.skip_template_at_index(index) if index is None: return None index = self.skip_postfix_type_qualifiers_at_index(index) can_have_pointer_suffix = ( type_token != "IDENTIFIER" or has_qualified_suffix or type_value == "auto" or type_value in self.type_aliases ) while ( can_have_pointer_suffix and index < len(self.tokens) and self.tokens[index][0] in {"MULTIPLY", *self.TYPE_REFERENCE_TOKENS} ): index += 1 index = self.skip_postfix_type_qualifiers_at_index(index) return self.skip_array_suffix_at_index(index) def skip_postfix_type_qualifiers_at_index(self, index): while ( index < len(self.tokens) and self.tokens[index][0] in self.POSTFIX_TYPE_QUALIFIER_TOKENS ): index += 1 return index def skip_template_at_index(self, index): depth = 0 while index < len(self.tokens): token_type = self.tokens[index][0] if token_type == "LESS_THAN": depth += 1 elif token_type == "GREATER_THAN": depth -= 1 if depth == 0: return index + 1 elif token_type == "SHIFT_RIGHT": depth -= 2 if depth == 0: return index + 1 if depth < 0: return None elif token_type in {"SEMICOLON", "ASSIGN", "EOF"}: return None index += 1 return None def skip_array_suffix_at_index(self, index): while index < len(self.tokens) and self.tokens[index][0] == "LBRACKET": index += 1 while index < len(self.tokens) and self.tokens[index][0] != "RBRACKET": index += 1 if index < len(self.tokens) and self.tokens[index][0] == "RBRACKET": index += 1 else: break return index
[docs] def eat(self, expected_type): """Consume a token of the expected type""" if self.current_token[0] == expected_type: token = self.current_token self.current_index += 1 if self.current_index < len(self.tokens): self.current_token = self.tokens[self.current_index] else: self.current_token = ("EOF", "") return token else: raise SyntaxError(f"Expected {expected_type}, got {self.current_token[0]}")
[docs] def parse_preprocessor(self): """Parse preprocessor directives""" directive_token = self.eat("PREPROCESSOR") directive_text = directive_token[1].strip() if directive_text.startswith("#include"): content = directive_text[8:].strip() return PreprocessorNode("include", content) elif directive_text.startswith("#define"): content = directive_text[7:].strip() return PreprocessorNode("define", content) else: return PreprocessorNode("other", directive_text)
def is_type_alias_start(self): return self.current_token[0] == "TYPEDEF" or self.is_identifier_value("using") def parse_type_alias(self): if self.current_token[0] == "TYPEDEF": return self.parse_typedef_alias() return self.parse_using_alias() def parse_typedef_alias(self): self.eat("TYPEDEF") first_type = self.parse_type() base_type = self.strip_declarator_markers(first_type) aliases = [self.parse_type_alias_declarator(first_type, allow_prefix=False)] while self.current_token[0] == "COMMA": self.eat("COMMA") aliases.append( self.parse_type_alias_declarator(base_type, allow_prefix=True) ) self.eat("SEMICOLON") return aliases def parse_type_alias_declarator(self, base_type, allow_prefix): alias_type = base_type if allow_prefix: alias_type = self.parse_declarator_prefix(alias_type) name = self.eat("IDENTIFIER")[1] alias_type += self.parse_array_suffix() self.type_aliases.add(name) return TypeAliasNode(alias_type, name) def parse_using_alias(self): self.eat("IDENTIFIER") if self.current_token[0] == "NAMESPACE": self.skip_until_semicolon() return None name = self.eat("IDENTIFIER")[1] self.eat("ASSIGN") alias_type = self.parse_type() self.eat("SEMICOLON") self.type_aliases.add(name) return TypeAliasNode(alias_type, name) def skip_until_semicolon(self): while self.current_token[0] not in {"SEMICOLON", "EOF"}: self.eat(self.current_token[0]) if self.current_token[0] == "SEMICOLON": self.eat("SEMICOLON")
[docs] def parse_struct(self): """Parse struct declaration""" self.eat("STRUCT") name = self.eat("IDENTIFIER")[1] self.eat("LBRACE") members = [] while self.current_token[0] != "RBRACE": member = self.parse_variable_declaration() members.append(member) self.eat("SEMICOLON") self.eat("RBRACE") self.eat("SEMICOLON") return StructNode(name, members)
[docs] def parse_function(self): """Parse function declaration including kernels""" qualifiers = [] while self.current_token[0] in self.FUNCTION_SPECIFIER_TOKENS: qualifiers.append(self.current_token[1]) self.eat(self.current_token[0]) return_type = self.parse_type() name = self.eat("IDENTIFIER")[1] params = self.parse_parameters() body = self.parse_block() if "__global__" in qualifiers: return KernelNode(return_type, name, params, body) else: return FunctionNode(return_type, name, params, body, qualifiers)
[docs] def parse_parameters(self): """Parse function parameters""" self.eat("LPAREN") params = [] if self.current_token[0] != "RPAREN": params.append(self.parse_parameter()) while self.current_token[0] == "COMMA": self.eat("COMMA") params.append(self.parse_parameter()) self.eat("RPAREN") return params
[docs] def parse_parameter(self): """Parse a single parameter""" param_type = self.parse_type() param_name = self.eat("IDENTIFIER")[1] param_type += self.parse_array_suffix() return VariableNode(param_type, param_name)
[docs] def parse_type(self): """Parse type specification""" type_parts = [] while self.current_token[0] in self.TYPE_QUALIFIER_TOKENS: type_parts.append(self.current_token[1]) self.eat(self.current_token[0]) if self.current_token[0] in self.TYPE_TOKENS: type_parts.append(self.parse_type_name()) self.parse_postfix_type_qualifiers(type_parts) while self.current_token[0] == "MULTIPLY": type_parts.append("*") self.eat("MULTIPLY") self.parse_postfix_type_qualifiers(type_parts) while self.current_token[0] in self.TYPE_REFERENCE_TOKENS: type_parts.append(self.current_token[1]) self.eat(self.current_token[0]) self.parse_postfix_type_qualifiers(type_parts) while self.current_token[0] == "LBRACKET": self.eat("LBRACKET") if self.current_token[0] == "NUMBER": type_parts.append(f"[{self.current_token[1]}]") self.eat("NUMBER") else: type_parts.append("[]") self.eat("RBRACKET") return " ".join(type_parts)
def parse_type_without_array_suffix(self): type_parts = [] while self.current_token[0] in self.TYPE_QUALIFIER_TOKENS: type_parts.append(self.current_token[1]) self.eat(self.current_token[0]) if self.current_token[0] in self.TYPE_TOKENS: type_parts.append(self.parse_type_name()) self.parse_postfix_type_qualifiers(type_parts) while self.current_token[0] == "MULTIPLY": type_parts.append("*") self.eat("MULTIPLY") self.parse_postfix_type_qualifiers(type_parts) while self.current_token[0] in self.TYPE_REFERENCE_TOKENS: type_parts.append(self.current_token[1]) self.eat(self.current_token[0]) self.parse_postfix_type_qualifiers(type_parts) return " ".join(type_parts) def parse_postfix_type_qualifiers(self, type_parts): while self.current_token[0] in self.POSTFIX_TYPE_QUALIFIER_TOKENS: type_parts.append(self.current_token[1]) self.eat(self.current_token[0]) def parse_type_name(self): type_name = self.current_token[1] self.eat(self.current_token[0]) while self.current_token[0] == "SCOPE": self.eat("SCOPE") member = self.eat("IDENTIFIER")[1] type_name += f"::{member}" if self.current_token[0] == "LESS_THAN": type_name += self.parse_template_suffix() return type_name
[docs] def parse_global_variable(self): """Parse global variable declaration""" qualifiers = [] while self.current_token[0] in ["CONSTANT", "SHARED", "DEVICE", "MANAGED"]: qualifiers.append(self.current_token[1]) self.eat(self.current_token[0]) var = self.parse_variable_declaration() var.qualifiers = qualifiers self.eat("SEMICOLON") if "__constant__" in qualifiers: return ConstantMemoryNode(var.vtype, var.name, var.value) elif "__shared__" in qualifiers: return SharedMemoryNode(var.vtype, var.name) else: return var
[docs] def parse_variable_declaration(self): """Parse variable declaration""" qualifiers = [] while self.current_token[0] in [ "SHARED", "CONSTANT", "STATIC", "EXTERN", "DEVICE", "MANAGED", ]: qualifiers.append(self.current_token[1]) self.eat(self.current_token[0]) vtype = self.parse_type() name = self.eat("IDENTIFIER")[1] vtype += self.parse_array_suffix() value = None if self.current_token[0] == "ASSIGN": self.eat("ASSIGN") value = self.parse_expression() elif self.current_token[0] == "LPAREN": args = self.parse_argument_list() value = FunctionCallNode(vtype, args) elif self.current_token[0] == "LBRACE": value = self.parse_initializer_list() var = VariableNode(vtype, name, value, qualifiers) if "__shared__" in qualifiers: return SharedMemoryNode(vtype, name) elif "__constant__" in qualifiers: return ConstantMemoryNode(vtype, name, value) else: return var
[docs] def parse_variable_declaration_list(self): """Parse one or more comma-separated variable declarations.""" qualifiers = [] while self.current_token[0] in [ "SHARED", "CONSTANT", "STATIC", "EXTERN", "DEVICE", "MANAGED", ]: qualifiers.append(self.current_token[1]) self.eat(self.current_token[0]) first_type = self.parse_type() base_type = self.strip_declarator_markers(first_type) declarations = [ self.parse_variable_declarator(first_type, qualifiers, allow_prefix=False) ] while self.current_token[0] == "COMMA": self.eat("COMMA") declarations.append( self.parse_variable_declarator(base_type, qualifiers, allow_prefix=True) ) return declarations
def parse_variable_declarator(self, base_type, qualifiers, allow_prefix): vtype = base_type if allow_prefix: vtype = self.parse_declarator_prefix(vtype) name = self.eat("IDENTIFIER")[1] vtype += self.parse_array_suffix() value = self.parse_variable_initializer(vtype) if "__shared__" in qualifiers: return SharedMemoryNode(vtype, name) if "__constant__" in qualifiers: return ConstantMemoryNode(vtype, name, value) return VariableNode(vtype, name, value, list(qualifiers)) def parse_declarator_prefix(self, base_type): parts = [base_type] if base_type else [] while self.current_token[0] in { "MULTIPLY", *self.TYPE_REFERENCE_TOKENS, *self.POSTFIX_TYPE_QUALIFIER_TOKENS, }: if self.current_token[0] in self.POSTFIX_TYPE_QUALIFIER_TOKENS: parts.append(self.current_token[1]) self.eat(self.current_token[0]) continue parts.append(self.current_token[1]) self.eat(self.current_token[0]) self.parse_postfix_type_qualifiers(parts) return " ".join(parts) def parse_variable_initializer(self, vtype): if self.current_token[0] == "ASSIGN": self.eat("ASSIGN") return self.parse_expression() if self.current_token[0] == "LPAREN": args = self.parse_argument_list() return FunctionCallNode(vtype, args) if self.current_token[0] == "LBRACE": return self.parse_initializer_list() return None def strip_declarator_markers(self, vtype): parts = str(vtype).split() while parts and parts[-1] in {"*", "&", "&&", "__restrict__", "restrict"}: parts.pop() return " ".join(parts)
[docs] def parse_array_suffix(self): """Parse one or more C-style array declarator suffixes.""" suffixes = [] while self.current_token[0] == "LBRACKET": self.eat("LBRACKET") if self.current_token[0] != "RBRACKET": size = self.parse_expression() suffixes.append(f"[{self.expression_to_text(size)}]") else: suffixes.append("[]") self.eat("RBRACKET") return "".join(suffixes)
def parse_argument_list(self): self.eat("LPAREN") args = [] if self.current_token[0] != "RPAREN": args.append(self.parse_expression()) while self.current_token[0] == "COMMA": self.eat("COMMA") args.append(self.parse_expression()) self.eat("RPAREN") return args
[docs] def parse_block(self): """Parse a block of statements""" self.eat("LBRACE") statements = [] while self.current_token[0] != "RBRACE" and self.current_token[0] != "EOF": stmt = self.parse_statement() if stmt: if isinstance(stmt, list): statements.extend(stmt) else: statements.append(stmt) self.eat("RBRACE") return statements
[docs] def parse_statement(self): """Parse a single statement""" if self.current_token[0] == "IF": return self.parse_if_statement() elif self.current_token[0] == "FOR": return self.parse_for_statement() elif self.current_token[0] == "WHILE": return self.parse_while_statement() elif self.current_token[0] == "DO": return self.parse_do_while_statement() elif self.current_token[0] == "SWITCH": return self.parse_switch_statement() elif self.current_token[0] == "RETURN": return self.parse_return_statement() elif self.current_token[0] == "BREAK": self.eat("BREAK") self.eat("SEMICOLON") return BreakNode() elif self.current_token[0] == "CONTINUE": self.eat("CONTINUE") self.eat("SEMICOLON") return ContinueNode() elif self.current_token[0] in ["SYNCTHREADS", "SYNCWARP"]: return self.parse_sync_statement() elif self.current_token[0] == "LBRACE": return self.parse_block() elif self.is_type_alias_start(): return self.parse_type_alias() elif self.is_identifier_value("delete"): stmt = self.parse_delete_statement() self.eat("SEMICOLON") return stmt elif self.is_variable_declaration(): declarations = self.parse_variable_declaration_list() self.eat("SEMICOLON") return declarations if len(declarations) > 1 else declarations[0] else: # Expression statement or assignment expr = self.parse_assignment_expression() self.eat("SEMICOLON") return expr
def is_identifier_value(self, value): return self.current_token[0] == "IDENTIFIER" and self.current_token[1] == value def parse_delete_statement(self): self.eat("IDENTIFIER") is_array = False if self.current_token[0] == "LBRACKET": self.eat("LBRACKET") self.eat("RBRACKET") is_array = True expression = self.parse_unary_expression() return DeleteNode(expression, is_array)
[docs] def is_variable_declaration(self): """Check if current position is a variable declaration""" saved_index = self.current_index while ( saved_index < len(self.tokens) and self.tokens[saved_index][0] in self.DECLARATION_QUALIFIER_TOKENS ): saved_index += 1 saved_index = self.skip_type_at_index(saved_index) if saved_index is not None: if ( saved_index < len(self.tokens) and self.tokens[saved_index][0] == "IDENTIFIER" ): saved_index += 1 if saved_index >= len(self.tokens): return True return self.tokens[saved_index][0] in [ "SEMICOLON", "ASSIGN", "LBRACKET", "LPAREN", "LBRACE", "COMMA", ] return False
[docs] def parse_if_statement(self): """Parse if statement""" self.eat("IF") self.eat("LPAREN") condition = self.parse_expression() self.eat("RPAREN") if_body = self.parse_statement() else_body = None if self.current_token[0] == "ELSE": self.eat("ELSE") else_body = self.parse_statement() return IfNode(condition, if_body, else_body)
[docs] def parse_for_statement(self): """Parse for loop""" self.eat("FOR") self.eat("LPAREN") if self.is_range_for_statement(): return self.parse_range_for_statement() init = None if self.current_token[0] != "SEMICOLON": if self.is_variable_declaration(): init_declarations = self.parse_variable_declaration_list() init = ( init_declarations if len(init_declarations) > 1 else init_declarations[0] ) else: init = self.parse_expression() self.eat("SEMICOLON") condition = None if self.current_token[0] != "SEMICOLON": condition = self.parse_expression() self.eat("SEMICOLON") update = None if self.current_token[0] != "RPAREN": update = self.parse_expression() self.eat("RPAREN") body = self.parse_statement() return ForNode(init, condition, update, body)
[docs] def is_range_for_statement(self): """Check if the current parenthesized for header is a range-for loop.""" index = self.skip_range_for_type_at_index(self.current_index) if index is None: return False if index >= len(self.tokens) or self.tokens[index][0] != "IDENTIFIER": return False index += 1 index = self.skip_array_suffix_at_index(index) return index < len(self.tokens) and self.tokens[index][0] == "COLON"
def skip_range_for_type_at_index(self, index): while ( index < len(self.tokens) and self.tokens[index][0] in self.TYPE_QUALIFIER_TOKENS ): index += 1 if index >= len(self.tokens) or self.tokens[index][0] not in self.TYPE_TOKENS: return None index += 1 while ( index + 1 < len(self.tokens) and self.tokens[index][0] == "SCOPE" and self.tokens[index + 1][0] == "IDENTIFIER" ): index += 2 if index < len(self.tokens) and self.tokens[index][0] == "LESS_THAN": index = self.skip_template_at_index(index) if index is None: return None index = self.skip_postfix_type_qualifiers_at_index(index) while index < len(self.tokens) and self.tokens[index][0] in { "MULTIPLY", *self.TYPE_REFERENCE_TOKENS, }: index += 1 index = self.skip_postfix_type_qualifiers_at_index(index) return self.skip_array_suffix_at_index(index)
[docs] def parse_range_for_statement(self): """Parse a C++ range-based for loop after 'for (' has been consumed.""" vtype = self.parse_type() name = self.eat("IDENTIFIER")[1] self.eat("COLON") iterable = self.parse_expression() self.eat("RPAREN") body = self.parse_statement() return RangeForNode(vtype, name, iterable, body)
[docs] def parse_while_statement(self): """Parse while loop""" self.eat("WHILE") self.eat("LPAREN") condition = self.parse_expression() self.eat("RPAREN") body = self.parse_statement() return WhileNode(condition, body)
[docs] def parse_do_while_statement(self): """Parse do-while loop""" self.eat("DO") body = self.parse_statement() self.eat("WHILE") self.eat("LPAREN") condition = self.parse_expression() self.eat("RPAREN") self.eat("SEMICOLON") return DoWhileNode(body, condition)
[docs] def parse_switch_statement(self): """Parse switch statement""" self.eat("SWITCH") self.eat("LPAREN") expression = self.parse_expression() self.eat("RPAREN") self.eat("LBRACE") cases = [] default_case = None while self.current_token[0] != "RBRACE": if self.current_token[0] == "CASE": self.eat("CASE") value = self.parse_expression() self.eat("COLON") body = [] while self.current_token[0] not in ["CASE", "DEFAULT", "RBRACE"]: body.append(self.parse_statement()) cases.append(CaseNode(value, body)) elif self.current_token[0] == "DEFAULT": self.eat("DEFAULT") self.eat("COLON") default_case = [] while self.current_token[0] not in ["CASE", "RBRACE"]: default_case.append(self.parse_statement()) else: self.eat(self.current_token[0]) # Skip unexpected tokens self.eat("RBRACE") return SwitchNode(expression, cases, default_case)
[docs] def parse_return_statement(self): """Parse return statement""" self.eat("RETURN") value = None if self.current_token[0] != "SEMICOLON": value = self.parse_expression() self.eat("SEMICOLON") return ReturnNode(value)
[docs] def parse_sync_statement(self): """Parse CUDA synchronization statements""" sync_type = self.current_token[1] self.eat(self.current_token[0]) self.eat("LPAREN") args = [] if self.current_token[0] != "RPAREN": args.append(self.parse_expression()) while self.current_token[0] == "COMMA": self.eat("COMMA") args.append(self.parse_expression()) self.eat("RPAREN") self.eat("SEMICOLON") return SyncNode(sync_type, args)
[docs] def parse_expression(self): """Parse expression with precedence""" return self.parse_ternary_expression()
[docs] def parse_ternary_expression(self): """Parse ternary conditional operator""" expr = self.parse_logical_or_expression() if self.current_token[0] == "QUESTION": self.eat("QUESTION") true_expr = self.parse_expression() self.eat("COLON") false_expr = self.parse_expression() return TernaryOpNode(expr, true_expr, false_expr) return expr
[docs] def parse_logical_or_expression(self): """Parse logical OR expression""" left = self.parse_logical_and_expression() while self.current_token[0] == "LOGICAL_OR": op = self.current_token[1] self.eat("LOGICAL_OR") right = self.parse_logical_and_expression() left = BinaryOpNode(left, op, right) return left
[docs] def parse_logical_and_expression(self): """Parse logical AND expression""" left = self.parse_bitwise_or_expression() while self.current_token[0] == "LOGICAL_AND": op = self.current_token[1] self.eat("LOGICAL_AND") right = self.parse_bitwise_or_expression() left = BinaryOpNode(left, op, right) return left
[docs] def parse_bitwise_or_expression(self): """Parse bitwise OR expression""" left = self.parse_bitwise_xor_expression() while self.current_token[0] == "BITWISE_OR": op = self.current_token[1] self.eat("BITWISE_OR") right = self.parse_bitwise_xor_expression() left = BinaryOpNode(left, op, right) return left
[docs] def parse_bitwise_xor_expression(self): """Parse bitwise XOR expression""" left = self.parse_bitwise_and_expression() while self.current_token[0] == "BITWISE_XOR": op = self.current_token[1] self.eat("BITWISE_XOR") right = self.parse_bitwise_and_expression() left = BinaryOpNode(left, op, right) return left
[docs] def parse_bitwise_and_expression(self): """Parse bitwise AND expression""" left = self.parse_equality_expression() while self.current_token[0] == "BITWISE_AND": op = self.current_token[1] self.eat("BITWISE_AND") right = self.parse_equality_expression() left = BinaryOpNode(left, op, right) return left
[docs] def parse_equality_expression(self): """Parse equality expression""" left = self.parse_relational_expression() while self.current_token[0] in ["EQUAL", "NOT_EQUAL"]: op = self.current_token[1] self.eat(self.current_token[0]) right = self.parse_relational_expression() left = BinaryOpNode(left, op, right) return left
[docs] def parse_relational_expression(self): """Parse relational expression""" left = self.parse_shift_expression() while self.current_token[0] in [ "LESS_THAN", "GREATER_THAN", "LESS_EQUAL", "GREATER_EQUAL", ]: op = self.current_token[1] self.eat(self.current_token[0]) right = self.parse_shift_expression() left = BinaryOpNode(left, op, right) return left
[docs] def parse_shift_expression(self): """Parse bit shift expression""" left = self.parse_additive_expression() while self.current_token[0] in ["SHIFT_LEFT", "SHIFT_RIGHT"]: op = self.current_token[1] self.eat(self.current_token[0]) right = self.parse_additive_expression() left = BinaryOpNode(left, op, right) return left
[docs] def parse_additive_expression(self): """Parse additive expression""" left = self.parse_multiplicative_expression() while self.current_token[0] in ["PLUS", "MINUS"]: op = self.current_token[1] self.eat(self.current_token[0]) right = self.parse_multiplicative_expression() left = BinaryOpNode(left, op, right) return left
[docs] def parse_multiplicative_expression(self): """Parse multiplicative expression""" left = self.parse_unary_expression() while self.current_token[0] in ["MULTIPLY", "DIVIDE", "MODULO"]: op = self.current_token[1] self.eat(self.current_token[0]) right = self.parse_unary_expression() left = BinaryOpNode(left, op, right) return left
[docs] def parse_unary_expression(self): """Parse unary expression""" if self.current_token[0] in [ "PLUS", "MINUS", "LOGICAL_NOT", "BITWISE_NOT", "MULTIPLY", "BITWISE_AND", ]: op = self.current_token[1] self.eat(self.current_token[0]) operand = self.parse_unary_expression() return UnaryOpNode(op, operand) elif self.current_token[0] in ["INCREMENT", "DECREMENT"]: op = self.current_token[1] self.eat(self.current_token[0]) operand = self.parse_postfix_expression() return UnaryOpNode(op, operand) else: return self.parse_postfix_expression()
[docs] def parse_postfix_expression(self): """Parse postfix expression""" left = self.parse_primary_expression() while True: if self.current_token[0] == "DOT": self.eat("DOT") member = self.eat("IDENTIFIER")[1] left = MemberAccessNode(left, member, False) elif self.current_token[0] == "SCOPE": self.eat("SCOPE") member = self.eat("IDENTIFIER")[1] left = self.append_qualified_name(left, "::", member) elif self.current_token[0] == "LESS_THAN" and self.is_template_suffix(): left = self.append_template_suffix(left) elif self.current_token[0] == "ARROW": self.eat("ARROW") member = self.eat("IDENTIFIER")[1] left = MemberAccessNode(left, member, True) elif self.current_token[0] == "LBRACKET": self.eat("LBRACKET") index = self.parse_expression() self.eat("RBRACKET") left = ArrayAccessNode(left, index) elif self.current_token[0] == "LPAREN": self.eat("LPAREN") args = [] if left == "sizeof" and self.is_sizeof_type_operand(): args.append(self.parse_type()) elif self.current_token[0] != "RPAREN": args.append(self.parse_expression()) while self.current_token[0] == "COMMA": self.eat("COMMA") args.append(self.parse_expression()) self.eat("RPAREN") # Check for atomic operations if isinstance(left, str) and left.startswith("atomic"): left = AtomicOperationNode(left, args) else: left = self.parse_function_call_node(left, args) elif self.current_token[0] == "KERNEL_LAUNCH_START": # Kernel launch: kernel<<<blocks, threads>>>(args) return self.parse_kernel_launch(left) elif self.current_token[0] in ["INCREMENT", "DECREMENT"]: op = self.current_token[1] self.eat(self.current_token[0]) left = UnaryOpNode(f"post{op}", left) else: break return left
def append_qualified_name(self, base, separator, member): if isinstance(base, str): return f"{base}{separator}{member}" return f"{self.expression_to_text(base)}{separator}{member}" def is_template_suffix(self): index = self.current_index if self.tokens[index][0] != "LESS_THAN": return False depth = 0 while index < len(self.tokens): token_type = self.tokens[index][0] if token_type == "LESS_THAN": depth += 1 elif token_type == "GREATER_THAN": depth -= 1 if depth == 0: next_type = ( self.tokens[index + 1][0] if index + 1 < len(self.tokens) else "EOF" ) return next_type in {"LPAREN", "SCOPE", "DOT"} elif token_type == "SHIFT_RIGHT": depth -= 2 if depth == 0: next_type = ( self.tokens[index + 1][0] if index + 1 < len(self.tokens) else "EOF" ) return next_type in {"LPAREN", "SCOPE", "DOT"} if depth < 0: return False elif token_type in {"SEMICOLON", "ASSIGN", "EOF"}: return False index += 1 return False def append_template_suffix(self, base): suffix = self.parse_template_suffix() if isinstance(base, str): return f"{base}{suffix}" return f"{self.expression_to_text(base)}{suffix}" def parse_template_suffix(self): self.eat("LESS_THAN") parts = [] depth = 1 while depth > 0: token_type, token_value = self.current_token if token_type == "LESS_THAN": depth += 1 parts.append(token_value) self.eat("LESS_THAN") elif token_type == "GREATER_THAN": depth -= 1 if depth == 0: self.eat("GREATER_THAN") break parts.append(token_value) self.eat("GREATER_THAN") elif token_type == "SHIFT_RIGHT": self.eat("SHIFT_RIGHT") for _ in range(2): depth -= 1 if depth == 0: break parts.append(">") else: parts.append(token_value) self.eat(token_type) return f"<{self.format_template_parts(parts)}>" def format_template_parts(self, parts): formatted = [] previous = None for part in parts: if part == ",": formatted.append(", ") else: if self.needs_template_part_space(previous, part): formatted.append(" ") formatted.append(part) previous = part return "".join(formatted).strip() def needs_template_part_space(self, previous, current): if previous is None: return False if previous in {"<", "[", "(", "::", ","}: return False if current in {">", "]", ")", ",", "::", "<", "[", "(", "*", "&", "&&"}: return False if previous in {"*", "&", "&&"}: return False return self.is_template_word(previous) and self.is_template_word(current) def is_template_word(self, part): return part.replace("_", "").isalnum() def parse_function_call_node(self, function_name, args): named_cast = self.parse_cpp_named_cast_call(function_name, args) if named_cast is not None: return named_cast if function_name == "cudaLaunchKernel" and len(args) == 6: return KernelLaunchNode( self.unwrap_cuda_kernel_function_arg(args[0]), args[1], args[2], args[4], args[5], [args[3]], ) return FunctionCallNode(function_name, args) def parse_cpp_named_cast_call(self, function_name, args): if not isinstance(function_name, str) or len(args) != 1: return None for cast_name in self.CPP_NAMED_CASTS: prefix = f"{cast_name}<" if function_name.startswith(prefix) and function_name.endswith(">"): return CastNode(function_name[len(prefix) : -1], args[0]) return None def unwrap_cuda_kernel_function_arg(self, function_arg): if isinstance(function_arg, CastNode): return function_arg.expression return function_arg def is_sizeof_type_operand(self): saved_index = self.current_index try: while self.current_token[0] in self.TYPE_QUALIFIER_TOKENS: self.eat(self.current_token[0]) if ( self.current_token[0] not in self.TYPE_TOKENS or self.current_token[0] == "IDENTIFIER" ): return False self.eat(self.current_token[0]) while self.current_token[0] == "MULTIPLY": self.eat("MULTIPLY") while self.current_token[0] == "LBRACKET": self.eat("LBRACKET") while self.current_token[0] not in ["RBRACKET", "EOF"]: self.eat(self.current_token[0]) self.eat("RBRACKET") return self.current_token[0] == "RPAREN" finally: self.current_index = saved_index self.current_token = self.tokens[self.current_index]
[docs] def parse_kernel_launch(self, kernel_name): """Parse CUDA kernel launch syntax""" self.eat("KERNEL_LAUNCH_START") blocks = self.parse_expression() self.eat("COMMA") threads = self.parse_expression() shared_mem = None stream = None if self.current_token[0] == "COMMA": self.eat("COMMA") shared_mem = self.parse_expression() if self.current_token[0] == "COMMA": self.eat("COMMA") stream = self.parse_expression() self.eat("KERNEL_LAUNCH_END") self.eat("LPAREN") args = [] if self.current_token[0] != "RPAREN": args.append(self.parse_expression()) while self.current_token[0] == "COMMA": self.eat("COMMA") args.append(self.parse_expression()) self.eat("RPAREN") return KernelLaunchNode(kernel_name, blocks, threads, shared_mem, stream, args)
[docs] def parse_primary_expression(self): """Parse primary expression""" if self.current_token[0] == "NUMBER": value = self.current_token[1] self.eat("NUMBER") return value elif self.current_token[0] == "STRING": value = self.current_token[1] self.eat("STRING") return value elif self.current_token[0] == "CHAR_LIT": value = self.current_token[1] self.eat("CHAR_LIT") return value elif self.current_token[0] in ["TRUE", "FALSE"]: value = self.current_token[1] self.eat(self.current_token[0]) return value elif self.current_token[0] in ["NULL", "NULLPTR"]: value = self.current_token[1] self.eat(self.current_token[0]) return value elif ( self.current_token[0] in self.TYPE_TOKENS and self.current_token[0] != "IDENTIFIER" ): value = self.current_token[1] self.eat(self.current_token[0]) return value elif self.current_token[0] == "IDENTIFIER": if self.current_token[1] == "new": return self.parse_new_expression() name = self.current_token[1] self.eat("IDENTIFIER") return name elif self.current_token[0] in [ "ATOMICADD", "ATOMICSUB", "ATOMICMAX", "ATOMICMIN", "ATOMICEXCH", "ATOMICCAS", ]: name = self.current_token[1] self.eat(self.current_token[0]) return name elif self.current_token[0] in [ "THREADIDX", "BLOCKIDX", "GRIDDIM", "BLOCKDIM", "WARPSIZE", ]: builtin_name = self.current_token[1] self.eat(self.current_token[0]) # Check for component access (.x, .y, .z) if self.current_token[0] == "DOT": self.eat("DOT") component = self.eat("IDENTIFIER")[1] return CudaBuiltinNode(builtin_name, component) else: return CudaBuiltinNode(builtin_name) elif self.current_token[0] == "LPAREN": if self.is_cast_expression(): self.eat("LPAREN") target_type = self.parse_type() self.eat("RPAREN") expr = self.parse_unary_expression() return CastNode(target_type, expr) self.eat("LPAREN") expr = self.parse_expression() self.eat("RPAREN") return expr elif self.current_token[0] == "LBRACE": return self.parse_initializer_list() else: raise SyntaxError( f"Unexpected token in primary expression: {self.current_token}" )
def parse_new_expression(self): self.eat("IDENTIFIER") target_type = self.parse_type_without_array_suffix() if self.current_token[0] == "LBRACKET": self.eat("LBRACKET") size = None if self.current_token[0] != "RBRACKET": size = self.parse_expression() self.eat("RBRACKET") return NewNode(target_type, size=size, is_array=True) args = [] if self.current_token[0] == "LPAREN": args = self.parse_argument_list() return NewNode(target_type, args=args) def parse_initializer_list(self): self.eat("LBRACE") elements = [] while self.current_token[0] != "RBRACE": elements.append(self.parse_initializer_element()) if self.current_token[0] == "COMMA": self.eat("COMMA") if self.current_token[0] == "RBRACE": break else: break self.eat("RBRACE") return InitializerListNode(elements) def parse_initializer_element(self): if self.current_token[0] in ["LBRACKET", "DOT"]: return self.parse_designated_initializer() return self.parse_expression() def parse_designated_initializer(self): designators = [] while self.current_token[0] in ["LBRACKET", "DOT"]: if self.current_token[0] == "LBRACKET": self.eat("LBRACKET") index = self.parse_expression() self.eat("RBRACKET") designators.append(("index", index)) else: self.eat("DOT") field = self.eat("IDENTIFIER")[1] designators.append(("field", field)) self.eat("ASSIGN") value = self.parse_expression() return DesignatedInitializerNode(designators, value) def is_cast_expression(self): if self.current_token[0] != "LPAREN": return False saved_index = self.current_index try: self.eat("LPAREN") while self.current_token[0] in self.TYPE_QUALIFIER_TOKENS: self.eat(self.current_token[0]) if ( self.current_token[0] not in self.TYPE_TOKENS or self.current_token[0] == "IDENTIFIER" ): return False self.eat(self.current_token[0]) while self.current_token[0] == "MULTIPLY": self.eat("MULTIPLY") while self.current_token[0] == "LBRACKET": self.eat("LBRACKET") while self.current_token[0] not in ["RBRACKET", "EOF"]: self.eat(self.current_token[0]) self.eat("RBRACKET") return self.current_token[0] == "RPAREN" finally: self.current_index = saved_index self.current_token = self.tokens[self.current_index] def expression_to_text(self, expr): if isinstance(expr, str): return expr if isinstance(expr, CudaBuiltinNode): if expr.component: return f"{expr.builtin_name}.{expr.component}" return expr.builtin_name if isinstance(expr, BinaryOpNode): left = self.expression_to_text(expr.left) right = self.expression_to_text(expr.right) return f"({left} {expr.op} {right})" if isinstance(expr, UnaryOpNode): return f"{expr.op}{self.expression_to_text(expr.operand)}" return str(expr)
[docs] def parse_assignment_expression(self): """Parse assignment expression""" left = self.parse_expression() if self.current_token[0] in [ "ASSIGN", "PLUS_EQUALS", "MINUS_EQUALS", "MULTIPLY_EQUALS", "DIVIDE_EQUALS", "MODULO_EQUALS", "AND_EQUALS", "OR_EQUALS", "XOR_EQUALS", "SHIFT_LEFT_EQUALS", "SHIFT_RIGHT_EQUALS", ]: op = self.current_token[1] self.eat(self.current_token[0]) right = self.parse_expression() return AssignmentNode(left, right, op) return left