From 3274af67d542b68b11c0cfb0026ca90432921f8f Mon Sep 17 00:00:00 2001 From: csh Date: Mon, 17 Nov 2025 17:45:39 +0800 Subject: [PATCH] =?UTF-8?q?:recycle:=20=E9=87=8D=E6=9E=84=E7=AC=A6?= =?UTF-8?q?=E5=8F=B7=E8=A1=A8=E7=9A=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lsp-server/src/language/symbol/builder.cpp | 1628 ++++++++--------- lsp-server/src/language/symbol/builder.hpp | 254 +-- lsp-server/src/language/symbol/graph/call.cpp | 52 + lsp-server/src/language/symbol/graph/call.hpp | 77 +- .../src/language/symbol/graph/inheritance.cpp | 58 + .../src/language/symbol/graph/inheritance.hpp | 91 +- .../src/language/symbol/graph/reference.cpp | 46 + .../src/language/symbol/graph/reference.hpp | 68 +- .../src/language/symbol/index/location.cpp | 68 + .../src/language/symbol/index/location.hpp | 104 +- .../src/language/symbol/index/scope.cpp | 93 + .../src/language/symbol/index/scope.hpp | 160 +- lsp-server/src/language/symbol/interface.hpp | 35 +- .../src/language/symbol/location_index.hpp | 95 - lsp-server/src/language/symbol/scope.hpp | 130 -- lsp-server/src/language/symbol/store.cpp | 59 + lsp-server/src/language/symbol/store.hpp | 89 +- lsp-server/src/language/symbol/table.cpp | 104 ++ lsp-server/src/language/symbol/table.hpp | 163 +- lsp-server/src/language/symbol/types.hpp | 403 ++-- .../src/provider/text_document/completion.cpp | 1512 ++++++++------- lsp-server/src/provider/workspace/symbol.cpp | 442 +++-- .../src/service/detail/symbol/conversion.cpp | 283 ++- .../src/service/detail/symbol/conversion.hpp | 25 +- .../src/service/detail/symbol/utils.cpp | 207 +++ .../src/service/detail/symbol/utils.hpp | 32 + lsp-server/src/service/parser.cpp | 736 ++++---- lsp-server/src/service/symbol.cpp | 547 +++--- lsp-server/src/service/symbol.hpp | 57 +- 29 files changed, 3656 insertions(+), 3962 deletions(-) create mode 100644 lsp-server/src/language/symbol/graph/call.cpp create mode 100644 lsp-server/src/language/symbol/graph/inheritance.cpp create mode 100644 lsp-server/src/language/symbol/graph/reference.cpp create mode 100644 lsp-server/src/language/symbol/index/location.cpp create mode 100644 lsp-server/src/language/symbol/index/scope.cpp delete mode 100644 lsp-server/src/language/symbol/location_index.hpp delete mode 100644 lsp-server/src/language/symbol/scope.hpp create mode 100644 lsp-server/src/language/symbol/store.cpp create mode 100644 lsp-server/src/language/symbol/table.cpp create mode 100644 lsp-server/src/service/detail/symbol/utils.cpp create mode 100644 lsp-server/src/service/detail/symbol/utils.hpp diff --git a/lsp-server/src/language/symbol/builder.cpp b/lsp-server/src/language/symbol/builder.cpp index d0096b1..55ad9b4 100644 --- a/lsp-server/src/language/symbol/builder.cpp +++ b/lsp-server/src/language/symbol/builder.cpp @@ -1,941 +1,759 @@ -#include #include "./builder.hpp" -namespace lsp::language::symbol -{ - Builder::Builder(SymbolTable& table) : table_(table), current_scope_id_(kInvalidScopeId), in_interface_section_(false) - { - } +#include - void Builder::Build(ast::ASTNode& root) - { - root.Accept(*this); - } +namespace lsp::language::symbol { +Builder::Builder(SymbolTable& table) + : table_(table), + current_scope_id_(kInvalidScopeId), + in_interface_section_(false) {} - // ===== Helper Methods ===== +void Builder::Build(ast::ASTNode& root) { root.Accept(*this); } - SymbolId Builder::CreateSymbol(const std::string& name, SymbolKind kind, const ast::Location& location, const std::optional& type_hint) - { - std::optional visibility; - if (in_interface_section_) - visibility = UnitVisibility::kInterface; - - Symbol symbol = [&]() -> Symbol { - switch (kind) - { - case SymbolKind::Class: - { - Class cls; - cls.name = name; - cls.selection_range = location; - cls.range = location; - cls.unit_visibility = visibility; - return Symbol(std::move(cls)); - } - case SymbolKind::Function: - { - Function fn; - fn.name = name; - fn.selection_range = location; - fn.range = location; - fn.declaration_range = location; - fn.return_type = type_hint; - fn.unit_visibility = visibility; - return Symbol(std::move(fn)); - } - case SymbolKind::Method: - { - Method method; - method.name = name; - method.selection_range = location; - method.range = location; - method.declaration_range = location; - method.return_type = type_hint; - return Symbol(std::move(method)); - } - case SymbolKind::Property: - { - Property property; - property.name = name; - property.selection_range = location; - property.range = location; - property.type = type_hint; - return Symbol(std::move(property)); - } - case SymbolKind::Field: - { - Field field; - field.name = name; - field.selection_range = location; - field.range = location; - field.type = type_hint; - return Symbol(std::move(field)); - } - case SymbolKind::Variable: - { - Variable var; - var.name = name; - var.selection_range = location; - var.range = location; - var.type = type_hint; - var.unit_visibility = visibility; - return Symbol(std::move(var)); - } - case SymbolKind::Constant: - { - Constant constant; - constant.name = name; - constant.selection_range = location; - constant.range = location; - constant.type = type_hint; - return Symbol(std::move(constant)); - } - default: - { - Variable fallback; - fallback.name = name; - fallback.selection_range = location; - fallback.range = location; - fallback.type = type_hint; - fallback.unit_visibility = visibility; - return Symbol(std::move(fallback)); - } - } - }(); - - SymbolId id = table_.CreateSymbol(std::move(symbol)); - - if (current_scope_id_ != kInvalidScopeId) - { - table_.AddSymbolToScope(current_scope_id_, name, id); - } - - return id; - } - - SymbolId Builder::CreateSymbol(const std::string& name, SymbolKind kind, const ast::ASTNode& node, const std::optional& type_hint) - { - return CreateSymbol(name, kind, node.span, type_hint); - } - - SymbolId Builder::CreateFunctionSymbol( - const std::string& name, - const ast::Location& location, - const std::vector>& parameters, - const std::optional& return_type) - { - std::optional visibility; - if (in_interface_section_) - visibility = UnitVisibility::kInterface; +SymbolId Builder::CreateSymbol(const std::string& name, SymbolKind kind, + const ast::Location& location, + const std::optional& type_hint) { + std::optional visibility; + if (in_interface_section_) visibility = UnitVisibility::kInterface; + Symbol symbol = [&]() -> Symbol { + switch (kind) { + case SymbolKind::Class: { + Class cls; + cls.name = name; + cls.selection_range = location; + cls.range = location; + cls.unit_visibility = visibility; + return Symbol(std::move(cls)); + } + case SymbolKind::Function: { Function fn; fn.name = name; fn.selection_range = location; fn.range = location; fn.declaration_range = location; - fn.return_type = ExtractTypeName(return_type); + fn.return_type = type_hint; fn.unit_visibility = visibility; - fn.parameters = BuildParameters(parameters); - - Symbol symbol(std::move(fn)); - SymbolId id = table_.CreateSymbol(std::move(symbol)); - - if (current_scope_id_ != kInvalidScopeId) - { - table_.AddSymbolToScope(current_scope_id_, name, id); - } - - return id; - } - - SymbolId Builder::CreateMethodSymbol( - const std::string& name, - const ast::Location& location, - const std::vector>& parameters, - const std::optional& return_type) - { + return Symbol(std::move(fn)); + } + case SymbolKind::Method: { Method method; method.name = name; method.selection_range = location; method.range = location; method.declaration_range = location; - method.return_type = ExtractTypeName(return_type); - method.parameters = BuildParameters(parameters); + method.return_type = type_hint; + return Symbol(std::move(method)); + } + case SymbolKind::Property: { + Property property; + property.name = name; + property.selection_range = location; + property.range = location; + property.type = type_hint; + return Symbol(std::move(property)); + } + case SymbolKind::Field: { + Field field; + field.name = name; + field.selection_range = location; + field.range = location; + field.type = type_hint; + return Symbol(std::move(field)); + } + case SymbolKind::Variable: { + Variable var; + var.name = name; + var.selection_range = location; + var.range = location; + var.type = type_hint; + var.unit_visibility = visibility; + return Symbol(std::move(var)); + } + case SymbolKind::Constant: { + Constant constant; + constant.name = name; + constant.selection_range = location; + constant.range = location; + constant.type = type_hint; + return Symbol(std::move(constant)); + } + default: { + Variable fallback; + fallback.name = name; + fallback.selection_range = location; + fallback.range = location; + fallback.type = type_hint; + fallback.unit_visibility = visibility; + return Symbol(std::move(fallback)); + } + } + }(); - Symbol symbol(std::move(method)); - SymbolId id = table_.CreateSymbol(std::move(symbol)); + SymbolId id = table_.CreateSymbol(std::move(symbol)); - if (current_scope_id_ != kInvalidScopeId) - { - table_.AddSymbolToScope(current_scope_id_, name, id); + if (current_scope_id_ != kInvalidScopeId) { + table_.AddSymbolToScope(current_scope_id_, name, id); + } + + return id; +} + +SymbolId Builder::CreateSymbol(const std::string& name, SymbolKind kind, + const ast::ASTNode& node, + const std::optional& type_hint) { + return CreateSymbol(name, kind, node.span, type_hint); +} + +SymbolId Builder::CreateFunctionSymbol( + const std::string& name, const ast::Location& location, + const std::vector>& parameters, + const std::optional& return_type) { + std::optional visibility; + if (in_interface_section_) visibility = UnitVisibility::kInterface; + + Function fn; + fn.name = name; + fn.selection_range = location; + fn.range = location; + fn.declaration_range = location; + fn.return_type = ExtractTypeName(return_type); + fn.unit_visibility = visibility; + fn.parameters = BuildParameters(parameters); + + Symbol symbol(std::move(fn)); + SymbolId id = table_.CreateSymbol(std::move(symbol)); + + if (current_scope_id_ != kInvalidScopeId) { + table_.AddSymbolToScope(current_scope_id_, name, id); + } + + return id; +} + +SymbolId Builder::CreateMethodSymbol( + const std::string& name, const ast::Location& location, + const std::vector>& parameters, + const std::optional& return_type) { + Method method; + method.name = name; + method.selection_range = location; + method.range = location; + method.declaration_range = location; + method.return_type = ExtractTypeName(return_type); + method.parameters = BuildParameters(parameters); + + Symbol symbol(std::move(method)); + SymbolId id = table_.CreateSymbol(std::move(symbol)); + + if (current_scope_id_ != kInvalidScopeId) { + table_.AddSymbolToScope(current_scope_id_, name, id); + } + + return id; +} + +ScopeId Builder::EnterScopeWithSymbol(ScopeKind kind, SymbolId symbol_id, + const ast::Location& range) { + current_scope_id_ = + table_.CreateScope(kind, range, current_scope_id_, symbol_id); + return current_scope_id_; +} + +ScopeId Builder::EnterScope(ScopeKind kind, const ast::Location& range) { + current_scope_id_ = + table_.CreateScope(kind, range, current_scope_id_, std::nullopt); + return current_scope_id_; +} + +void Builder::ExitScope() { + const auto* scope_info = table_.scopes().scope(current_scope_id_); + if (scope_info && scope_info->parent) { + current_scope_id_ = *scope_info->parent; + } +} + +void Builder::VisitStatements( + const std::vector& statements) { + for (const auto& stmt : statements) { + if (stmt) stmt->Accept(*this); + } +} + +void Builder::VisitExpression(ast::Expression& expr) { expr.Accept(*this); } + +std::optional Builder::ExtractTypeName( + const std::optional& type) const { + if (type) return type->name; + return std::nullopt; +} + +std::vector Builder::BuildParameters( + const std::vector>& parameters) const { + std::vector result; + result.reserve(parameters.size()); + + for (const auto& param : parameters) { + if (!param) continue; + + language::symbol::Parameter p; + p.name = param->name; + if (param->type) p.type = param->type->name; + if (param->default_value) p.default_value = ""; + result.push_back(std::move(p)); + } + + return result; +} + +void Builder::VisitProgram(ast::Program& node) { + current_scope_id_ = table_.CreateScope(ScopeKind::kGlobal, node.span, + std::nullopt, std::nullopt); + VisitStatements(node.statements); +} + +void Builder::VisitUnitDefinition(ast::UnitDefinition& node) { + [[maybe_unused]] auto unit_scope = EnterScope(ScopeKind::kUnit, node.span); + + // Process interface section + in_interface_section_ = true; + VisitStatements(node.interface_statements); + + // Process implementation section + in_interface_section_ = false; + VisitStatements(node.implementation_statements); + + ExitScope(); +} + +void Builder::VisitClassDefinition(ast::ClassDefinition& node) { + auto class_id = CreateSymbol(node.name, SymbolKind::Class, node.location); + + [[maybe_unused]] auto class_scope = + EnterScopeWithSymbol(ScopeKind::kClass, class_id, node.span); + + auto prev_parent = current_parent_symbol_id_; + current_parent_symbol_id_ = class_id; + + for (auto& member : node.members) { + if (member) member->Accept(*this); + } + + current_parent_symbol_id_ = prev_parent; + ExitScope(); +} + +void Builder::VisitFunctionDeclaration(ast::FunctionDeclaration& node) { + auto location = ast::Location{}; + if (!node.name.empty() && !node.parameters.empty()) { + location = node.span; + } + CreateFunctionSymbol(node.name, location, node.parameters, node.return_type); +} + +void Builder::VisitFunctionDefinition(ast::FunctionDefinition& node) { + auto func_id = CreateFunctionSymbol(node.name, node.location, node.parameters, + node.return_type); + + if (node.body) { + [[maybe_unused]] auto func_scope = + EnterScopeWithSymbol(ScopeKind::kFunction, func_id, node.body->span); + + auto prev_function = current_function_id_; + current_function_id_ = func_id; + + for (auto& param : node.parameters) { + if (param) { + CreateSymbol(param->name, SymbolKind::Variable, param->location); + } + } + + VisitStatements(node.body->statements); + + current_function_id_ = prev_function; + ExitScope(); + } +} + +void Builder::VisitMethodDeclaration(ast::MethodDeclaration& node) { + auto method_id = CreateMethodSymbol(node.name, node.location, node.parameters, + node.return_type); + + if (node.body) { + [[maybe_unused]] auto method_scope = + EnterScopeWithSymbol(ScopeKind::kFunction, method_id, node.body->span); + + auto prev_function = current_function_id_; + current_function_id_ = method_id; + + for (auto& param : node.parameters) { + if (param) { + CreateSymbol(param->name, SymbolKind::Variable, param->location); + } + } + + VisitStatements(node.body->statements); + + current_function_id_ = prev_function; + ExitScope(); + } +} + +void Builder::VisitPropertyDeclaration(ast::PropertyDeclaration& node) { + auto type = ExtractTypeName(node.type); + CreateSymbol(node.name, SymbolKind::Property, node.location, type); +} + +void Builder::VisitExternalMethodDefinition( + ast::ExternalMethodDefinition& node) { + std::string method_name = node.name; + size_t dot_pos = method_name.find_last_of('.'); + if (dot_pos != std::string::npos) { + method_name = method_name.substr(dot_pos + 1); + } + + std::optional method_id; + + if (current_parent_symbol_id_) { + const auto* scope_info = table_.scopes().scope(current_scope_id_); + if (scope_info) { + method_id = + table_.scopes().FindSymbolInScope(current_scope_id_, method_name); + } + } + + if (!method_id) { + method_id = CreateMethodSymbol(method_name, node.location, node.parameters, + node.return_type); + } + + if (node.body) { + [[maybe_unused]] auto method_scope = + EnterScopeWithSymbol(ScopeKind::kFunction, *method_id, node.body->span); + + auto prev_function = current_function_id_; + current_function_id_ = *method_id; + + for (auto& param : node.parameters) { + if (param) { + CreateSymbol(param->name, SymbolKind::Variable, param->location); + } + } + + VisitStatements(node.body->statements); + + current_function_id_ = prev_function; + ExitScope(); + } +} + +void Builder::VisitClassMember(ast::ClassMember& node) { + std::visit( + [this](auto& member) { + if (member) member->Accept(*this); + }, + node.member); +} + +void Builder::VisitVarDeclaration(ast::VarDeclaration& node) { + CreateSymbol(node.name, SymbolKind::Variable, node.location, + ExtractTypeName(node.type)); + + if (node.initializer) { + VisitExpression(*node.initializer.value()); + } +} + +void Builder::VisitStaticDeclaration(ast::StaticDeclaration& node) { + CreateSymbol(node.name, SymbolKind::Variable, node.location, + ExtractTypeName(node.type)); + + if (node.initializer) { + VisitExpression(*node.initializer.value()); + } +} + +void Builder::VisitGlobalDeclaration(ast::GlobalDeclaration& node) { + CreateSymbol(node.name, SymbolKind::Variable, node.location, + ExtractTypeName(node.type)); + + if (node.initializer) { + VisitExpression(*node.initializer.value()); + } +} + +void Builder::VisitConstDeclaration(ast::ConstDeclaration& node) { + CreateSymbol(node.name, SymbolKind::Constant, node.location, + ExtractTypeName(node.type)); + + if (node.value) { + VisitExpression(*node.value); + } +} + +void Builder::VisitFieldDeclaration(ast::FieldDeclaration& node) { + CreateSymbol(node.name, SymbolKind::Field, node.location, + ExtractTypeName(node.type)); + + if (node.initializer) { + VisitExpression(*node.initializer.value()); + } +} + +void Builder::VisitUsesStatement(ast::UsesStatement& node) { (void)node; } + +void Builder::VisitIdentifier(ast::Identifier& node) { + auto symbol_id = + table_.scopes().FindSymbolInScopeChain(current_scope_id_, node.name); + if (symbol_id) { + table_.AddReference(*symbol_id, node.location, false, false); + } +} + +void Builder::VisitCallExpression(ast::CallExpression& node) { + if (node.callee) { + if (auto* id = dynamic_cast(node.callee.get())) { + auto symbol_id = + table_.scopes().FindSymbolInScopeChain(current_scope_id_, id->name); + if (symbol_id) { + table_.AddReference(*symbol_id, id->location, false, false); + + if (current_function_id_) { + table_.AddCall(*current_function_id_, *symbol_id, node.span); } - - return id; + } + } else { + node.callee->Accept(*this); } + } - ScopeId Builder::EnterScopeWithSymbol(ScopeKind kind, SymbolId symbol_id, const ast::Location& range) - { - current_scope_id_ = table_.CreateScope(kind, range, current_scope_id_, symbol_id); - return current_scope_id_; - } + for (auto& arg : node.arguments) { + if (arg.value) VisitExpression(*arg.value); + } +} - ScopeId Builder::EnterScope(ScopeKind kind, const ast::Location& range) - { - current_scope_id_ = table_.CreateScope(kind, range, current_scope_id_, std::nullopt); - return current_scope_id_; - } +void Builder::VisitAttributeExpression(ast::AttributeExpression& node) { + if (node.object) { + node.object->Accept(*this); + } - void Builder::ExitScope() - { - const auto* scope_info = table_.scopes().scope(current_scope_id_); - if (scope_info && scope_info->parent) - { - current_scope_id_ = *scope_info->parent; - } - } + if (node.attribute) { + node.attribute->Accept(*this); + } +} - void Builder::VisitStatements(const std::vector& statements) - { - for (const auto& stmt : statements) - { - if (stmt) - stmt->Accept(*this); - } - } +void Builder::VisitAssignmentExpression(ast::AssignmentExpression& node) { + ProcessLValue(node.left, true); - void Builder::VisitExpression(ast::Expression& expr) - { - expr.Accept(*this); - } + if (node.right) { + VisitExpression(*node.right); + } +} - std::optional Builder::ExtractTypeName(const std::optional& type) const - { - if (type) - return type->name; - return std::nullopt; - } - - std::vector Builder::BuildParameters(const std::vector>& parameters) const - { - std::vector result; - result.reserve(parameters.size()); - - for (const auto& param : parameters) - { - if (!param) - continue; - - language::symbol::Parameter p; - p.name = param->name; - if (param->type) - p.type = param->type->name; - if (param->default_value) - p.default_value = ""; - result.push_back(std::move(p)); - } - - return result; - } - - // ===== Node Visitation ===== - - void Builder::VisitProgram(ast::Program& node) - { - current_scope_id_ = table_.CreateScope(ScopeKind::kGlobal, node.span, std::nullopt, std::nullopt); - VisitStatements(node.statements); - } - - void Builder::VisitUnitDefinition(ast::UnitDefinition& node) - { - auto unit_scope = EnterScope(ScopeKind::kUnit, node.span); - - // Process interface section - in_interface_section_ = true; - VisitStatements(node.interface_statements); - - // Process implementation section - in_interface_section_ = false; - VisitStatements(node.implementation_statements); - - ExitScope(); - } - - void Builder::VisitClassDefinition(ast::ClassDefinition& node) - { - auto class_id = CreateSymbol(node.name, SymbolKind::Class, node.location); - - auto class_scope = EnterScopeWithSymbol(ScopeKind::kClass, class_id, node.span); - - auto prev_parent = current_parent_symbol_id_; - current_parent_symbol_id_ = class_id; - - for (auto& member : node.members) - { - if (member) - member->Accept(*this); - } - - current_parent_symbol_id_ = prev_parent; - ExitScope(); - } - - void Builder::VisitFunctionDeclaration(ast::FunctionDeclaration& node) - { - auto location = ast::Location{}; - if (!node.name.empty() && !node.parameters.empty()) - { - location = node.span; - } - CreateFunctionSymbol(node.name, location, node.parameters, node.return_type); - } - - void Builder::VisitFunctionDefinition(ast::FunctionDefinition& node) - { - auto func_id = CreateFunctionSymbol(node.name, node.location, node.parameters, node.return_type); - - if (node.body) - { - auto func_scope = EnterScopeWithSymbol(ScopeKind::kFunction, func_id, node.body->span); - - auto prev_function = current_function_id_; - current_function_id_ = func_id; - - for (auto& param : node.parameters) - { - if (param) - { - CreateSymbol(param->name, SymbolKind::Variable, param->location); - } +void Builder::ProcessLValue(const ast::LValue& lvalue, bool is_write) { + std::visit( + [this, is_write](auto& val) { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + if (val) { + auto symbol_id = table_.scopes().FindSymbolInScopeChain( + current_scope_id_, val->name); + if (symbol_id) { + table_.AddReference(*symbol_id, val->location, false, is_write); } - - VisitStatements(node.body->statements); - - current_function_id_ = prev_function; - ExitScope(); - } - } - - void Builder::VisitMethodDeclaration(ast::MethodDeclaration& node) - { - auto method_id = CreateMethodSymbol(node.name, node.location, node.parameters, node.return_type); - - if (node.body) - { - auto method_scope = EnterScopeWithSymbol(ScopeKind::kFunction, method_id, node.body->span); - - auto prev_function = current_function_id_; - current_function_id_ = method_id; - - for (auto& param : node.parameters) - { - if (param) - { - CreateSymbol(param->name, SymbolKind::Variable, param->location); - } + } + } else if constexpr (std::is_same_v>) { + if (val) { + if (val->object) { + val->object->Accept(*this); } - - VisitStatements(node.body->statements); - - current_function_id_ = prev_function; - ExitScope(); - } - } - - void Builder::VisitPropertyDeclaration(ast::PropertyDeclaration& node) - { - auto type = ExtractTypeName(node.type); - CreateSymbol(node.name, SymbolKind::Property, node.location, type); - } - - void Builder::VisitExternalMethodDefinition(ast::ExternalMethodDefinition& node) - { - std::string method_name = node.name; - size_t dot_pos = method_name.find_last_of('.'); - if (dot_pos != std::string::npos) - { - method_name = method_name.substr(dot_pos + 1); - } - - std::optional method_id; - - if (current_parent_symbol_id_) - { - const auto* scope_info = table_.scopes().scope(current_scope_id_); - if (scope_info) - { - method_id = table_.scopes().FindSymbolInScope(current_scope_id_, method_name); + if (val->attribute) { + val->attribute->Accept(*this); } - } - - if (!method_id) - { - method_id = CreateMethodSymbol(method_name, node.location, node.parameters, node.return_type); - } - - if (node.body) - { - auto method_scope = EnterScopeWithSymbol(ScopeKind::kFunction, *method_id, node.body->span); - - auto prev_function = current_function_id_; - current_function_id_ = *method_id; - - for (auto& param : node.parameters) - { - if (param) - { - CreateSymbol(param->name, SymbolKind::Variable, param->location); - } + } + } else if constexpr (std::is_same_v>) { + if (val) { + if (val->base) { + val->base->Accept(*this); } - - VisitStatements(node.body->statements); - - current_function_id_ = prev_function; - ExitScope(); - } - } - - void Builder::VisitClassMember(ast::ClassMember& node) - { - std::visit([this](auto& member) { - if (member) - member->Accept(*this); - }, - node.member); - } - - void Builder::VisitVarDeclaration(ast::VarDeclaration& node) - { - CreateSymbol(node.name, SymbolKind::Variable, node.location, ExtractTypeName(node.type)); - - if (node.initializer) - { - VisitExpression(*node.initializer.value()); - } - } - - void Builder::VisitStaticDeclaration(ast::StaticDeclaration& node) - { - CreateSymbol(node.name, SymbolKind::Variable, node.location, ExtractTypeName(node.type)); - - if (node.initializer) - { - VisitExpression(*node.initializer.value()); - } - } - - void Builder::VisitGlobalDeclaration(ast::GlobalDeclaration& node) - { - CreateSymbol(node.name, SymbolKind::Variable, node.location, ExtractTypeName(node.type)); - - if (node.initializer) - { - VisitExpression(*node.initializer.value()); - } - } - - void Builder::VisitConstDeclaration(ast::ConstDeclaration& node) - { - CreateSymbol(node.name, SymbolKind::Constant, node.location, ExtractTypeName(node.type)); - - if (node.value) - { - VisitExpression(*node.value); - } - } - - void Builder::VisitFieldDeclaration(ast::FieldDeclaration& node) - { - CreateSymbol(node.name, SymbolKind::Field, node.location, ExtractTypeName(node.type)); - - if (node.initializer) - { - VisitExpression(*node.initializer.value()); - } - } - - void Builder::VisitUsesStatement(ast::UsesStatement& node) - { - // Unit imports would be stored in the Unit symbol - // For now just skip since we don't have a way to store them - } - - // ===== Expression Processing (references) ===== - - void Builder::VisitIdentifier(ast::Identifier& node) - { - auto symbol_id = table_.scopes().FindSymbolInScopeChain(current_scope_id_, node.name); - if (symbol_id) - { - table_.AddReference(*symbol_id, node.location, false, false); - } - } - - void Builder::VisitCallExpression(ast::CallExpression& node) - { - if (node.callee) - { - if (auto* id = dynamic_cast(node.callee.get())) - { - auto symbol_id = table_.scopes().FindSymbolInScopeChain(current_scope_id_, id->name); - if (symbol_id) - { - table_.AddReference(*symbol_id, id->location, false, false); - - if (current_function_id_) - { - table_.AddCall(*current_function_id_, *symbol_id, node.span); - } - } - } - else - { - node.callee->Accept(*this); + for (auto& idx : val->indices) { + if (idx.start) idx.start->Accept(*this); + if (idx.end) idx.end->Accept(*this); + if (idx.step) idx.step->Accept(*this); } + } + } else if constexpr (std::is_same_v< + T, std::unique_ptr>) { + if (val && val->value) { + val->value->Accept(*this); + } } + }, + lvalue); +} - for (auto& arg : node.arguments) - { - if (arg.value) - VisitExpression(*arg.value); - } +void Builder::VisitBlockStatement(ast::BlockStatement& node) { + [[maybe_unused]] auto block_scope = EnterScope(ScopeKind::kBlock, node.span); + VisitStatements(node.statements); + ExitScope(); +} + +void Builder::VisitIfStatement(ast::IfStatement& node) { + for (auto& branch : node.branches) { + if (branch.condition) VisitExpression(*branch.condition); + + if (branch.body) branch.body->Accept(*this); + } + + if (node.else_body && *node.else_body) { + (*node.else_body)->Accept(*this); + } +} + +void Builder::VisitForInStatement(ast::ForInStatement& node) { + [[maybe_unused]] auto for_scope = EnterScope(ScopeKind::kBlock, node.span); + + if (!node.key.empty()) { + CreateSymbol(node.key, SymbolKind::Variable, node.key_location); + } + + if (!node.value.empty()) { + CreateSymbol(node.value, SymbolKind::Variable, node.value_location); + } + + if (node.collection) VisitExpression(*node.collection); + + if (node.body) node.body->Accept(*this); + + ExitScope(); +} + +void Builder::VisitForToStatement(ast::ForToStatement& node) { + [[maybe_unused]] auto for_scope = EnterScope(ScopeKind::kBlock, node.span); + + CreateSymbol(node.counter, SymbolKind::Variable, node.counter_location); + + if (node.start) VisitExpression(*node.start); + if (node.end) VisitExpression(*node.end); + if (node.step) VisitExpression(*node.step); + if (node.body) node.body->Accept(*this); + + ExitScope(); +} + +void Builder::VisitWhileStatement(ast::WhileStatement& node) { + if (node.condition) VisitExpression(*node.condition); + if (node.body) node.body->Accept(*this); +} + +void Builder::VisitRepeatStatement(ast::RepeatStatement& node) { + VisitStatements(node.body); + if (node.condition) VisitExpression(*node.condition); +} + +void Builder::VisitCaseStatement(ast::CaseStatement& node) { + if (node.discriminant) VisitExpression(*node.discriminant); + + for (auto& branch : node.branches) { + for (auto& value : branch.values) { + if (value) VisitExpression(*value); + } + if (branch.body) branch.body->Accept(*this); + } + + if (node.else_body && *node.else_body) { + (*node.else_body)->Accept(*this); + } +} + +void Builder::VisitTryStatement(ast::TryStatement& node) { + if (node.try_body) { + VisitStatements(node.try_body->statements); + } + + if (node.except_body) { + VisitStatements(node.except_body->statements); + } +} + +void Builder::VisitMatrixIterationStatement( + ast::MatrixIterationStatement& node) { + [[maybe_unused]] auto iter_scope = EnterScope(ScopeKind::kBlock, node.span); + + if (node.target) VisitExpression(*node.target); + + if (node.body) VisitStatements(node.body->statements); + + ExitScope(); +} + +void Builder::VisitExpressionStatement(ast::ExpressionStatement& node) { + if (node.expression) VisitExpression(*node.expression); +} + +void Builder::VisitReturnStatement(ast::ReturnStatement& node) { + if (node.value && *node.value) VisitExpression(**node.value); +} + +void Builder::VisitBinaryExpression(ast::BinaryExpression& node) { + if (node.left) VisitExpression(*node.left); + if (node.right) VisitExpression(*node.right); +} + +void Builder::VisitTernaryExpression(ast::TernaryExpression& node) { + if (node.condition) VisitExpression(*node.condition); + if (node.consequence) VisitExpression(*node.consequence); + if (node.alternative) VisitExpression(*node.alternative); +} + +void Builder::VisitSubscriptExpression(ast::SubscriptExpression& node) { + if (node.base) VisitExpression(*node.base); + + for (auto& idx : node.indices) { + if (idx.start) VisitExpression(*idx.start); + if (idx.end) VisitExpression(*idx.end); + if (idx.step) VisitExpression(*idx.step); + } +} + +void Builder::VisitArrayExpression(ast::ArrayExpression& node) { + for (auto& elem : node.elements) { + if (elem.key) VisitExpression(*elem.key.value()); + if (elem.value) VisitExpression(*elem.value); + } +} + +void Builder::VisitAnonymousFunctionExpression( + ast::AnonymousFunctionExpression& node) { + auto func_id = CreateFunctionSymbol("", node.span, node.parameters, + node.return_type); + + if (node.body) { + [[maybe_unused]] auto func_scope = EnterScopeWithSymbol( + ScopeKind::kAnonymousFunction, func_id, node.body->span); + + auto prev_function = current_function_id_; + current_function_id_ = func_id; + + for (auto& param : node.parameters) { + if (param) { + CreateSymbol(param->name, SymbolKind::Variable, param->location); + } } - void Builder::VisitAttributeExpression(ast::AttributeExpression& node) - { - if (node.object) - { - node.object->Accept(*this); - } - - if (node.attribute) - { - node.attribute->Accept(*this); - } - } - - void Builder::VisitAssignmentExpression(ast::AssignmentExpression& node) - { - ProcessLValue(node.left, true); - - if (node.right) - { - VisitExpression(*node.right); - } - } - - void Builder::ProcessLValue(const ast::LValue& lvalue, bool is_write) - { - std::visit([this, is_write](auto& val) { - using T = std::decay_t; - if constexpr (std::is_same_v>) - { - if (val) - { - auto symbol_id = table_.scopes().FindSymbolInScopeChain(current_scope_id_, val->name); - if (symbol_id) - { - table_.AddReference(*symbol_id, val->location, false, is_write); - } - } - } - else if constexpr (std::is_same_v>) - { - if (val) - { - if (val->object) - { - val->object->Accept(*this); - } - if (val->attribute) - { - val->attribute->Accept(*this); - } - } - } - else if constexpr (std::is_same_v>) - { - if (val) - { - if (val->base) - { - val->base->Accept(*this); - } - for (auto& idx : val->indices) - { - if (idx.start) - idx.start->Accept(*this); - if (idx.end) - idx.end->Accept(*this); - if (idx.step) - idx.step->Accept(*this); - } - } - } - else if constexpr (std::is_same_v>) - { - if (val && val->value) - { - val->value->Accept(*this); - } - } - }, - lvalue); - } - - // ===== Statement Processing ===== - - void Builder::VisitBlockStatement(ast::BlockStatement& node) - { - auto block_scope = EnterScope(ScopeKind::kBlock, node.span); - VisitStatements(node.statements); - ExitScope(); - } - - void Builder::VisitIfStatement(ast::IfStatement& node) - { - for (auto& branch : node.branches) - { - if (branch.condition) - VisitExpression(*branch.condition); - - if (branch.body) - branch.body->Accept(*this); - } - - if (node.else_body && *node.else_body) - { - (*node.else_body)->Accept(*this); - } - } - - void Builder::VisitForInStatement(ast::ForInStatement& node) - { - auto for_scope = EnterScope(ScopeKind::kBlock, node.span); - - if (!node.key.empty()) - { - CreateSymbol(node.key, SymbolKind::Variable, node.key_location); - } - - if (!node.value.empty()) - { - CreateSymbol(node.value, SymbolKind::Variable, node.value_location); - } - - if (node.collection) - VisitExpression(*node.collection); - - if (node.body) - node.body->Accept(*this); - - ExitScope(); - } - - void Builder::VisitForToStatement(ast::ForToStatement& node) - { - auto for_scope = EnterScope(ScopeKind::kBlock, node.span); - - CreateSymbol(node.counter, SymbolKind::Variable, node.counter_location); - - if (node.start) - VisitExpression(*node.start); - if (node.end) - VisitExpression(*node.end); - if (node.step) - VisitExpression(*node.step); - if (node.body) - node.body->Accept(*this); - - ExitScope(); - } - - void Builder::VisitWhileStatement(ast::WhileStatement& node) - { - if (node.condition) - VisitExpression(*node.condition); - if (node.body) - node.body->Accept(*this); - } - - void Builder::VisitRepeatStatement(ast::RepeatStatement& node) - { - VisitStatements(node.body); - if (node.condition) - VisitExpression(*node.condition); - } - - void Builder::VisitCaseStatement(ast::CaseStatement& node) - { - if (node.discriminant) - VisitExpression(*node.discriminant); - - for (auto& branch : node.branches) - { - for (auto& value : branch.values) - { - if (value) - VisitExpression(*value); - } - if (branch.body) - branch.body->Accept(*this); - } - - if (node.else_body && *node.else_body) - { - (*node.else_body)->Accept(*this); - } - } - - void Builder::VisitTryStatement(ast::TryStatement& node) - { - if (node.try_body) - { - VisitStatements(node.try_body->statements); - } - - if (node.except_body) - { - VisitStatements(node.except_body->statements); - } - } - - void Builder::VisitMatrixIterationStatement(ast::MatrixIterationStatement& node) - { - auto iter_scope = EnterScope(ScopeKind::kBlock, node.span); - - if (node.target) - VisitExpression(*node.target); - - if (node.body) - VisitStatements(node.body->statements); - - ExitScope(); - } - - void Builder::VisitExpressionStatement(ast::ExpressionStatement& node) - { - if (node.expression) - VisitExpression(*node.expression); - } - - void Builder::VisitReturnStatement(ast::ReturnStatement& node) - { - if (node.value && *node.value) - VisitExpression(**node.value); - } - - // ===== Other Expressions ===== - - void Builder::VisitBinaryExpression(ast::BinaryExpression& node) - { - if (node.left) - VisitExpression(*node.left); - if (node.right) - VisitExpression(*node.right); - } - - void Builder::VisitTernaryExpression(ast::TernaryExpression& node) - { - if (node.condition) - VisitExpression(*node.condition); - if (node.consequence) - VisitExpression(*node.consequence); - if (node.alternative) - VisitExpression(*node.alternative); - } - - void Builder::VisitSubscriptExpression(ast::SubscriptExpression& node) - { - if (node.base) - VisitExpression(*node.base); - - for (auto& idx : node.indices) - { - if (idx.start) - VisitExpression(*idx.start); - if (idx.end) - VisitExpression(*idx.end); - if (idx.step) - VisitExpression(*idx.step); - } - } - - void Builder::VisitArrayExpression(ast::ArrayExpression& node) - { - for (auto& elem : node.elements) - { - if (elem.key) - VisitExpression(*elem.key.value()); - if (elem.value) - VisitExpression(*elem.value); - } - } - - void Builder::VisitAnonymousFunctionExpression(ast::AnonymousFunctionExpression& node) - { - auto func_id = CreateFunctionSymbol("", node.span, node.parameters, node.return_type); - - if (node.body) - { - auto func_scope = EnterScopeWithSymbol(ScopeKind::kAnonymousFunction, func_id, node.body->span); - - auto prev_function = current_function_id_; - current_function_id_ = func_id; - - for (auto& param : node.parameters) - { - if (param) - { - CreateSymbol(param->name, SymbolKind::Variable, param->location); - } - } - - VisitStatements(node.body->statements); - - current_function_id_ = prev_function; - ExitScope(); - } - } - - void Builder::VisitUnaryPlusExpression(ast::UnaryPlusExpression& node) - { - if (node.argument) - VisitExpression(*node.argument); - } - - void Builder::VisitUnaryMinusExpression(ast::UnaryMinusExpression& node) - { - if (node.argument) - VisitExpression(*node.argument); - } - - void Builder::VisitPrefixIncrementExpression(ast::PrefixIncrementExpression& node) - { - if (node.argument) - VisitExpression(*node.argument); - } - - void Builder::VisitPrefixDecrementExpression(ast::PrefixDecrementExpression& node) - { - if (node.argument) - VisitExpression(*node.argument); - } - - void Builder::VisitPostfixIncrementExpression(ast::PostfixIncrementExpression& node) - { - if (node.argument) - VisitExpression(*node.argument); - } - - void Builder::VisitPostfixDecrementExpression(ast::PostfixDecrementExpression& node) - { - if (node.argument) - VisitExpression(*node.argument); - } - - void Builder::VisitLogicalNotExpression(ast::LogicalNotExpression& node) - { - if (node.argument) - VisitExpression(*node.argument); - } - - void Builder::VisitBitwiseNotExpression(ast::BitwiseNotExpression& node) - { - if (node.argument) - VisitExpression(*node.argument); - } - - void Builder::VisitDerivativeExpression(ast::DerivativeExpression& node) - { - if (node.argument) - VisitExpression(*node.argument); - } - - void Builder::VisitMatrixTransposeExpression(ast::MatrixTransposeExpression& node) - { - if (node.argument) - VisitExpression(*node.argument); - } - - void Builder::VisitExprOperatorExpression(ast::ExprOperatorExpression& node) - { - if (node.argument) - VisitExpression(*node.argument); - } - - void Builder::VisitFunctionPointerExpression([[maybe_unused]] ast::FunctionPointerExpression& node) - { - if (node.argument) - VisitExpression(*node.argument); - } - - void Builder::VisitNewExpression(ast::NewExpression& node) - { - if (node.target) - VisitExpression(*node.target); - } - - void Builder::VisitEchoExpression(ast::EchoExpression& node) - { - for (auto& expr : node.expressions) - { - if (expr) - VisitExpression(*expr); - } - } - - void Builder::VisitRaiseExpression(ast::RaiseExpression& node) - { - if (node.exception) - VisitExpression(*node.exception); - } - - void Builder::VisitInheritedExpression(ast::InheritedExpression& node) - { - if (node.call && *node.call) - { - (*node.call)->Accept(*this); - } - } - - void Builder::VisitParenthesizedExpression(ast::ParenthesizedExpression& node) - { - for (auto& elem : node.elements) - { - if (elem.key) - VisitExpression(*elem.key.value()); - if (elem.value) - VisitExpression(*elem.value); - } - } - - void Builder::VisitColumnReference(ast::ColumnReference& node) - { - if (node.value) - { - VisitExpression(*node.value); - } - } - - void Builder::VisitUnpackPattern([[maybe_unused]] ast::UnpackPattern& node) - { - // Unpack pattern processing - would need to handle LValue variants - } - - void Builder::VisitCompilerDirective([[maybe_unused]] ast::CompilerDirective& node) - { - // Compiler directive processing - } - - void Builder::VisitConditionalDirective([[maybe_unused]] ast::ConditionalDirective& node) - { - // Conditional compilation directive - no statements to visit - } - - void Builder::VisitConditionalBlock(ast::ConditionalBlock& node) - { - VisitStatements(node.consequence); - VisitStatements(node.alternative); - } - -} // namespace lsp::language::symbol + VisitStatements(node.body->statements); + + current_function_id_ = prev_function; + ExitScope(); + } +} + +void Builder::VisitUnaryPlusExpression(ast::UnaryPlusExpression& node) { + if (node.argument) VisitExpression(*node.argument); +} + +void Builder::VisitUnaryMinusExpression(ast::UnaryMinusExpression& node) { + if (node.argument) VisitExpression(*node.argument); +} + +void Builder::VisitPrefixIncrementExpression( + ast::PrefixIncrementExpression& node) { + if (node.argument) VisitExpression(*node.argument); +} + +void Builder::VisitPrefixDecrementExpression( + ast::PrefixDecrementExpression& node) { + if (node.argument) VisitExpression(*node.argument); +} + +void Builder::VisitPostfixIncrementExpression( + ast::PostfixIncrementExpression& node) { + if (node.argument) VisitExpression(*node.argument); +} + +void Builder::VisitPostfixDecrementExpression( + ast::PostfixDecrementExpression& node) { + if (node.argument) VisitExpression(*node.argument); +} + +void Builder::VisitLogicalNotExpression(ast::LogicalNotExpression& node) { + if (node.argument) VisitExpression(*node.argument); +} + +void Builder::VisitBitwiseNotExpression(ast::BitwiseNotExpression& node) { + if (node.argument) VisitExpression(*node.argument); +} + +void Builder::VisitDerivativeExpression(ast::DerivativeExpression& node) { + if (node.argument) VisitExpression(*node.argument); +} + +void Builder::VisitMatrixTransposeExpression( + ast::MatrixTransposeExpression& node) { + if (node.argument) VisitExpression(*node.argument); +} + +void Builder::VisitExprOperatorExpression(ast::ExprOperatorExpression& node) { + if (node.argument) VisitExpression(*node.argument); +} + +void Builder::VisitFunctionPointerExpression( + [[maybe_unused]] ast::FunctionPointerExpression& node) { + if (node.argument) VisitExpression(*node.argument); +} + +void Builder::VisitNewExpression(ast::NewExpression& node) { + if (node.target) VisitExpression(*node.target); +} + +void Builder::VisitEchoExpression(ast::EchoExpression& node) { + for (auto& expr : node.expressions) { + if (expr) VisitExpression(*expr); + } +} + +void Builder::VisitRaiseExpression(ast::RaiseExpression& node) { + if (node.exception) VisitExpression(*node.exception); +} + +void Builder::VisitInheritedExpression(ast::InheritedExpression& node) { + if (node.call && *node.call) { + (*node.call)->Accept(*this); + } +} + +void Builder::VisitParenthesizedExpression(ast::ParenthesizedExpression& node) { + for (auto& elem : node.elements) { + if (elem.key) VisitExpression(*elem.key.value()); + if (elem.value) VisitExpression(*elem.value); + } +} + +void Builder::VisitColumnReference(ast::ColumnReference& node) { + if (node.value) { + VisitExpression(*node.value); + } +} + +void Builder::VisitUnpackPattern([[maybe_unused]] ast::UnpackPattern& node) { + (void)node; +} + +void Builder::VisitCompilerDirective( + [[maybe_unused]] ast::CompilerDirective& node) { + (void)node; +} + +void Builder::VisitConditionalDirective( + [[maybe_unused]] ast::ConditionalDirective& node) { + (void)node; +} + +void Builder::VisitConditionalBlock(ast::ConditionalBlock& node) { + VisitStatements(node.consequence); + VisitStatements(node.alternative); +} + +} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/builder.hpp b/lsp-server/src/language/symbol/builder.hpp index b3209d6..74f0753 100644 --- a/lsp-server/src/language/symbol/builder.hpp +++ b/lsp-server/src/language/symbol/builder.hpp @@ -2,155 +2,157 @@ #include "./table.hpp" -namespace lsp::language::symbol -{ - // Symbol table builder - responsible for stable traversal logic only - class Builder : public ast::ASTVisitor - { - public: - explicit Builder(SymbolTable& table); +namespace lsp::language::symbol { +class Builder : public ast::ASTVisitor { + public: + explicit Builder(SymbolTable& table); - void Build(ast::ASTNode& root); + void Build(ast::ASTNode& root); - // Visit statements and declarations - void VisitProgram(ast::Program& node) override; - void VisitUnitDefinition(ast::UnitDefinition& node) override; - void VisitClassDefinition(ast::ClassDefinition& node) override; - void VisitFunctionDefinition(ast::FunctionDefinition& node) override; - void VisitFunctionDeclaration(ast::FunctionDeclaration& node) override; - void VisitMethodDeclaration(ast::MethodDeclaration& node) override; - void VisitPropertyDeclaration(ast::PropertyDeclaration& node) override; - void VisitExternalMethodDefinition(ast::ExternalMethodDefinition& node) override; - void VisitClassMember(ast::ClassMember& node) override; + // Visit statements and declarations + void VisitProgram(ast::Program& node) override; + void VisitUnitDefinition(ast::UnitDefinition& node) override; + void VisitClassDefinition(ast::ClassDefinition& node) override; + void VisitFunctionDefinition(ast::FunctionDefinition& node) override; + void VisitFunctionDeclaration(ast::FunctionDeclaration& node) override; + void VisitMethodDeclaration(ast::MethodDeclaration& node) override; + void VisitPropertyDeclaration(ast::PropertyDeclaration& node) override; + void VisitExternalMethodDefinition( + ast::ExternalMethodDefinition& node) override; + void VisitClassMember(ast::ClassMember& node) override; - // Declarations - void VisitVarDeclaration(ast::VarDeclaration& node) override; - void VisitStaticDeclaration(ast::StaticDeclaration& node) override; - void VisitGlobalDeclaration(ast::GlobalDeclaration& node) override; - void VisitConstDeclaration(ast::ConstDeclaration& node) override; - void VisitFieldDeclaration(ast::FieldDeclaration& node) override; + // Declarations + void VisitVarDeclaration(ast::VarDeclaration& node) override; + void VisitStaticDeclaration(ast::StaticDeclaration& node) override; + void VisitGlobalDeclaration(ast::GlobalDeclaration& node) override; + void VisitConstDeclaration(ast::ConstDeclaration& node) override; + void VisitFieldDeclaration(ast::FieldDeclaration& node) override; - void VisitBlockStatement(ast::BlockStatement& node) override; - void VisitIfStatement(ast::IfStatement& node) override; - void VisitForInStatement(ast::ForInStatement& node) override; - void VisitForToStatement(ast::ForToStatement& node) override; - void VisitWhileStatement(ast::WhileStatement& node) override; - void VisitRepeatStatement(ast::RepeatStatement& node) override; - void VisitCaseStatement(ast::CaseStatement& node) override; - void VisitTryStatement(ast::TryStatement& node) override; + void VisitBlockStatement(ast::BlockStatement& node) override; + void VisitIfStatement(ast::IfStatement& node) override; + void VisitForInStatement(ast::ForInStatement& node) override; + void VisitForToStatement(ast::ForToStatement& node) override; + void VisitWhileStatement(ast::WhileStatement& node) override; + void VisitRepeatStatement(ast::RepeatStatement& node) override; + void VisitCaseStatement(ast::CaseStatement& node) override; + void VisitTryStatement(ast::TryStatement& node) override; - // Uses statement handling - void VisitUsesStatement(ast::UsesStatement& node) override; + // Uses statement handling + void VisitUsesStatement(ast::UsesStatement& node) override; - // Visit expressions (collect references) - void VisitIdentifier(ast::Identifier& node) override; - void VisitCallExpression(ast::CallExpression& node) override; - void VisitAttributeExpression(ast::AttributeExpression& node) override; - void VisitAssignmentExpression(ast::AssignmentExpression& node) override; + // Visit expressions (collect references) + void VisitIdentifier(ast::Identifier& node) override; + void VisitCallExpression(ast::CallExpression& node) override; + void VisitAttributeExpression(ast::AttributeExpression& node) override; + void VisitAssignmentExpression(ast::AssignmentExpression& node) override; - // Other node visits - void VisitLiteral([[maybe_unused]] ast::Literal& node) override {} - void VisitBinaryExpression(ast::BinaryExpression& node) override; - void VisitTernaryExpression(ast::TernaryExpression& node) override; - void VisitSubscriptExpression(ast::SubscriptExpression& node) override; - void VisitArrayExpression(ast::ArrayExpression& node) override; - void VisitAnonymousFunctionExpression(ast::AnonymousFunctionExpression& node) override; + // Other node visits + void VisitLiteral([[maybe_unused]] ast::Literal& node) override {} + void VisitBinaryExpression(ast::BinaryExpression& node) override; + void VisitTernaryExpression(ast::TernaryExpression& node) override; + void VisitSubscriptExpression(ast::SubscriptExpression& node) override; + void VisitArrayExpression(ast::ArrayExpression& node) override; + void VisitAnonymousFunctionExpression( + ast::AnonymousFunctionExpression& node) override; - // Unary expressions (refined types) - void VisitUnaryPlusExpression(ast::UnaryPlusExpression& node) override; - void VisitUnaryMinusExpression(ast::UnaryMinusExpression& node) override; - void VisitPrefixIncrementExpression(ast::PrefixIncrementExpression& node) override; - void VisitPrefixDecrementExpression(ast::PrefixDecrementExpression& node) override; - void VisitPostfixIncrementExpression(ast::PostfixIncrementExpression& node) override; - void VisitPostfixDecrementExpression(ast::PostfixDecrementExpression& node) override; - void VisitLogicalNotExpression(ast::LogicalNotExpression& node) override; - void VisitBitwiseNotExpression(ast::BitwiseNotExpression& node) override; - void VisitDerivativeExpression(ast::DerivativeExpression& node) override; - void VisitMatrixTransposeExpression(ast::MatrixTransposeExpression& node) override; - void VisitExprOperatorExpression(ast::ExprOperatorExpression& node) override; + // Unary expressions (refined types) + void VisitUnaryPlusExpression(ast::UnaryPlusExpression& node) override; + void VisitUnaryMinusExpression(ast::UnaryMinusExpression& node) override; + void VisitPrefixIncrementExpression( + ast::PrefixIncrementExpression& node) override; + void VisitPrefixDecrementExpression( + ast::PrefixDecrementExpression& node) override; + void VisitPostfixIncrementExpression( + ast::PostfixIncrementExpression& node) override; + void VisitPostfixDecrementExpression( + ast::PostfixDecrementExpression& node) override; + void VisitLogicalNotExpression(ast::LogicalNotExpression& node) override; + void VisitBitwiseNotExpression(ast::BitwiseNotExpression& node) override; + void VisitDerivativeExpression(ast::DerivativeExpression& node) override; + void VisitMatrixTransposeExpression( + ast::MatrixTransposeExpression& node) override; + void VisitExprOperatorExpression(ast::ExprOperatorExpression& node) override; - void VisitFunctionPointerExpression([[maybe_unused]] ast::FunctionPointerExpression& node) override; + void VisitFunctionPointerExpression( + [[maybe_unused]] ast::FunctionPointerExpression& node) override; - // Other expressions - void VisitNewExpression(ast::NewExpression& node) override; - void VisitEchoExpression(ast::EchoExpression& node) override; - void VisitRaiseExpression(ast::RaiseExpression& node) override; - void VisitInheritedExpression(ast::InheritedExpression& node) override; - void VisitParenthesizedExpression(ast::ParenthesizedExpression& node) override; + // Other expressions + void VisitNewExpression(ast::NewExpression& node) override; + void VisitEchoExpression(ast::EchoExpression& node) override; + void VisitRaiseExpression(ast::RaiseExpression& node) override; + void VisitInheritedExpression(ast::InheritedExpression& node) override; + void VisitParenthesizedExpression( + ast::ParenthesizedExpression& node) override; - void VisitExpressionStatement(ast::ExpressionStatement& node) override; - void VisitBreakStatement([[maybe_unused]] ast::BreakStatement& node) override {} - void VisitContinueStatement([[maybe_unused]] ast::ContinueStatement& node) override {} - void VisitReturnStatement(ast::ReturnStatement& node) override; - void VisitTSSQLExpression([[maybe_unused]] ast::TSSQLExpression& node) override {} - void VisitColumnReference(ast::ColumnReference& node) override; - void VisitUnpackPattern([[maybe_unused]] ast::UnpackPattern& node) override; - void VisitMatrixIterationStatement(ast::MatrixIterationStatement& node) override; + void VisitExpressionStatement(ast::ExpressionStatement& node) override; + void VisitBreakStatement( + [[maybe_unused]] ast::BreakStatement& node) override {} + void VisitContinueStatement( + [[maybe_unused]] ast::ContinueStatement& node) override {} + void VisitReturnStatement(ast::ReturnStatement& node) override; + void VisitTSSQLExpression( + [[maybe_unused]] ast::TSSQLExpression& node) override {} + void VisitColumnReference(ast::ColumnReference& node) override; + void VisitUnpackPattern([[maybe_unused]] ast::UnpackPattern& node) override; + void VisitMatrixIterationStatement( + ast::MatrixIterationStatement& node) override; - // Compiler directives (correct names without Statement suffix) - void VisitCompilerDirective([[maybe_unused]] ast::CompilerDirective& node) override; - void VisitConditionalDirective(ast::ConditionalDirective& node) override; - void VisitConditionalBlock(ast::ConditionalBlock& node) override; - void VisitTSLXBlock([[maybe_unused]] ast::TSLXBlock& node) override {} + // Compiler directives (correct names without Statement suffix) + void VisitCompilerDirective( + [[maybe_unused]] ast::CompilerDirective& node) override; + void VisitConditionalDirective(ast::ConditionalDirective& node) override; + void VisitConditionalBlock(ast::ConditionalBlock& node) override; + void VisitTSLXBlock([[maybe_unused]] ast::TSLXBlock& node) override {} - void VisitParameter([[maybe_unused]] ast::Parameter& node) override {} + void VisitParameter([[maybe_unused]] ast::Parameter& node) override {} - private: - // Symbol creation helpers - now returns symbol with parameters pre-set - SymbolId CreateFunctionSymbol( - const std::string& name, - const ast::Location& location, - const std::vector>& parameters, - const std::optional& return_type); + private: + SymbolId CreateFunctionSymbol( + const std::string& name, const ast::Location& location, + const std::vector>& parameters, + const std::optional& return_type); - SymbolId CreateMethodSymbol( - const std::string& name, - const ast::Location& location, - const std::vector>& parameters, - const std::optional& return_type); + SymbolId CreateMethodSymbol( + const std::string& name, const ast::Location& location, + const std::vector>& parameters, + const std::optional& return_type); - SymbolId CreateSymbol( - const std::string& name, - SymbolKind kind, - const ast::Location& location, - const std::optional& type_hint = std::nullopt); + SymbolId CreateSymbol( + const std::string& name, SymbolKind kind, const ast::Location& location, + const std::optional& type_hint = std::nullopt); - SymbolId CreateSymbol( - const std::string& name, - SymbolKind kind, - const ast::ASTNode& node, - const std::optional& type_hint = std::nullopt); + SymbolId CreateSymbol( + const std::string& name, SymbolKind kind, const ast::ASTNode& node, + const std::optional& type_hint = std::nullopt); - // Scope management - ScopeId EnterScopeWithSymbol(ScopeKind kind, SymbolId symbol_id, const ast::Location& range); - ScopeId EnterScope(ScopeKind kind, const ast::Location& range); - void ExitScope(); + // Scope management + ScopeId EnterScopeWithSymbol(ScopeKind kind, SymbolId symbol_id, + const ast::Location& range); + ScopeId EnterScope(ScopeKind kind, const ast::Location& range); + void ExitScope(); - // Traversal helpers - void VisitStatements(const std::vector& statements); - void VisitExpression(ast::Expression& expr); + // Traversal helpers + void VisitStatements(const std::vector& statements); + void VisitExpression(ast::Expression& expr); - // LValue processing for assignments - void ProcessLValue(const ast::LValue& lvalue, bool is_write); + // LValue processing for assignments + void ProcessLValue(const ast::LValue& lvalue, bool is_write); - // Type and parameter extraction - std::optional ExtractTypeName( - const std::optional& type) const; + // Type and parameter extraction + std::optional ExtractTypeName( + const std::optional& type) const; - std::vector BuildParameters( - const std::vector>& parameters) const; + std::vector BuildParameters( + const std::vector>& parameters) const; - private: - // Symbol table reference - SymbolTable& table_; + private: + SymbolTable& table_; - // Current traversal state - ScopeId current_scope_id_; - std::optional current_parent_symbol_id_; - std::optional current_function_id_; + ScopeId current_scope_id_; + std::optional current_parent_symbol_id_; + std::optional current_function_id_; - // Interface/implementation section tracking - bool in_interface_section_; - }; + bool in_interface_section_; +}; -} // namespace lsp::language::symbol +} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/graph/call.cpp b/lsp-server/src/language/symbol/graph/call.cpp new file mode 100644 index 0000000..6b13a54 --- /dev/null +++ b/lsp-server/src/language/symbol/graph/call.cpp @@ -0,0 +1,52 @@ +#include "call.hpp" + +#include + +namespace lsp::language::symbol::graph { + +void Call::OnSymbolRemoved(SymbolId id) { + callers_map_.erase(id); + callees_map_.erase(id); + + for (auto& [_, calls] : callers_map_) { + calls.erase(std::remove_if(calls.begin(), calls.end(), + [id](const symbol::Call& call) { + return call.caller == id; + }), + calls.end()); + } + + for (auto& [_, calls] : callees_map_) { + calls.erase(std::remove_if(calls.begin(), calls.end(), + [id](const symbol::Call& call) { + return call.callee == id; + }), + calls.end()); + } +} + +void Call::Clear() { + callers_map_.clear(); + callees_map_.clear(); +} + +void Call::AddCall(SymbolId caller, SymbolId callee, + const ast::Location& location) { + symbol::Call call{caller, callee, location}; + callers_map_[callee].push_back(call); + callees_map_[caller].push_back(call); +} + +const std::vector& Call::callers(SymbolId id) const { + static const std::vector kEmpty; + auto it = callers_map_.find(id); + return it != callers_map_.end() ? it->second : kEmpty; +} + +const std::vector& Call::callees(SymbolId id) const { + static const std::vector kEmpty; + auto it = callees_map_.find(id); + return it != callees_map_.end() ? it->second : kEmpty; +} + +} // namespace lsp::language::symbol::graph diff --git a/lsp-server/src/language/symbol/graph/call.hpp b/lsp-server/src/language/symbol/graph/call.hpp index 09fa4e5..ffbbe8e 100644 --- a/lsp-server/src/language/symbol/graph/call.hpp +++ b/lsp-server/src/language/symbol/graph/call.hpp @@ -1,77 +1,26 @@ #pragma once -#include #include #include #include "../interface.hpp" #include "../types.hpp" -namespace lsp::language::symbol::graph -{ +namespace lsp::language::symbol::graph { - class Call : public ISymbolGraph - { - public: - void OnSymbolRemoved(SymbolId id) override - { - callers_map_.erase(id); - callees_map_.erase(id); +class Call : public ISymbolGraph { + public: + void OnSymbolRemoved(SymbolId id) override; + void Clear() override; - // Remove all calls where this symbol is the caller - for (auto& [symbol_id, calls] : callers_map_) - { - calls.erase(std::remove_if(calls.begin(), calls.end(), [id](const symbol::Call& call) { - return call.caller == id; - }), - calls.end()); - } + void AddCall(SymbolId caller, SymbolId callee, const ast::Location& location); - // Remove all calls where this symbol is the callee - for (auto& [symbol_id, calls] : callees_map_) - { - calls.erase(std::remove_if(calls.begin(), calls.end(), [id](const symbol::Call& call) { - return call.callee == id; - }), - calls.end()); - } - } + const std::vector& callers(SymbolId id) const; + const std::vector& callees(SymbolId id) const; - void Clear() override - { - callers_map_.clear(); - callees_map_.clear(); - } + private: + std::unordered_map> callers_map_; + std::unordered_map> callees_map_; +}; - void AddCall(SymbolId caller, SymbolId callee, const ast::Location& location) - { - symbol::Call call{ caller, callee, location }; - callers_map_[callee].push_back(call); // Who calls this symbol - callees_map_[caller].push_back(call); // What this symbol calls - } - - // Accessor (snake_case) - const std::vector& callers(SymbolId id) const - { - static const std::vector kEmpty; - auto it = callers_map_.find(id); - return it != callers_map_.end() ? it->second : kEmpty; - } - - // Accessor (snake_case) - const std::vector& callees(SymbolId id) const - { - static const std::vector kEmpty; - auto it = callees_map_.find(id); - return it != callees_map_.end() ? it->second : kEmpty; - } - - private: - // Map: symbol -> who calls it (incoming calls) - std::unordered_map> callers_map_; - - // Map: symbol -> what it calls (outgoing calls) - std::unordered_map> callees_map_; - }; - -} // namespace lsp::language::symbol::graph +} // namespace lsp::language::symbol::graph diff --git a/lsp-server/src/language/symbol/graph/inheritance.cpp b/lsp-server/src/language/symbol/graph/inheritance.cpp new file mode 100644 index 0000000..19cfc74 --- /dev/null +++ b/lsp-server/src/language/symbol/graph/inheritance.cpp @@ -0,0 +1,58 @@ +#include "inheritance.hpp" + +#include + +namespace lsp::language::symbol::graph { + +void Inheritance::OnSymbolRemoved(SymbolId id) { + base_classes_.erase(id); + derived_classes_.erase(id); + + for (auto& [_, bases] : base_classes_) { + bases.erase(std::remove(bases.begin(), bases.end(), id), bases.end()); + } + + for (auto& [_, derived] : derived_classes_) { + derived.erase(std::remove(derived.begin(), derived.end(), id), + derived.end()); + } +} + +void Inheritance::Clear() { + base_classes_.clear(); + derived_classes_.clear(); +} + +void Inheritance::AddInheritance(SymbolId derived, SymbolId base) { + base_classes_[derived].push_back(base); + derived_classes_[base].push_back(derived); +} + +const std::vector& Inheritance::base_classes(SymbolId id) const { + static const std::vector kEmpty; + auto it = base_classes_.find(id); + return it != base_classes_.end() ? it->second : kEmpty; +} + +const std::vector& Inheritance::derived_classes(SymbolId id) const { + static const std::vector kEmpty; + auto it = derived_classes_.find(id); + return it != derived_classes_.end() ? it->second : kEmpty; +} + +bool Inheritance::IsSubclassOf(SymbolId derived, SymbolId base) const { + auto it = base_classes_.find(derived); + if (it == base_classes_.end()) { + return false; + } + + for (SymbolId parent : it->second) { + if (parent == base || IsSubclassOf(parent, base)) { + return true; + } + } + + return false; +} + +} // namespace lsp::language::symbol::graph diff --git a/lsp-server/src/language/symbol/graph/inheritance.hpp b/lsp-server/src/language/symbol/graph/inheritance.hpp index 27d150f..826af2b 100644 --- a/lsp-server/src/language/symbol/graph/inheritance.hpp +++ b/lsp-server/src/language/symbol/graph/inheritance.hpp @@ -1,90 +1,27 @@ #pragma once -#include #include #include #include "../interface.hpp" #include "../types.hpp" -namespace lsp::language::symbol::graph -{ +namespace lsp::language::symbol::graph { - class Inheritance : public ISymbolGraph - { - public: - void OnSymbolRemoved(SymbolId id) override - { - base_classes_.erase(id); - derived_classes_.erase(id); +class Inheritance : public ISymbolGraph { + public: + void OnSymbolRemoved(SymbolId id) override; + void Clear() override; - // Remove from all base class lists - for (auto& [derived_id, bases] : base_classes_) - { - bases.erase(std::remove(bases.begin(), bases.end(), id), bases.end()); - } + void AddInheritance(SymbolId derived, SymbolId base); - // Remove from all derived class lists - for (auto& [base_id, derived] : derived_classes_) - { - derived.erase(std::remove(derived.begin(), derived.end(), id), - derived.end()); - } - } + const std::vector& base_classes(SymbolId id) const; + const std::vector& derived_classes(SymbolId id) const; + bool IsSubclassOf(SymbolId derived, SymbolId base) const; - void Clear() override - { - base_classes_.clear(); - derived_classes_.clear(); - } + private: + std::unordered_map> base_classes_; + std::unordered_map> derived_classes_; +}; - void AddInheritance(SymbolId derived, SymbolId base) - { - base_classes_[derived].push_back(base); - derived_classes_[base].push_back(derived); - } - - // Accessor (snake_case) - const std::vector& base_classes(SymbolId id) const - { - static const std::vector kEmpty; - auto it = base_classes_.find(id); - return it != base_classes_.end() ? it->second : kEmpty; - } - - // Accessor (snake_case) - const std::vector& derived_classes(SymbolId id) const - { - static const std::vector kEmpty; - auto it = derived_classes_.find(id); - return it != derived_classes_.end() ? it->second : kEmpty; - } - - bool IsSubclassOf(SymbolId derived, SymbolId base) const - { - auto it = base_classes_.find(derived); - if (it == base_classes_.end()) - { - return false; - } - - for (SymbolId parent : it->second) - { - if (parent == base || IsSubclassOf(parent, base)) - { - return true; - } - } - - return false; - } - - private: - // Map from derived class to its base classes - std::unordered_map> base_classes_; - - // Map from base class to its derived classes - std::unordered_map> derived_classes_; - }; - -} // namespace lsp::language::symbol::graph +} // namespace lsp::language::symbol::graph diff --git a/lsp-server/src/language/symbol/graph/reference.cpp b/lsp-server/src/language/symbol/graph/reference.cpp new file mode 100644 index 0000000..d5cc7b8 --- /dev/null +++ b/lsp-server/src/language/symbol/graph/reference.cpp @@ -0,0 +1,46 @@ +#include "reference.hpp" + +#include + +namespace lsp::language::symbol::graph { + +void Reference::OnSymbolRemoved(SymbolId id) { + references_.erase(id); + + for (auto& [_, refs] : references_) { + refs.erase(std::remove_if(refs.begin(), refs.end(), + [id](const symbol::Reference& ref) { + return ref.symbol_id == id; + }), + refs.end()); + } +} + +void Reference::Clear() { references_.clear(); } + +void Reference::AddReference(SymbolId symbol_id, const ast::Location& location, + bool is_definition, bool is_write) { + references_[symbol_id].push_back( + {location, symbol_id, is_definition, is_write}); +} + +const std::vector& Reference::references(SymbolId id) const { + static const std::vector kEmpty; + auto it = references_.find(id); + return it != references_.end() ? it->second : kEmpty; +} + +std::optional Reference::FindDefinitionLocation( + SymbolId id) const { + auto it = references_.find(id); + if (it != references_.end()) { + for (const auto& ref : it->second) { + if (ref.is_definition) { + return ref.location; + } + } + } + return std::nullopt; +} + +} // namespace lsp::language::symbol::graph diff --git a/lsp-server/src/language/symbol/graph/reference.hpp b/lsp-server/src/language/symbol/graph/reference.hpp index c448087..79b0371 100644 --- a/lsp-server/src/language/symbol/graph/reference.hpp +++ b/lsp-server/src/language/symbol/graph/reference.hpp @@ -1,67 +1,27 @@ #pragma once -#include +#include #include #include #include "../interface.hpp" #include "../types.hpp" -namespace lsp::language::symbol::graph -{ +namespace lsp::language::symbol::graph { - class Reference : public ISymbolGraph - { - public: - void OnSymbolRemoved(SymbolId id) override - { - references_.erase(id); +class Reference : public ISymbolGraph { + public: + void OnSymbolRemoved(SymbolId id) override; + void Clear() override; - // Remove all references to the removed symbol - for (auto& [symbol_id, refs] : references_) - { - refs.erase(std::remove_if(refs.begin(), refs.end(), [id](const symbol::Reference& ref) { - return ref.symbol_id == id; - }), - refs.end()); - } - } + void AddReference(SymbolId symbol_id, const ast::Location& location, + bool is_definition = false, bool is_write = false); - void Clear() override { references_.clear(); } + const std::vector& references(SymbolId id) const; + std::optional FindDefinitionLocation(SymbolId id) const; - void AddReference(SymbolId symbol_id, const ast::Location& location, bool is_definition = false, bool is_write = false) - { - references_[symbol_id].push_back( - { location, symbol_id, is_definition, is_write }); - } + private: + std::unordered_map> references_; +}; - // Accessor (snake_case) - const std::vector& references(SymbolId id) const - { - static const std::vector kEmpty; - auto it = references_.find(id); - return it != references_.end() ? it->second : kEmpty; - } - - std::optional FindDefinitionLocation(SymbolId id) const - { - auto it = references_.find(id); - if (it != references_.end()) - { - for (const auto& ref : it->second) - { - if (ref.is_definition) - { - return ref.location; - } - } - } - return std::nullopt; - } - - private: - // Map from symbol to all its references - std::unordered_map> references_; - }; - -} // namespace lsp::language::symbol::graph +} // namespace lsp::language::symbol::graph diff --git a/lsp-server/src/language/symbol/index/location.cpp b/lsp-server/src/language/symbol/index/location.cpp new file mode 100644 index 0000000..061d9a4 --- /dev/null +++ b/lsp-server/src/language/symbol/index/location.cpp @@ -0,0 +1,68 @@ +#include "location.hpp" + +#include + +namespace lsp::language::symbol::index { + +void Location::OnSymbolAdded(const Symbol& symbol) { + const auto& loc = symbol.selection_range(); + entries_.push_back({loc.start_offset, loc.end_offset, symbol.id()}); + needs_sort_ = true; +} + +void Location::OnSymbolRemoved(SymbolId id) { + entries_.erase( + std::remove_if(entries_.begin(), entries_.end(), + [id](const Entry& e) { return e.symbol_id == id; }), + entries_.end()); +} + +void Location::Clear() { + entries_.clear(); + needs_sort_ = false; +} + +std::optional Location::FindSymbolAt( + const ast::Location& location) const { + EnsureSorted(); + uint32_t pos = location.start_offset; + + auto it = + std::lower_bound(entries_.begin(), entries_.end(), pos, + [](const Entry& e, uint32_t p) { return e.start < p; }); + + if (it != entries_.begin()) { + --it; + } + + std::optional result; + uint32_t min_span = UINT32_MAX; + + for (; it != entries_.end() && it->start <= pos; ++it) { + if (pos >= it->start && pos < it->end) { + uint32_t span = it->end - it->start; + if (span < min_span) { + min_span = span; + result = it->symbol_id; + } + } + } + + return result; +} + +bool Location::Entry::operator<(const Entry& other) const { + if (start != other.start) { + return start < other.start; + } + return end > other.end; +} + +void Location::EnsureSorted() const { + if (needs_sort_) { + std::sort(entries_.begin(), entries_.end()); + needs_sort_ = false; + } +} + +} // namespace lsp::language::symbol::index diff --git a/lsp-server/src/language/symbol/index/location.hpp b/lsp-server/src/language/symbol/index/location.hpp index 78175cc..9f5cfe9 100644 --- a/lsp-server/src/language/symbol/index/location.hpp +++ b/lsp-server/src/language/symbol/index/location.hpp @@ -1,100 +1,34 @@ #pragma once -#include #include #include #include "../interface.hpp" #include "../types.hpp" -namespace lsp::language::symbol::index -{ +namespace lsp::language::symbol::index { - class Location : public ISymbolIndex - { - public: - void OnSymbolAdded(const Symbol& symbol) override - { - const auto& loc = symbol.selection_range(); - entries_.push_back({ loc.start_offset, loc.end_offset, symbol.id() }); - needs_sort_ = true; - } +class Location : public ISymbolIndex { + public: + void OnSymbolAdded(const Symbol& symbol) override; + void OnSymbolRemoved(SymbolId id) override; + void Clear() override; - void OnSymbolRemoved(SymbolId id) override - { - entries_.erase(std::remove_if(entries_.begin(), entries_.end(), [id](const Entry& e) { - return e.symbol_id == id; - }), - entries_.end()); - } + std::optional FindSymbolAt(const ast::Location& location) const; - void Clear() override - { - entries_.clear(); - needs_sort_ = false; - } + private: + struct Entry { + uint32_t start; + uint32_t end; + SymbolId symbol_id; - std::optional FindSymbolAt(const ast::Location& location) const - { - EnsureSorted(); - uint32_t pos = location.start_offset; + bool operator<(const Entry& other) const; + }; - auto it = std::lower_bound(entries_.begin(), entries_.end(), pos, [](const Entry& e, uint32_t p) { - return e.start < p; - }); + void EnsureSorted() const; - if (it != entries_.begin()) - { - --it; - } + mutable std::vector entries_; + mutable bool needs_sort_ = false; +}; - std::optional result; - uint32_t min_span = UINT32_MAX; - - for (; it != entries_.end() && it->start <= pos; ++it) - { - if (pos >= it->start && pos < it->end) - { - uint32_t span = it->end - it->start; - if (span < min_span) - { - min_span = span; - result = it->symbol_id; - } - } - } - - return result; - } - - private: - struct Entry - { - uint32_t start; - uint32_t end; - SymbolId symbol_id; - - bool operator<(const Entry& other) const - { - if (start != other.start) - { - return start < other.start; - } - return end > other.end; - } - }; - - void EnsureSorted() const - { - if (needs_sort_) - { - std::sort(entries_.begin(), entries_.end()); - needs_sort_ = false; - } - } - - mutable std::vector entries_; - mutable bool needs_sort_ = false; - }; - -} // namespace lsp::language::symbol::index +} // namespace lsp::language::symbol::index diff --git a/lsp-server/src/language/symbol/index/scope.cpp b/lsp-server/src/language/symbol/index/scope.cpp new file mode 100644 index 0000000..c71cc80 --- /dev/null +++ b/lsp-server/src/language/symbol/index/scope.cpp @@ -0,0 +1,93 @@ +#include "scope.hpp" + +#include + +namespace lsp::language::symbol::index { + +void Scope::OnSymbolAdded(const Symbol&) {} + +void Scope::OnSymbolRemoved(SymbolId id) { + for (auto& [_, scope] : scopes_) { + for (auto it = scope.symbols.begin(); it != scope.symbols.end();) { + if (it->second == id) { + it = scope.symbols.erase(it); + } else { + ++it; + } + } + } +} + +void Scope::Clear() { + scopes_.clear(); + next_scope_id_ = 1; + global_scope_ = kInvalidScopeId; +} + +ScopeId Scope::CreateScope(ScopeKind kind, const ast::Location& range, + std::optional parent, + std::optional owner) { + ScopeId id = next_scope_id_++; + scopes_[id] = {id, kind, range, parent, owner, {}}; + + if (kind == ScopeKind::kGlobal) { + global_scope_ = id; + } + + return id; +} + +void Scope::AddSymbol(ScopeId scope_id, const std::string& name, + SymbolId symbol_id) { + auto it = scopes_.find(scope_id); + if (it != scopes_.end()) { + it->second.symbols[ToLower(name)] = symbol_id; + } +} + +std::optional Scope::FindSymbolInScope( + ScopeId scope_id, const std::string& name) const { + auto it = scopes_.find(scope_id); + if (it == scopes_.end()) { + return std::nullopt; + } + + auto sym_it = it->second.symbols.find(ToLower(name)); + return sym_it != it->second.symbols.end() ? std::optional(sym_it->second) + : std::nullopt; +} + +std::optional Scope::FindSymbolInScopeChain( + ScopeId scope_id, const std::string& name) const { + std::optional current = scope_id; + + while (current) { + if (auto result = FindSymbolInScope(*current, name)) { + return result; + } + + auto it = scopes_.find(*current); + current = it != scopes_.end() ? it->second.parent : std::nullopt; + } + + return std::nullopt; +} + +const symbol::Scope* Scope::scope(ScopeId id) const { + auto it = scopes_.find(id); + return it != scopes_.end() ? &it->second : nullptr; +} + +ScopeId Scope::global_scope() const { return global_scope_; } + +const std::unordered_map& Scope::all_scopes() const { + return scopes_; +} + +std::string Scope::ToLower(const std::string& s) { + std::string result = s; + std::transform(result.begin(), result.end(), result.begin(), ::tolower); + return result; +} + +} // namespace lsp::language::symbol::index diff --git a/lsp-server/src/language/symbol/index/scope.hpp b/lsp-server/src/language/symbol/index/scope.hpp index f4e1855..cbdae52 100644 --- a/lsp-server/src/language/symbol/index/scope.hpp +++ b/lsp-server/src/language/symbol/index/scope.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -8,137 +7,46 @@ #include "../interface.hpp" #include "../types.hpp" -namespace lsp::language::symbol -{ +namespace lsp::language::symbol { - struct Scope - { - ScopeId id; - ScopeKind kind; - ast::Location range; - std::optional parent; - std::optional owner; - std::unordered_map symbols; - }; +struct Scope { + ScopeId id; + ScopeKind kind; + ast::Location range; + std::optional parent; + std::optional owner; + std::unordered_map symbols; +}; -} // namespace lsp::language::symbol +} // namespace lsp::language::symbol -namespace lsp::language::symbol::index -{ +namespace lsp::language::symbol::index { - class Scope : public ISymbolIndex - { - public: - void OnSymbolAdded(const Symbol&) override {} +class Scope : public ISymbolIndex { + public: + void OnSymbolAdded(const Symbol&) override; + void OnSymbolRemoved(SymbolId id) override; + void Clear() override; - void OnSymbolRemoved(SymbolId id) override - { - for (auto& [_, scope] : scopes_) - { - for (auto it = scope.symbols.begin(); it != scope.symbols.end();) - { - if (it->second == id) - { - it = scope.symbols.erase(it); - } - else - { - ++it; - } - } - } - } + ScopeId CreateScope(ScopeKind kind, const ast::Location& range, + std::optional parent = std::nullopt, + std::optional owner = std::nullopt); + void AddSymbol(ScopeId scope_id, const std::string& name, SymbolId symbol_id); + std::optional FindSymbolInScope(ScopeId scope_id, + const std::string& name) const; + std::optional FindSymbolInScopeChain(ScopeId scope_id, + const std::string& name) const; - void Clear() override - { - scopes_.clear(); - next_scope_id_ = 1; - global_scope_ = kInvalidScopeId; - } + const symbol::Scope* scope(ScopeId id) const; + ScopeId global_scope() const; + const std::unordered_map& all_scopes() const; - ScopeId CreateScope(ScopeKind kind, const ast::Location& range, std::optional parent = std::nullopt, std::optional owner = std::nullopt) - { - ScopeId id = next_scope_id_++; - scopes_[id] = { id, kind, range, parent, owner, {} }; + private: + static std::string ToLower(const std::string& s); - if (kind == ScopeKind::kGlobal) - { - global_scope_ = id; - } + ScopeId next_scope_id_ = 1; + ScopeId global_scope_ = kInvalidScopeId; + std::unordered_map scopes_; +}; - return id; - } - - void AddSymbol(ScopeId scope_id, const std::string& name, SymbolId symbol_id) - { - auto it = scopes_.find(scope_id); - if (it != scopes_.end()) - { - it->second.symbols[ToLower(name)] = symbol_id; - } - } - - std::optional FindSymbolInScope( - ScopeId scope_id, - const std::string& name) const - { - auto it = scopes_.find(scope_id); - if (it == scopes_.end()) - { - return std::nullopt; - } - - auto sym_it = it->second.symbols.find(ToLower(name)); - return sym_it != it->second.symbols.end() ? std::optional(sym_it->second) : std::nullopt; - } - - std::optional FindSymbolInScopeChain( - ScopeId scope_id, - const std::string& name) const - { - std::optional current = scope_id; - - while (current) - { - if (auto result = FindSymbolInScope(*current, name)) - { - return result; - } - - auto it = scopes_.find(*current); - current = it != scopes_.end() ? it->second.parent : std::nullopt; - } - - return std::nullopt; - } - - // Accessor (snake_case) - const symbol::Scope* scope(ScopeId id) const - { - auto it = scopes_.find(id); - return it != scopes_.end() ? &it->second : nullptr; - } - - // Accessor (snake_case) - ScopeId global_scope() const { return global_scope_; } - - // Accessor (snake_case) - const std::unordered_map& all_scopes() const - { - return scopes_; - } - - private: - static std::string ToLower(const std::string& s) - { - std::string result = s; - std::transform(result.begin(), result.end(), result.begin(), ::tolower); - return result; - } - - ScopeId next_scope_id_ = 1; - ScopeId global_scope_ = kInvalidScopeId; - std::unordered_map scopes_; - }; - -} // namespace lsp::language::symbol::index +} // namespace lsp::language::symbol::index diff --git a/lsp-server/src/language/symbol/interface.hpp b/lsp-server/src/language/symbol/interface.hpp index f051382..5554f1f 100644 --- a/lsp-server/src/language/symbol/interface.hpp +++ b/lsp-server/src/language/symbol/interface.hpp @@ -2,26 +2,21 @@ #include "./types.hpp" -namespace lsp::language::symbol -{ +namespace lsp::language::symbol { - // 索引接口:用于通过各种方式查找符号 - class ISymbolIndex - { - public: - virtual ~ISymbolIndex() = default; - virtual void OnSymbolAdded(const Symbol& symbol) = 0; - virtual void OnSymbolRemoved(SymbolId id) = 0; - virtual void Clear() = 0; - }; +class ISymbolIndex { + public: + virtual ~ISymbolIndex() = default; + virtual void OnSymbolAdded(const Symbol& symbol) = 0; + virtual void OnSymbolRemoved(SymbolId id) = 0; + virtual void Clear() = 0; +}; - // 关系接口:用于管理符号之间的关系 - class ISymbolGraph - { - public: - virtual ~ISymbolGraph() = default; - virtual void OnSymbolRemoved(SymbolId id) = 0; - virtual void Clear() = 0; - }; +class ISymbolGraph { + public: + virtual ~ISymbolGraph() = default; + virtual void OnSymbolRemoved(SymbolId id) = 0; + virtual void Clear() = 0; +}; -} +} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/location_index.hpp b/lsp-server/src/language/symbol/location_index.hpp deleted file mode 100644 index 0ef83c9..0000000 --- a/lsp-server/src/language/symbol/location_index.hpp +++ /dev/null @@ -1,95 +0,0 @@ -#pragma once - -#include -#include -#include -#include "./interface.hpp" -#include "./types.hpp" - -namespace lsp::language::symbol -{ - - class LocationIndex : public ISymbolIndex - { - public: - void OnSymbolAdded(const Symbol& symbol) override - { - const auto& range = symbol.range(); - entries_.push_back({ range.start_offset, - range.end_offset, - symbol.id() }); - needs_sort_ = true; - } - - void OnSymbolRemoved(SymbolId id) override - { - entries_.erase( - std::remove_if(entries_.begin(), entries_.end(), [id](const Entry& e) { return e.symbol_id == id; }), - entries_.end()); - } - - void Clear() override - { - entries_.clear(); - needs_sort_ = false; - } - - std::optional FindAt(const ast::Location& location) const - { - EnsureSorted(); - uint32_t pos = location.start_offset; - - auto it = std::lower_bound( - entries_.begin(), entries_.end(), pos, [](const Entry& e, uint32_t p) { return e.start < p; }); - - if (it != entries_.begin()) - --it; - - std::optional result; - uint32_t min_span = UINT32_MAX; - - for (; it != entries_.end() && it->start <= pos; ++it) - { - if (pos >= it->start && pos < it->end) - { - uint32_t span = it->end - it->start; - if (span < min_span) - { - min_span = span; - result = it->symbol_id; - } - } - } - - return result; - } - - private: - struct Entry - { - uint32_t start; - uint32_t end; - SymbolId symbol_id; - - bool operator<(const Entry& other) const - { - if (start != other.start) - return start < other.start; - return end > other.end; - } - }; - - void EnsureSorted() const - { - if (needs_sort_) - { - std::sort(entries_.begin(), entries_.end()); - needs_sort_ = false; - } - } - - mutable std::vector entries_; - mutable bool needs_sort_ = false; - }; - -} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/scope.hpp b/lsp-server/src/language/symbol/scope.hpp deleted file mode 100644 index a8a1748..0000000 --- a/lsp-server/src/language/symbol/scope.hpp +++ /dev/null @@ -1,130 +0,0 @@ -#pragma once - -#include -#include -#include -#include "./interface.hpp" -#include "./types.hpp" - -namespace lsp::language::symbol -{ - - struct Scope - { - ScopeId id; - ScopeKind kind; - ast::Location range; - std::optional parent; - std::optional owner; - std::unordered_map symbols; - }; - - class ScopeIndex : public ISymbolIndex - { - public: - void OnSymbolAdded(const Symbol&) override {} - - void OnSymbolRemoved(SymbolId id) override - { - for (auto& [_, scope] : scopes_) - { - for (auto it = scope.symbols.begin(); it != scope.symbols.end();) - { - if (it->second == id) - { - it = scope.symbols.erase(it); - } - else - { - ++it; - } - } - } - } - - void Clear() override - { - scopes_.clear(); - next_scope_id_ = 1; - global_scope_ = kInvalidScopeId; - } - - ScopeId CreateScope(ScopeKind kind, const ast::Location& range, std::optional parent = std::nullopt, std::optional owner = std::nullopt) - { - ScopeId id = next_scope_id_++; - scopes_[id] = { id, kind, range, parent, owner, {} }; - - if (kind == ScopeKind::kGlobal) - { - global_scope_ = id; - } - - return id; - } - - void AddSymbol(ScopeId scope_id, const std::string& name, SymbolId symbol_id) - { - auto it = scopes_.find(scope_id); - if (it != scopes_.end()) - { - it->second.symbols[ToLower(name)] = symbol_id; - } - } - - std::optional FindInScope(ScopeId scope_id, const std::string& name) const - { - auto it = scopes_.find(scope_id); - if (it == scopes_.end()) - return std::nullopt; - - auto sym_it = it->second.symbols.find(ToLower(name)); - return sym_it != it->second.symbols.end() ? - std::optional(sym_it->second) : - std::nullopt; - } - - std::optional FindInScopeChain(ScopeId scope_id, const std::string& name) const - { - std::optional current = scope_id; - - while (current) - { - if (auto result = FindInScope(*current, name)) - { - return result; - } - - auto it = scopes_.find(*current); - current = it != scopes_.end() ? it->second.parent : std::nullopt; - } - - return std::nullopt; - } - - const Scope* GetScope(ScopeId id) const - { - auto it = scopes_.find(id); - return it != scopes_.end() ? &it->second : nullptr; - } - - ScopeId GetGlobalScope() const { return global_scope_; } - - const std::unordered_map& GetAllScopes() const - { - return scopes_; - } - - private: - static std::string ToLower(const std::string& s) - { - std::string result = s; - std::transform(result.begin(), result.end(), result.begin(), ::tolower); - return result; - } - - ScopeId next_scope_id_ = 1; - ScopeId global_scope_ = kInvalidScopeId; - std::unordered_map scopes_; - }; - -} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/store.cpp b/lsp-server/src/language/symbol/store.cpp new file mode 100644 index 0000000..13d7ee0 --- /dev/null +++ b/lsp-server/src/language/symbol/store.cpp @@ -0,0 +1,59 @@ +#include "store.hpp" + +#include + +namespace lsp::language::symbol { + +SymbolId SymbolStore::Add(Symbol def) { + SymbolId id = next_id_++; + std::visit([id](auto& s) { s.id = id; }, def.mutable_data()); + + auto [it, _] = definitions_.emplace(id, std::move(def)); + const auto& stored = it->second; + by_name_[stored.name()].push_back(id); + return id; +} + +bool SymbolStore::Remove(SymbolId id) { + auto it = definitions_.find(id); + if (it == definitions_.end()) { + return false; + } + + const std::string& name = it->second.name(); + auto& ids = by_name_[name]; + ids.erase(std::remove(ids.begin(), ids.end(), id), ids.end()); + if (ids.empty()) { + by_name_.erase(name); + } + + definitions_.erase(it); + return true; +} + +void SymbolStore::Clear() { + definitions_.clear(); + by_name_.clear(); + next_id_ = 1; +} + +const Symbol* SymbolStore::Get(SymbolId id) const { + auto it = definitions_.find(id); + return it != definitions_.end() ? &it->second : nullptr; +} + +std::vector> SymbolStore::GetAll() const { + std::vector> result; + result.reserve(definitions_.size()); + for (const auto& [_, def] : definitions_) { + result.push_back(std::cref(def)); + } + return result; +} + +std::vector SymbolStore::FindByName(const std::string& name) const { + auto it = by_name_.find(name); + return it != by_name_.end() ? it->second : std::vector(); +} + +} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/store.hpp b/lsp-server/src/language/symbol/store.hpp index d375eb0..d023cc2 100644 --- a/lsp-server/src/language/symbol/store.hpp +++ b/lsp-server/src/language/symbol/store.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -8,80 +7,22 @@ #include "./types.hpp" -namespace lsp::language::symbol -{ +namespace lsp::language::symbol { - class SymbolStore - { - public: - SymbolId Add(Symbol def) - { - SymbolId id = next_id_++; - // Update the symbol's ID - std::visit([id](auto& s) { s.id = id; }, def.mutable_data()); +class SymbolStore { + public: + SymbolId Add(Symbol def); + bool Remove(SymbolId id); + void Clear(); - auto [it, inserted] = definitions_.emplace(id, std::move(def)); - const auto& stored = it->second; - by_name_[stored.name()].push_back(id); - return id; - } + const Symbol* Get(SymbolId id) const; + std::vector> GetAll() const; + std::vector FindByName(const std::string& name) const; - bool Remove(SymbolId id) - { - auto it = definitions_.find(id); - if (it == definitions_.end()) - { - return false; - } + private: + SymbolId next_id_ = 1; + std::unordered_map definitions_; + std::unordered_map> by_name_; +}; - const std::string& name = it->second.name(); - auto& ids = by_name_[name]; - ids.erase(std::remove(ids.begin(), ids.end(), id), ids.end()); - if (ids.empty()) - { - by_name_.erase(name); - } - - definitions_.erase(it); - return true; - } - - void Clear() - { - definitions_.clear(); - by_name_.clear(); - next_id_ = 1; - } - - // Accessor (snake_case) - 返回指针是合理的,因为可能不存在 - const Symbol* Get(SymbolId id) const - { - auto it = definitions_.find(id); - return it != definitions_.end() ? &it->second : nullptr; - } - - // Accessor (snake_case) - 使用reference_wrapper替代指针 - std::vector> GetAll() const - { - std::vector> result; - result.reserve(definitions_.size()); - for (const auto& [id, def] : definitions_) - { - result.push_back(std::cref(def)); - } - return result; - } - - std::vector FindByName(const std::string& name) const - { - auto it = by_name_.find(name); - return it != by_name_.end() ? it->second : std::vector(); - } - - private: - SymbolId next_id_ = 1; - std::unordered_map definitions_; - std::unordered_map> by_name_; - }; - -} // namespace lsp::language::symbol +} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/table.cpp b/lsp-server/src/language/symbol/table.cpp new file mode 100644 index 0000000..e5fe8e1 --- /dev/null +++ b/lsp-server/src/language/symbol/table.cpp @@ -0,0 +1,104 @@ +#include "table.hpp" + +#include + +namespace lsp::language::symbol { + +SymbolId SymbolTable::CreateSymbol(Symbol symbol) { + auto def = Symbol(std::move(symbol)); + auto id = store_.Add(def); + + location_index_.OnSymbolAdded(def); + scope_index_.OnSymbolAdded(def); + + return id; +} + +bool SymbolTable::RemoveSymbol(SymbolId id) { + location_index_.OnSymbolRemoved(id); + scope_index_.OnSymbolRemoved(id); + reference_graph_.OnSymbolRemoved(id); + inheritance_graph_.OnSymbolRemoved(id); + call_graph_.OnSymbolRemoved(id); + + return store_.Remove(id); +} + +void SymbolTable::Clear() { + store_.Clear(); + location_index_.Clear(); + scope_index_.Clear(); + reference_graph_.Clear(); + inheritance_graph_.Clear(); + call_graph_.Clear(); +} + +std::vector SymbolTable::FindSymbolsByName( + const std::string& name) const { + return store_.FindByName(name); +} + +std::optional SymbolTable::FindSymbolAt( + const ast::Location& location) const { + return location_index_.FindSymbolAt(location); +} + +const Symbol* SymbolTable::definition(SymbolId id) const { + return store_.Get(id); +} + +std::vector> SymbolTable::all_definitions() + const { + return store_.GetAll(); +} + +index::Location& SymbolTable::locations() { return location_index_; } +index::Scope& SymbolTable::scopes() { return scope_index_; } + +const index::Location& SymbolTable::locations() const { + return location_index_; +} + +const index::Scope& SymbolTable::scopes() const { return scope_index_; } + +graph::Reference& SymbolTable::references() { return reference_graph_; } +graph::Inheritance& SymbolTable::inheritance() { return inheritance_graph_; } +graph::Call& SymbolTable::calls() { return call_graph_; } + +const graph::Reference& SymbolTable::references() const { + return reference_graph_; +} + +const graph::Inheritance& SymbolTable::inheritance() const { + return inheritance_graph_; +} + +const graph::Call& SymbolTable::calls() const { return call_graph_; } + +ScopeId SymbolTable::CreateScope(ScopeKind kind, const ast::Location& range, + std::optional parent, + std::optional owner) { + return scope_index_.CreateScope(kind, range, parent, owner); +} + +void SymbolTable::AddSymbolToScope(ScopeId scope_id, const std::string& name, + SymbolId symbol_id) { + scope_index_.AddSymbol(scope_id, name, symbol_id); +} + +void SymbolTable::AddReference(SymbolId symbol_id, + const ast::Location& location, + bool is_definition, bool is_write) { + reference_graph_.AddReference(symbol_id, location, is_definition, is_write); +} + +void SymbolTable::AddInheritance(SymbolId derived, SymbolId base) { + inheritance_graph_.AddInheritance(derived, base); +} + +void SymbolTable::AddCall(SymbolId caller, SymbolId callee, + const ast::Location& location) { + call_graph_.AddCall(caller, callee, location); +} + +} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/table.hpp b/lsp-server/src/language/symbol/table.hpp index e469684..7df6c64 100644 --- a/lsp-server/src/language/symbol/table.hpp +++ b/lsp-server/src/language/symbol/table.hpp @@ -11,138 +11,53 @@ #include "./index/scope.hpp" #include "./store.hpp" -namespace lsp::language::symbol -{ +namespace lsp::language::symbol { - class SymbolTable - { - public: - SymbolTable() = default; +class SymbolTable { + public: + SymbolTable() = default; - // ===== Symbol Operations ===== + SymbolId CreateSymbol(Symbol symbol); + bool RemoveSymbol(SymbolId id); + void Clear(); - SymbolId CreateSymbol(Symbol symbol) - { - auto def = Symbol(std::move(symbol)); - auto id = store_.Add(def); + std::vector FindSymbolsByName(const std::string& name) const; + std::optional FindSymbolAt(const ast::Location& location) const; - // Notify all indexes - location_index_.OnSymbolAdded(def); - scope_index_.OnSymbolAdded(def); + const Symbol* definition(SymbolId id) const; + std::vector> all_definitions() const; - return id; - } + index::Location& locations(); + index::Scope& scopes(); - bool RemoveSymbol(SymbolId id) - { - // Notify all components - location_index_.OnSymbolRemoved(id); - scope_index_.OnSymbolRemoved(id); - reference_graph_.OnSymbolRemoved(id); - inheritance_graph_.OnSymbolRemoved(id); - call_graph_.OnSymbolRemoved(id); + const index::Location& locations() const; + const index::Scope& scopes() const; - return store_.Remove(id); - } + graph::Reference& references(); + graph::Inheritance& inheritance(); + graph::Call& calls(); - void Clear() - { - store_.Clear(); - location_index_.Clear(); - scope_index_.Clear(); - reference_graph_.Clear(); - inheritance_graph_.Clear(); - call_graph_.Clear(); - } + const graph::Reference& references() const; + const graph::Inheritance& inheritance() const; + const graph::Call& calls() const; - // ===== Basic Queries ===== + ScopeId CreateScope(ScopeKind kind, const ast::Location& range, + std::optional parent = std::nullopt, + std::optional owner = std::nullopt); + void AddSymbolToScope(ScopeId scope_id, const std::string& name, + SymbolId symbol_id); + void AddReference(SymbolId symbol_id, const ast::Location& location, + bool is_definition = false, bool is_write = false); + void AddInheritance(SymbolId derived, SymbolId base); + void AddCall(SymbolId caller, SymbolId callee, const ast::Location& location); - std::vector FindSymbolsByName(const std::string& name) const - { - return store_.FindByName(name); - } + private: + SymbolStore store_; + index::Location location_index_; + index::Scope scope_index_; + graph::Reference reference_graph_; + graph::Inheritance inheritance_graph_; + graph::Call call_graph_; +}; - std::optional FindSymbolAt(const ast::Location& location) const - { - return location_index_.FindSymbolAt(location); - } - - // ===== Accessors (snake_case) ===== - - // Get single definition - 返回指针是合理的,因为可能不存在 - const Symbol* definition(SymbolId id) const - { - return store_.Get(id); - } - - // Get all definitions - 使用reference_wrapper替代指针 - std::vector> all_definitions() const - { - return store_.GetAll(); - } - - // Access indexes (non-const) - index::Location& locations() { return location_index_; } - index::Scope& scopes() { return scope_index_; } - - // Access indexes (const) - const index::Location& locations() const { return location_index_; } - const index::Scope& scopes() const { return scope_index_; } - - // Access graphs (non-const) - graph::Reference& references() { return reference_graph_; } - graph::Inheritance& inheritance() { return inheritance_graph_; } - graph::Call& calls() { return call_graph_; } - - // Access graphs (const) - const graph::Reference& references() const { return reference_graph_; } - const graph::Inheritance& inheritance() const { return inheritance_graph_; } - const graph::Call& calls() const { return call_graph_; } - - // ===== Convenience Methods (shortcuts for common operations) ===== - - // Create scope - ScopeId CreateScope(ScopeKind kind, const ast::Location& range, std::optional parent = std::nullopt, std::optional owner = std::nullopt) - { - return scope_index_.CreateScope(kind, range, parent, owner); - } - - // Add symbol to scope - void AddSymbolToScope(ScopeId scope_id, const std::string& name, SymbolId symbol_id) - { - scope_index_.AddSymbol(scope_id, name, symbol_id); - } - - // Add reference - void AddReference(SymbolId symbol_id, const ast::Location& location, bool is_definition = false, bool is_write = false) - { - reference_graph_.AddReference(symbol_id, location, is_definition, is_write); - } - - // Add inheritance relationship - void AddInheritance(SymbolId derived, SymbolId base) - { - inheritance_graph_.AddInheritance(derived, base); - } - - // Add call relationship - void AddCall(SymbolId caller, SymbolId callee, const ast::Location& location) - { - call_graph_.AddCall(caller, callee, location); - } - - private: - // Storage layer - SymbolStore store_; - - // Index layer - index::Location location_index_; - index::Scope scope_index_; - - // Graph layer - graph::Reference reference_graph_; - graph::Inheritance inheritance_graph_; - graph::Call call_graph_; - }; - -} // namespace lsp::language::symbol +} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/types.hpp b/lsp-server/src/language/symbol/types.hpp index 490a243..d5ffa18 100644 --- a/lsp-server/src/language/symbol/types.hpp +++ b/lsp-server/src/language/symbol/types.hpp @@ -5,285 +5,246 @@ #include #include -#include "../ast/types.hpp" #include "../../protocol/protocol.hpp" +#include "../ast/types.hpp" -namespace lsp::language::symbol -{ +namespace lsp::language::symbol { - // ===== Basic Types ===== +// ===== Basic Types ===== - using SymbolId = uint64_t; - constexpr SymbolId kInvalidSymbolId = 0; +using SymbolId = uint64_t; +constexpr SymbolId kInvalidSymbolId = 0; - using ScopeId = uint64_t; - constexpr ScopeId kInvalidScopeId = 0; +using ScopeId = uint64_t; +constexpr ScopeId kInvalidScopeId = 0; - using SymbolKind = protocol::SymbolKind; +using SymbolKind = protocol::SymbolKind; - enum class ScopeKind - { - kGlobal, - kUnit, - kClass, - kFunction, - kAnonymousFunction, - kBlock - }; +enum class ScopeKind { + kGlobal, + kUnit, + kClass, + kFunction, + kAnonymousFunction, + kBlock +}; - enum class VariableScope - { - kAutomatic, - kStatic, - kGlobal, - kParameter, - kField - }; +enum class VariableScope { kAutomatic, kStatic, kGlobal, kParameter, kField }; - enum class UnitVisibility - { - kInterface, - kImplementation - }; +enum class UnitVisibility { kInterface, kImplementation }; - // ===== Parameters ===== +// ===== Parameters ===== - struct Parameter - { - std::string name; - std::optional type; - std::optional default_value; - }; +struct Parameter { + std::string name; + std::optional type; + std::optional default_value; +}; - struct UnitImport - { - std::string unit_name; - ast::Location location; - }; +struct UnitImport { + std::string unit_name; + ast::Location location; +}; - // ===== Symbol Types (all fields directly expanded) ===== +// ===== Symbol Types (all fields directly expanded) ===== - struct Function - { - static constexpr SymbolKind kind = SymbolKind::Function; +struct Function { + static constexpr SymbolKind kind = SymbolKind::Function; - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; // Name identifier location (for cursor) - ast::Location range; // Full range (including modifiers, body) + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; // Name identifier location (for cursor) + ast::Location range; // Full range (including modifiers, body) - ast::Location declaration_range; - std::optional implementation_range; - std::vector parameters; - std::optional return_type; + ast::Location declaration_range; + std::optional implementation_range; + std::vector parameters; + std::optional return_type; - std::vector imports; - std::optional unit_visibility; + std::vector imports; + std::optional unit_visibility; - bool HasImplementation() const - { - return implementation_range.has_value(); - } - }; + bool HasImplementation() const { return implementation_range.has_value(); } +}; - struct Class - { - static constexpr SymbolKind kind = SymbolKind::Class; +struct Class { + static constexpr SymbolKind kind = SymbolKind::Class; - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; - ast::Location range; + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; + ast::Location range; - std::optional unit_visibility; + std::optional unit_visibility; - std::vector base_classes; - std::vector members; + std::vector base_classes; + std::vector members; - std::vector imports; - }; + std::vector imports; +}; - struct Method - { - static constexpr SymbolKind kind = SymbolKind::Method; +struct Method { + static constexpr SymbolKind kind = SymbolKind::Method; - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; - ast::Location range; + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; + ast::Location range; - // Location information - ast::Location declaration_range; - std::optional implementation_range; + // Location information + ast::Location declaration_range; + std::optional implementation_range; - // Method-specific - ast::MethodKind method_kind = ast::MethodKind::kOrdinary; - ast::AccessModifier access = ast::AccessModifier::kPublic; - std::optional method_modifier = - ast::MethodModifier::kNone; - bool is_static = false; - std::vector parameters; - std::optional return_type; + // Method-specific + ast::MethodKind method_kind = ast::MethodKind::kOrdinary; + ast::AccessModifier access = ast::AccessModifier::kPublic; + std::optional method_modifier = + ast::MethodModifier::kNone; + bool is_static = false; + std::vector parameters; + std::optional return_type; - std::vector imports; + std::vector imports; - bool HasImplementation() const - { - return implementation_range.has_value(); - } - }; + bool HasImplementation() const { return implementation_range.has_value(); } +}; - struct Property - { - static constexpr SymbolKind kind = SymbolKind::Property; +struct Property { + static constexpr SymbolKind kind = SymbolKind::Property; - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; - ast::Location range; + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; + ast::Location range; - // Property-specific - ast::AccessModifier access = ast::AccessModifier::kPublic; - std::optional type; - std::optional getter; - std::optional setter; - }; + // Property-specific + ast::AccessModifier access = ast::AccessModifier::kPublic; + std::optional type; + std::optional getter; + std::optional setter; +}; - struct Field - { - static constexpr SymbolKind kind = SymbolKind::Field; +struct Field { + static constexpr SymbolKind kind = SymbolKind::Field; - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; - ast::Location range; + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; + ast::Location range; - ast::AccessModifier access = ast::AccessModifier::kPublic; - std::optional reference_modifier; - std::optional type; - bool is_static = false; - }; + ast::AccessModifier access = ast::AccessModifier::kPublic; + std::optional reference_modifier; + std::optional type; + bool is_static = false; +}; - struct Variable - { - static constexpr SymbolKind kind = SymbolKind::Variable; +struct Variable { + static constexpr SymbolKind kind = SymbolKind::Variable; - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; - ast::Location range; + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; + ast::Location range; - std::optional type; - std::optional reference_modifier; - VariableScope storage = VariableScope::kAutomatic; - std::optional unit_visibility; - bool has_initializer = false; - }; + std::optional type; + std::optional reference_modifier; + VariableScope storage = VariableScope::kAutomatic; + std::optional unit_visibility; + bool has_initializer = false; +}; - struct Constant - { - static constexpr SymbolKind kind = SymbolKind::Constant; +struct Constant { + static constexpr SymbolKind kind = SymbolKind::Constant; - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; - ast::Location range; + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; + ast::Location range; - std::optional type; - std::string value; - }; + std::optional type; + std::string value; +}; - struct Unit - { - static constexpr SymbolKind kind = SymbolKind::Namespace; +struct Unit { + static constexpr SymbolKind kind = SymbolKind::Namespace; - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; - ast::Location range; + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; + ast::Location range; - std::vector interface_imports; - std::vector implementation_imports; - }; + std::vector interface_imports; + std::vector implementation_imports; +}; - struct Reference - { - ast::Location location; - SymbolId symbol_id; - bool is_definition; - bool is_write; - }; +struct Reference { + ast::Location location; + SymbolId symbol_id; + bool is_definition; + bool is_write; +}; - struct Call - { - SymbolId caller; - SymbolId callee; - ast::Location call_site; - }; +struct Call { + SymbolId caller; + SymbolId callee; + ast::Location call_site; +}; - // ===== Symbol Data Variant ===== +// ===== Symbol Data Variant ===== - using SymbolData = std::variant; +using SymbolData = std::variant; - // ===== Symbol ===== +// ===== Symbol ===== - class Symbol - { - public: - explicit Symbol(SymbolData data) : data_(std::move(data)) {} +class Symbol { + public: + explicit Symbol(SymbolData data) : data_(std::move(data)) {} - // Type checking and conversion - template - bool Is() const - { - return std::holds_alternative(data_); - } + // Type checking and conversion + template + bool Is() const { + return std::holds_alternative(data_); + } - template - const T* As() const - { - return std::get_if(&data_); - } + template + const T* As() const { + return std::get_if(&data_); + } - template - T* As() - { - return std::get_if(&data_); - } + template + T* As() { + return std::get_if(&data_); + } - // Accessors (snake_case per Google style) - const SymbolData& data() const { return data_; } - SymbolData& mutable_data() { return data_; } + // Accessors (snake_case per Google style) + const SymbolData& data() const { return data_; } + SymbolData& mutable_data() { return data_; } - // Common accessors (all symbol types have these) - SymbolId id() const - { - return std::visit([](const auto& s) { return s.id; }, data_); - } + // Common accessors (all symbol types have these) + SymbolId id() const { + return std::visit([](const auto& s) { return s.id; }, data_); + } - const std::string& name() const - { - return std::visit([](const auto& s) -> const auto& { return s.name; }, - data_); - } + const std::string& name() const { + return std::visit([](const auto& s) -> const auto& { return s.name; }, + data_); + } - ast::Location selection_range() const - { - return std::visit([](const auto& s) { return s.selection_range; }, data_); - } + ast::Location selection_range() const { + return std::visit([](const auto& s) { return s.selection_range; }, data_); + } - ast::Location range() const - { - return std::visit([](const auto& s) { return s.range; }, data_); - } + ast::Location range() const { + return std::visit([](const auto& s) { return s.range; }, data_); + } - SymbolKind kind() const - { - return std::visit([](const auto& s) { return s.kind; }, data_); - } + SymbolKind kind() const { + return std::visit([](const auto& s) { return s.kind; }, data_); + } - private: - SymbolData data_; - }; + private: + SymbolData data_; +}; -} // namespace lsp::language::symbol +} // namespace lsp::language::symbol diff --git a/lsp-server/src/provider/text_document/completion.cpp b/lsp-server/src/provider/text_document/completion.cpp index e90612b..b8e5480 100644 --- a/lsp-server/src/provider/text_document/completion.cpp +++ b/lsp-server/src/provider/text_document/completion.cpp @@ -1,872 +1,822 @@ +#include "./completion.hpp" + #include + #include #include + #include "../../language/keyword/repo.hpp" +#include "../../protocol/transform/facade.hpp" +#include "../../service/detail/symbol/utils.hpp" #include "../../service/document.hpp" #include "../../service/symbol.hpp" #include "../../utils/string.hpp" -#include "../../protocol/transform/facade.hpp" -#include "./completion.hpp" -namespace lsp::provider::text_document -{ - namespace context_analyzer - { - CompletionContext Analyze(const protocol::CompletionParams& params, const std::optional& document_content) - { - CompletionContext result; - result.uri = params.textDocument.uri; - result.position = params.position; +namespace lsp::provider::text_document { +namespace context_analyzer { +CompletionContext Analyze(const protocol::CompletionParams& params, + const std::optional& document_content) { + CompletionContext result; + result.uri = params.textDocument.uri; + result.position = params.position; - if (!document_content.has_value()) - { - spdlog::warn("context_analyzer: Document content not available for URI: {}", result.uri); - return result; - } + if (!document_content.has_value()) { + spdlog::warn("context_analyzer: Document content not available for URI: {}", + result.uri); + return result; + } - const std::string& content = *document_content; + const std::string& content = *document_content; - // 定位到光标所在行的起始位置 - size_t line_start = 0; - size_t current_line = 0; - for (size_t i = 0; i < content.length(); ++i) - { - if (current_line == result.position.line) - { - line_start = i; - break; - } - if (content[i] == '\n') - { - ++current_line; - } - } + // 定位到光标所在行的起始位置 + size_t line_start = 0; + size_t current_line = 0; + for (size_t i = 0; i < content.length(); ++i) { + if (current_line == result.position.line) { + line_start = i; + break; + } + if (content[i] == '\n') { + ++current_line; + } + } - size_t cursor_pos = line_start + result.position.character; - cursor_pos = std::min(cursor_pos, content.length()); + size_t cursor_pos = line_start + result.position.character; + cursor_pos = std::min(cursor_pos, content.length()); - // 提取当前行直到光标位置的内容 - result.line_content = content.substr(line_start, cursor_pos - line_start); + // 提取当前行直到光标位置的内容 + result.line_content = content.substr(line_start, cursor_pos - line_start); - std::string class_name, method_prefix; - result.is_class_method_context = IsClassMethodContext(result.line_content, class_name, method_prefix); - if (result.is_class_method_context) - { - result.class_name = class_name; - result.prefix = method_prefix; - } - else - { - // 分析上下文类型(按优先级检查) - result.is_new_context = IsNewContext(result.line_content); - result.is_unit_context = IsUnitContext(result.line_content); - result.is_class_context = IsClassContext(result.line_content); + std::string class_name, method_prefix; + result.is_class_method_context = + IsClassMethodContext(result.line_content, class_name, method_prefix); + if (result.is_class_method_context) { + result.class_name = class_name; + result.prefix = method_prefix; + } else { + // 分析上下文类型(按优先级检查) + result.is_new_context = IsNewContext(result.line_content); + result.is_unit_context = IsUnitContext(result.line_content); + result.is_class_context = IsClassContext(result.line_content); - // 根据上下文类型提取前缀 - if (result.is_new_context) - result.prefix = ExtractNewPrefix(result.line_content); - else if (result.is_unit_context) - result.prefix = ExtractUnitPrefix(result.line_content); - else if (result.is_class_context) - result.prefix = ExtractClassPrefix(result.line_content); - else - result.prefix = ExtractPrefix(cursor_pos, line_start, content); - } - return result; - } + // 根据上下文类型提取前缀 + if (result.is_new_context) + result.prefix = ExtractNewPrefix(result.line_content); + else if (result.is_unit_context) + result.prefix = ExtractUnitPrefix(result.line_content); + else if (result.is_class_context) + result.prefix = ExtractClassPrefix(result.line_content); + else + result.prefix = ExtractPrefix(cursor_pos, line_start, content); + } + return result; +} - bool IsNewContext(const std::string& line) - { - std::string line_lower = utils::ToLower(line); - size_t new_pos = line_lower.rfind("new "); +bool IsNewContext(const std::string& line) { + std::string line_lower = utils::ToLower(line); + size_t new_pos = line_lower.rfind("new "); - if (new_pos == std::string::npos) - return false; + if (new_pos == std::string::npos) return false; - std::string after_new = utils::Trim(line.substr(new_pos + 4)); + std::string after_new = utils::Trim(line.substr(new_pos + 4)); - // 如果 new 后面已经有左括号,说明已经开始构造调用,不再是类名补全上下文 - if (after_new.find('(') != std::string::npos) - return false; + // 如果 new 后面已经有左括号,说明已经开始构造调用,不再是类名补全上下文 + if (after_new.find('(') != std::string::npos) return false; - return true; - } + return true; +} - bool IsClassMethodContext(const std::string& line, std::string& out_class_name, std::string& out_prefix) - { - std::string line_lower = utils::ToLower(line); +bool IsClassMethodContext(const std::string& line, std::string& out_class_name, + std::string& out_prefix) { + std::string line_lower = utils::ToLower(line); - // 查找 class( 模式 - size_t class_pos = line_lower.rfind("class("); - if (class_pos == std::string::npos) - return false; + // 查找 class( 模式 + size_t class_pos = line_lower.rfind("class("); + if (class_pos == std::string::npos) return false; - // 提取 class( 之后的内容 - std::string after_class = line.substr(class_pos + 6); + // 提取 class( 之后的内容 + std::string after_class = line.substr(class_pos + 6); - // 查找右括号和点号 - size_t close_paren = after_class.find(')'); - if (close_paren == std::string::npos) - return false; + // 查找右括号和点号 + size_t close_paren = after_class.find(')'); + if (close_paren == std::string::npos) return false; - size_t dot_pos = after_class.find('.', close_paren); - if (dot_pos == std::string::npos) - return false; + size_t dot_pos = after_class.find('.', close_paren); + if (dot_pos == std::string::npos) return false; - // 提取类名(括号内的内容) - out_class_name = utils::Trim(after_class.substr(0, close_paren)); - if (out_class_name.empty()) - return false; + // 提取类名(括号内的内容) + out_class_name = utils::Trim(after_class.substr(0, close_paren)); + if (out_class_name.empty()) return false; - // 提取点号后的前缀 - out_prefix = utils::Trim(after_class.substr(dot_pos + 1)); + // 提取点号后的前缀 + out_prefix = utils::Trim(after_class.substr(dot_pos + 1)); - return true; - } + return true; +} - bool IsUnitContext(const std::string& line) - { - std::string line_lower = utils::ToLower(line); +bool IsUnitContext(const std::string& line) { + std::string line_lower = utils::ToLower(line); - // 检查 "unit(" 模式 - size_t unit_pos = line_lower.rfind("unit("); - if (unit_pos == std::string::npos) - return false; + // 检查 "unit(" 模式 + size_t unit_pos = line_lower.rfind("unit("); + if (unit_pos == std::string::npos) return false; - // 确保 unit( 后面没有闭合括号 - size_t after_paren = unit_pos + 5; - std::string after = line.substr(after_paren); + // 确保 unit( 后面没有闭合括号 + size_t after_paren = unit_pos + 5; + std::string after = line.substr(after_paren); - if (after.find(')') != std::string::npos) - return false; + if (after.find(')') != std::string::npos) return false; - return true; - } + return true; +} - bool IsClassContext(const std::string& line) - { - std::string line_lower = utils::ToLower(line); +bool IsClassContext(const std::string& line) { + std::string line_lower = utils::ToLower(line); - // 检查 "class(" 模式 - size_t class_pos = line_lower.rfind("class("); - if (class_pos == std::string::npos) - return false; + // 检查 "class(" 模式 + size_t class_pos = line_lower.rfind("class("); + if (class_pos == std::string::npos) return false; - // 确保 class( 后面没有闭合括号 - size_t after_paren = class_pos + 6; - std::string after = line.substr(after_paren); + // 确保 class( 后面没有闭合括号 + size_t after_paren = class_pos + 6; + std::string after = line.substr(after_paren); - if (after.find(')') != std::string::npos) - return false; + if (after.find(')') != std::string::npos) return false; - return true; - } + return true; +} - std::string ExtractPrefix(size_t cursor_pos, size_t line_start, const std::string& content) - { - size_t prefix_start = cursor_pos; +std::string ExtractPrefix(size_t cursor_pos, size_t line_start, + const std::string& content) { + size_t prefix_start = cursor_pos; - while (prefix_start > line_start) - { - char ch = content[prefix_start - 1]; - if (!std::isalnum(ch) && ch != '_') - { - break; - } - --prefix_start; - } + while (prefix_start > line_start) { + char ch = content[prefix_start - 1]; + if (!std::isalnum(ch) && ch != '_') { + break; + } + --prefix_start; + } - return content.substr(prefix_start, cursor_pos - prefix_start); - } + return content.substr(prefix_start, cursor_pos - prefix_start); +} - std::string ExtractNewPrefix(const std::string& line) - { - std::string line_lower = utils::ToLower(line); - size_t new_pos = line_lower.rfind("new "); +std::string ExtractNewPrefix(const std::string& line) { + std::string line_lower = utils::ToLower(line); + size_t new_pos = line_lower.rfind("new "); - if (new_pos == std::string::npos) - return ""; + if (new_pos == std::string::npos) return ""; - return utils::Trim(line.substr(new_pos + 4)); - } + return utils::Trim(line.substr(new_pos + 4)); +} - std::string ExtractUnitPrefix(const std::string& line) - { - std::string line_lower = utils::ToLower(line); - size_t unit_pos = line_lower.rfind("unit("); +std::string ExtractUnitPrefix(const std::string& line) { + std::string line_lower = utils::ToLower(line); + size_t unit_pos = line_lower.rfind("unit("); - if (unit_pos == std::string::npos) - return ""; + if (unit_pos == std::string::npos) return ""; - return utils::Trim(line.substr(unit_pos + 5)); - } + return utils::Trim(line.substr(unit_pos + 5)); +} - std::string ExtractClassPrefix(const std::string& line) - { - std::string line_lower = utils::ToLower(line); - size_t class_pos = line_lower.rfind("class("); +std::string ExtractClassPrefix(const std::string& line) { + std::string line_lower = utils::ToLower(line); + size_t class_pos = line_lower.rfind("class("); - if (class_pos == std::string::npos) - return ""; + if (class_pos == std::string::npos) return ""; - return utils::Trim(line.substr(class_pos + 6)); - } + return utils::Trim(line.substr(class_pos + 6)); +} +} // namespace context_analyzer + +std::string Completion::GetMethod() const { return "textDocument/completion"; } + +std::string Completion::GetProviderName() const { + return "TextDocumentCompletion"; +} + +std::string Completion::ProvideResponse(const protocol::RequestMessage& request, + ExecutionContext& execution_context) { + spdlog::debug("{}: Processing completion request", GetProviderName()); + + if (!request.params.has_value()) { + spdlog::warn("{}: Missing params in request", GetProviderName()); + return BuildErrorResponseMessage( + request, protocol::ErrorCodes::InvalidParams, "Missing params"); + } + + auto params = + transform::FromLSPAny.template operator()( + request.params.value()); + protocol::CompletionList completion_list = + BuildCompletionList(params, execution_context); + + protocol::ResponseMessage response; + response.id = request.id; + response.result = transform::ToLSPAny(completion_list); + + std::optional json = transform::Serialize(response); + if (!json.has_value()) + return BuildErrorResponseMessage(request, + protocol::ErrorCodes::InternalError, + "Failed to serialize response"); + + return json.value(); +} + +protocol::CompletionList Completion::BuildCompletionList( + const protocol::CompletionParams& params, + ExecutionContext& execution_context) { + spdlog::trace("{}: URI='{}', Position=({}, {})", GetProviderName(), + params.textDocument.uri, params.position.line, + params.position.character); + + auto document_service = execution_context.GetService(); + auto symbol_service = execution_context.GetService(); + auto document_content = document_service->GetContent(params.textDocument.uri); + CompletionContext comp_context = + context_analyzer::Analyze(params, document_content); + + spdlog::debug( + "{}: Completion context - prefix ='{}', is_new_context = {}, " + "is_unit_context = {}, is_class_context = {}, is_class_method_context = " + "{}", + GetProviderName(), comp_context.prefix, comp_context.is_new_context, + comp_context.is_unit_context, comp_context.is_class_context, + comp_context.is_class_method_context); + + std::vector collected_items; + + // 根据上下文类型决定收集哪些补全项 + if (comp_context.is_class_method_context) { + // class(xxx). 上下文:提供指定类的类方法 + auto class_method_items = + CollectClassMethods(comp_context, *symbol_service); + spdlog::trace("{}: Found {} class method completions for class '{}'", + GetProviderName(), class_method_items.size(), + comp_context.class_name); + collected_items.insert(collected_items.end(), class_method_items.begin(), + class_method_items.end()); + } else if (comp_context.is_new_context) { + // new 上下文:提供类名(标记为 new 上下文) + auto class_items = CollectClassNames(comp_context, *symbol_service, true); + spdlog::trace("{}: Found {} class completions for new context", + GetProviderName(), class_items.size()); + collected_items.insert(collected_items.end(), class_items.begin(), + class_items.end()); + } else if (comp_context.is_unit_context) { + // unit( 上下文:只提供 Unit 名称 + auto unit_items = CollectUnitNames(comp_context, *symbol_service); + spdlog::trace("{}: Found {} unit completions", GetProviderName(), + unit_items.size()); + collected_items.insert(collected_items.end(), unit_items.begin(), + unit_items.end()); + } else if (comp_context.is_class_context) { + // class( 上下文:只提供类名 + auto class_items = CollectClassNames(comp_context, *symbol_service, false); + spdlog::trace("{}: Found {} class completions", GetProviderName(), + class_items.size()); + collected_items.insert(collected_items.end(), class_items.begin(), + class_items.end()); + } else { + // 关键字 + auto keyword_items = CollectKeywords(comp_context.prefix); + spdlog::trace("{}: Found {} keyword completions", GetProviderName(), + keyword_items.size()); + collected_items.insert(collected_items.end(), keyword_items.begin(), + keyword_items.end()); + + // 当前编辑文档的所有函数 + auto editing_items = CollectEditingFunctions(comp_context, *symbol_service); + spdlog::trace("{}: Found {} editing functions", GetProviderName(), + editing_items.size()); + collected_items.insert(collected_items.end(), editing_items.begin(), + editing_items.end()); + + // 工作区顶层函数 + auto workspace_items = + CollectWorkspaceFunctions(comp_context, *symbol_service); + spdlog::trace("{}: Found {} workspace functions", GetProviderName(), + workspace_items.size()); + collected_items.insert(collected_items.end(), workspace_items.begin(), + workspace_items.end()); + + // 系统库顶层函数 + auto system_items = CollectSystemFunctions(comp_context, *symbol_service); + spdlog::trace("{}: Found {} system functions", GetProviderName(), + system_items.size()); + collected_items.insert(collected_items.end(), system_items.begin(), + system_items.end()); + } + + std::vector sorted_items = + FilterAndSort(collected_items, comp_context.prefix); + + protocol::CompletionList result; + result.isIncomplete = false; + result.items = std::move(sorted_items); + spdlog::info("{}: Provided {} completion items for prefix '{}'", + GetProviderName(), result.items.size(), comp_context.prefix); + return result; +} + +// ==================== 符号收集函数 ==================== + +std::vector Completion::CollectKeywords( + const std::string& prefix) { + std::vector items; + auto keywords_item = + language::keyword::Repo::Instance().FindByPrefix(utils::ToLower(prefix)); + for (const auto& keyword : keywords_item) { + protocol::CompletionItem item; + item.label = keyword.keyword; + item.kind = protocol::CompletionItemKind::Keyword; + item.labelDetails = protocol::CompletionItemLabelDetails{ + .detail = "", .description = "[K]"}; + items.push_back({item, CompletionSource::kKeyword}); + } + return items; +} + +std::vector Completion::CollectClassMethods( + const CompletionContext& comp_context, + const service::Symbol& symbol_service) { + std::vector items; + + spdlog::debug("CollectClassMethods - Searching for class: '{}', prefix: '{}'", + comp_context.class_name, comp_context.prefix); + + // 从工作区查找指定的类 + const auto& workspace_repo = symbol_service.WorkspaceRepo(); + auto class_info = workspace_repo.FindClass(comp_context.class_name); + + if (!class_info.has_value()) { + spdlog::warn("CollectClassMethods - Class '{}' not found in workspace", + comp_context.class_name); + return items; + } + + spdlog::debug( + "CollectClassMethods - Found class '{}' with {} exported symbols", + comp_context.class_name, class_info->exported_symbols.size()); + + // 遍历该类的导出符号,收集类方法 + size_t total_members = 0; + size_t class_methods = 0; + size_t filtered_by_prefix = 0; + + for (const auto& member : class_info->exported_symbols) { + total_members++; + + spdlog::trace( + "CollectClassMethods - Member: name='{}', kind={}, is_class_method={}", + member.name, static_cast(member.kind), member.is_class_method); + + // 只收集类方法(is_class_method 为 true) + if (!member.is_class_method) { + spdlog::trace("CollectClassMethods - Skipped '{}' (not a class method)", + member.name); + continue; } - std::string Completion::GetMethod() const - { - return "textDocument/completion"; + class_methods++; + + // 前缀过滤 + if (!comp_context.prefix.empty() && + !utils::IStartsWith(member.name, comp_context.prefix)) { + spdlog::trace("CollectClassMethods - Filtered '{}' by prefix", + member.name); + filtered_by_prefix++; + continue; } - std::string Completion::GetProviderName() const - { - return "TextDocumentCompletion"; + protocol::CompletionItem item; + item.label = member.name; + item.kind = ToCompletionItemKind(member.kind); + + // 构建方法签名 + std::string detail = "("; + for (size_t i = 0; i < member.parameters.size(); ++i) { + if (i > 0) detail += ", "; + detail += member.parameters[i].name; + if (!member.parameters[i].type.empty()) + detail += ": " + member.parameters[i].type; } + detail += ")"; + if (member.return_type.has_value()) detail += ": " + *member.return_type; - std::string Completion::ProvideResponse(const protocol::RequestMessage& request, ExecutionContext& execution_context) - { - spdlog::debug("{}: Processing completion request", GetProviderName()); + item.labelDetails = protocol::CompletionItemLabelDetails{ + .detail = detail, .description = "[W]"}; - if (!request.params.has_value()) - { - spdlog::warn("{}: Missing params in request", GetProviderName()); - return BuildErrorResponseMessage(request, protocol::ErrorCodes::InvalidParams, "Missing params"); - } + CompletionInfo ci = {.type = "class_method_context", + .class_name = comp_context.class_name}; + item.data = transform::LSPAnyConverter::ToLSPAny(ci); - auto params = transform::FromLSPAny.template operator()(request.params.value()); - protocol::CompletionList completion_list = BuildCompletionList(params, execution_context); + spdlog::debug("CollectClassMethods - Added class method: {} {}", + member.name, detail); + items.push_back({item, CompletionSource::kWorkspace}); + } - protocol::ResponseMessage response; - response.id = request.id; - response.result = transform::ToLSPAny(completion_list); + spdlog::info( + "CollectClassMethods - Class '{}': total_members={}, class_methods={}, " + "filtered_by_prefix={}, result_count={}", + comp_context.class_name, total_members, class_methods, filtered_by_prefix, + items.size()); - std::optional json = transform::Serialize(response); - if (!json.has_value()) - return BuildErrorResponseMessage(request, protocol::ErrorCodes::InternalError, "Failed to serialize response"); + return items; +} - return json.value(); +std::vector Completion::CollectUnitNames( + const CompletionContext& comp_context, + const service::Symbol& symbol_service) { + std::vector items; + + // 从工作区收集 Unit 名称 + const auto& workspace_repo = symbol_service.WorkspaceRepo(); + auto units = workspace_repo.Units(); + + for (const auto& file_index : units) { + if (comp_context.prefix.empty() || + utils::IStartsWith(file_index.primary_symbol, comp_context.prefix)) { + protocol::CompletionItem item; + item.label = file_index.primary_symbol; + item.kind = protocol::CompletionItemKind::Module; + item.labelDetails = protocol::CompletionItemLabelDetails{ + .detail = "", .description = "[W]"}; + // 添加上下文标记 + item.data = protocol::LSPAny("unit_context"); + items.push_back({item, CompletionSource::kWorkspace}); } + } - protocol::CompletionList Completion::BuildCompletionList(const protocol::CompletionParams& params, ExecutionContext& execution_context) - { - spdlog::trace("{}: URI='{}', Position=({}, {})", GetProviderName(), params.textDocument.uri, params.position.line, params.position.character); + // 从系统库收集 Unit 名称 + // const auto& system_repo = symbol_service.SystemRepo(); + // auto system_units = system_repo.Units(); - auto document_service = execution_context.GetService(); - auto symbol_service = execution_context.GetService(); - auto document_content = document_service->GetContent(params.textDocument.uri); - CompletionContext comp_context = context_analyzer::Analyze(params, document_content); + // for (const auto& sys_index : system_units) + // { + // if (comp_context.prefix.empty() || + // utils::IStartsWith(sys_index.primary_symbol, comp_context.prefix)) + // { + // protocol::CompletionItem item; + // item.label = sys_index.primary_symbol; + // item.kind = protocol::CompletionItemKind::Module; + // item.labelDetails = protocol::CompletionItemLabelDetails{ + // .detail = "", + // .description = "[S]" + // }; + // // 添加上下文标记 + // item.data = protocol::LSPAny("unit_context"); + // items.push_back({ item, CompletionSource::kSystem }); + // } + // } - spdlog::debug("{}: Completion context - prefix ='{}', is_new_context = {}, is_unit_context = {}, is_class_context = {}, is_class_method_context = {}", - GetProviderName(), - comp_context.prefix, - comp_context.is_new_context, - comp_context.is_unit_context, - comp_context.is_class_context, - comp_context.is_class_method_context - ); + return items; +} - std::vector collected_items; +std::vector Completion::CollectClassNames( + const CompletionContext& comp_context, + const service::Symbol& symbol_service, bool is_new_context) { + std::vector items; - // 根据上下文类型决定收集哪些补全项 - if (comp_context.is_class_method_context) - { - // class(xxx). 上下文:提供指定类的类方法 - auto class_method_items = CollectClassMethods(comp_context, *symbol_service); - spdlog::trace("{}: Found {} class method completions for class '{}'", GetProviderName(), class_method_items.size(), comp_context.class_name); - collected_items.insert(collected_items.end(), class_method_items.begin(), class_method_items.end()); - } - else if (comp_context.is_new_context) - { - // new 上下文:提供类名(标记为 new 上下文) - auto class_items = CollectClassNames(comp_context, *symbol_service, true); - spdlog::trace("{}: Found {} class completions for new context", GetProviderName(), class_items.size()); - collected_items.insert(collected_items.end(), class_items.begin(), class_items.end()); - } - else if (comp_context.is_unit_context) - { - // unit( 上下文:只提供 Unit 名称 - auto unit_items = CollectUnitNames(comp_context, *symbol_service); - spdlog::trace("{}: Found {} unit completions", GetProviderName(), unit_items.size()); - collected_items.insert(collected_items.end(), unit_items.begin(), unit_items.end()); - } - else if (comp_context.is_class_context) - { - // class( 上下文:只提供类名 - auto class_items = CollectClassNames(comp_context, *symbol_service, false); - spdlog::trace("{}: Found {} class completions", GetProviderName(), class_items.size()); - collected_items.insert(collected_items.end(), class_items.begin(), class_items.end()); - } - else - { - // 关键字 - auto keyword_items = CollectKeywords(comp_context.prefix); - spdlog::trace("{}: Found {} keyword completions", GetProviderName(), keyword_items.size()); - collected_items.insert(collected_items.end(), keyword_items.begin(), keyword_items.end()); + // 从工作区收集类名 + const auto& workspace_repo = symbol_service.WorkspaceRepo(); + auto classes = workspace_repo.Classes(); - // 当前编辑文档的所有函数 - auto editing_items = CollectEditingFunctions(comp_context, *symbol_service); - spdlog::trace("{}: Found {} editing functions", GetProviderName(), editing_items.size()); - collected_items.insert(collected_items.end(), editing_items.begin(), editing_items.end()); + for (const auto& file_index : classes) { + if (comp_context.prefix.empty() || + utils::IStartsWith(file_index.primary_symbol, comp_context.prefix)) { + protocol::CompletionItem item; + item.label = file_index.primary_symbol; + item.kind = protocol::CompletionItemKind::Class; + item.labelDetails = protocol::CompletionItemLabelDetails{ + .detail = "", .description = "[W]"}; - // 工作区顶层函数 - auto workspace_items = CollectWorkspaceFunctions(comp_context, *symbol_service); - spdlog::trace("{}: Found {} workspace functions", GetProviderName(), workspace_items.size()); - collected_items.insert(collected_items.end(), workspace_items.begin(), workspace_items.end()); + if (is_new_context) + item.data = protocol::LSPAny("new_context"); + else + item.data = protocol::LSPAny("class_context"); - // 系统库顶层函数 - auto system_items = CollectSystemFunctions(comp_context, *symbol_service); - spdlog::trace("{}: Found {} system functions", GetProviderName(), system_items.size()); - collected_items.insert(collected_items.end(), system_items.begin(), system_items.end()); - } - - std::vector sorted_items = FilterAndSort(collected_items, comp_context.prefix); - - protocol::CompletionList result; - result.isIncomplete = false; - result.items = std::move(sorted_items); - spdlog::info("{}: Provided {} completion items for prefix '{}'", GetProviderName(), result.items.size(), comp_context.prefix); - return result; + items.push_back({item, CompletionSource::kWorkspace}); } + } - // ==================== 符号收集函数 ==================== + // 从系统库收集类名 + // const auto& system_repo = symbol_service.SystemRepo(); + // auto system_classes = system_repo.Classes(); - std::vector Completion::CollectKeywords(const std::string& prefix) - { - std::vector items; - auto keywords_item = language::keyword::Repo::Instance().FindByPrefix(utils::ToLower(prefix)); - for (const auto& keyword : keywords_item) - { - protocol::CompletionItem item; - item.label = keyword.keyword; - item.kind = protocol::CompletionItemKind::Keyword; - item.labelDetails = protocol::CompletionItemLabelDetails{ - .detail = "", - .description = "[K]" - }; - items.push_back({ item, CompletionSource::kKeyword }); + // for (const auto& sys_index : system_classes) + // { + // if (comp_context.prefix.empty() || + // utils::IStartsWith(sys_index.primary_symbol, comp_context.prefix)) + // { + // protocol::CompletionItem item; + // item.label = sys_index.primary_symbol; + // item.kind = protocol::CompletionItemKind::Class; + // item.labelDetails = protocol::CompletionItemLabelDetails{ + // .detail = "", + // .description = "[S]" + // }; + + // if (is_new_context) + // item.data = protocol::LSPAny("new_context"); + // else + // item.data = protocol::LSPAny("class_context"); + + // items.push_back({ item, CompletionSource::kSystem }); + // } + // } + + return items; +} + +std::vector Completion::CollectEditingFunctions( + const CompletionContext& comp_context, + const service::Symbol& symbol_service) { + std::vector items; + + spdlog::trace("CollectEditingFunctions - URI: {}, Prefix: '{}'", + comp_context.uri, comp_context.prefix); + + const auto& editing_repo = symbol_service.EditingRepo(); + const auto* symbol_table = editing_repo.GetSymbolTable(comp_context.uri); + + if (!symbol_table) { + spdlog::trace("CollectEditingFunctions - No symbol table for URI: {}", + comp_context.uri); + return items; + } + + auto top_level_symbols = + service::symbol::utils::GetTopLevelSymbols(*symbol_table); + spdlog::trace("CollectEditingFunctions - Found {} top-level symbols", + top_level_symbols.size()); + + for (const auto* symbol : top_level_symbols) { + if (symbol->kind() != protocol::SymbolKind::Function) continue; + + if (!comp_context.prefix.empty() && + !utils::IStartsWith(symbol->name(), comp_context.prefix)) + continue; + + protocol::CompletionItem item; + item.label = symbol->name(); + item.kind = ToCompletionItemKind(symbol->kind()); + + auto detail = service::symbol::utils::BuildSymbolDetail(*symbol); + if (detail.empty()) detail = "()"; + + item.labelDetails = protocol::CompletionItemLabelDetails{ + .detail = detail, .description = "[E]"}; + + spdlog::trace("CollectEditingFunctions - Added: {} {}", symbol->name(), + detail); + items.push_back({item, CompletionSource::kEditing}); + } + + spdlog::debug("CollectEditingFunctions - Collected {} functions", + items.size()); + return items; +} + +std::vector Completion::CollectWorkspaceFunctions( + const CompletionContext& comp_context, + const service::Symbol& symbol_service) { + std::vector items; + + const auto& workspace_repo = symbol_service.WorkspaceRepo(); + + // 获取所有工作区函数 + auto functions = workspace_repo.Functions(); + + for (const auto& file_index : functions) { + if (comp_context.prefix.empty() || + utils::IStartsWith(file_index.primary_symbol, comp_context.prefix)) { + protocol::CompletionItem item; + item.label = file_index.primary_symbol; + item.kind = ToCompletionItemKind(file_index.primary_kind); + + std::string detail = ""; + // 查找与 primary_symbol 匹配的导出符号以获取详细信息 + for (const auto& exported_sym : file_index.exported_symbols) { + if (exported_sym.name == file_index.primary_symbol) { + // 构建参数列表 + detail = "("; + for (size_t i = 0; i < exported_sym.parameters.size(); ++i) { + if (i > 0) detail += ", "; + detail += exported_sym.parameters[i].name; + if (!exported_sym.parameters[i].type.empty()) + detail += ": " + exported_sym.parameters[i].type; + } + detail += ")"; + + // 添加返回类型 + if (exported_sym.return_type.has_value()) + detail += ": " + *exported_sym.return_type; + + break; } - return items; + } + + item.labelDetails = protocol::CompletionItemLabelDetails{ + .detail = detail, .description = "[W]"}; + items.push_back({item, CompletionSource::kWorkspace}); } + } - std::vector Completion::CollectClassMethods(const CompletionContext& comp_context, const service::Symbol& symbol_service) - { - std::vector items; + return items; +} - spdlog::debug("CollectClassMethods - Searching for class: '{}', prefix: '{}'", comp_context.class_name, comp_context.prefix); +std::vector Completion::CollectSystemFunctions( + const CompletionContext& comp_context, + const service::Symbol& symbol_service) { + std::vector items; - // 从工作区查找指定的类 - const auto& workspace_repo = symbol_service.WorkspaceRepo(); - auto class_info = workspace_repo.FindClass(comp_context.class_name); + const auto& system_repo = symbol_service.SystemRepo(); + auto functions = system_repo.Functions(); - if (!class_info.has_value()) - { - spdlog::warn("CollectClassMethods - Class '{}' not found in workspace", - comp_context.class_name); - return items; + for (const auto& file_index : functions) { + if (comp_context.prefix.empty() || + utils::IStartsWith(file_index.primary_symbol, comp_context.prefix)) { + protocol::CompletionItem item; + item.label = file_index.primary_symbol; + item.kind = ToCompletionItemKind(file_index.primary_kind); + + std::string detail = ""; + // 查找与 primary_symbol 匹配的导出符号以获取详细信息 + for (const auto& exported_sym : file_index.exported_symbols) { + if (exported_sym.name == file_index.primary_symbol) { + // 构建参数列表 + detail = "("; + for (size_t i = 0; i < exported_sym.parameters.size(); ++i) { + if (i > 0) detail += ", "; + detail += exported_sym.parameters[i].name; + if (!exported_sym.parameters[i].type.empty()) + detail += ": " + exported_sym.parameters[i].type; + } + detail += ")"; + + // 添加返回类型 + if (exported_sym.return_type.has_value()) + detail += ": " + *exported_sym.return_type; + + break; } + } - spdlog::debug("CollectClassMethods - Found class '{}' with {} exported symbols", - comp_context.class_name, - class_info->exported_symbols.size()); - - // 遍历该类的导出符号,收集类方法 - size_t total_members = 0; - size_t class_methods = 0; - size_t filtered_by_prefix = 0; - - for (const auto& member : class_info->exported_symbols) - { - total_members++; - - spdlog::trace("CollectClassMethods - Member: name='{}', kind={}, is_class_method={}", - member.name, - static_cast(member.kind), - member.is_class_method); - - // 只收集类方法(is_class_method 为 true) - if (!member.is_class_method) - { - spdlog::trace("CollectClassMethods - Skipped '{}' (not a class method)", member.name); - continue; - } - - class_methods++; - - // 前缀过滤 - if (!comp_context.prefix.empty() && - !utils::IStartsWith(member.name, comp_context.prefix)) - { - spdlog::trace("CollectClassMethods - Filtered '{}' by prefix", member.name); - filtered_by_prefix++; - continue; - } - - protocol::CompletionItem item; - item.label = member.name; - item.kind = ToCompletionItemKind(member.kind); - - // 构建方法签名 - std::string detail = "("; - for (size_t i = 0; i < member.parameters.size(); ++i) - { - if (i > 0) - detail += ", "; - detail += member.parameters[i].name; - if (!member.parameters[i].type.empty()) - detail += ": " + member.parameters[i].type; - } - detail += ")"; - if (member.return_type.has_value()) - detail += ": " + *member.return_type; - - item.labelDetails = protocol::CompletionItemLabelDetails{ - .detail = detail, - .description = "[W]" - }; - - CompletionInfo ci = {.type = "class_method_context", .class_name = comp_context.class_name}; - item.data = transform::LSPAnyConverter::ToLSPAny(ci); - - spdlog::debug("CollectClassMethods - Added class method: {} {}", member.name, detail); - items.push_back({ item, CompletionSource::kWorkspace }); - } - - spdlog::info("CollectClassMethods - Class '{}': total_members={}, class_methods={}, filtered_by_prefix={}, result_count={}", - comp_context.class_name, - total_members, - class_methods, - filtered_by_prefix, - items.size()); - - return items; + item.labelDetails = protocol::CompletionItemLabelDetails{ + .detail = detail, .description = "[W]"}; + items.push_back({item, CompletionSource::kWorkspace}); } - - std::vector Completion::CollectUnitNames(const CompletionContext& comp_context, const service::Symbol& symbol_service) - { - std::vector items; - - // 从工作区收集 Unit 名称 - const auto& workspace_repo = symbol_service.WorkspaceRepo(); - auto units = workspace_repo.Units(); - - for (const auto& file_index : units) - { - if (comp_context.prefix.empty() || - utils::IStartsWith(file_index.primary_symbol, comp_context.prefix)) - { - protocol::CompletionItem item; - item.label = file_index.primary_symbol; - item.kind = protocol::CompletionItemKind::Module; - item.labelDetails = protocol::CompletionItemLabelDetails{ - .detail = "", - .description = "[W]" - }; - // 添加上下文标记 - item.data = protocol::LSPAny("unit_context"); - items.push_back({ item, CompletionSource::kWorkspace }); - } - } - - // 从系统库收集 Unit 名称 - // const auto& system_repo = symbol_service.SystemRepo(); - // auto system_units = system_repo.Units(); - - // for (const auto& sys_index : system_units) - // { - // if (comp_context.prefix.empty() || - // utils::IStartsWith(sys_index.primary_symbol, comp_context.prefix)) - // { - // protocol::CompletionItem item; - // item.label = sys_index.primary_symbol; - // item.kind = protocol::CompletionItemKind::Module; - // item.labelDetails = protocol::CompletionItemLabelDetails{ - // .detail = "", - // .description = "[S]" - // }; - // // 添加上下文标记 - // item.data = protocol::LSPAny("unit_context"); - // items.push_back({ item, CompletionSource::kSystem }); - // } - // } - - return items; - } - - std::vector Completion::CollectClassNames(const CompletionContext& comp_context, const service::Symbol& symbol_service, bool is_new_context) - { - std::vector items; - - // 从工作区收集类名 - const auto& workspace_repo = symbol_service.WorkspaceRepo(); - auto classes = workspace_repo.Classes(); - - for (const auto& file_index : classes) - { - if (comp_context.prefix.empty() || - utils::IStartsWith(file_index.primary_symbol, comp_context.prefix)) - { - protocol::CompletionItem item; - item.label = file_index.primary_symbol; - item.kind = protocol::CompletionItemKind::Class; - item.labelDetails = protocol::CompletionItemLabelDetails{ - .detail = "", - .description = "[W]" - }; - - if (is_new_context) - item.data = protocol::LSPAny("new_context"); - else - item.data = protocol::LSPAny("class_context"); - - items.push_back({ item, CompletionSource::kWorkspace }); - } - } - - // 从系统库收集类名 - // const auto& system_repo = symbol_service.SystemRepo(); - // auto system_classes = system_repo.Classes(); - - // for (const auto& sys_index : system_classes) - // { - // if (comp_context.prefix.empty() || - // utils::IStartsWith(sys_index.primary_symbol, comp_context.prefix)) - // { - // protocol::CompletionItem item; - // item.label = sys_index.primary_symbol; - // item.kind = protocol::CompletionItemKind::Class; - // item.labelDetails = protocol::CompletionItemLabelDetails{ - // .detail = "", - // .description = "[S]" - // }; - - // if (is_new_context) - // item.data = protocol::LSPAny("new_context"); - // else - // item.data = protocol::LSPAny("class_context"); - - // items.push_back({ item, CompletionSource::kSystem }); - // } - // } - - return items; - } - - std::vector Completion::CollectEditingFunctions(const CompletionContext& comp_context, const service::Symbol& symbol_service) - { - std::vector items; - - spdlog::trace("CollectEditingFunctions - URI: {}, Prefix: '{}'", - comp_context.uri, - comp_context.prefix); - - const auto& editing_repo = symbol_service.EditingRepo(); - const auto* symbol_table = editing_repo.GetSymbolTable(comp_context.uri); - - if (!symbol_table) - { - spdlog::trace("CollectEditingFunctions - No symbol table for URI: {}", comp_context.uri); - return items; - } - - // 获取所有顶层符号 - auto top_level_symbols = symbol_table->GetDocumentSymbols(); - - spdlog::trace("CollectEditingFunctions - Found {} top-level symbols", top_level_symbols.size()); - - for (const auto* def : top_level_symbols) - { - // 跳过全局命名空间 - if (def->name == "::") - continue; - - // 只收集函数 - if (def->kind != protocol::SymbolKind::Function) - continue; - - // 前缀过滤 - if (!comp_context.prefix.empty() && - !utils::IStartsWith(def->name, comp_context.prefix)) - continue; - - protocol::CompletionItem item; - item.label = def->name; - item.kind = ToCompletionItemKind(def->kind); - - // 使用已格式化的签名 - std::string detail = def->detail.empty() ? "()" : def->detail; - - item.labelDetails = protocol::CompletionItemLabelDetails{ - .detail = detail, - .description = "[E]" // Editing document - }; - - spdlog::trace("CollectEditingFunctions - Added: {} {}", def->name, detail); - items.push_back({ item, CompletionSource::kEditing }); - } - - spdlog::debug("CollectEditingFunctions - Collected {} functions", items.size()); - return items; - } - - std::vector Completion::CollectWorkspaceFunctions(const CompletionContext& comp_context, const service::Symbol& symbol_service) - { - std::vector items; - - const auto& workspace_repo = symbol_service.WorkspaceRepo(); - - // 获取所有工作区函数 - auto functions = workspace_repo.Functions(); - - for (const auto& file_index : functions) - { - if (comp_context.prefix.empty() || utils::IStartsWith(file_index.primary_symbol, comp_context.prefix)) - { - protocol::CompletionItem item; - item.label = file_index.primary_symbol; - item.kind = ToCompletionItemKind(file_index.primary_kind); - - std::string detail = ""; - // 查找与 primary_symbol 匹配的导出符号以获取详细信息 - for (const auto& exported_sym : file_index.exported_symbols) - { - if (exported_sym.name == file_index.primary_symbol) - { - // 构建参数列表 - detail = "("; - for (size_t i = 0; i < exported_sym.parameters.size(); ++i) - { - if (i > 0) - detail += ", "; - detail += exported_sym.parameters[i].name; - if (!exported_sym.parameters[i].type.empty()) - detail += ": " + exported_sym.parameters[i].type; - } - detail += ")"; - - // 添加返回类型 - if (exported_sym.return_type.has_value()) - detail += ": " + *exported_sym.return_type; - - break; - } - } - - item.labelDetails = protocol::CompletionItemLabelDetails{ - .detail = detail, - .description = "[W]" - }; - items.push_back({ item, CompletionSource::kWorkspace }); - } - } - - return items; - } - - std::vector Completion::CollectSystemFunctions(const CompletionContext& comp_context, const service::Symbol& symbol_service) - { - std::vector items; - - const auto& system_repo = symbol_service.SystemRepo(); - auto functions = system_repo.Functions(); - - for (const auto& file_index : functions) - { - if (comp_context.prefix.empty() || - utils::IStartsWith(file_index.primary_symbol, comp_context.prefix)) - { - protocol::CompletionItem item; - item.label = file_index.primary_symbol; - item.kind = ToCompletionItemKind(file_index.primary_kind); - - std::string detail = ""; - // 查找与 primary_symbol 匹配的导出符号以获取详细信息 - for (const auto& exported_sym : file_index.exported_symbols) - { - if (exported_sym.name == file_index.primary_symbol) - { - // 构建参数列表 - detail = "("; - for (size_t i = 0; i < exported_sym.parameters.size(); ++i) - { - if (i > 0) - detail += ", "; - detail += exported_sym.parameters[i].name; - if (!exported_sym.parameters[i].type.empty()) - detail += ": " + exported_sym.parameters[i].type; - } - detail += ")"; - - // 添加返回类型 - if (exported_sym.return_type.has_value()) - detail += ": " + *exported_sym.return_type; - - break; - } - } - - item.labelDetails = protocol::CompletionItemLabelDetails{ - .detail = detail, - .description = "[W]" - }; - items.push_back({ item, CompletionSource::kWorkspace }); - } - } - - return items; - } - - // ==================== 过滤和排序 ==================== - - std::vector Completion::FilterAndSort(const std::vector& items, const std::string& prefix) - { - // 过滤重复项和不匹配项 - std::vector unique_items; - std::set seen_labels; - - for (const auto& item : items) - { - // 去重 - if (seen_labels.count(item.item.label) > 0) - continue; - seen_labels.insert(item.item.label); - - // 前缀过滤(大小写不敏感) - if (!prefix.empty() && - !utils::IStartsWith(item.item.label, prefix)) - continue; - - unique_items.push_back(&item); - } - - // 排序 - std::sort(unique_items.begin(), unique_items.end(), [&prefix](const SourcedCompletionItem* a, const SourcedCompletionItem* b) -> bool { - // 1. 完全匹配最优先 - std::string a_lower = utils::ToLower(a->item.label); - std::string b_lower = utils::ToLower(b->item.label); - std::string prefix_lower = utils::ToLower(prefix); - - bool a_exact = (a_lower == prefix_lower); - bool b_exact = (b_lower == prefix_lower); - - if (a_exact != b_exact) - { - return a_exact; - } - - // 2. 匹配分数(前缀匹配程度) - if (!prefix.empty()) - { - int a_score = GetMatchScore(a->item.label, prefix); - int b_score = GetMatchScore(b->item.label, prefix); - - if (a_score != b_score) - { - return a_score > b_score; - } - } - - // 3. 来源优先级:Editing > Workspace > System > Keyword - if (a->source != b->source) - { - static auto get_source_priority = [](CompletionSource source) -> int { - switch (source) - { - case CompletionSource::kEditing: - return 0; - case CompletionSource::kWorkspace: - return 1; - case CompletionSource::kSystem: - return 2; - case CompletionSource::kKeyword: - return 3; - default: - return 4; - } - }; - - return get_source_priority(a->source) < get_source_priority(b->source); - } - - // 4. 字母顺序 - return a_lower < b_lower; - }); - - // 构建最终结果 - std::vector result; - result.reserve(unique_items.size()); - - for (const auto* sourced_item : unique_items) - { - result.push_back(sourced_item->item); - } - - spdlog::trace("Filtered {} items to {} unique items", items.size(), result.size()); - - return result; - } - - int Completion::GetMatchScore(const std::string& label, const std::string& prefix) - { - if (prefix.empty()) - { - return 0; - } - - std::string label_lower = utils::ToLower(label); + } + + return items; +} + +// ==================== 过滤和排序 ==================== + +std::vector Completion::FilterAndSort( + const std::vector& items, + const std::string& prefix) { + // 过滤重复项和不匹配项 + std::vector unique_items; + std::set seen_labels; + + for (const auto& item : items) { + // 去重 + if (seen_labels.count(item.item.label) > 0) continue; + seen_labels.insert(item.item.label); + + // 前缀过滤(大小写不敏感) + if (!prefix.empty() && !utils::IStartsWith(item.item.label, prefix)) + continue; + + unique_items.push_back(&item); + } + + // 排序 + std::sort( + unique_items.begin(), unique_items.end(), + [&prefix](const SourcedCompletionItem* a, + const SourcedCompletionItem* b) -> bool { + // 1. 完全匹配最优先 + std::string a_lower = utils::ToLower(a->item.label); + std::string b_lower = utils::ToLower(b->item.label); std::string prefix_lower = utils::ToLower(prefix); - // 如果不是前缀匹配,返回0分 - if (!utils::StartsWith(label_lower, prefix_lower)) - return 0; + bool a_exact = (a_lower == prefix_lower); + bool b_exact = (b_lower == prefix_lower); - // 基础分数:前缀占标签长度的比例(0-100) - int base_score = (static_cast(prefix.length()) * 100) / static_cast(label.length()); - - // 大小写完全匹配额外加分 - bool case_match = label.substr(0, prefix.length()) == prefix; - int case_bonus = case_match ? 10 : 0; - - // 长度匹配加分(标签越短越好) - int length_bonus = 0; - if (label.length() == prefix.length()) - { - length_bonus = 20; // 完全匹配长度 - } - else if (label.length() <= prefix.length() + 3) - { - length_bonus = 10; // 接近匹配长度 + if (a_exact != b_exact) { + return a_exact; } - return base_score + case_bonus + length_bonus; - } + // 2. 匹配分数(前缀匹配程度) + if (!prefix.empty()) { + int a_score = GetMatchScore(a->item.label, prefix); + int b_score = GetMatchScore(b->item.label, prefix); - // ==================== 辅助函数 ==================== - - protocol::CompletionItemKind Completion::ToCompletionItemKind(language::symbol::SymbolKind kind) - { - using language::symbol::SymbolKind; - - switch (kind) - { - case SymbolKind::Function: - case SymbolKind::Method: - return protocol::CompletionItemKind::Function; - case SymbolKind::Constructor: - return protocol::CompletionItemKind::Constructor; - case SymbolKind::Class: - return protocol::CompletionItemKind::Class; - case SymbolKind::Property: - return protocol::CompletionItemKind::Property; - case SymbolKind::Variable: - case SymbolKind::TypeParameter: - return protocol::CompletionItemKind::Variable; - case SymbolKind::Constant: - return protocol::CompletionItemKind::Constant; - case SymbolKind::Module: - return protocol::CompletionItemKind::Module; - default: - return protocol::CompletionItemKind::Text; + if (a_score != b_score) { + return a_score > b_score; + } } - } + // 3. 来源优先级:Editing > Workspace > System > Keyword + if (a->source != b->source) { + static auto get_source_priority = [](CompletionSource source) -> int { + switch (source) { + case CompletionSource::kEditing: + return 0; + case CompletionSource::kWorkspace: + return 1; + case CompletionSource::kSystem: + return 2; + case CompletionSource::kKeyword: + return 3; + default: + return 4; + } + }; + + return get_source_priority(a->source) < + get_source_priority(b->source); + } + + // 4. 字母顺序 + return a_lower < b_lower; + }); + + // 构建最终结果 + std::vector result; + result.reserve(unique_items.size()); + + for (const auto* sourced_item : unique_items) { + result.push_back(sourced_item->item); + } + + spdlog::trace("Filtered {} items to {} unique items", items.size(), + result.size()); + + return result; } + +int Completion::GetMatchScore(const std::string& label, + const std::string& prefix) { + if (prefix.empty()) { + return 0; + } + + std::string label_lower = utils::ToLower(label); + std::string prefix_lower = utils::ToLower(prefix); + + // 如果不是前缀匹配,返回0分 + if (!utils::StartsWith(label_lower, prefix_lower)) return 0; + + // 基础分数:前缀占标签长度的比例(0-100) + int base_score = (static_cast(prefix.length()) * 100) / + static_cast(label.length()); + + // 大小写完全匹配额外加分 + bool case_match = label.substr(0, prefix.length()) == prefix; + int case_bonus = case_match ? 10 : 0; + + // 长度匹配加分(标签越短越好) + int length_bonus = 0; + if (label.length() == prefix.length()) { + length_bonus = 20; // 完全匹配长度 + } else if (label.length() <= prefix.length() + 3) { + length_bonus = 10; // 接近匹配长度 + } + + return base_score + case_bonus + length_bonus; +} + +// ==================== 辅助函数 ==================== + +protocol::CompletionItemKind Completion::ToCompletionItemKind( + language::symbol::SymbolKind kind) { + using language::symbol::SymbolKind; + + switch (kind) { + case SymbolKind::Function: + case SymbolKind::Method: + return protocol::CompletionItemKind::Function; + case SymbolKind::Constructor: + return protocol::CompletionItemKind::Constructor; + case SymbolKind::Class: + return protocol::CompletionItemKind::Class; + case SymbolKind::Property: + return protocol::CompletionItemKind::Property; + case SymbolKind::Variable: + case SymbolKind::TypeParameter: + return protocol::CompletionItemKind::Variable; + case SymbolKind::Constant: + return protocol::CompletionItemKind::Constant; + case SymbolKind::Module: + return protocol::CompletionItemKind::Module; + default: + return protocol::CompletionItemKind::Text; + } +} + +} // namespace lsp::provider::text_document diff --git a/lsp-server/src/provider/workspace/symbol.cpp b/lsp-server/src/provider/workspace/symbol.cpp index d9443dc..5a77ae7 100644 --- a/lsp-server/src/provider/workspace/symbol.cpp +++ b/lsp-server/src/provider/workspace/symbol.cpp @@ -1,242 +1,230 @@ +#include "./symbol.hpp" + #include + #include #include -#include "./symbol.hpp" + #include "../../protocol/transform/facade.hpp" #include "../../service/document.hpp" #include "../../service/symbol.hpp" -namespace lsp::provider::workspace -{ - std::string Symbol::GetMethod() const - { - return "workspace/symbol"; - } +namespace lsp::provider::workspace { +std::string Symbol::GetMethod() const { return "workspace/symbol"; } - std::string Symbol::GetProviderName() const - { - return "WorkSpaceSymbol"; - } +std::string Symbol::GetProviderName() const { return "WorkSpaceSymbol"; } - std::string Symbol::ProvideResponse(const protocol::RequestMessage& request, ExecutionContext& context) - { - spdlog::debug("WorkspaceSymbolProvider: Providing response for method {}", request.method); +std::string Symbol::ProvideResponse(const protocol::RequestMessage& request, + ExecutionContext& context) { + spdlog::debug("WorkspaceSymbolProvider: Providing response for method {}", + request.method); - if (!request.params.has_value()) - { - spdlog::warn("{}: Missing params in request", GetProviderName()); - return BuildErrorResponseMessage(request, protocol::ErrorCodes::InvalidParams, "Missing params"); - } + if (!request.params.has_value()) { + spdlog::warn("{}: Missing params in request", GetProviderName()); + return BuildErrorResponseMessage( + request, protocol::ErrorCodes::InvalidParams, "Missing params"); + } - protocol::WorkspaceSymbolParams params = transform::FromLSPAny.template operator()(request.params.value()); + protocol::WorkspaceSymbolParams params = + transform::FromLSPAny.template + operator()(request.params.value()); - auto symbols = BuildSymbolResponse(params, context); + auto symbols = BuildSymbolResponse(params, context); - protocol::ResponseMessage response; - response.id = request.id; - response.result = transform::ToLSPAny(symbols); + protocol::ResponseMessage response; + response.id = request.id; + response.result = transform::ToLSPAny(symbols); - std::optional json = transform::Serialize(response); - if (!json.has_value()) - return BuildErrorResponseMessage(request, protocol::ErrorCodes::InternalError, "Failed to serialize response"); - return json.value(); - } - - std::vector Symbol::BuildSymbolResponse(const protocol::WorkspaceSymbolParams& params, ExecutionContext& context) - { - spdlog::trace("{}: Searching for symbols matching '{}'", GetProviderName(), params.query); - - auto symbols = SearchSymbols(params.query, context); - - // 按匹配分数排序 - std::sort(symbols.begin(), symbols.end(), [¶ms](const protocol::SymbolInformation& a, const protocol::SymbolInformation& b) { - // 可以实现更复杂的排序逻辑 - return a.name < b.name; - }); - - // 限制返回数量(避免返回太多结果) - const size_t max_results = 100; - if (symbols.size() > max_results) - { - symbols.resize(max_results); - spdlog::debug("{}: Limited results to {} symbols", GetProviderName(), max_results); - } - - spdlog::info("{}: Found {} symbols matching '{}'", GetProviderName(), symbols.size(), params.query); - - return symbols; - } - - std::vector Symbol::SearchSymbols(const std::string& query, ExecutionContext& context) - { - std::vector results; - - // 从容器获取服务 - const auto& document_service = context.GetService(); - const auto& symbol_service = context.GetService(); - - // 获取所有打开的文档 - auto document_uris = document_service.GetAllDocumentUris(); - - for (const auto& uri : document_uris) - { - // 获取文档符号 - auto doc_symbols = symbol_service.GetDocumentSymbols(uri); - - // 递归转换并过滤符号 - for (const auto& symbol : doc_symbols) - { - ConvertToSymbolInformation(symbol, uri, query, results); - } - } - - return results; - } - - void Symbol::ConvertToSymbolInformation( - const protocol::DocumentSymbol& doc_symbol, - const protocol::DocumentUri& uri, - const std::string& query, - std::vector& results, - const std::string& container_name) - { - // 检查是否匹配查询 - if (MatchesQuery(doc_symbol.name, query)) - { - protocol::SymbolInformation info; - info.name = doc_symbol.name; - info.kind = doc_symbol.kind; - info.location.uri = uri; - info.location.range = doc_symbol.range; - - // 设置容器名称 - if (!container_name.empty()) - { - info.containerName = container_name; - } - - results.push_back(info); - } - - // 递归处理子符号 - std::string new_container = container_name.empty() ? doc_symbol.name : container_name + "." + doc_symbol.name; - - for (const auto& child : doc_symbol.children.value()) - { - ConvertToSymbolInformation(child, uri, query, results, new_container); - } - } - - bool Symbol::MatchesQuery(const std::string& symbol_name, const std::string& query) - { - // 空查询匹配所有 - if (query.empty()) - return true; - - // 转换为小写进行不区分大小写的匹配 - std::string lower_symbol = symbol_name; - std::string lower_query = query; - std::transform(lower_symbol.begin(), lower_symbol.end(), lower_symbol.begin(), ::tolower); - std::transform(lower_query.begin(), lower_query.end(), lower_query.begin(), ::tolower); - - // 1. 精确匹配 - if (lower_symbol == lower_query) - return true; - - // 2. 前缀匹配 - if (lower_symbol.find(lower_query) == 0) - return true; - - // 3. 包含匹配 - if (lower_symbol.find(lower_query) != std::string::npos) - return true; - - // 4. 模糊匹配(驼峰匹配) - if (FuzzyMatch(query, symbol_name)) - return true; - - return false; - } - - bool Symbol::FuzzyMatch(const std::string& pattern, const std::string& text) - { - // 实现驼峰匹配 - // 例如 "gS" 匹配 "getString" - size_t pattern_idx = 0; - size_t text_idx = 0; - - while (pattern_idx < pattern.length() && text_idx < text.length()) - { - char p = pattern[pattern_idx]; - - // 查找下一个匹配字符 - bool found = false; - while (text_idx < text.length()) - { - char t = text[text_idx]; - - // 不区分大小写匹配 - if (std::tolower(p) == std::tolower(t)) - { - found = true; - text_idx++; - break; - } - - // 如果模式字符是大写,只在大写字母位置匹配 - if (std::isupper(p) && !std::isupper(t)) - { - text_idx++; - continue; - } - - text_idx++; - } - - if (!found) - return false; - - pattern_idx++; - } - - return pattern_idx == pattern.length(); - } - - int Symbol::CalculateScore(const std::string& symbol_name, const std::string& query) - { - int score = 0; - - // 精确匹配得分最高 - if (symbol_name == query) - return 1000; - - // 前缀匹配得分次高 - if (symbol_name.find(query) == 0) - return 900; - - // 不区分大小写的前缀匹配 - std::string lower_symbol = symbol_name; - std::string lower_query = query; - std::transform(lower_symbol.begin(), lower_symbol.end(), lower_symbol.begin(), ::tolower); - std::transform(lower_query.begin(), lower_query.end(), lower_query.begin(), ::tolower); - - if (lower_symbol.find(lower_query) == 0) - return 800; - - // 包含匹配 - size_t pos = lower_symbol.find(lower_query); - if (pos != std::string::npos) - { - // 越靠前得分越高 - score = 700 - static_cast(pos * 10); - } - - // 模糊匹配得分最低 - if (FuzzyMatch(query, symbol_name)) - { - score = std::max(score, 500); - } - - return score; - } + std::optional json = transform::Serialize(response); + if (!json.has_value()) + return BuildErrorResponseMessage(request, + protocol::ErrorCodes::InternalError, + "Failed to serialize response"); + return json.value(); } + +std::vector Symbol::BuildSymbolResponse( + const protocol::WorkspaceSymbolParams& params, ExecutionContext& context) { + spdlog::trace("{}: Searching for symbols matching '{}'", GetProviderName(), + params.query); + + auto symbols = SearchSymbols(params.query, context); + + // 按匹配分数排序 + std::sort(symbols.begin(), symbols.end(), + [¶ms](const protocol::SymbolInformation& a, + const protocol::SymbolInformation& b) { + // 可以实现更复杂的排序逻辑 + return a.name < b.name; + }); + + // 限制返回数量(避免返回太多结果) + const size_t max_results = 100; + if (symbols.size() > max_results) { + symbols.resize(max_results); + spdlog::debug("{}: Limited results to {} symbols", GetProviderName(), + max_results); + } + + spdlog::info("{}: Found {} symbols matching '{}'", GetProviderName(), + symbols.size(), params.query); + + return symbols; +} + +std::vector Symbol::SearchSymbols( + const std::string& query, ExecutionContext& context) { + std::vector results; + + // 从容器获取服务 + auto document_service = context.GetService(); + auto symbol_service = context.GetService(); + + // 获取所有打开的文档 + auto document_uris = document_service->GetAllDocumentUris(); + + for (const auto& uri : document_uris) { + // 获取文档符号 + auto doc_symbols = symbol_service->GetDocumentSymbols(uri); + + // 递归转换并过滤符号 + for (const auto& symbol : doc_symbols) { + ConvertToSymbolInformation(symbol, uri, query, results); + } + } + + return results; +} + +void Symbol::ConvertToSymbolInformation( + const protocol::DocumentSymbol& doc_symbol, + const protocol::DocumentUri& uri, const std::string& query, + std::vector& results, + const std::string& container_name) { + // 检查是否匹配查询 + if (MatchesQuery(doc_symbol.name, query)) { + protocol::SymbolInformation info; + info.name = doc_symbol.name; + info.kind = doc_symbol.kind; + info.location.uri = uri; + info.location.range = doc_symbol.range; + + // 设置容器名称 + if (!container_name.empty()) { + info.containerName = container_name; + } + + results.push_back(info); + } + + // 递归处理子符号 + std::string new_container = container_name.empty() + ? doc_symbol.name + : container_name + "." + doc_symbol.name; + + for (const auto& child : doc_symbol.children.value()) { + ConvertToSymbolInformation(child, uri, query, results, new_container); + } +} + +bool Symbol::MatchesQuery(const std::string& symbol_name, + const std::string& query) { + // 空查询匹配所有 + if (query.empty()) return true; + + // 转换为小写进行不区分大小写的匹配 + std::string lower_symbol = symbol_name; + std::string lower_query = query; + std::transform(lower_symbol.begin(), lower_symbol.end(), lower_symbol.begin(), + ::tolower); + std::transform(lower_query.begin(), lower_query.end(), lower_query.begin(), + ::tolower); + + // 1. 精确匹配 + if (lower_symbol == lower_query) return true; + + // 2. 前缀匹配 + if (lower_symbol.find(lower_query) == 0) return true; + + // 3. 包含匹配 + if (lower_symbol.find(lower_query) != std::string::npos) return true; + + // 4. 模糊匹配(驼峰匹配) + if (FuzzyMatch(query, symbol_name)) return true; + + return false; +} + +bool Symbol::FuzzyMatch(const std::string& pattern, const std::string& text) { + // 实现驼峰匹配 + // 例如 "gS" 匹配 "getString" + size_t pattern_idx = 0; + size_t text_idx = 0; + + while (pattern_idx < pattern.length() && text_idx < text.length()) { + char p = pattern[pattern_idx]; + + // 查找下一个匹配字符 + bool found = false; + while (text_idx < text.length()) { + char t = text[text_idx]; + + // 不区分大小写匹配 + if (std::tolower(p) == std::tolower(t)) { + found = true; + text_idx++; + break; + } + + // 如果模式字符是大写,只在大写字母位置匹配 + if (std::isupper(p) && !std::isupper(t)) { + text_idx++; + continue; + } + + text_idx++; + } + + if (!found) return false; + + pattern_idx++; + } + + return pattern_idx == pattern.length(); +} + +int Symbol::CalculateScore(const std::string& symbol_name, + const std::string& query) { + int score = 0; + + // 精确匹配得分最高 + if (symbol_name == query) return 1000; + + // 前缀匹配得分次高 + if (symbol_name.find(query) == 0) return 900; + + // 不区分大小写的前缀匹配 + std::string lower_symbol = symbol_name; + std::string lower_query = query; + std::transform(lower_symbol.begin(), lower_symbol.end(), lower_symbol.begin(), + ::tolower); + std::transform(lower_query.begin(), lower_query.end(), lower_query.begin(), + ::tolower); + + if (lower_symbol.find(lower_query) == 0) return 800; + + // 包含匹配 + size_t pos = lower_symbol.find(lower_query); + if (pos != std::string::npos) { + // 越靠前得分越高 + score = 700 - static_cast(pos * 10); + } + + // 模糊匹配得分最低 + if (FuzzyMatch(query, symbol_name)) { + score = std::max(score, 500); + } + + return score; +} +} // namespace lsp::provider::workspace diff --git a/lsp-server/src/service/detail/symbol/conversion.cpp b/lsp-server/src/service/detail/symbol/conversion.cpp index b412239..e400cca 100644 --- a/lsp-server/src/service/detail/symbol/conversion.cpp +++ b/lsp-server/src/service/detail/symbol/conversion.cpp @@ -1,188 +1,111 @@ -#include #include "./conversion.hpp" -namespace lsp::service::symbol -{ - FileSymbolIndex BuildFromEditingTable(const EditingSymbolTable& editing_table) - { - FileSymbolIndex index; +#include - index.uri = editing_table.uri; - index.file_type = editing_table.file_type; - index.file_path = editing_table.file_path; +#include "./utils.hpp" - if (editing_table.symbol_table) - ExtractSymbols(index, *editing_table.symbol_table); +namespace lsp::service::symbol { +FileSymbolIndex BuildFromEditingTable(const EditingSymbolTable& editing_table) { + FileSymbolIndex index; - return index; - } + index.uri = editing_table.uri; + index.file_type = editing_table.file_type; + index.file_path = editing_table.file_path; - void ExtractSymbols(FileSymbolIndex& index, const language::symbol::SymbolTable& table) - { - auto all_defs = table.GetAllDefinitions(); - if (all_defs.empty()) - return; - - // 根据文件类型提取不同的符号 - switch (index.file_type) - { - case TsfFileType::kUnit: - ExtractUnitSymbols(index, table, all_defs); - break; - - case TsfFileType::kClass: - ExtractClassSymbols(index, table, all_defs); - break; - - case TsfFileType::kFunction: - ExtractFunctionSymbols(index, table, all_defs); - break; - - case TsfFileType::kScript: - ExtractScriptSymbols(index, table, all_defs); - break; - } - } - - void ExtractUnitSymbols(FileSymbolIndex& index, const language::symbol::SymbolTable& table, const std::vector& all_defs) - { - for (const auto* def : all_defs) - { - if (def->kind == language::symbol::SymbolKind::Module) - { - index.primary_symbol = def->name; - index.primary_kind = def->kind; - break; - } - } - - // 提取顶层函数(没有父符号的 Function) - for (const auto* def : all_defs) - { - if (def->kind == language::symbol::SymbolKind::Function && !def->parent_id) - { - MemberSymbol member; - member.name = def->name; - member.kind = def->kind; - member.return_type = def->type_hint; - - index.exported_symbols.push_back(std::move(member)); - } - } - } - - void ExtractClassSymbols(FileSymbolIndex& index, const language::symbol::SymbolTable& table, const std::vector& all_defs) - { - // 查找 Class 定义 - const language::symbol::SymbolDefinition* class_def = nullptr; - for (const auto* def : all_defs) - { - if (def->kind == language::symbol::SymbolKind::Class) - { - index.primary_symbol = def->name; - index.primary_kind = def->kind; - class_def = def; - break; - } - } - - if (!class_def) - return; - - // 提取类的所有成员 - auto children = table.GetChildren(class_def->id); - for (auto child_id : children) - { - const auto* member_def = table.GetDefinition(child_id); - if (!member_def) - continue; - - // 只提取 public 成员 - if (member_def->access_modifier && *member_def->access_modifier != language::ast::AccessModifier::kPublic) - continue; - - // 过滤掉私有成员 - if (member_def->kind == language::symbol::SymbolKind::Method || - member_def->kind == language::symbol::SymbolKind::Property || - member_def->kind == language::symbol::SymbolKind::Field || - member_def->kind == language::symbol::SymbolKind::Constructor) - { - MemberSymbol member; - member.name = member_def->name; - member.kind = member_def->kind; - member.return_type = member_def->type_hint; - member.access = member_def->access_modifier; - member.is_class_method = member_def->is_class_method; - member.parameters = ExtractFunctionParameters(table, member_def->id); - - index.exported_symbols.push_back(std::move(member)); - } - } - } - - void ExtractFunctionSymbols(FileSymbolIndex& index, const language::symbol::SymbolTable& table, const std::vector& all_defs) - { - // Function 文件的主符号就是函数本身 - for (const auto* def : all_defs) - { - if (def->kind == language::symbol::SymbolKind::Function && !def->parent_id) - { - index.primary_symbol = def->name; - index.primary_kind = def->kind; - - MemberSymbol member; - member.name = def->name; - member.kind = def->kind; - member.return_type = def->type_hint; - member.parameters = ExtractFunctionParameters(table, def->id); - index.exported_symbols.push_back(std::move(member)); - break; - } - } - } - - void ExtractScriptSymbols(FileSymbolIndex& index, const language::symbol::SymbolTable& table, const std::vector& all_defs) - { - index.primary_symbol = "Script"; - index.primary_kind = language::symbol::SymbolKind::Module; - - for (const auto* def : all_defs) - { - if (def->kind == language::symbol::SymbolKind::Function && !def->parent_id) - { - MemberSymbol member; - member.name = def->name; - member.kind = def->kind; - member.return_type = def->type_hint; - member.parameters = ExtractFunctionParameters(table, def->id); - - index.exported_symbols.push_back(std::move(member)); - } - } - } - - std::vector ExtractFunctionParameters(const language::symbol::SymbolTable& table, language::symbol::SymbolId function_id) - { - std::vector params; - - auto children = table.GetChildren(function_id); - for (auto child_id : children) - { - const auto* child_def = table.GetDefinition(child_id); - if (!child_def) - continue; - - if (child_def->kind == language::symbol::SymbolKind::Variable) - { - symbol::Parameter param; - param.name = child_def->name; - param.type = child_def->type_hint.value_or(""); - param.has_default = false; - params.push_back(std::move(param)); - } - } - - return params; - } + if (editing_table.symbol_table) + ExtractSymbols(index, *editing_table.symbol_table); + return index; } + +void ExtractSymbols(FileSymbolIndex& index, + const language::symbol::SymbolTable& table) { + switch (index.file_type) { + case TsfFileType::kUnit: + ExtractUnitSymbols(index, table); + break; + case TsfFileType::kClass: + ExtractClassSymbols(index, table); + break; + case TsfFileType::kFunction: + ExtractFunctionSymbols(index, table); + break; + case TsfFileType::kScript: + ExtractScriptSymbols(index, table); + break; + } +} + +void ExtractUnitSymbols(FileSymbolIndex& index, + const language::symbol::SymbolTable& table) { + auto top_level = utils::GetTopLevelSymbols(table); + for (const auto* symbol : top_level) { + if (symbol->kind() == language::symbol::SymbolKind::Namespace) { + index.primary_symbol = symbol->name(); + index.primary_kind = symbol->kind(); + break; + } + } + + for (const auto* symbol : top_level) { + if (symbol->kind() == language::symbol::SymbolKind::Function) { + index.exported_symbols.push_back(utils::BuildMemberSymbol(*symbol)); + } + } +} + +void ExtractClassSymbols(FileSymbolIndex& index, + const language::symbol::SymbolTable& table) { + auto top_level = utils::GetTopLevelSymbols(table); + const language::symbol::Symbol* class_symbol = nullptr; + + for (const auto* symbol : top_level) { + if (symbol->kind() == language::symbol::SymbolKind::Class) { + class_symbol = symbol; + index.primary_symbol = symbol->name(); + index.primary_kind = symbol->kind(); + break; + } + } + + if (!class_symbol) return; + + auto members = utils::GetChildSymbols(table, class_symbol->id()); + for (const auto* member : members) { + if (!utils::IsPublicMember(*member)) continue; + + if (member->kind() == language::symbol::SymbolKind::Method || + member->kind() == language::symbol::SymbolKind::Property || + member->kind() == language::symbol::SymbolKind::Field) { + index.exported_symbols.push_back(utils::BuildMemberSymbol(*member)); + } + } +} + +void ExtractFunctionSymbols(FileSymbolIndex& index, + const language::symbol::SymbolTable& table) { + auto top_level = utils::GetTopLevelSymbols(table); + for (const auto* symbol : top_level) { + if (symbol->kind() == language::symbol::SymbolKind::Function) { + index.primary_symbol = symbol->name(); + index.primary_kind = symbol->kind(); + index.exported_symbols.push_back(utils::BuildMemberSymbol(*symbol)); + break; + } + } +} + +void ExtractScriptSymbols(FileSymbolIndex& index, + const language::symbol::SymbolTable& table) { + index.primary_symbol = "Script"; + index.primary_kind = language::symbol::SymbolKind::Namespace; + + auto top_level = utils::GetTopLevelSymbols(table); + for (const auto* symbol : top_level) { + if (symbol->kind() == language::symbol::SymbolKind::Function) { + index.exported_symbols.push_back(utils::BuildMemberSymbol(*symbol)); + } + } +} + +} // namespace lsp::service::symbol diff --git a/lsp-server/src/service/detail/symbol/conversion.hpp b/lsp-server/src/service/detail/symbol/conversion.hpp index 5d094a9..632b9da 100644 --- a/lsp-server/src/service/detail/symbol/conversion.hpp +++ b/lsp-server/src/service/detail/symbol/conversion.hpp @@ -1,14 +1,17 @@ #pragma once -#include "./types.hpp" #include "../../../language/symbol/table.hpp" +#include "./types.hpp" -namespace lsp::service::symbol -{ - FileSymbolIndex BuildFromEditingTable(const EditingSymbolTable& editing_table); - void ExtractSymbols(FileSymbolIndex& index, const language::symbol::SymbolTable& table); - void ExtractUnitSymbols(FileSymbolIndex& index, const language::symbol::SymbolTable& table, const std::vector& all_defs); - void ExtractClassSymbols(FileSymbolIndex& index, const language::symbol::SymbolTable& table, const std::vector& all_defs); - void ExtractFunctionSymbols(FileSymbolIndex& index, const language::symbol::SymbolTable& table, const std::vector& all_defs); - void ExtractScriptSymbols(FileSymbolIndex& index, const language::symbol::SymbolTable& table, const std::vector& all_defs); - std::vector ExtractFunctionParameters(const language::symbol::SymbolTable& table, language::symbol::SymbolId function_id); -} +namespace lsp::service::symbol { +FileSymbolIndex BuildFromEditingTable(const EditingSymbolTable& editing_table); +void ExtractSymbols(FileSymbolIndex& index, + const language::symbol::SymbolTable& table); +void ExtractUnitSymbols(FileSymbolIndex& index, + const language::symbol::SymbolTable& table); +void ExtractClassSymbols(FileSymbolIndex& index, + const language::symbol::SymbolTable& table); +void ExtractFunctionSymbols(FileSymbolIndex& index, + const language::symbol::SymbolTable& table); +void ExtractScriptSymbols(FileSymbolIndex& index, + const language::symbol::SymbolTable& table); +} // namespace lsp::service::symbol diff --git a/lsp-server/src/service/detail/symbol/utils.cpp b/lsp-server/src/service/detail/symbol/utils.cpp new file mode 100644 index 0000000..38ad065 --- /dev/null +++ b/lsp-server/src/service/detail/symbol/utils.cpp @@ -0,0 +1,207 @@ +#include "utils.hpp" + +#include + +namespace lsp::service::symbol::utils { +namespace { + +std::vector CollectSymbolsInScope( + const language::symbol::SymbolTable& table, + language::symbol::ScopeId scope_id) { + std::vector result; + const auto* scope = table.scopes().scope(scope_id); + if (!scope) { + return result; + } + + result.reserve(scope->symbols.size()); + for (const auto& [_, symbol_id] : scope->symbols) { + if (const auto* symbol = table.definition(symbol_id)) { + result.push_back(symbol); + } + } + + std::sort(result.begin(), result.end(), + [](const language::symbol::Symbol* lhs, + const language::symbol::Symbol* rhs) { + return lhs->range().start_offset < rhs->range().start_offset; + }); + return result; +} + +language::symbol::ScopeId FindScopeOwnedBy( + const language::symbol::SymbolTable& table, + language::symbol::SymbolId owner) { + const auto& scopes = table.scopes().all_scopes(); + for (const auto& [scope_id, info] : scopes) { + if (info.owner && *info.owner == owner) { + return scope_id; + } + } + return language::symbol::kInvalidScopeId; +} + +std::string BuildParameterList( + const std::vector& params) { + std::string detail = "("; + for (size_t i = 0; i < params.size(); ++i) { + if (i > 0) { + detail += ", "; + } + + detail += params[i].name; + if (params[i].type && !params[i].type->empty()) { + detail += ": " + *params[i].type; + } + } + detail += ")"; + return detail; +} + +} // namespace + +std::vector GetTopLevelSymbols( + const language::symbol::SymbolTable& table) { + auto global_scope = table.scopes().global_scope(); + if (global_scope == language::symbol::kInvalidScopeId) { + return {}; + } + return CollectSymbolsInScope(table, global_scope); +} + +std::vector GetChildSymbols( + const language::symbol::SymbolTable& table, + language::symbol::SymbolId parent_id) { + auto scope_id = FindScopeOwnedBy(table, parent_id); + if (scope_id == language::symbol::kInvalidScopeId) { + return {}; + } + return CollectSymbolsInScope(table, scope_id); +} + +std::vector ConvertParameters( + const std::vector& params) { + std::vector converted; + converted.reserve(params.size()); + + for (const auto& param : params) { + service::symbol::Parameter converted_param; + converted_param.name = param.name; + converted_param.type = param.type.value_or(""); + converted_param.has_default = param.default_value.has_value(); + converted.push_back(std::move(converted_param)); + } + + return converted; +} + +std::optional GetReturnType( + const language::symbol::Symbol& symbol) { + if (const auto* fn = symbol.As()) { + return fn->return_type; + } + if (const auto* method = symbol.As()) { + return method->return_type; + } + if (const auto* property = symbol.As()) { + return property->type; + } + if (const auto* field = symbol.As()) { + return field->type; + } + if (const auto* variable = symbol.As()) { + if (variable->type) { + return *variable->type; + } + } + if (const auto* constant = symbol.As()) { + if (constant->type && !constant->type->empty()) { + return constant->type; + } + } + return std::nullopt; +} + +std::string BuildSymbolDetail(const language::symbol::Symbol& symbol) { + if (const auto* fn = symbol.As()) { + std::string detail = BuildParameterList(fn->parameters); + if (fn->return_type && !fn->return_type->empty()) { + detail += ": " + *fn->return_type; + } + return detail; + } + + if (const auto* method = symbol.As()) { + std::string detail = BuildParameterList(method->parameters); + if (method->return_type && !method->return_type->empty()) { + detail += ": " + *method->return_type; + } + return detail; + } + + if (const auto* property = symbol.As()) { + return property->type ? ": " + *property->type : ""; + } + + if (const auto* field = symbol.As()) { + return field->type ? ": " + *field->type : ""; + } + + if (const auto* variable = symbol.As()) { + return variable->type ? ": " + *variable->type : ""; + } + + if (const auto* constant = symbol.As()) { + if (constant->type && !constant->type->empty()) { + return ": " + *constant->type; + } + } + + return {}; +} + +bool IsPublicMember(const language::symbol::Symbol& symbol) { + if (const auto* method = symbol.As()) { + return method->access == language::ast::AccessModifier::kPublic; + } + if (const auto* property = symbol.As()) { + return property->access == language::ast::AccessModifier::kPublic; + } + if (const auto* field = symbol.As()) { + return field->access == language::ast::AccessModifier::kPublic; + } + return true; +} + +service::symbol::MemberSymbol BuildMemberSymbol( + const language::symbol::Symbol& symbol) { + service::symbol::MemberSymbol member; + member.name = symbol.name(); + member.kind = symbol.kind(); + member.is_class_method = false; + + if (const auto* fn = symbol.As()) { + member.parameters = ConvertParameters(fn->parameters); + member.return_type = fn->return_type; + } else if (const auto* method = symbol.As()) { + member.parameters = ConvertParameters(method->parameters); + member.return_type = method->return_type; + member.access = method->access; + member.is_class_method = method->is_static; + } else if (const auto* property = symbol.As()) { + member.return_type = property->type; + member.access = property->access; + } else if (const auto* field = symbol.As()) { + member.return_type = field->type; + member.access = field->access; + member.is_class_method = field->is_static; + } else if (const auto* variable = symbol.As()) { + member.return_type = variable->type; + } else if (const auto* constant = symbol.As()) { + member.return_type = constant->type; + } + + return member; +} + +} // namespace lsp::service::symbol::utils diff --git a/lsp-server/src/service/detail/symbol/utils.hpp b/lsp-server/src/service/detail/symbol/utils.hpp new file mode 100644 index 0000000..50e6a39 --- /dev/null +++ b/lsp-server/src/service/detail/symbol/utils.hpp @@ -0,0 +1,32 @@ +#pragma once + +#include +#include +#include + +#include "../../../language/symbol/table.hpp" +#include "./types.hpp" + +namespace lsp::service::symbol::utils { + +std::vector GetTopLevelSymbols( + const language::symbol::SymbolTable& table); + +std::vector GetChildSymbols( + const language::symbol::SymbolTable& table, + language::symbol::SymbolId parent_id); + +std::vector ConvertParameters( + const std::vector& params); + +std::string BuildSymbolDetail(const language::symbol::Symbol& symbol); + +service::symbol::MemberSymbol BuildMemberSymbol( + const language::symbol::Symbol& symbol); + +bool IsPublicMember(const language::symbol::Symbol& symbol); + +std::optional GetReturnType( + const language::symbol::Symbol& symbol); + +} // namespace lsp::service::symbol::utils diff --git a/lsp-server/src/service/parser.cpp b/lsp-server/src/service/parser.cpp index 7832d5a..9f4493a 100644 --- a/lsp-server/src/service/parser.cpp +++ b/lsp-server/src/service/parser.cpp @@ -1,390 +1,360 @@ +#include "./parser.hpp" + +#include + +#include #include #include #include -#include -#include -#include "./utils/text_coordinates.hpp" + #include "../language/ast/deserializer.hpp" -#include "./parser.hpp" +#include "../language/symbol/builder.hpp" +#include "./detail/symbol/utils.hpp" +#include "./utils/text_coordinates.hpp" -namespace lsp::service -{ - // ============= TreeSitter ================== // - - TreeSitter::TreeSitter() - { - parser_ = ts_parser_new(); - if (!parser_) - throw std::runtime_error("Failed to create tree-sitter parser"); - } - - TreeSitter::~TreeSitter() - { - if (parser_) - ts_parser_delete(parser_); - } - - TreeSitter::TreeSitter(TreeSitter&& other) noexcept : - parser_(other.parser_) - { - other.parser_ = nullptr; - } - - bool TreeSitter::SetLanguage(const TSLanguage* language) - { - if (!parser_) - return false; - return ts_parser_set_language(parser_, language); - } - - TSTree* TreeSitter::Parse(const char* content, size_t length, TSTree* old_tree) - { - if (!parser_) - return nullptr; - return ts_parser_parse_string(parser_, old_tree, content, length); - } - - TSParser* TreeSitter::GetRawParser() const - { - return parser_; - }; - - // ============= SyntaxTree ================== // - - SyntaxTree::SyntaxTree(TSTree* tree) : - tree_(tree, ts_tree_delete) {} - - SyntaxTree::~SyntaxTree() = default; - - TSTree* SyntaxTree::Get() const - { - return tree_.get(); - } - - void SyntaxTree::ApplyEdit(const protocol::TextDocumentContentChangeEvent& change, const protocol::string& content) - { - if (!tree_) - return; - protocol::uinteger start_offset = utils::text_coordinates::ToOffset(change.range.start, content); - protocol::uinteger end_offset = utils::text_coordinates::ToOffset(change.range.end, content); - - TSInputEdit edit = {}; - edit.start_byte = start_offset; - edit.old_end_byte = end_offset; - edit.new_end_byte = start_offset + change.text.length(); - - edit.start_point = utils::text_coordinates::ToPoint(change.range.start); - edit.old_end_point = utils::text_coordinates::ToPoint(change.range.end); - edit.new_end_point = utils::text_coordinates::CalculateEndPoint(change.text, edit.start_point); - - ts_tree_edit(tree_.get(), &edit); - } - - TSNode SyntaxTree::GetRootNode() const - { - return tree_ ? ts_tree_root_node(tree_.get()) : TSNode{}; - } - - // ============= SyntaxTreeManager ================== // - - void SyntaxTreeManager::StoreTree(const protocol::DocumentUri& uri, std::unique_ptr tree) - { - std::unique_lock lock(mutex_); - trees_[uri] = std::move(tree); - } - - void SyntaxTreeManager::RemoveTree(const protocol::DocumentUri& uri) - { - std::unique_lock lock(mutex_); - trees_.erase(uri); - } - - SyntaxTree* SyntaxTreeManager::GetTree(const protocol::DocumentUri& uri) - { - std::shared_lock lock(mutex_); - auto it = trees_.find(uri); - return (it != trees_.end()) ? it->second.get() : nullptr; - } - - const SyntaxTree* SyntaxTreeManager::GetTree(const protocol::DocumentUri& uri) const - { - std::shared_lock lock(mutex_); - auto it = trees_.find(uri); - return (it != trees_.end()) ? it->second.get() : nullptr; - } - - size_t SyntaxTreeManager::GetTreeCount() const - { - std::shared_lock lock(mutex_); - return trees_.size(); - } - - void SyntaxTreeManager::Clear() - { - std::unique_lock lock(mutex_); - trees_.clear(); - } - - // ============= Parser ================== // - - Parser::Parser(std::shared_ptr event_bus) : - event_bus_(std::move(event_bus)) - { - if (parser_.SetLanguage(tree_sitter_tsf())) - spdlog::info("Set tree-sitter-tsf successfully"); - else - spdlog::error("Failed to set tree-sitter language"); - - event_bus_->Subscribe( - [this](const auto& e) { OnDocumentOpened(e); }); - event_bus_->Subscribe( - [this](const auto& e) { OnDocumentChanged(e); }); - event_bus_->Subscribe( - [this](const auto& e) { OnDocumentClosed(e); }); - } - - Parser::~Parser() = default; - - TSParser* Parser::GetRawParser() const - { - return parser_.GetRawParser(); - }; - - TSTree* Parser::GetTree(const protocol::DocumentUri& uri) const - { - auto* syntax_tree = syntax_tree_manager_.GetTree(uri); - return syntax_tree ? syntax_tree->Get() : nullptr; - } - - std::optional Parser::ParseTsfFile(const std::string& file_path) - { - std::string extension = std::filesystem::path(file_path).extension().string(); - std::transform(extension.begin(), extension.end(), extension.begin(), ::tolower); - if (extension != ".tsf" && extension != ".tsl") - return std::nullopt; - spdlog::debug("Parse tsf file: {}", file_path); - - // 读取文件 - std::ifstream file(file_path); - if (!file.is_open()) - { - spdlog::trace("Cannot open file: {}", file_path); - return std::nullopt; - } - - std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); - file.close(); - - // 创建局部的 parser 和 extractor(线程安全) - TreeSitter local_parser; - if (!local_parser.SetLanguage(tree_sitter_tsf())) - { - spdlog::error("Failed to set tree-sitter language in ParseTsfFile"); - return std::nullopt; - } - - // 解析 - TSTree* tree = local_parser.Parse(content.c_str(), content.length()); - if (!tree) - { - spdlog::trace("Failed to parse file: {}", file_path); - return std::nullopt; - } - // 解析 - auto tree_deleter = [](TSTree* t) { ts_tree_delete(t); }; - std::unique_ptr tree_guard(tree, tree_deleter); - - symbol::EditingSymbolTable table; - table.uri = "file://" + std::filesystem::absolute(file_path).string(); - table.version = 0; - table.last_parsed = std::chrono::system_clock::now(); - table.is_dirty = false; - table.deserializer = std::make_unique(); - table.file_path = std::filesystem::absolute(file_path).string(); - - auto ast_result = table.deserializer->Parse(ts_tree_root_node(tree), content); - if (ast_result.HasErrors()) - { - for (const auto& it : ast_result.errors) - spdlog::warn("Parse Error = {}", it.message); - } - if (!ast_result.root) - { - spdlog::error("AST root is null for: {}", file_path); - return std::nullopt; - } - table.ast_root = std::move(ast_result.root); - table.symbol_table = std::make_unique(); - - try - { - // 使用 Builder 构建符号表 - table.symbol_table->Build(*table.ast_root); - spdlog::debug("Successfully built symbol table for: {}", file_path); - } - catch (const std::exception& e) - { - spdlog::error("Failed to build symbol table for {}: {}", file_path, e.what()); - return std::nullopt; - } - - auto metadata = InferFileMetadata(*table.symbol_table); - if (!metadata) - { - spdlog::warn("Cannot infer file metadata from symbol table: {}", file_path); - return std::nullopt; - } - table.file_type = metadata->file_type; - - // 验证TSF文件 - std::string file_stem = std::filesystem::path(file_path).stem().string(); - auto normalize_name = [](std::string s) { - std::transform(s.begin(), s.end(), s.begin(), ::tolower); - size_t at_pos = s.find("@"); - if (at_pos != std::string::npos) - s = s.substr(0, at_pos); - return s; - }; - if (normalize_name(file_stem) != normalize_name(metadata->primary_symbol)) - spdlog::warn("File name '{}' doesn't match primary symbol '{}' in: {}", file_stem, metadata->primary_symbol, file_path); - - spdlog::debug("Successfully parsed file: {} (type: {}, symbol: {})", - file_path, - static_cast(metadata->file_type), - metadata->primary_symbol); - return table; - } - - std::optional Parser::InferFileMetadata(const language::symbol::SymbolTable& table) - { - auto top_level_symbols = table.GetDocumentSymbols(); - - if (top_level_symbols.empty()) - { - spdlog::warn("No top-level symbol definitions found in symbol table"); - return std::nullopt; - } - - // 过滤掉特殊的全局命名空间符号 "::" - std::vector filtered_symbols; - for (const auto* sym : top_level_symbols) - { - if (sym->name != "::") - { - filtered_symbols.push_back(sym); - } - } - - if (filtered_symbols.empty()) - { - spdlog::warn("No valid top-level symbols found (only global namespace)"); - return std::nullopt; - } - - // 按位置排序 - std::sort(filtered_symbols.begin(), filtered_symbols.end(), [](const language::symbol::SymbolDefinition* a, const language::symbol::SymbolDefinition* b) { - if (a->location.start_byte != b->location.start_byte) - { - return a->location.start_byte < b->location.start_byte; - } - if (a->location.start_line != b->location.start_line) - { - return a->location.start_line < b->location.start_line; - } - return a->location.start_column < b->location.start_column; - }); - - // 取第一个真实的顶层符号 - const language::symbol::SymbolDefinition* first_top_level = filtered_symbols[0]; - - FileMetadata metadata; - metadata.primary_symbol = first_top_level->name; - metadata.primary_kind = first_top_level->kind; - - // 根据第一个符号推断文件类型 - switch (first_top_level->kind) - { - case protocol::SymbolKind::Module: - metadata.file_type = symbol::TsfFileType::kUnit; - spdlog::trace("File type inferred as Unit from first symbol: '{}'", first_top_level->name); - break; - - case protocol::SymbolKind::Class: - metadata.file_type = symbol::TsfFileType::kClass; - spdlog::trace("File type inferred as Class from first symbol: '{}'", first_top_level->name); - break; - - case protocol::SymbolKind::Function: - metadata.file_type = symbol::TsfFileType::kFunction; - spdlog::trace("File type inferred as Function from first symbol: '{}'", first_top_level->name); - break; - - default: - metadata.file_type = symbol::TsfFileType::kScript; - spdlog::trace("File type inferred as Script (default) from first symbol: '{}'", first_top_level->name); - break; - } - - spdlog::debug("Inferred file metadata: type={}, symbol='{}', kind={}, location={}:{}", - static_cast(metadata.file_type), - metadata.primary_symbol, - static_cast(metadata.primary_kind), - first_top_level->location.start_line, - first_top_level->location.start_column); - - return metadata; - } - - void Parser::OnDocumentOpened(const events::DocumentOpend& event) - { - TSTree* tree = parser_.Parse(event.textDocument.text.c_str(), event.textDocument.text.length()); - if (tree) - { - syntax_tree_manager_.StoreTree(event.textDocument.uri, std::make_unique(tree)); - event_bus_->Publish(events::DocumentParsed{ - .item = event.textDocument, - .tree = tree }); - spdlog::debug("Successfully parsed document: {}", event.textDocument.uri); - } - else - { - spdlog::error("Failed to parsed document: {}", event.textDocument.uri); - } - } - - void Parser::OnDocumentChanged(const events::DocumentChanged& event) - { - SyntaxTree* syntax_tree = syntax_tree_manager_.GetTree(event.uri); - TSTree* old_tree = syntax_tree ? syntax_tree->Get() : nullptr; - - // 应用增量编辑 - if (syntax_tree && old_tree) - for (const auto& change : event.changes) - syntax_tree->ApplyEdit(change, event.content); - - // 增量解析 - TSTree* tree = parser_.Parse(event.content.c_str(), event.content.length(), old_tree); - - if (tree) - { - syntax_tree_manager_.StoreTree(event.uri, std::make_unique(tree)); - event_bus_->Publish(events::DocumentReparsed{ - .item{ .uri = event.uri, .languageId = "", .version = event.version, .text = std::move(event.content) }, - .tree = tree, - }); - - spdlog::debug("Document reparsed successfully: {}", event.uri); - } - else - { - spdlog::error("Failed to reparse document: {}", event.uri); - } - } - - void Parser::OnDocumentClosed(const events::DocumentClosed& event) - { - syntax_tree_manager_.RemoveTree(event.textDocument.uri); - spdlog::debug("Removed syntax tree for: {}", event.textDocument.uri); - } +namespace lsp::service { +// ============= TreeSitter ================== // +TreeSitter::TreeSitter() { + parser_ = ts_parser_new(); + if (!parser_) throw std::runtime_error("Failed to create tree-sitter parser"); } + +TreeSitter::~TreeSitter() { + if (parser_) ts_parser_delete(parser_); +} + +TreeSitter::TreeSitter(TreeSitter&& other) noexcept : parser_(other.parser_) { + other.parser_ = nullptr; +} + +bool TreeSitter::SetLanguage(const TSLanguage* language) { + if (!parser_) return false; + return ts_parser_set_language(parser_, language); +} + +TSTree* TreeSitter::Parse(const char* content, size_t length, + TSTree* old_tree) { + if (!parser_) return nullptr; + return ts_parser_parse_string(parser_, old_tree, content, length); +} + +TSParser* TreeSitter::GetRawParser() const { return parser_; }; + +// ============= SyntaxTree ================== // + +SyntaxTree::SyntaxTree(TSTree* tree) : tree_(tree, ts_tree_delete) {} + +SyntaxTree::~SyntaxTree() = default; + +TSTree* SyntaxTree::Get() const { return tree_.get(); } + +void SyntaxTree::ApplyEdit( + const protocol::TextDocumentContentChangeEvent& change, + const protocol::string& content) { + if (!tree_) return; + protocol::uinteger start_offset = + utils::text_coordinates::ToOffset(change.range.start, content); + protocol::uinteger end_offset = + utils::text_coordinates::ToOffset(change.range.end, content); + + TSInputEdit edit = {}; + edit.start_byte = start_offset; + edit.old_end_byte = end_offset; + edit.new_end_byte = start_offset + change.text.length(); + + edit.start_point = utils::text_coordinates::ToPoint(change.range.start); + edit.old_end_point = utils::text_coordinates::ToPoint(change.range.end); + edit.new_end_point = + utils::text_coordinates::CalculateEndPoint(change.text, edit.start_point); + + ts_tree_edit(tree_.get(), &edit); +} + +TSNode SyntaxTree::GetRootNode() const { + return tree_ ? ts_tree_root_node(tree_.get()) : TSNode{}; +} + +// ============= SyntaxTreeManager ================== // + +void SyntaxTreeManager::StoreTree(const protocol::DocumentUri& uri, + std::unique_ptr tree) { + std::unique_lock lock(mutex_); + trees_[uri] = std::move(tree); +} + +void SyntaxTreeManager::RemoveTree(const protocol::DocumentUri& uri) { + std::unique_lock lock(mutex_); + trees_.erase(uri); +} + +SyntaxTree* SyntaxTreeManager::GetTree(const protocol::DocumentUri& uri) { + std::shared_lock lock(mutex_); + auto it = trees_.find(uri); + return (it != trees_.end()) ? it->second.get() : nullptr; +} + +const SyntaxTree* SyntaxTreeManager::GetTree( + const protocol::DocumentUri& uri) const { + std::shared_lock lock(mutex_); + auto it = trees_.find(uri); + return (it != trees_.end()) ? it->second.get() : nullptr; +} + +size_t SyntaxTreeManager::GetTreeCount() const { + std::shared_lock lock(mutex_); + return trees_.size(); +} + +void SyntaxTreeManager::Clear() { + std::unique_lock lock(mutex_); + trees_.clear(); +} + +// ============= Parser ================== // + +Parser::Parser(std::shared_ptr event_bus) + : event_bus_(std::move(event_bus)) { + if (parser_.SetLanguage(tree_sitter_tsf())) + spdlog::info("Set tree-sitter-tsf successfully"); + else + spdlog::error("Failed to set tree-sitter language"); + + event_bus_->Subscribe( + [this](const auto& e) { OnDocumentOpened(e); }); + event_bus_->Subscribe( + [this](const auto& e) { OnDocumentChanged(e); }); + event_bus_->Subscribe( + [this](const auto& e) { OnDocumentClosed(e); }); +} + +Parser::~Parser() = default; + +TSParser* Parser::GetRawParser() const { return parser_.GetRawParser(); }; + +TSTree* Parser::GetTree(const protocol::DocumentUri& uri) const { + auto* syntax_tree = syntax_tree_manager_.GetTree(uri); + return syntax_tree ? syntax_tree->Get() : nullptr; +} + +std::optional Parser::ParseTsfFile( + const std::string& file_path) { + std::string extension = std::filesystem::path(file_path).extension().string(); + std::transform(extension.begin(), extension.end(), extension.begin(), + ::tolower); + if (extension != ".tsf" && extension != ".tsl") return std::nullopt; + spdlog::debug("Parse tsf file: {}", file_path); + + // 读取文件 + std::ifstream file(file_path); + if (!file.is_open()) { + spdlog::trace("Cannot open file: {}", file_path); + return std::nullopt; + } + + std::string content((std::istreambuf_iterator(file)), + std::istreambuf_iterator()); + file.close(); + + // 创建局部的 parser 和 extractor(线程安全) + TreeSitter local_parser; + if (!local_parser.SetLanguage(tree_sitter_tsf())) { + spdlog::error("Failed to set tree-sitter language in ParseTsfFile"); + return std::nullopt; + } + + // 解析 + TSTree* tree = local_parser.Parse(content.c_str(), content.length()); + if (!tree) { + spdlog::trace("Failed to parse file: {}", file_path); + return std::nullopt; + } + // 解析 + auto tree_deleter = [](TSTree* t) { ts_tree_delete(t); }; + std::unique_ptr tree_guard(tree, + tree_deleter); + + symbol::EditingSymbolTable table; + table.uri = "file://" + std::filesystem::absolute(file_path).string(); + table.version = 0; + table.last_parsed = std::chrono::system_clock::now(); + table.is_dirty = false; + table.deserializer = std::make_unique(); + table.file_path = std::filesystem::absolute(file_path).string(); + + auto ast_result = table.deserializer->Parse(ts_tree_root_node(tree), content); + if (ast_result.HasErrors()) { + for (const auto& it : ast_result.errors) + spdlog::warn("Parse Error = {}", it.message); + } + if (!ast_result.root) { + spdlog::error("AST root is null for: {}", file_path); + return std::nullopt; + } + table.ast_root = std::move(ast_result.root); + table.symbol_table = std::make_unique(); + + try { + language::symbol::Builder builder(*table.symbol_table); + builder.Build(*table.ast_root); + spdlog::debug("Successfully built symbol table for: {}", file_path); + } catch (const std::exception& e) { + spdlog::error("Failed to build symbol table for {}: {}", file_path, + e.what()); + return std::nullopt; + } + + auto metadata = InferFileMetadata(*table.symbol_table); + if (!metadata) { + spdlog::warn("Cannot infer file metadata from symbol table: {}", file_path); + return std::nullopt; + } + table.file_type = metadata->file_type; + + // 验证TSF文件 + std::string file_stem = std::filesystem::path(file_path).stem().string(); + auto normalize_name = [](std::string s) { + std::transform(s.begin(), s.end(), s.begin(), ::tolower); + size_t at_pos = s.find("@"); + if (at_pos != std::string::npos) s = s.substr(0, at_pos); + return s; + }; + if (normalize_name(file_stem) != normalize_name(metadata->primary_symbol)) + spdlog::warn("File name '{}' doesn't match primary symbol '{}' in: {}", + file_stem, metadata->primary_symbol, file_path); + + spdlog::debug("Successfully parsed file: {} (type: {}, symbol: {})", + file_path, static_cast(metadata->file_type), + metadata->primary_symbol); + return table; +} + +std::optional Parser::InferFileMetadata( + const language::symbol::SymbolTable& table) { + auto top_level_symbols = symbol::utils::GetTopLevelSymbols(table); + + if (top_level_symbols.empty()) { + spdlog::warn("No top-level symbol definitions found in symbol table"); + return std::nullopt; + } + + std::vector filtered_symbols; + filtered_symbols.reserve(top_level_symbols.size()); + for (const auto* sym : top_level_symbols) { + if (sym->name() != "::") { + filtered_symbols.push_back(sym); + } + } + + if (filtered_symbols.empty()) { + spdlog::warn("No valid top-level symbols found (only global namespace)"); + return std::nullopt; + } + + std::sort( + filtered_symbols.begin(), filtered_symbols.end(), + [](const language::symbol::Symbol* a, const language::symbol::Symbol* b) { + const auto& loc_a = a->range(); + const auto& loc_b = b->range(); + + if (loc_a.start_offset != loc_b.start_offset) + return loc_a.start_offset < loc_b.start_offset; + if (loc_a.start_line != loc_b.start_line) + return loc_a.start_line < loc_b.start_line; + return loc_a.start_column < loc_b.start_column; + }); + + const language::symbol::Symbol* first_top_level = filtered_symbols[0]; + + FileMetadata metadata; + metadata.primary_symbol = first_top_level->name(); + metadata.primary_kind = first_top_level->kind(); + + // 根据第一个符号推断文件类型 + switch (first_top_level->kind()) { + case protocol::SymbolKind::Module: + metadata.file_type = symbol::TsfFileType::kUnit; + spdlog::trace("File type inferred as Unit from first symbol: '{}'", + first_top_level->name()); + break; + + case protocol::SymbolKind::Class: + metadata.file_type = symbol::TsfFileType::kClass; + spdlog::trace("File type inferred as Class from first symbol: '{}'", + first_top_level->name()); + break; + + case protocol::SymbolKind::Function: + metadata.file_type = symbol::TsfFileType::kFunction; + spdlog::trace("File type inferred as Function from first symbol: '{}'", + first_top_level->name()); + break; + + default: + metadata.file_type = symbol::TsfFileType::kScript; + spdlog::trace( + "File type inferred as Script (default) from first symbol: '{}'", + first_top_level->name()); + break; + } + + spdlog::debug( + "Inferred file metadata: type={}, symbol='{}', kind={}, location={}:{}", + static_cast(metadata.file_type), metadata.primary_symbol, + static_cast(metadata.primary_kind), + first_top_level->range().start_line, + first_top_level->range().start_column); + + return metadata; +} + +void Parser::OnDocumentOpened(const events::DocumentOpend& event) { + TSTree* tree = parser_.Parse(event.textDocument.text.c_str(), + event.textDocument.text.length()); + if (tree) { + syntax_tree_manager_.StoreTree(event.textDocument.uri, + std::make_unique(tree)); + event_bus_->Publish( + events::DocumentParsed{.item = event.textDocument, .tree = tree}); + spdlog::debug("Successfully parsed document: {}", event.textDocument.uri); + } else { + spdlog::error("Failed to parsed document: {}", event.textDocument.uri); + } +} + +void Parser::OnDocumentChanged(const events::DocumentChanged& event) { + SyntaxTree* syntax_tree = syntax_tree_manager_.GetTree(event.uri); + TSTree* old_tree = syntax_tree ? syntax_tree->Get() : nullptr; + + // 应用增量编辑 + if (syntax_tree && old_tree) + for (const auto& change : event.changes) + syntax_tree->ApplyEdit(change, event.content); + + // 增量解析 + TSTree* tree = + parser_.Parse(event.content.c_str(), event.content.length(), old_tree); + + if (tree) { + syntax_tree_manager_.StoreTree(event.uri, + std::make_unique(tree)); + event_bus_->Publish(events::DocumentReparsed{ + .item{.uri = event.uri, + .languageId = "", + .version = event.version, + .text = std::move(event.content)}, + .tree = tree, + }); + + spdlog::debug("Document reparsed successfully: {}", event.uri); + } else { + spdlog::error("Failed to reparse document: {}", event.uri); + } +} + +void Parser::OnDocumentClosed(const events::DocumentClosed& event) { + syntax_tree_manager_.RemoveTree(event.textDocument.uri); + spdlog::debug("Removed syntax tree for: {}", event.textDocument.uri); +} + +} // namespace lsp::service diff --git a/lsp-server/src/service/symbol.cpp b/lsp-server/src/service/symbol.cpp index 85432f6..a7d7ec0 100644 --- a/lsp-server/src/service/symbol.cpp +++ b/lsp-server/src/service/symbol.cpp @@ -1,259 +1,302 @@ -#include -#include -#include -#include "../language/ast/deserializer.hpp" -#include "./detail/symbol/conversion.hpp" -#include "./parser.hpp" #include "./symbol.hpp" -namespace lsp::service -{ - Symbol::Symbol(std::shared_ptr event_bus) : - event_bus_(std::move(event_bus)) - { - // 订阅文档事件 - event_bus_->Subscribe( - [this](const auto& e) { OnDocumentParsed(e); }); +#include - event_bus_->Subscribe( - [this](const auto& e) { OnDocumentReparsed(e); }); +#include +#include +#include - event_bus_->Subscribe( - [this](const auto& e) { OnDocumentClosed(e); }); - } +#include "../language/ast/deserializer.hpp" +#include "../language/symbol/builder.hpp" +#include "./detail/symbol/conversion.hpp" +#include "./detail/symbol/utils.hpp" +#include "./parser.hpp" - Symbol::~Symbol() = default; +namespace lsp::service { +namespace { - // ==================== 符号加载 ==================== - - void Symbol::LoadSystemSymbols(const std::string& folder) - { - spdlog::info("Loading system symbols from: {}", folder); - auto start = std::chrono::steady_clock::now(); - - if (!std::filesystem::exists(folder)) - { - spdlog::warn("System folder does not exist: {}", folder); - return; - } - - size_t loaded = 0; - size_t failed = 0; - - auto options = std::filesystem::directory_options::follow_directory_symlink | - std::filesystem::directory_options::skip_permission_denied; - - for (const auto& entry : std::filesystem::recursive_directory_iterator(folder, options)) - { - try - { - // 使用 Parser 解析文件 - auto table_opt = Parser::ParseTsfFile(entry.path().string()); - if (!table_opt) - { - failed++; - continue; - } - - auto& table = *table_opt; - - // 确保符号表存在 - if (!table.symbol_table) - { - spdlog::warn("Symbol table is null for: {}", entry.path().string()); - failed++; - continue; - } - - // 提取系统索引 - symbol::SystemSymbolIndex system_index = symbol::BuildFromEditingTable(table); - - system_repo_.Add(std::move(system_index)); - loaded++; - - if (loaded % 100 == 0) - { - spdlog::debug("Loaded {} system symbols", loaded); - } - } - catch (const std::exception& e) - { - spdlog::error("Exception loading system symbol {}: {}", entry.path().string(), e.what()); - failed++; - } - } - - auto duration = std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count(); - spdlog::info("System symbols loaded: {} files, {} failed, {}ms", loaded, failed, duration); - } - - void Symbol::LoadWorkspaceSymbols(const std::string& folder) - { - spdlog::info("Loading workspace symbols from: {}", folder); - auto start = std::chrono::steady_clock::now(); - - if (!std::filesystem::exists(folder)) - { - spdlog::warn("Workspace folder does not exist: {}", folder); - return; - } - - size_t loaded = 0; - size_t failed = 0; - - auto options = std::filesystem::directory_options::follow_directory_symlink | - std::filesystem::directory_options::skip_permission_denied; - - for (const auto& entry : std::filesystem::recursive_directory_iterator(folder, options)) - { - if (!entry.is_regular_file() || entry.path().extension() != ".tsf") - continue; - - try - { - // 使用 Parser 解析文件 - auto table_opt = Parser::ParseTsfFile(entry.path().string()); - if (!table_opt) - { - failed++; - continue; - } - - auto& table = *table_opt; - - // 确保符号表存在 - if (!table.symbol_table) - { - spdlog::warn("Symbol table is null for: {}", entry.path().string()); - failed++; - continue; - } - - // 提取文件索引 - symbol::FileSymbolIndex file_index = symbol::BuildFromEditingTable(table); - - workspace_repo_.Add(std::move(file_index)); - loaded++; - - if (loaded % 100 == 0) - spdlog::debug("Loaded {} workspace symbols", loaded); - } - catch (const std::exception& e) - { - spdlog::error("Exception loading workspace symbol {}: {}", entry.path().string(), e.what()); - failed++; - } - } - - auto duration = std::chrono::duration_cast( - std::chrono::steady_clock::now() - start) - .count(); - - spdlog::info("Workspace symbols loaded: {} files, {} failed, {}ms", loaded, failed, duration); - } - - void Symbol::ReloadWorkspaceSymbols(const std::string& folder) - { - spdlog::info("Reloading workspace symbols"); - - // 清空现有符号 - workspace_repo_.Clear(); - - // 重新加载 - LoadWorkspaceSymbols(folder); - } - - const symbol::repository::System& Symbol::SystemRepo() const - { - return system_repo_; - } - - const symbol::repository::Workspace& Symbol::WorkspaceRepo() const - { - return workspace_repo_; - } - - const symbol::repository::Editing& Symbol::EditingRepo() const - { - return editing_repo_; - } - - // ==================== 事件处理 ==================== - - void Symbol::OnDocumentParsed(const events::DocumentParsed& event) - { - if (!event.tree) - { - spdlog::warn("Received null tree for document: {}", event.item.uri); - return; - } - - try - { - // 创建 EditingSymbolTable - symbol::EditingSymbolTable table; - table.uri = event.item.uri; - table.version = event.item.version; - table.last_parsed = std::chrono::system_clock::now(); - table.is_dirty = false; - table.deserializer = std::make_unique(); - - // 使用 AST 反序列化器解析文档 - auto ast_result = table.deserializer->Parse(ts_tree_root_node(event.tree), event.item.text); - if (!ast_result.IsSuccess()) - { - spdlog::error("Failed to deserialize AST for: {}", event.item.uri); - return; - } - - table.ast_root = std::move(ast_result.root); - - // 确定文件类型 (根据 AST root 的 Program 内容判断) - symbol::TsfFileType file_type = symbol::TsfFileType::kScript; - if (table.ast_root) - { - auto& program = *table.ast_root; - if (!program.statements.empty()) - { - auto& first_stmt = program.statements[0]; - if (auto* unit_def = dynamic_cast(first_stmt.get())) - { - file_type = symbol::TsfFileType::kUnit; - } - else if (auto* class_def = dynamic_cast(first_stmt.get())) - { - file_type = symbol::TsfFileType::kClass; - } - else if (auto* func_def = dynamic_cast(first_stmt.get())) - { - file_type = symbol::TsfFileType::kFunction; - } - } - } - table.file_type = file_type; - - // 创建符号表并构建 - table.symbol_table = std::make_unique(); - table.symbol_table->Build(*table.ast_root); - // 添加到编辑仓库 - editing_repo_.AddOrUpdate(std::move(table)); - - spdlog::debug("Document parsed and added to editing repo: {}", event.item.uri); - } - catch (const std::exception& e) - { - spdlog::error("Exception loading editing symbol {}: {}", event.item.uri, e.what()); - } - } - - void Symbol::OnDocumentReparsed(const events::DocumentReparsed& event) - { - spdlog::debug("OnDocumentReparsed"); - } - - void Symbol::OnDocumentClosed(const events::DocumentClosed& event) - { - editing_repo_.Remove(event.textDocument.uri); - spdlog::debug("Document closed and removed from editing repo: {}", event.textDocument.uri); - } +lsp::protocol::Range ToRange(const language::ast::Location& location) { + lsp::protocol::Range range; + range.start.line = location.start_line; + range.start.character = location.start_column; + range.end.line = location.end_line; + range.end.character = location.end_column; + return range; } + +lsp::protocol::DocumentSymbol ToDocumentSymbol( + const language::symbol::SymbolTable& table, + const language::symbol::Symbol& symbol) { + lsp::protocol::DocumentSymbol doc_symbol; + doc_symbol.name = symbol.name(); + doc_symbol.kind = symbol.kind(); + doc_symbol.range = ToRange(symbol.range()); + doc_symbol.selectionRange = ToRange(symbol.selection_range()); + + auto detail = lsp::service::symbol::utils::BuildSymbolDetail(symbol); + if (!detail.empty()) { + doc_symbol.detail = detail; + } + + auto children = + lsp::service::symbol::utils::GetChildSymbols(table, symbol.id()); + if (!children.empty()) { + std::vector child_symbols; + child_symbols.reserve(children.size()); + for (const auto* child : children) { + child_symbols.push_back(ToDocumentSymbol(table, *child)); + } + doc_symbol.children = std::move(child_symbols); + } + + return doc_symbol; +} + +} // namespace +Symbol::Symbol(std::shared_ptr event_bus) + : event_bus_(std::move(event_bus)) { + // 订阅文档事件 + event_bus_->Subscribe( + [this](const auto& e) { OnDocumentParsed(e); }); + + event_bus_->Subscribe( + [this](const auto& e) { OnDocumentReparsed(e); }); + + event_bus_->Subscribe( + [this](const auto& e) { OnDocumentClosed(e); }); +} + +Symbol::~Symbol() = default; + +// ==================== 符号加载 ==================== + +void Symbol::LoadSystemSymbols(const std::string& folder) { + spdlog::info("Loading system symbols from: {}", folder); + auto start = std::chrono::steady_clock::now(); + + if (!std::filesystem::exists(folder)) { + spdlog::warn("System folder does not exist: {}", folder); + return; + } + + size_t loaded = 0; + size_t failed = 0; + + auto options = std::filesystem::directory_options::follow_directory_symlink | + std::filesystem::directory_options::skip_permission_denied; + + for (const auto& entry : + std::filesystem::recursive_directory_iterator(folder, options)) { + try { + // 使用 Parser 解析文件 + auto table_opt = Parser::ParseTsfFile(entry.path().string()); + if (!table_opt) { + failed++; + continue; + } + + auto& table = *table_opt; + + // 确保符号表存在 + if (!table.symbol_table) { + spdlog::warn("Symbol table is null for: {}", entry.path().string()); + failed++; + continue; + } + + // 提取系统索引 + symbol::SystemSymbolIndex system_index = + symbol::BuildFromEditingTable(table); + + system_repo_.Add(std::move(system_index)); + loaded++; + + if (loaded % 100 == 0) { + spdlog::debug("Loaded {} system symbols", loaded); + } + } catch (const std::exception& e) { + spdlog::error("Exception loading system symbol {}: {}", + entry.path().string(), e.what()); + failed++; + } + } + + auto duration = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + spdlog::info("System symbols loaded: {} files, {} failed, {}ms", loaded, + failed, duration); +} + +void Symbol::LoadWorkspaceSymbols(const std::string& folder) { + spdlog::info("Loading workspace symbols from: {}", folder); + auto start = std::chrono::steady_clock::now(); + + if (!std::filesystem::exists(folder)) { + spdlog::warn("Workspace folder does not exist: {}", folder); + return; + } + + size_t loaded = 0; + size_t failed = 0; + + auto options = std::filesystem::directory_options::follow_directory_symlink | + std::filesystem::directory_options::skip_permission_denied; + + for (const auto& entry : + std::filesystem::recursive_directory_iterator(folder, options)) { + if (!entry.is_regular_file() || entry.path().extension() != ".tsf") + continue; + + try { + // 使用 Parser 解析文件 + auto table_opt = Parser::ParseTsfFile(entry.path().string()); + if (!table_opt) { + failed++; + continue; + } + + auto& table = *table_opt; + + // 确保符号表存在 + if (!table.symbol_table) { + spdlog::warn("Symbol table is null for: {}", entry.path().string()); + failed++; + continue; + } + + // 提取文件索引 + symbol::FileSymbolIndex file_index = symbol::BuildFromEditingTable(table); + + workspace_repo_.Add(std::move(file_index)); + loaded++; + + if (loaded % 100 == 0) + spdlog::debug("Loaded {} workspace symbols", loaded); + } catch (const std::exception& e) { + spdlog::error("Exception loading workspace symbol {}: {}", + entry.path().string(), e.what()); + failed++; + } + } + + auto duration = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start) + .count(); + + spdlog::info("Workspace symbols loaded: {} files, {} failed, {}ms", loaded, + failed, duration); +} + +void Symbol::ReloadWorkspaceSymbols(const std::string& folder) { + spdlog::info("Reloading workspace symbols"); + + // 清空现有符号 + workspace_repo_.Clear(); + + // 重新加载 + LoadWorkspaceSymbols(folder); +} + +const symbol::repository::System& Symbol::SystemRepo() const { + return system_repo_; +} + +const symbol::repository::Workspace& Symbol::WorkspaceRepo() const { + return workspace_repo_; +} + +const symbol::repository::Editing& Symbol::EditingRepo() const { + return editing_repo_; +} + +std::vector Symbol::GetDocumentSymbols( + const protocol::DocumentUri& uri) const { + std::vector result; + const auto* table = editing_repo_.GetSymbolTable(uri); + if (!table) return result; + + auto top_level = symbol::utils::GetTopLevelSymbols(*table); + result.reserve(top_level.size()); + for (const auto* symbol : top_level) { + if (symbol->name() == "::") continue; + result.push_back(ToDocumentSymbol(*table, *symbol)); + } + + return result; +} + +// ==================== 事件处理 ==================== + +void Symbol::OnDocumentParsed(const events::DocumentParsed& event) { + if (!event.tree) { + spdlog::warn("Received null tree for document: {}", event.item.uri); + return; + } + + try { + // 创建 EditingSymbolTable + symbol::EditingSymbolTable table; + table.uri = event.item.uri; + table.version = event.item.version; + table.last_parsed = std::chrono::system_clock::now(); + table.is_dirty = false; + table.deserializer = std::make_unique(); + + // 使用 AST 反序列化器解析文档 + auto ast_result = table.deserializer->Parse(ts_tree_root_node(event.tree), + event.item.text); + if (!ast_result.IsSuccess()) { + spdlog::error("Failed to deserialize AST for: {}", event.item.uri); + return; + } + + table.ast_root = std::move(ast_result.root); + + // 确定文件类型 (根据 AST root 的 Program 内容判断) + symbol::TsfFileType file_type = symbol::TsfFileType::kScript; + if (table.ast_root) { + auto& program = *table.ast_root; + if (!program.statements.empty()) { + auto& first_stmt = program.statements[0]; + if (auto* unit_def = dynamic_cast( + first_stmt.get())) { + file_type = symbol::TsfFileType::kUnit; + } else if (auto* class_def = + dynamic_cast( + first_stmt.get())) { + file_type = symbol::TsfFileType::kClass; + } else if (auto* func_def = + dynamic_cast( + first_stmt.get())) { + file_type = symbol::TsfFileType::kFunction; + } + } + } + table.file_type = file_type; + + // 创建符号表并构建 + table.symbol_table = std::make_unique(); + language::symbol::Builder builder(*table.symbol_table); + builder.Build(*table.ast_root); + // 添加到编辑仓库 + editing_repo_.AddOrUpdate(std::move(table)); + + spdlog::debug("Document parsed and added to editing repo: {}", + event.item.uri); + } catch (const std::exception& e) { + spdlog::error("Exception loading editing symbol {}: {}", event.item.uri, + e.what()); + } +} + +void Symbol::OnDocumentReparsed(const events::DocumentReparsed& event) { + spdlog::debug("OnDocumentReparsed"); +} + +void Symbol::OnDocumentClosed(const events::DocumentClosed& event) { + editing_repo_.Remove(event.textDocument.uri); + spdlog::debug("Document closed and removed from editing repo: {}", + event.textDocument.uri); +} +} // namespace lsp::service diff --git a/lsp-server/src/service/symbol.hpp b/lsp-server/src/service/symbol.hpp index 8d1cec2..b13c261 100644 --- a/lsp-server/src/service/symbol.hpp +++ b/lsp-server/src/service/symbol.hpp @@ -1,38 +1,41 @@ #pragma once #include -#include "./base/events.hpp" +#include + +#include "../protocol/protocol.hpp" #include "./base/event_bus.hpp" +#include "./base/events.hpp" +#include "./detail/symbol/repository/editing.hpp" #include "./detail/symbol/repository/system.hpp" #include "./detail/symbol/repository/workspace.hpp" -#include "./detail/symbol/repository/editing.hpp" -namespace lsp::service -{ - class Symbol - { - public: - explicit Symbol(std::shared_ptr event_bus); - ~Symbol(); +namespace lsp::service { +class Symbol { + public: + explicit Symbol(std::shared_ptr event_bus); + ~Symbol(); - // === 加载符号 === - void LoadSystemSymbols(const std::string& folder); - void LoadWorkspaceSymbols(const std::string& folder); - void ReloadWorkspaceSymbols(const std::string& folder); + // === 加载符号 === + void LoadSystemSymbols(const std::string& folder); + void LoadWorkspaceSymbols(const std::string& folder); + void ReloadWorkspaceSymbols(const std::string& folder); - const symbol::repository::System& SystemRepo() const; - const symbol::repository::Workspace& WorkspaceRepo() const; - const symbol::repository::Editing& EditingRepo() const; + const symbol::repository::System& SystemRepo() const; + const symbol::repository::Workspace& WorkspaceRepo() const; + const symbol::repository::Editing& EditingRepo() const; + std::vector GetDocumentSymbols( + const protocol::DocumentUri& uri) const; - private: - void OnDocumentParsed(const events::DocumentParsed& event); - void OnDocumentReparsed(const events::DocumentReparsed& event); - void OnDocumentClosed(const events::DocumentClosed& event); + private: + void OnDocumentParsed(const events::DocumentParsed& event); + void OnDocumentReparsed(const events::DocumentReparsed& event); + void OnDocumentClosed(const events::DocumentClosed& event); - private: - symbol::repository::System system_repo_; - symbol::repository::Workspace workspace_repo_; - symbol::repository::Editing editing_repo_; + private: + symbol::repository::System system_repo_; + symbol::repository::Workspace workspace_repo_; + symbol::repository::Editing editing_repo_; - std::shared_ptr event_bus_; - }; -} + std::shared_ptr event_bus_; +}; +} // namespace lsp::service