module; export module lsp.test.ast.main; import std; import tree_sitter; import lsp.language.ast; import lsp.test.ast.debug_printer; extern "C" const TSLanguage* tree_sitter_tsf(void); extern "C" bool ts_node_has_error(TSNode self); using namespace lsp::language::ast; // ==================== 文件读取 ==================== std::string ReadFile(const std::string& filepath) { std::ifstream file(filepath); if (!file.is_open()) { throw std::runtime_error("Cannot open file: " + filepath); } std::ostringstream oss; oss << file.rdbuf(); return oss.str(); } namespace { bool IsIdentifierStart(char ch) { unsigned char uch = static_cast(ch); return std::isalpha(uch) || ch == '_'; } bool IsIdentifierChar(char ch) { unsigned char uch = static_cast(ch); return std::isalnum(uch) || ch == '_'; } bool HasWordBoundaryBefore(const std::string& source, std::size_t pos) { return pos == 0 || !IsIdentifierChar(source[pos - 1]); } bool HasWordBoundaryAfter(const std::string& source, std::size_t pos) { return pos >= source.size() || !IsIdentifierChar(source[pos]); } bool MatchKeyword(const std::string& source, std::size_t pos, std::string_view keyword) { if (pos + keyword.size() > source.size()) return false; for (std::size_t i = 0; i < keyword.size(); ++i) { unsigned char lhs = static_cast(source[pos + i]); unsigned char rhs = static_cast(keyword[i]); if (std::tolower(lhs) != std::tolower(rhs)) return false; } return true; } std::string RewriteSingleVariableForIn(const std::string& source) { enum class ScanState { kNormal, kSingleQuote, kDoubleQuote, kLineComment, kBraceComment, kParenStarComment, }; ScanState state = ScanState::kNormal; std::string rewritten; rewritten.reserve(source.size() + 64); bool changed = false; std::size_t i = 0; while (i < source.size()) { char ch = source[i]; if (state == ScanState::kSingleQuote) { rewritten.push_back(ch); if (ch == '\\' && i + 1 < source.size()) { rewritten.push_back(source[i + 1]); i += 2; continue; } if (ch == '\'') state = ScanState::kNormal; ++i; continue; } if (state == ScanState::kDoubleQuote) { rewritten.push_back(ch); if (ch == '\\' && i + 1 < source.size()) { rewritten.push_back(source[i + 1]); i += 2; continue; } if (ch == '"') state = ScanState::kNormal; ++i; continue; } if (state == ScanState::kLineComment) { rewritten.push_back(ch); ++i; if (ch == '\n') state = ScanState::kNormal; continue; } if (state == ScanState::kBraceComment) { rewritten.push_back(ch); ++i; if (ch == '}') state = ScanState::kNormal; continue; } if (state == ScanState::kParenStarComment) { rewritten.push_back(ch); if (ch == '*' && i + 1 < source.size() && source[i + 1] == ')') { rewritten.push_back(')'); i += 2; state = ScanState::kNormal; continue; } ++i; continue; } if (ch == '\'') { state = ScanState::kSingleQuote; rewritten.push_back(ch); ++i; continue; } if (ch == '"') { state = ScanState::kDoubleQuote; rewritten.push_back(ch); ++i; continue; } if (ch == '/' && i + 1 < source.size() && source[i + 1] == '/') { state = ScanState::kLineComment; rewritten.push_back('/'); rewritten.push_back('/'); i += 2; continue; } if (ch == '{') { state = ScanState::kBraceComment; rewritten.push_back(ch); ++i; continue; } if (ch == '(' && i + 1 < source.size() && source[i + 1] == '*') { state = ScanState::kParenStarComment; rewritten.push_back('('); rewritten.push_back('*'); i += 2; continue; } if (MatchKeyword(source, i, "for") && HasWordBoundaryBefore(source, i) && HasWordBoundaryAfter(source, i + 3)) { std::size_t cursor = i + 3; if (cursor < source.size() && std::isspace(static_cast(source[cursor]))) { while (cursor < source.size() && std::isspace(static_cast(source[cursor]))) ++cursor; std::size_t id_start = cursor; if (id_start < source.size() && IsIdentifierStart(source[id_start])) { ++cursor; while (cursor < source.size() && IsIdentifierChar(source[cursor])) ++cursor; std::size_t id_end = cursor; std::size_t ws_start = cursor; while (cursor < source.size() && std::isspace(static_cast(source[cursor]))) ++cursor; if (cursor < source.size() && source[cursor] != ',' && MatchKeyword(source, cursor, "in") && HasWordBoundaryAfter(source, cursor + 2)) { rewritten.append(source, i, id_end - i); rewritten.append(", "); rewritten.append(source, id_start, id_end - id_start); rewritten.append(source, ws_start, cursor - ws_start); rewritten.append(source, cursor, 2); i = cursor + 2; changed = true; continue; } } } } rewritten.push_back(ch); ++i; } return changed ? rewritten : source; } } // ==================== Tree-Sitter 解析 ==================== class TreeSitterParser { public: TreeSitterParser() { parser_ = ts_parser_new(); if (!parser_) { throw std::runtime_error("Failed to create parser"); } // 设置语言 if (!ts_parser_set_language(parser_, tree_sitter_tsf())) { ts_parser_delete(parser_); throw std::runtime_error("Failed to set language"); } } ~TreeSitterParser() { if (tree_) { ts_tree_delete(tree_); } if (parser_) { ts_parser_delete(parser_); } } TSTree* Parse(const std::string& source) { if (tree_) { ts_tree_delete(tree_); tree_ = nullptr; } parsed_source_ = source; tree_ = ts_parser_parse_string( parser_, nullptr, parsed_source_.c_str(), parsed_source_.length()); if (tree_) { TSNode root = ts_tree_root_node(tree_); if (ts_node_has_error(root)) { std::string rewritten = RewriteSingleVariableForIn(parsed_source_); if (rewritten != parsed_source_) { TSTree* fallback = ts_parser_parse_string( parser_, nullptr, rewritten.c_str(), rewritten.length()); if (fallback) { TSNode fallback_root = ts_tree_root_node(fallback); if (!ts_node_has_error(fallback_root)) { ts_tree_delete(tree_); tree_ = fallback; parsed_source_ = std::move(rewritten); } else { ts_tree_delete(fallback); } } } } } if (!tree_) { throw std::runtime_error("Failed to parse source"); } return tree_; } TSNode GetRootNode() { if (!tree_) { throw std::runtime_error("No tree available"); } return ts_tree_root_node(tree_); } const std::string& ParsedSource() const { return parsed_source_; } private: TSParser* parser_ = nullptr; TSTree* tree_ = nullptr; std::string parsed_source_; }; namespace { ParseResult ParseSourceToAst(const std::string& source, bool incremental = false) { TreeSitterParser ts_parser; [[maybe_unused]] TSTree* tree = ts_parser.Parse(source); TSNode root = ts_parser.GetRootNode(); const std::string& parsed_source = ts_parser.ParsedSource(); Deserializer deserializer; if (incremental) { auto inc_result = deserializer.ParseIncremental(root, parsed_source); return std::move(inc_result.result); } return deserializer.Parse(root, parsed_source); } void Require(bool condition, const std::string& message) { if (!condition) { throw std::runtime_error(message); } } template T* RequireCast(Base* ptr, const std::string& message) { auto* casted = dynamic_cast(ptr); Require(casted != nullptr, message); return casted; } Statement* RequireSingleStatement(ParseResult& result, const std::string& case_name) { Require(!result.HasErrors(), case_name + ": parse contains syntax errors"); Require(result.root != nullptr, case_name + ": root is null"); Require(result.root->statements.size() == 1, case_name + ": expected exactly one top-level statement"); return result.root->statements[0].get(); } void TestFunctionParameters() { const std::string source = "function add(var a: integer; b: integer = 1): integer;\n" "begin\n" " return a + b;\n" "end\n"; auto result = ParseSourceToAst(source); auto* stmt = RequireCast( RequireSingleStatement(result, "TestFunctionParameters"), "TestFunctionParameters: top-level statement is not FunctionDefinition"); Require(stmt->parameters.size() == 2, "TestFunctionParameters: expected 2 parameters"); auto* p0 = stmt->parameters[0].get(); auto* p1 = stmt->parameters[1].get(); Require(p0->name == "a", "TestFunctionParameters: first parameter name should be a"); Require(p0->mode == ParameterMode::kVar, "TestFunctionParameters: first parameter mode should be var"); Require(p0->type.has_value() && p0->type->name == "integer", "TestFunctionParameters: first parameter type should be integer"); Require(p1->name == "b", "TestFunctionParameters: second parameter name should be b"); Require(p1->default_value.has_value() && p1->default_value->get() != nullptr, "TestFunctionParameters: second parameter default value should exist"); } void TestMultiDeclarationInitializer() { const std::string source = "var a, b: integer := 3;\n"; auto result = ParseSourceToAst(source); auto* block = RequireCast( RequireSingleStatement(result, "TestMultiDeclarationInitializer"), "TestMultiDeclarationInitializer: top-level statement is not BlockStatement"); Require(block->statements.size() == 2, "TestMultiDeclarationInitializer: expected block with two declarations"); auto* last_decl = RequireCast( block->statements[1].get(), "TestMultiDeclarationInitializer: second statement is not VarDeclaration"); Require(last_decl->initializer.has_value() && last_decl->initializer->get() != nullptr, "TestMultiDeclarationInitializer: initializer should be present on last declaration"); } void TestFunctionDeclarationOverload() { const std::string source = "function f(a: integer); overload;\n"; auto result = ParseSourceToAst(source); auto* stmt = RequireCast( RequireSingleStatement(result, "TestFunctionDeclarationOverload"), "TestFunctionDeclarationOverload: top-level statement is not FunctionDeclaration"); Require(stmt->is_overload, "TestFunctionDeclarationOverload: is_overload should be true"); } void TestConditionalDirectiveType() { const std::string source = "{$define FLAG}\n" "{$undef FLAG}\n"; auto result = ParseSourceToAst(source); Require(!result.HasErrors(), "TestConditionalDirectiveType: parse contains syntax errors"); Require(result.root != nullptr, "TestConditionalDirectiveType: root is null"); Require(result.root->statements.size() == 2, "TestConditionalDirectiveType: expected two directives"); auto* define_stmt = RequireCast( result.root->statements[0].get(), "TestConditionalDirectiveType: first statement is not ConditionalDirective"); auto* undef_stmt = RequireCast( result.root->statements[1].get(), "TestConditionalDirectiveType: second statement is not ConditionalDirective"); Require(define_stmt->type == ConditionalCompilationType::kDefine, "TestConditionalDirectiveType: define directive parsed type mismatch"); Require(undef_stmt->type == ConditionalCompilationType::kUndef, "TestConditionalDirectiveType: undef directive parsed type mismatch"); } void TestExternalMethodStatic() { const std::string source = "class function TFoo.Bar(): integer;\n" "begin\n" " return 1;\n" "end\n"; auto result = ParseSourceToAst(source); auto* stmt = RequireCast( RequireSingleStatement(result, "TestExternalMethodStatic"), "TestExternalMethodStatic: top-level statement is not ExternalMethodDefinition"); Require(stmt->is_static, "TestExternalMethodStatic: class method should be static"); } void TestBinaryOperatorLocation() { const std::string source = "function plus(): integer;\n" "begin\n" " return 1 + 2;\n" "end\n"; auto result = ParseSourceToAst(source); auto* fn = RequireCast( RequireSingleStatement(result, "TestBinaryOperatorLocation"), "TestBinaryOperatorLocation: top-level statement is not FunctionDefinition"); Require(fn->body != nullptr, "TestBinaryOperatorLocation: function body is null"); Require(fn->body->statements.size() == 1, "TestBinaryOperatorLocation: expected one statement in body"); auto* ret = RequireCast( fn->body->statements[0].get(), "TestBinaryOperatorLocation: body statement is not ReturnStatement"); Require(ret->value.has_value() && ret->value->get() != nullptr, "TestBinaryOperatorLocation: return value is missing"); auto* binary = RequireCast( ret->value->get(), "TestBinaryOperatorLocation: return value is not BinaryExpression"); Require(binary->operator_location.end_offset > binary->operator_location.start_offset, "TestBinaryOperatorLocation: operator location should have non-zero width"); } void TestForInSingleVariable() { const std::string source = "function iterate();\n" "begin\n" " for n in arr do begin\n" " echo n;\n" " end;\n" "end\n"; auto result = ParseSourceToAst(source); auto* fn = RequireCast( RequireSingleStatement(result, "TestForInSingleVariable"), "TestForInSingleVariable: top-level statement is not FunctionDefinition"); Require(fn->body != nullptr, "TestForInSingleVariable: function body is null"); Require(fn->body->statements.size() == 1, "TestForInSingleVariable: expected one statement in body"); auto* for_stmt = RequireCast( fn->body->statements[0].get(), "TestForInSingleVariable: body statement is not ForInStatement"); Require(for_stmt->key.empty(), "TestForInSingleVariable: key should be empty for single-variable form"); Require(for_stmt->value == "n", "TestForInSingleVariable: value should be n"); Require(for_stmt->collection != nullptr, "TestForInSingleVariable: collection should exist"); Require(for_stmt->body != nullptr, "TestForInSingleVariable: body should exist"); } int RunSelfTests() { struct TestCase { std::string name; std::function func; }; std::vector tests = { { "function parameters are parsed", TestFunctionParameters }, { "multi declaration keeps initializer", TestMultiDeclarationInitializer }, { "function declaration overload flag", TestFunctionDeclarationOverload }, { "conditional directive type", TestConditionalDirectiveType }, { "external class method static flag", TestExternalMethodStatic }, { "binary operator location", TestBinaryOperatorLocation }, { "for-in single variable form", TestForInSingleVariable }, }; int failed = 0; std::cout << "Running AST self tests (" << tests.size() << " cases)" << std::endl; for (const auto& test : tests) { try { test.func(); std::cout << "[PASS] " << test.name << std::endl; } catch (const std::exception& e) { failed++; std::cout << "[FAIL] " << test.name << "\n" << " " << e.what() << std::endl; } } std::cout << "AST self tests finished: " << (tests.size() - failed) << " passed, " << failed << " failed." << std::endl; return failed == 0 ? 0 : 1; } } // ==================== 主程序 ==================== void PrintUsage(const char* program_name) { std::cout << "Usage: " << program_name << " [options]\n"; std::cout << "\nOptions:\n"; std::cout << " -v, --verbose Show verbose output with source code\n"; std::cout << " -c, --compact Use compact output mode\n"; std::cout << " -s, --show-source Show source code snippets (default: off)\n"; std::cout << " -l, --hide-location Hide location information\n"; std::cout << " -n, --no-colors Disable colored output\n"; std::cout << " -t, --no-tree Disable tree characters\n"; std::cout << " -k, --show-kind Show node kind enums\n"; std::cout << " -i, --incremental Test incremental parsing\n"; std::cout << " --self-test Run AST assertion tests\n"; std::cout << " -h, --help Show this help message\n"; std::cout << "\nPreset Modes:\n"; std::cout << " --verbose Equivalent to: -s -k\n"; std::cout << " --compact Equivalent to: -c -l -t\n"; std::cout << "\nExamples:\n"; std::cout << " " << program_name << " test.tsf\n"; std::cout << " " << program_name << " test.tsf --verbose\n"; std::cout << " " << program_name << " test.tsf --compact\n"; std::cout << " " << program_name << " test.tsf -s -n # Show source without colors\n"; } export int Run(int argc, char* argv[]) { std::string filepath; bool test_incremental = false; bool run_self_test = false; // 默认打印选项 debug::PrintOptions opts = debug::PrintOptions::Default(); // 解析命令行参数 for (int i = 1; i < argc; ++i) { std::string arg = argv[i]; if (arg == "-h" || arg == "--help") { PrintUsage(argv[0]); return 0; } else if (arg == "-v" || arg == "--verbose") { opts = debug::PrintOptions::Verbose(); } else if (arg == "-c" || arg == "--compact") { opts = debug::PrintOptions::Compact(); } else if (arg == "-s" || arg == "--show-source") { opts.show_source_code = true; } else if (arg == "-l" || arg == "--hide-location") { opts.show_location = false; } else if (arg == "-n" || arg == "--no-colors") { opts.use_colors = false; } else if (arg == "-t" || arg == "--no-tree") { opts.use_tree_chars = false; } else if (arg == "-k" || arg == "--show-kind") { opts.show_node_kind = true; } else if (arg == "-i" || arg == "--incremental") { test_incremental = true; } else if (arg == "--self-test") { run_self_test = true; } else if (filepath.empty()) { filepath = arg; } else { std::cerr << "Unknown argument: " << arg << "\n"; return 1; } } if (run_self_test) return RunSelfTests(); if (filepath.empty()) { PrintUsage(argv[0]); return 1; } try { // 读取文件 std::cout << "Reading file: " << filepath << "\n"; std::string source = ReadFile(filepath); std::cout << "File size: " << source.length() << " bytes\n\n"; // 创建 Tree-Sitter 解析器 TreeSitterParser ts_parser; [[maybe_unused]] TSTree* tree = ts_parser.Parse(source); TSNode root = ts_parser.GetRootNode(); const std::string& parsed_source = ts_parser.ParsedSource(); // 创建 AST 反序列化器 Deserializer deserializer; ParseResult result; if (test_incremental) { std::cout << "Using incremental parsing...\n\n"; auto inc_result = deserializer.ParseIncremental(root, parsed_source); result = std::move(inc_result.result); std::cout << "Incremental Parse Statistics:\n"; std::cout << " Nodes parsed: " << inc_result.nodes_parsed << "\n"; std::cout << " Nodes unchanged: " << inc_result.nodes_unchanged << "\n"; std::cout << " Total nodes: " << inc_result.TotalNodes() << "\n"; std::cout << " Change rate: " << (inc_result.ChangeRate() * 100) << "%\n\n"; } else { result = deserializer.Parse(root, parsed_source); } // 打印 AST 结果(带源码) DebugPrint(result, parsed_source, opts); // 打印摘要 std::cout << "\n"; std::cout << debug::Color::BrightBlue << "========================================\n" << debug::Color::Reset; std::cout << debug::Color::BrightCyan << "Summary:\n" << debug::Color::Reset; std::cout << " File: " << filepath << "\n"; std::cout << " Size: " << source.length() << " bytes\n"; if (result.root) { std::cout << " Statements: " << result.root->statements.size() << "\n"; } std::cout << " Errors: "; if (result.HasErrors()) { std::cout << debug::Color::BrightRed << result.errors.size() << " ✗" << debug::Color::Reset << "\n"; std::cout << " Status: " << debug::Color::BrightRed << "FAILED" << debug::Color::Reset << "\n"; return 1; } else { std::cout << debug::Color::BrightGreen << "0 ✓" << debug::Color::Reset << "\n"; std::cout << " Status: " << debug::Color::BrightGreen << "SUCCESS" << debug::Color::Reset << "\n"; return 0; } } catch (const std::exception& e) { std::cerr << debug::Color::BrightRed << "Error: " << e.what() << debug::Color::Reset << "\n"; return 1; } }