From 4c2e242920864c5be1f8d40ad4d44b319d4e74d2 Mon Sep 17 00:00:00 2001 From: csh Date: Tue, 18 Nov 2025 23:11:40 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20=E6=96=B0=E5=A2=9E=E8=AF=AD?= =?UTF-8?q?=E4=B9=89=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit :recycle: 重构符号表,职责更清晰单一 :bug: 同步修复`test_symbol` --- lsp-server/src/CMakeLists.txt | 10 +- lsp-server/src/language/semantic/analyzer.cpp | 555 ++++++ lsp-server/src/language/semantic/analyzer.hpp | 162 ++ .../src/language/semantic/graph/call.cpp | 59 + .../src/language/semantic/graph/call.hpp | 45 + .../language/semantic/graph/inheritance.cpp | 70 + .../language/semantic/graph/inheritance.hpp | 50 + .../src/language/semantic/graph/reference.cpp | 55 + .../src/language/semantic/graph/reference.hpp | 46 + .../src/language/semantic/interface.hpp | 31 + .../src/language/semantic/name_resolver.cpp | 443 +++++ .../src/language/semantic/name_resolver.hpp | 226 +++ .../src/language/semantic/semantic_model.cpp | 56 + .../src/language/semantic/semantic_model.hpp | 125 ++ .../src/language/semantic/type_system.cpp | 550 ++++++ .../src/language/semantic/type_system.hpp | 406 +++++ lsp-server/src/language/semantic/types.hpp | 44 + lsp-server/src/language/symbol/builder.cpp | 1588 +++++++++-------- lsp-server/src/language/symbol/builder.hpp | 247 ++- lsp-server/src/language/symbol/graph/call.cpp | 52 - lsp-server/src/language/symbol/graph/call.hpp | 26 - .../src/language/symbol/graph/inheritance.cpp | 58 - .../src/language/symbol/graph/inheritance.hpp | 27 - .../src/language/symbol/graph/reference.cpp | 46 - .../src/language/symbol/graph/reference.hpp | 27 - .../src/language/symbol/index/dispatcher.hpp | 68 + .../src/language/symbol/index/location.cpp | 127 +- .../src/language/symbol/index/location.hpp | 41 +- .../src/language/symbol/index/scope.cpp | 176 +- .../src/language/symbol/index/scope.hpp | 70 +- lsp-server/src/language/symbol/interface.hpp | 45 +- lsp-server/src/language/symbol/store.cpp | 104 +- lsp-server/src/language/symbol/store.hpp | 32 +- lsp-server/src/language/symbol/table.cpp | 164 +- lsp-server/src/language/symbol/table.hpp | 72 +- lsp-server/src/language/symbol/types.hpp | 384 ++-- lsp-server/test/test_symbol/CMakeLists.txt | 3 - lsp-server/test/test_symbol/debug_printer.cpp | 170 -- lsp-server/test/test_symbol/debug_printer.hpp | 17 +- lsp-server/test/test_symbol/test.cpp | 50 - 40 files changed, 4635 insertions(+), 1892 deletions(-) create mode 100644 lsp-server/src/language/semantic/analyzer.cpp create mode 100644 lsp-server/src/language/semantic/analyzer.hpp create mode 100644 lsp-server/src/language/semantic/graph/call.cpp create mode 100644 lsp-server/src/language/semantic/graph/call.hpp create mode 100644 lsp-server/src/language/semantic/graph/inheritance.cpp create mode 100644 lsp-server/src/language/semantic/graph/inheritance.hpp create mode 100644 lsp-server/src/language/semantic/graph/reference.cpp create mode 100644 lsp-server/src/language/semantic/graph/reference.hpp create mode 100644 lsp-server/src/language/semantic/interface.hpp create mode 100644 lsp-server/src/language/semantic/name_resolver.cpp create mode 100644 lsp-server/src/language/semantic/name_resolver.hpp create mode 100644 lsp-server/src/language/semantic/semantic_model.cpp create mode 100644 lsp-server/src/language/semantic/semantic_model.hpp create mode 100644 lsp-server/src/language/semantic/type_system.cpp create mode 100644 lsp-server/src/language/semantic/type_system.hpp create mode 100644 lsp-server/src/language/semantic/types.hpp delete mode 100644 lsp-server/src/language/symbol/graph/call.cpp delete mode 100644 lsp-server/src/language/symbol/graph/call.hpp delete mode 100644 lsp-server/src/language/symbol/graph/inheritance.cpp delete mode 100644 lsp-server/src/language/symbol/graph/inheritance.hpp delete mode 100644 lsp-server/src/language/symbol/graph/reference.cpp delete mode 100644 lsp-server/src/language/symbol/graph/reference.hpp create mode 100644 lsp-server/src/language/symbol/index/dispatcher.hpp diff --git a/lsp-server/src/CMakeLists.txt b/lsp-server/src/CMakeLists.txt index 6a45ff9..8631d4a 100644 --- a/lsp-server/src/CMakeLists.txt +++ b/lsp-server/src/CMakeLists.txt @@ -49,11 +49,15 @@ set(SOURCES language/symbol/builder.cpp language/symbol/index/location.cpp language/symbol/index/scope.cpp - language/symbol/graph/call.cpp - language/symbol/graph/inheritance.cpp - language/symbol/graph/reference.cpp language/symbol/store.cpp language/symbol/table.cpp + language/semantic/graph/call.cpp + language/semantic/graph/inheritance.cpp + language/semantic/graph/reference.cpp + language/semantic/semantic_model.cpp + language/semantic/type_system.cpp + language/semantic/name_resolver.cpp + language/semantic/analyzer.cpp language/keyword/repo.cpp provider/base/bootstrap.cpp provider/base/interface.cpp diff --git a/lsp-server/src/language/semantic/analyzer.cpp b/lsp-server/src/language/semantic/analyzer.cpp new file mode 100644 index 0000000..1b2de2d --- /dev/null +++ b/lsp-server/src/language/semantic/analyzer.cpp @@ -0,0 +1,555 @@ +#include "./analyzer.hpp" + +namespace lsp::language::semantic +{ + + Analyzer::Analyzer(const symbol::SymbolTable& symbol_table, + SemanticModel& semantic_model) + : symbol_table_(symbol_table), + semantic_model_(semantic_model), + current_function_id_(std::nullopt), + current_class_id_(std::nullopt) + { + } + + void Analyzer::Analyze(ast::ASTNode& root) + { + root.Accept(*this); + } + + // ===== Statements ===== + + void Analyzer::VisitProgram(ast::Program& node) + { + VisitStatements(node.statements); + } + + void Analyzer::VisitUnitDefinition(ast::UnitDefinition& node) + { + // TODO: 处理 Unit 导入关系 + VisitStatements(node.interface_section.statements); + VisitStatements(node.implementation_section.statements); + } + + void Analyzer::VisitClassDefinition(ast::ClassDefinition& node) + { + // 查找类符号 + auto class_id = ResolveIdentifier(node.name, node.location); + if (!class_id) + { + return; + } + + // 设置当前类上下文 + auto prev_class = current_class_id_; + current_class_id_ = class_id; + + // 处理继承关系 + for (const auto& base_class : node.base_classes) + { + auto base_id = ResolveIdentifier(base_class, node.location); + if (base_id) + { + semantic_model_.AddInheritance(*class_id, *base_id); + } + } + + // 访问类成员 + for (const auto& member : node.members) + { + member->Accept(*this); + } + + // 恢复上下文 + current_class_id_ = prev_class; + } + + void Analyzer::VisitFunctionDefinition(ast::FunctionDefinition& node) + { + auto func_id = ResolveIdentifier(node.name, node.location); + if (!func_id) + { + return; + } + + auto prev_func = current_function_id_; + current_function_id_ = func_id; + + // 访问函数体 + if (node.body) + { + node.body->Accept(*this); + } + + current_function_id_ = prev_func; + } + + void Analyzer::VisitFunctionDeclaration(ast::FunctionDeclaration& node) + { + // 函数声明不需要额外的语义分析 + } + + void Analyzer::VisitMethodDeclaration(ast::MethodDeclaration& node) + { + auto method_id = ResolveIdentifier(node.name, node.location); + if (!method_id) + { + return; + } + + auto prev_func = current_function_id_; + current_function_id_ = method_id; + + if (node.body) + { + node.body->Accept(*this); + } + + current_function_id_ = prev_func; + } + + void Analyzer::VisitExternalMethodDefinition(ast::ExternalMethodDefinition& node) + { + auto method_id = ResolveIdentifier(node.name, node.location); + if (!method_id) + { + return; + } + + auto prev_func = current_function_id_; + current_function_id_ = method_id; + + if (node.body) + { + node.body->Accept(*this); + } + + current_function_id_ = prev_func; + } + + // ===== Declarations ===== + + void Analyzer::VisitVarDeclaration(ast::VarDeclaration& node) + { + if (node.initializer) + { + VisitExpression(*node.initializer); + } + } + + void Analyzer::VisitStaticDeclaration(ast::StaticDeclaration& node) + { + if (node.initializer) + { + VisitExpression(*node.initializer); + } + } + + void Analyzer::VisitGlobalDeclaration(ast::GlobalDeclaration& node) + { + if (node.initializer) + { + VisitExpression(*node.initializer); + } + } + + void Analyzer::VisitConstDeclaration(ast::ConstDeclaration& node) + { + VisitExpression(*node.value); + } + + void Analyzer::VisitFieldDeclaration(ast::FieldDeclaration& node) + { + // 字段声明不需要额外的语义分析 + } + + void Analyzer::VisitClassMember(ast::ClassMember& node) + { + node.member->Accept(*this); + } + + void Analyzer::VisitPropertyDeclaration(ast::PropertyDeclaration& node) + { + // TODO: 处理 property getter/setter + } + + // ===== Control Flow ===== + + void Analyzer::VisitBlockStatement(ast::BlockStatement& node) + { + VisitStatements(node.statements); + } + + void Analyzer::VisitIfStatement(ast::IfStatement& node) + { + VisitExpression(*node.condition); + node.then_branch->Accept(*this); + if (node.else_branch) + { + node.else_branch->Accept(*this); + } + } + + void Analyzer::VisitForInStatement(ast::ForInStatement& node) + { + VisitExpression(*node.iterable); + node.body->Accept(*this); + } + + void Analyzer::VisitForToStatement(ast::ForToStatement& node) + { + VisitExpression(*node.start); + VisitExpression(*node.end); + if (node.step) + { + VisitExpression(*node.step); + } + node.body->Accept(*this); + } + + void Analyzer::VisitWhileStatement(ast::WhileStatement& node) + { + VisitExpression(*node.condition); + node.body->Accept(*this); + } + + void Analyzer::VisitRepeatStatement(ast::RepeatStatement& node) + { + VisitStatements(node.statements); + VisitExpression(*node.condition); + } + + void Analyzer::VisitCaseStatement(ast::CaseStatement& node) + { + VisitExpression(*node.expression); + for (const auto& branch : node.branches) + { + for (const auto& value : branch.values) + { + VisitExpression(*value); + } + branch.statement->Accept(*this); + } + if (node.else_branch) + { + node.else_branch->Accept(*this); + } + } + + void Analyzer::VisitTryStatement(ast::TryStatement& node) + { + node.try_block->Accept(*this); + if (node.except_block) + { + node.except_block->Accept(*this); + } + if (node.finally_block) + { + node.finally_block->Accept(*this); + } + } + + void Analyzer::VisitUsesStatement(ast::UsesStatement& node) + { + // TODO: 处理 Unit 引用 + } + + // ===== Expressions ===== + + void Analyzer::VisitIdentifier(ast::Identifier& node) + { + auto symbol_id = ResolveIdentifier(node.name, node.location); + if (symbol_id) + { + TrackReference(*symbol_id, node.location, false); + } + } + + void Analyzer::VisitCallExpression(ast::CallExpression& node) + { + // 访问被调用的表达式 + VisitExpression(*node.callee); + + // 访问参数 + for (const auto& arg : node.arguments) + { + VisitExpression(*arg); + } + + // 如果 callee 是标识符,记录调用关系 + if (auto* ident = dynamic_cast(node.callee.get())) + { + auto callee_id = ResolveIdentifier(ident->name, ident->location); + if (callee_id) + { + TrackCall(*callee_id, node.location); + } + } + } + + void Analyzer::VisitAttributeExpression(ast::AttributeExpression& node) + { + VisitExpression(*node.object); + // TODO: 解析成员访问 + } + + void Analyzer::VisitAssignmentExpression(ast::AssignmentExpression& node) + { + // 处理左值(写引用) + ProcessLValue(node.target); + + // 处理右值 + VisitExpression(*node.value); + } + + void Analyzer::VisitLiteral(ast::Literal& node) + { + // 字面量不需要处理 + } + + void Analyzer::VisitBinaryExpression(ast::BinaryExpression& node) + { + VisitExpression(*node.left); + VisitExpression(*node.right); + } + + void Analyzer::VisitTernaryExpression(ast::TernaryExpression& node) + { + VisitExpression(*node.condition); + VisitExpression(*node.true_expr); + VisitExpression(*node.false_expr); + } + + void Analyzer::VisitSubscriptExpression(ast::SubscriptExpression& node) + { + VisitExpression(*node.object); + for (const auto& index : node.indices) + { + VisitExpression(*index); + } + } + + void Analyzer::VisitArrayExpression(ast::ArrayExpression& node) + { + for (const auto& element : node.elements) + { + VisitExpression(*element); + } + } + + void Analyzer::VisitAnonymousFunctionExpression(ast::AnonymousFunctionExpression& node) + { + if (node.body) + { + node.body->Accept(*this); + } + } + + // Unary expressions + void Analyzer::VisitUnaryPlusExpression(ast::UnaryPlusExpression& node) + { + VisitExpression(*node.operand); + } + + void Analyzer::VisitUnaryMinusExpression(ast::UnaryMinusExpression& node) + { + VisitExpression(*node.operand); + } + + void Analyzer::VisitPrefixIncrementExpression(ast::PrefixIncrementExpression& node) + { + VisitExpression(*node.operand); + } + + void Analyzer::VisitPrefixDecrementExpression(ast::PrefixDecrementExpression& node) + { + VisitExpression(*node.operand); + } + + void Analyzer::VisitPostfixIncrementExpression(ast::PostfixIncrementExpression& node) + { + VisitExpression(*node.operand); + } + + void Analyzer::VisitPostfixDecrementExpression(ast::PostfixDecrementExpression& node) + { + VisitExpression(*node.operand); + } + + void Analyzer::VisitLogicalNotExpression(ast::LogicalNotExpression& node) + { + VisitExpression(*node.operand); + } + + void Analyzer::VisitBitwiseNotExpression(ast::BitwiseNotExpression& node) + { + VisitExpression(*node.operand); + } + + void Analyzer::VisitDerivativeExpression(ast::DerivativeExpression& node) + { + VisitExpression(*node.operand); + } + + void Analyzer::VisitMatrixTransposeExpression(ast::MatrixTransposeExpression& node) + { + VisitExpression(*node.operand); + } + + void Analyzer::VisitExprOperatorExpression(ast::ExprOperatorExpression& node) + { + VisitExpression(*node.operand); + } + + void Analyzer::VisitFunctionPointerExpression(ast::FunctionPointerExpression& node) + { + // TODO: 处理函数指针 + } + + void Analyzer::VisitNewExpression(ast::NewExpression& node) + { + // TODO: 处理对象创建 + } + + void Analyzer::VisitEchoExpression(ast::EchoExpression& node) + { + VisitExpression(*node.expression); + } + + void Analyzer::VisitRaiseExpression(ast::RaiseExpression& node) + { + if (node.exception) + { + VisitExpression(*node.exception); + } + } + + void Analyzer::VisitInheritedExpression(ast::InheritedExpression& node) + { + // TODO: 处理 inherited 调用 + } + + void Analyzer::VisitParenthesizedExpression(ast::ParenthesizedExpression& node) + { + VisitExpression(*node.expression); + } + + void Analyzer::VisitExpressionStatement(ast::ExpressionStatement& node) + { + VisitExpression(*node.expression); + } + + void Analyzer::VisitReturnStatement(ast::ReturnStatement& node) + { + if (node.value) + { + VisitExpression(*node.value); + } + } + + void Analyzer::VisitColumnReference(ast::ColumnReference& node) + { + // TODO: 处理列引用 + } + + void Analyzer::VisitUnpackPattern(ast::UnpackPattern& node) + { + // TODO: 处理解包模式 + } + + void Analyzer::VisitMatrixIterationStatement(ast::MatrixIterationStatement& node) + { + VisitExpression(*node.matrix); + node.body->Accept(*this); + } + + void Analyzer::VisitCompilerDirective(ast::CompilerDirective& node) + { + // 编译器指令不需要语义分析 + } + + void Analyzer::VisitConditionalDirective(ast::ConditionalDirective& node) + { + for (const auto& block : node.blocks) + { + block->Accept(*this); + } + } + + void Analyzer::VisitConditionalBlock(ast::ConditionalBlock& node) + { + VisitStatements(node.statements); + } + + // ===== Helper Methods ===== + + std::optional Analyzer::ResolveIdentifier( + const std::string& name, + const ast::Location& location) + { + auto result = semantic_model_.name_resolver().ResolveNameAtLocation(name, location); + return result.IsResolved() ? std::optional(result.symbol_id) : std::nullopt; + } + + void Analyzer::TrackReference(symbol::SymbolId symbol_id, + const ast::Location& location, + bool is_write) + { + semantic_model_.AddReference(symbol_id, location, false, is_write); + } + + void Analyzer::TrackCall(symbol::SymbolId callee, + const ast::Location& location) + { + if (current_function_id_) + { + semantic_model_.AddCall(*current_function_id_, callee, location); + } + } + + std::shared_ptr Analyzer::InferExpressionType(ast::Expression& expr) + { + // TODO: 实现完整的类型推断 + return semantic_model_.type_system().GetUnknownType(); + } + + void Analyzer::VisitStatements(const std::vector& statements) + { + for (const auto& stmt : statements) + { + stmt->Accept(*this); + } + } + + void Analyzer::VisitExpression(ast::Expression& expr) + { + expr.Accept(*this); + } + + void Analyzer::ProcessLValue(const ast::LValue& lvalue) + { + std::visit( + [this](const auto& value) { + using T = std::decay_t; + if constexpr (std::is_same_v) + { + // 简单标识符 + auto symbol_id = ResolveIdentifier(value, ast::Location()); + if (symbol_id) + { + TrackReference(*symbol_id, ast::Location(), true); + } + } + else if constexpr (std::is_same_v>) + { + // 复杂左值(如 a.b, a[i]) + if (value) + { + VisitExpression(*value); + } + } + }, + lvalue); + } + +} // namespace lsp::language::semantic diff --git a/lsp-server/src/language/semantic/analyzer.hpp b/lsp-server/src/language/semantic/analyzer.hpp new file mode 100644 index 0000000..1854754 --- /dev/null +++ b/lsp-server/src/language/semantic/analyzer.hpp @@ -0,0 +1,162 @@ +#pragma once + +#include "../ast/types.hpp" +#include "../symbol/table.hpp" +#include "./semantic_model.hpp" + +namespace lsp::language::semantic +{ + + /** + * SemanticAnalyzer - 语义分析器 + * + * 职责: + * 1. 遍历 AST 构建语义信息 + * 2. 收集引用关系 + * 3. 收集调用关系 + * 4. 收集继承关系 + * 5. 推断表达式类型 + * + * 注意:SemanticAnalyzer 依赖符号表已经构建完成 + */ + class Analyzer : public ast::ASTVisitor + { + public: + explicit Analyzer(const symbol::SymbolTable& symbol_table, SemanticModel& semantic_model); + + void Analyze(ast::ASTNode& root); + + // ===== Visit statements and declarations ===== + void VisitProgram(ast::Program& node) override; + void VisitUnitDefinition(ast::UnitDefinition& node) override; + void VisitClassDefinition(ast::ClassDefinition& node) override; + void VisitFunctionDefinition(ast::FunctionDefinition& node) override; + void VisitFunctionDeclaration(ast::FunctionDeclaration& node) override; + void VisitMethodDeclaration(ast::MethodDeclaration& node) override; + void VisitExternalMethodDefinition(ast::ExternalMethodDefinition& node) override; + + // Declarations + void VisitVarDeclaration(ast::VarDeclaration& node) override; + void VisitStaticDeclaration(ast::StaticDeclaration& node) override; + void VisitGlobalDeclaration(ast::GlobalDeclaration& node) override; + void VisitConstDeclaration(ast::ConstDeclaration& node) override; + void VisitFieldDeclaration(ast::FieldDeclaration& node) override; + void VisitClassMember(ast::ClassMember& node) override; + void VisitPropertyDeclaration(ast::PropertyDeclaration& node) override; + + void VisitBlockStatement(ast::BlockStatement& node) override; + void VisitIfStatement(ast::IfStatement& node) override; + void VisitForInStatement(ast::ForInStatement& node) override; + void VisitForToStatement(ast::ForToStatement& node) override; + void VisitWhileStatement(ast::WhileStatement& node) override; + void VisitRepeatStatement(ast::RepeatStatement& node) override; + void VisitCaseStatement(ast::CaseStatement& node) override; + void VisitTryStatement(ast::TryStatement& node) override; + + // Uses statement handling + void VisitUsesStatement(ast::UsesStatement& node) override; + + // ===== Visit expressions (collect references) ===== + void VisitIdentifier(ast::Identifier& node) override; + void VisitCallExpression(ast::CallExpression& node) override; + void VisitAttributeExpression(ast::AttributeExpression& node) override; + void VisitAssignmentExpression(ast::AssignmentExpression& node) override; + + // Other expressions + void VisitLiteral(ast::Literal& node) override; + void VisitBinaryExpression(ast::BinaryExpression& node) override; + void VisitTernaryExpression(ast::TernaryExpression& node) override; + void VisitSubscriptExpression(ast::SubscriptExpression& node) override; + void VisitArrayExpression(ast::ArrayExpression& node) override; + void VisitAnonymousFunctionExpression(ast::AnonymousFunctionExpression& node) override; + + // Unary expressions + void VisitUnaryPlusExpression(ast::UnaryPlusExpression& node) override; + void VisitUnaryMinusExpression(ast::UnaryMinusExpression& node) override; + void VisitPrefixIncrementExpression(ast::PrefixIncrementExpression& node) override; + void VisitPrefixDecrementExpression(ast::PrefixDecrementExpression& node) override; + void VisitPostfixIncrementExpression(ast::PostfixIncrementExpression& node) override; + void VisitPostfixDecrementExpression(ast::PostfixDecrementExpression& node) override; + void VisitLogicalNotExpression(ast::LogicalNotExpression& node) override; + void VisitBitwiseNotExpression(ast::BitwiseNotExpression& node) override; + void VisitDerivativeExpression(ast::DerivativeExpression& node) override; + void VisitMatrixTransposeExpression(ast::MatrixTransposeExpression& node) override; + void VisitExprOperatorExpression(ast::ExprOperatorExpression& node) override; + void VisitFunctionPointerExpression(ast::FunctionPointerExpression& node) override; + + // Other expressions + void VisitNewExpression(ast::NewExpression& node) override; + void VisitEchoExpression(ast::EchoExpression& node) override; + void VisitRaiseExpression(ast::RaiseExpression& node) override; + void VisitInheritedExpression(ast::InheritedExpression& node) override; + void VisitParenthesizedExpression(ast::ParenthesizedExpression& node) override; + + void VisitExpressionStatement(ast::ExpressionStatement& node) override; + void VisitBreakStatement(ast::BreakStatement& node) override {} + void VisitContinueStatement(ast::ContinueStatement& node) override {} + void VisitReturnStatement(ast::ReturnStatement& node) override; + void VisitTSSQLExpression(ast::TSSQLExpression& node) override {} + void VisitColumnReference(ast::ColumnReference& node) override; + void VisitUnpackPattern(ast::UnpackPattern& node) override; + void VisitMatrixIterationStatement(ast::MatrixIterationStatement& node) override; + + // Compiler directives + void VisitCompilerDirective(ast::CompilerDirective& node) override; + void VisitConditionalDirective(ast::ConditionalDirective& node) override; + void VisitConditionalBlock(ast::ConditionalBlock& node) override; + void VisitTSLXBlock(ast::TSLXBlock& node) override {} + + void VisitParameter(ast::Parameter& node) override {} + + private: + const symbol::SymbolTable& symbol_table_; + SemanticModel& semantic_model_; + + // 当前上下文 + std::optional current_function_id_; + std::optional current_class_id_; + + // ===== Helper methods ===== + + /** + * 查找标识符对应的符号 + */ + std::optional ResolveIdentifier( + const std::string& name, + const ast::Location& location); + + /** + * 记录符号引用 + */ + void TrackReference(symbol::SymbolId symbol_id, + const ast::Location& location, + bool is_write = false); + + /** + * 记录函数调用 + */ + void TrackCall(symbol::SymbolId callee, + const ast::Location& location); + + /** + * 推断表达式类型 + */ + std::shared_ptr InferExpressionType(ast::Expression& expr); + + /** + * 遍历语句列表 + */ + void VisitStatements(const std::vector& statements); + + /** + * 访问表达式 + */ + void VisitExpression(ast::Expression& expr); + + /** + * 处理左值(赋值目标) + */ + void ProcessLValue(const ast::LValue& lvalue); + }; + +} // namespace lsp::language::semantic diff --git a/lsp-server/src/language/semantic/graph/call.cpp b/lsp-server/src/language/semantic/graph/call.cpp new file mode 100644 index 0000000..083b9c9 --- /dev/null +++ b/lsp-server/src/language/semantic/graph/call.cpp @@ -0,0 +1,59 @@ +#include "call.hpp" + +#include + +namespace lsp::language::semantic::graph +{ + + void Call::OnSymbolRemoved(SymbolId id) + { + callers_map_.erase(id); + callees_map_.erase(id); + + for (auto& [_, calls] : callers_map_) + { + calls.erase(std::remove_if(calls.begin(), calls.end(), + [id](const semantic::Call& call) { + return call.caller == id; + }), + calls.end()); + } + + for (auto& [_, calls] : callees_map_) + { + calls.erase(std::remove_if(calls.begin(), calls.end(), + [id](const semantic::Call& call) { + return call.callee == id; + }), + calls.end()); + } + } + + void Call::Clear() + { + callers_map_.clear(); + callees_map_.clear(); + } + + void Call::AddCall(SymbolId caller, SymbolId callee, const ast::Location& location) + { + semantic::Call call{caller, callee, location}; + callers_map_[callee].push_back(call); + callees_map_[caller].push_back(call); + } + + const std::vector& Call::callers(SymbolId id) const + { + static const std::vector kEmpty; + auto it = callers_map_.find(id); + return it != callers_map_.end() ? it->second : kEmpty; + } + + const std::vector& Call::callees(SymbolId id) const + { + static const std::vector kEmpty; + auto it = callees_map_.find(id); + return it != callees_map_.end() ? it->second : kEmpty; + } + +} // namespace lsp::language::semantic::graph diff --git a/lsp-server/src/language/semantic/graph/call.hpp b/lsp-server/src/language/semantic/graph/call.hpp new file mode 100644 index 0000000..057ca5c --- /dev/null +++ b/lsp-server/src/language/semantic/graph/call.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include +#include + +#include "../interface.hpp" +#include "../types.hpp" + +namespace lsp::language::semantic::graph +{ + + using symbol::SymbolId; + + /** + * Call - 调用图 + * + * 管理函数/方法的调用关系 + */ + class Call : public ISemanticGraph + { + public: + void OnSymbolRemoved(SymbolId id) override; + void Clear() override; + + /** + * 添加调用关系 + */ + void AddCall(SymbolId caller, SymbolId callee, const ast::Location& location); + + /** + * 获取调用指定符号的所有调用者 + */ + const std::vector& callers(SymbolId id) const; + + /** + * 获取指定符号调用的所有被调用者 + */ + const std::vector& callees(SymbolId id) const; + + private: + std::unordered_map> callers_map_; + std::unordered_map> callees_map_; + }; + +} // namespace lsp::language::semantic::graph diff --git a/lsp-server/src/language/semantic/graph/inheritance.cpp b/lsp-server/src/language/semantic/graph/inheritance.cpp new file mode 100644 index 0000000..dfb6752 --- /dev/null +++ b/lsp-server/src/language/semantic/graph/inheritance.cpp @@ -0,0 +1,70 @@ +#include "inheritance.hpp" + +#include + +namespace lsp::language::semantic::graph +{ + + void Inheritance::OnSymbolRemoved(SymbolId id) + { + base_classes_.erase(id); + derived_classes_.erase(id); + + for (auto& [_, bases] : base_classes_) + { + bases.erase(std::remove(bases.begin(), bases.end(), id), bases.end()); + } + + for (auto& [_, derived] : derived_classes_) + { + derived.erase(std::remove(derived.begin(), derived.end(), id), + derived.end()); + } + } + + void Inheritance::Clear() + { + base_classes_.clear(); + derived_classes_.clear(); + } + + void Inheritance::AddInheritance(SymbolId derived, SymbolId base) + { + base_classes_[derived].push_back(base); + derived_classes_[base].push_back(derived); + } + + const std::vector& Inheritance::base_classes(SymbolId id) const + { + static const std::vector kEmpty; + auto it = base_classes_.find(id); + return it != base_classes_.end() ? it->second : kEmpty; + } + + const std::vector& Inheritance::derived_classes(SymbolId id) const + { + static const std::vector kEmpty; + auto it = derived_classes_.find(id); + return it != derived_classes_.end() ? it->second : kEmpty; + } + + bool Inheritance::IsSubclassOf(SymbolId derived, SymbolId base) const + { + auto it = base_classes_.find(derived); + if (it == base_classes_.end()) + { + return false; + } + + for (SymbolId parent : it->second) + { + if (parent == base || IsSubclassOf(parent, base)) + { + return true; + } + } + + return false; + } + +} // namespace lsp::language::semantic::graph diff --git a/lsp-server/src/language/semantic/graph/inheritance.hpp b/lsp-server/src/language/semantic/graph/inheritance.hpp new file mode 100644 index 0000000..c050e99 --- /dev/null +++ b/lsp-server/src/language/semantic/graph/inheritance.hpp @@ -0,0 +1,50 @@ +#pragma once + +#include +#include + +#include "../interface.hpp" +#include "../../symbol/types.hpp" + +namespace lsp::language::semantic::graph +{ + + using symbol::SymbolId; + + /** + * Inheritance - 继承图 + * + * 管理类的继承关系 + */ + class Inheritance : public ISemanticGraph + { + public: + void OnSymbolRemoved(SymbolId id) override; + void Clear() override; + + /** + * 添加继承关系 + */ + void AddInheritance(SymbolId derived, SymbolId base); + + /** + * 获取类的所有基类 + */ + const std::vector& base_classes(SymbolId id) const; + + /** + * 获取类的所有派生类 + */ + const std::vector& derived_classes(SymbolId id) const; + + /** + * 检查 derived 是否是 base 的子类(直接或间接) + */ + bool IsSubclassOf(SymbolId derived, SymbolId base) const; + + private: + std::unordered_map> base_classes_; + std::unordered_map> derived_classes_; + }; + +} // namespace lsp::language::semantic::graph diff --git a/lsp-server/src/language/semantic/graph/reference.cpp b/lsp-server/src/language/semantic/graph/reference.cpp new file mode 100644 index 0000000..d4e2f2a --- /dev/null +++ b/lsp-server/src/language/semantic/graph/reference.cpp @@ -0,0 +1,55 @@ +#include "reference.hpp" + +#include + +namespace lsp::language::semantic::graph +{ + + void Reference::OnSymbolRemoved(SymbolId id) + { + references_.erase(id); + + for (auto& [_, refs] : references_) + { + refs.erase(std::remove_if(refs.begin(), refs.end(), + [id](const semantic::Reference& ref) { + return ref.symbol_id == id; + }), + refs.end()); + } + } + + void Reference::Clear() { references_.clear(); } + + void Reference::AddReference(SymbolId symbol_id, const ast::Location& location, + bool is_definition, bool is_write) + { + references_[symbol_id].push_back( + {location, symbol_id, is_definition, is_write}); + } + + const std::vector& Reference::references(SymbolId id) const + { + static const std::vector kEmpty; + auto it = references_.find(id); + return it != references_.end() ? it->second : kEmpty; + } + + std::optional Reference::FindDefinitionLocation( + SymbolId id) const + { + auto it = references_.find(id); + if (it != references_.end()) + { + for (const auto& ref : it->second) + { + if (ref.is_definition) + { + return ref.location; + } + } + } + return std::nullopt; + } + +} // namespace lsp::language::semantic::graph diff --git a/lsp-server/src/language/semantic/graph/reference.hpp b/lsp-server/src/language/semantic/graph/reference.hpp new file mode 100644 index 0000000..44820f5 --- /dev/null +++ b/lsp-server/src/language/semantic/graph/reference.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include +#include +#include + +#include "../interface.hpp" +#include "../types.hpp" + +namespace lsp::language::semantic::graph +{ + + using symbol::SymbolId; + + /** + * Reference - 引用图 + * + * 管理符号的引用关系,记录每个符号在哪些位置被引用 + */ + class Reference : public ISemanticGraph + { + public: + void OnSymbolRemoved(SymbolId id) override; + void Clear() override; + + /** + * 添加符号引用 + */ + void AddReference(SymbolId symbol_id, const ast::Location& location, + bool is_definition = false, bool is_write = false); + + /** + * 获取符号的所有引用 + */ + const std::vector& references(SymbolId id) const; + + /** + * 查找符号的定义位置 + */ + std::optional FindDefinitionLocation(SymbolId id) const; + + private: + std::unordered_map> references_; + }; + +} // namespace lsp::language::semantic::graph diff --git a/lsp-server/src/language/semantic/interface.hpp b/lsp-server/src/language/semantic/interface.hpp new file mode 100644 index 0000000..d76fb8d --- /dev/null +++ b/lsp-server/src/language/semantic/interface.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include "../symbol/types.hpp" + +namespace lsp::language::semantic +{ + + /** + * ISemanticGraph - 语义图接口 + * + * 所有语义关系图(引用图、继承图、调用图等)都应实现此接口 + * 用于统一管理图的生命周期和符号删除时的清理 + */ + class ISemanticGraph + { + public: + virtual ~ISemanticGraph() = default; + + /** + * 当符号被删除时调用,清理相关的语义信息 + * @param id 被删除的符号ID + */ + virtual void OnSymbolRemoved(symbol::SymbolId id) = 0; + + /** + * 清空所有语义信息 + */ + virtual void Clear() = 0; + }; + +} // namespace lsp::language::semantic diff --git a/lsp-server/src/language/semantic/name_resolver.cpp b/lsp-server/src/language/semantic/name_resolver.cpp new file mode 100644 index 0000000..8d7c34d --- /dev/null +++ b/lsp-server/src/language/semantic/name_resolver.cpp @@ -0,0 +1,443 @@ +#include "./name_resolver.hpp" + +#include + +namespace lsp::language::semantic +{ + + NameResolutionResult NameResolver::ResolveName( + const std::string& name, + symbol::ScopeId scope_id, + bool search_parent) const + { + if (!search_parent) + { + // 仅在当前作用域查找 + auto symbols = symbol_table_.FindSymbolsByName(name); + if (symbols.empty()) + { + return NameResolutionResult::NotFound(); + } + if (symbols.size() == 1) + { + return NameResolutionResult::Success(symbols[0]); + } + return NameResolutionResult::Ambiguous(std::move(symbols)); + } + + // 作用域链查找 + auto candidates = SearchScopeChain(name, scope_id); + if (candidates.empty()) + { + return NameResolutionResult::NotFound(); + } + if (candidates.size() == 1) + { + return NameResolutionResult::Success(candidates[0]); + } + return NameResolutionResult::Ambiguous(std::move(candidates)); + } + + NameResolutionResult NameResolver::ResolveNameAtLocation( + const std::string& name, + const ast::Location& location) const + { + // 根据位置找到对应的作用域 + auto scope_id = symbol_table_.scopes().FindScopeAt(location); + if (!scope_id) + { + // 如果找不到作用域,使用全局作用域 + scope_id = 1; // 假设全局作用域ID为1 + } + + return ResolveName(name, *scope_id, true); + } + + NameResolutionResult NameResolver::ResolveMemberAccess( + symbol::SymbolId object_symbol_id, + const std::string& member_name) const + { + // 获取对象的类型 + auto object_type = type_system_.GetSymbolType(object_symbol_id); + if (!object_type || object_type->kind() != TypeKind::kClass) + { + return NameResolutionResult::NotFound(); + } + + const auto* class_type = object_type->As(); + return ResolveClassMember(class_type->class_id(), member_name, false); + } + + NameResolutionResult NameResolver::ResolveClassMember( + symbol::SymbolId class_id, + const std::string& member_name, + bool static_only) const + { + const auto* class_symbol = symbol_table_.definition(class_id); + if (!class_symbol || !class_symbol->Is()) + { + return NameResolutionResult::NotFound(); + } + + const auto* class_data = class_symbol->As(); + std::vector candidates; + + // 在类的成员中查找 + for (auto member_id : class_data->members) + { + const auto* member = symbol_table_.definition(member_id); + if (!member) + continue; + + if (member->name() != member_name) + continue; + + // 如果仅查找静态成员 + if (static_only) + { + if (member->Is()) + { + const auto* method = member->As(); + if (!method->is_static) + continue; + } + else if (member->Is()) + { + const auto* field = member->As(); + if (!field->is_static) + continue; + } + } + + candidates.push_back(member_id); + } + + // TODO: 在基类中查找(需要继承图支持) + + if (candidates.empty()) + { + return NameResolutionResult::NotFound(); + } + if (candidates.size() == 1) + { + return NameResolutionResult::Success(candidates[0]); + } + return NameResolutionResult::Ambiguous(std::move(candidates)); + } + + NameResolutionResult NameResolver::ResolveFunctionCall( + const std::string& function_name, + const std::vector>& arg_types, + symbol::ScopeId scope_id) const + { + // 查找所有同名函数 + auto candidates_ids = SearchScopeChain(function_name, scope_id); + if (candidates_ids.empty()) + { + return NameResolutionResult::NotFound(); + } + + // 如果只有一个候选,直接返回 + if (candidates_ids.size() == 1) + { + return NameResolutionResult::Success(candidates_ids[0]); + } + + // 计算每个候选的匹配得分 + std::vector candidates; + for (auto candidate_id : candidates_ids) + { + auto candidate = CalculateOverloadScore(candidate_id, arg_types); + if (candidate.match_score >= 0) + { + candidates.push_back(candidate); + } + } + + return SelectBestOverload(candidates); + } + + NameResolutionResult NameResolver::ResolveMethodCall( + symbol::SymbolId object_symbol_id, + const std::string& method_name, + const std::vector>& arg_types) const + { + // 先解析方法名称,获取所有候选 + auto resolution = ResolveMemberAccess(object_symbol_id, method_name); + if (!resolution.IsResolved()) + { + return resolution; + } + + // 如果没有歧义,直接返回 + if (!resolution.is_ambiguous) + { + return resolution; + } + + // 执行重载解析 + std::vector candidates; + for (auto candidate_id : resolution.candidates) + { + auto candidate = CalculateOverloadScore(candidate_id, arg_types); + if (candidate.match_score >= 0) + { + candidates.push_back(candidate); + } + } + + return SelectBestOverload(candidates); + } + + NameResolutionResult NameResolver::ResolveQualifiedName( + const std::vector& qualified_name, + symbol::ScopeId scope_id) const + { + if (qualified_name.empty()) + { + return NameResolutionResult::NotFound(); + } + + // 从第一个名称开始解析 + auto current_result = ResolveName(qualified_name[0], scope_id, true); + if (!current_result.IsResolved()) + { + return current_result; + } + + symbol::SymbolId current_id = current_result.symbol_id; + + // 依次解析后续的限定名称 + for (size_t i = 1; i < qualified_name.size(); ++i) + { + const auto* current_symbol = symbol_table_.definition(current_id); + if (!current_symbol) + { + return NameResolutionResult::NotFound(); + } + + // 根据当前符号类型决定如何查找下一个名称 + if (current_symbol->Is()) + { + // 在类中查找成员 + current_result = ResolveClassMember(current_id, qualified_name[i], true); + } + else if (current_symbol->Is()) + { + // 在 Unit 中查找符号(TODO: 需要 Unit 作用域支持) + return NameResolutionResult::NotFound(); + } + else + { + // 其他类型不支持限定名称 + return NameResolutionResult::NotFound(); + } + + if (!current_result.IsResolved()) + { + return current_result; + } + + current_id = current_result.symbol_id; + } + + return NameResolutionResult::Success(current_id); + } + + std::optional NameResolver::GetSymbolScope( + symbol::SymbolId symbol_id) const + { + // TODO: 实现符号到作用域的映射 + // 目前返回 nullopt + return std::nullopt; + } + + bool NameResolver::IsSymbolVisibleInScope( + symbol::SymbolId symbol_id, + symbol::ScopeId scope_id) const + { + // TODO: 实现可见性检查 + // 考虑访问修饰符、作用域层次等 + return true; + } + + // ===== Private Methods ===== + + std::vector NameResolver::SearchScopeChain( + const std::string& name, + symbol::ScopeId start_scope) const + { + std::vector results; + + symbol::ScopeId current_scope = start_scope; + while (current_scope != symbol::kInvalidScopeId) + { + // 在当前作用域中查找 + auto scope_symbols = symbol_table_.scopes().FindSymbols(current_scope, name); + results.insert(results.end(), scope_symbols.begin(), scope_symbols.end()); + + // 如果找到了符号,停止向上查找 + if (!results.empty()) + { + break; + } + + // 移到父作用域 + auto parent = symbol_table_.scopes().GetParent(current_scope); + if (!parent) + { + break; + } + current_scope = *parent; + } + + return results; + } + + NameResolutionResult NameResolver::SelectBestOverload( + const std::vector& candidates) const + { + if (candidates.empty()) + { + return NameResolutionResult::NotFound(); + } + + // 按匹配得分排序 + auto sorted = candidates; + std::sort(sorted.begin(), sorted.end()); + + // 检查是否有唯一的最佳匹配 + if (sorted.size() > 1 && sorted[0].match_score == sorted[1].match_score) + { + // 存在歧义 + std::vector ambiguous_ids; + for (const auto& candidate : sorted) + { + if (candidate.match_score == sorted[0].match_score) + { + ambiguous_ids.push_back(candidate.symbol_id); + } + } + return NameResolutionResult::Ambiguous(std::move(ambiguous_ids)); + } + + return NameResolutionResult::Success(sorted[0].symbol_id); + } + + OverloadCandidate NameResolver::CalculateOverloadScore( + symbol::SymbolId candidate_id, + const std::vector>& arg_types) const + { + OverloadCandidate result; + result.symbol_id = candidate_id; + + // 获取候选的参数类型 + auto param_types = GetParameterTypes(candidate_id); + + // 参数数量不匹配 + if (param_types.size() != arg_types.size()) + { + result.match_score = -1; + return result; + } + + // 计算每个参数的转换代价 + int total_score = 0; + for (size_t i = 0; i < arg_types.size(); ++i) + { + auto compat = type_system_.CheckCompatibility(*arg_types[i], *param_types[i]); + if (!compat.is_compatible) + { + result.match_score = -1; + return result; + } + + result.arg_conversions.push_back(compat); + + // 计算得分:完全匹配得分最高 + if (compat.conversion_cost == 0) + { + total_score += 100; + } + else + { + // 转换代价越低,得分越高 + total_score += std::max(0, 50 - compat.conversion_cost * 10); + } + + // 需要显式转换的降低得分 + if (compat.requires_cast) + { + total_score -= 20; + } + } + + result.match_score = total_score; + return result; + } + + std::vector> NameResolver::GetParameterTypes( + symbol::SymbolId symbol_id) const + { + const auto* symbol = symbol_table_.definition(symbol_id); + if (!symbol) + { + return {}; + } + + std::vector> param_types; + + if (symbol->Is()) + { + const auto* func = symbol->As(); + for (const auto& param : func->parameters) + { + if (param.type) + { + auto type = type_system_.GetTypeByName(*param.type); + param_types.push_back(type); + } + else + { + param_types.push_back(type_system_.GetUnknownType()); + } + } + } + else if (symbol->Is()) + { + const auto* method = symbol->As(); + for (const auto& param : method->parameters) + { + if (param.type) + { + auto type = type_system_.GetTypeByName(*param.type); + param_types.push_back(type); + } + else + { + param_types.push_back(type_system_.GetUnknownType()); + } + } + } + + return param_types; + } + + std::optional NameResolver::GetOwnerClassId( + symbol::SymbolId symbol_id) const + { + // TODO: 实现符号到所属类的映射 + // 需要在符号表中维护这个关系 + return std::nullopt; + } + + bool NameResolver::CheckMemberAccessibility( + const symbol::Symbol& member, + symbol::ScopeId access_scope) const + { + // TODO: 实现访问权限检查 + // 检查 public/private/protected 访问权限 + return true; + } + +} // namespace lsp::language::semantic diff --git a/lsp-server/src/language/semantic/name_resolver.hpp b/lsp-server/src/language/semantic/name_resolver.hpp new file mode 100644 index 0000000..9da333a --- /dev/null +++ b/lsp-server/src/language/semantic/name_resolver.hpp @@ -0,0 +1,226 @@ +#pragma once + +#include +#include +#include + +#include "../symbol/table.hpp" +#include "./type_system.hpp" + +namespace lsp::language::semantic +{ + + /** + * NameResolutionResult - 名称解析结果 + */ + struct NameResolutionResult + { + symbol::SymbolId symbol_id = symbol::kInvalidSymbolId; + bool is_ambiguous = false; // 是否存在歧义(多个候选) + std::vector candidates; // 所有候选符号 + + bool IsResolved() const + { + return symbol_id != symbol::kInvalidSymbolId; + } + + static NameResolutionResult Success(symbol::SymbolId id) + { + return {id, false, {id}}; + } + + static NameResolutionResult Ambiguous(std::vector symbols) + { + return { + symbols.empty() ? symbol::kInvalidSymbolId : symbols[0], + true, + std::move(symbols)}; + } + + static NameResolutionResult NotFound() + { + return {symbol::kInvalidSymbolId, false, {}}; + } + }; + + /** + * OverloadCandidate - 重载候选 + * 用于函数/方法重载解析 + */ + struct OverloadCandidate + { + symbol::SymbolId symbol_id; + int match_score = 0; // 匹配得分(越高越好) + std::vector arg_conversions; // 每个参数的转换信息 + + bool operator<(const OverloadCandidate& other) const + { + return match_score > other.match_score; // 降序排序 + } + }; + + /** + * NameResolver - 名称解析器 + * + * 职责: + * 1. 根据名称和上下文查找符号 + * 2. 作用域链查找 + * 3. 重载解析 + * 4. 成员访问解析 + */ + class NameResolver + { + public: + explicit NameResolver(const symbol::SymbolTable& symbol_table, + const TypeSystem& type_system) + : symbol_table_(symbol_table), type_system_(type_system) {} + + // ===== 基本名称解析 ===== + + /** + * 在指定作用域中解析简单名称 + * @param name 符号名称 + * @param scope_id 起始作用域ID + * @param search_parent 是否搜索父作用域 + * @return 解析结果 + */ + NameResolutionResult ResolveName( + const std::string& name, + symbol::ScopeId scope_id, + bool search_parent = true) const; + + /** + * 在指定位置解析名称 + * @param name 符号名称 + * @param location 代码位置 + * @return 解析结果 + */ + NameResolutionResult ResolveNameAtLocation( + const std::string& name, + const ast::Location& location) const; + + // ===== 成员访问解析 ===== + + /** + * 解析成员访问 (object.member) + * @param object_symbol_id 对象符号ID + * @param member_name 成员名称 + * @return 解析结果 + */ + NameResolutionResult ResolveMemberAccess( + symbol::SymbolId object_symbol_id, + const std::string& member_name) const; + + /** + * 解析类型成员访问 (Class.StaticMember) + * @param class_id 类符号ID + * @param member_name 成员名称 + * @param static_only 是否仅搜索静态成员 + * @return 解析结果 + */ + NameResolutionResult ResolveClassMember( + symbol::SymbolId class_id, + const std::string& member_name, + bool static_only = false) const; + + // ===== 重载解析 ===== + + /** + * 解析函数调用(处理重载) + * @param function_name 函数名称 + * @param arg_types 参数类型列表 + * @param scope_id 调用所在的作用域 + * @return 最佳匹配的函数符号ID + */ + NameResolutionResult ResolveFunctionCall( + const std::string& function_name, + const std::vector>& arg_types, + symbol::ScopeId scope_id) const; + + /** + * 解析方法调用(处理重载) + * @param object_symbol_id 对象符号ID + * @param method_name 方法名称 + * @param arg_types 参数类型列表 + * @return 最佳匹配的方法符号ID + */ + NameResolutionResult ResolveMethodCall( + symbol::SymbolId object_symbol_id, + const std::string& method_name, + const std::vector>& arg_types) const; + + // ===== 限定名称解析 ===== + + /** + * 解析限定名称 (Unit.Class.Method) + * @param qualified_name 限定名称(用点分隔) + * @param scope_id 起始作用域 + * @return 解析结果 + */ + NameResolutionResult ResolveQualifiedName( + const std::vector& qualified_name, + symbol::ScopeId scope_id) const; + + // ===== 辅助方法 ===== + + /** + * 获取符号的作用域(如果符号本身定义了作用域) + * 例如:类、函数定义了新的作用域 + */ + std::optional GetSymbolScope( + symbol::SymbolId symbol_id) const; + + /** + * 检查符号是否在指定作用域可见 + */ + bool IsSymbolVisibleInScope( + symbol::SymbolId symbol_id, + symbol::ScopeId scope_id) const; + + private: + const symbol::SymbolTable& symbol_table_; + const TypeSystem& type_system_; + + // ===== 内部辅助方法 ===== + + /** + * 在作用域链中查找符号 + */ + std::vector SearchScopeChain( + const std::string& name, + symbol::ScopeId start_scope) const; + + /** + * 选择最佳重载候选 + */ + NameResolutionResult SelectBestOverload( + const std::vector& candidates) const; + + /** + * 计算重载候选的匹配得分 + */ + OverloadCandidate CalculateOverloadScore( + symbol::SymbolId candidate_id, + const std::vector>& arg_types) const; + + /** + * 提取函数/方法的参数类型 + */ + std::vector> GetParameterTypes( + symbol::SymbolId symbol_id) const; + + /** + * 获取符号所属的类ID(如果是成员) + */ + std::optional GetOwnerClassId( + symbol::SymbolId symbol_id) const; + + /** + * 检查成员的访问权限 + */ + bool CheckMemberAccessibility( + const symbol::Symbol& member, + symbol::ScopeId access_scope) const; + }; + +} // namespace lsp::language::semantic diff --git a/lsp-server/src/language/semantic/semantic_model.cpp b/lsp-server/src/language/semantic/semantic_model.cpp new file mode 100644 index 0000000..052c8ec --- /dev/null +++ b/lsp-server/src/language/semantic/semantic_model.cpp @@ -0,0 +1,56 @@ +#include "semantic_model.hpp" + +namespace lsp::language::semantic +{ + + SemanticModel::SemanticModel(const symbol::SymbolTable& symbol_table) + : symbol_table_(symbol_table), + reference_graph_(), + inheritance_graph_(), + call_graph_(), + type_system_(), + name_resolver_(std::make_unique(symbol_table, type_system_)) + { + // 设置类型系统的继承检查器 + type_system_.SetInheritanceChecker( + [this](symbol::SymbolId derived, symbol::SymbolId base) { + return inheritance_graph_.IsSubclassOf(derived, base); + }); + } + + void SemanticModel::Clear() + { + reference_graph_.Clear(); + inheritance_graph_.Clear(); + call_graph_.Clear(); + // Note: TypeSystem 和 NameResolver 的状态由外部管理 + } + + void SemanticModel::OnSymbolRemoved(symbol::SymbolId id) + { + reference_graph_.OnSymbolRemoved(id); + inheritance_graph_.OnSymbolRemoved(id); + call_graph_.OnSymbolRemoved(id); + } + + void SemanticModel::AddReference(symbol::SymbolId symbol_id, + const ast::Location& location, + bool is_definition, + bool is_write) + { + reference_graph_.AddReference(symbol_id, location, is_definition, is_write); + } + + void SemanticModel::AddInheritance(symbol::SymbolId derived, + symbol::SymbolId base) + { + inheritance_graph_.AddInheritance(derived, base); + } + + void SemanticModel::AddCall(symbol::SymbolId caller, symbol::SymbolId callee, + const ast::Location& location) + { + call_graph_.AddCall(caller, callee, location); + } + +} // namespace lsp::language::semantic diff --git a/lsp-server/src/language/semantic/semantic_model.hpp b/lsp-server/src/language/semantic/semantic_model.hpp new file mode 100644 index 0000000..34b6848 --- /dev/null +++ b/lsp-server/src/language/semantic/semantic_model.hpp @@ -0,0 +1,125 @@ +#pragma once + +#include + +#include "../symbol/table.hpp" +#include "./graph/call.hpp" +#include "./graph/inheritance.hpp" +#include "./graph/reference.hpp" +#include "./name_resolver.hpp" +#include "./type_system.hpp" + +namespace lsp::language::semantic +{ + + /** + * SemanticModel - 完整的语义模型 + * + * 职责: + * 1. 管理语义关系图(引用图、继承图、调用图) + * 2. 提供类型系统 + * 3. 提供名称解析服务 + * 4. 协调各个语义组件 + * + * 这是语义分析层的核心类,整合了所有语义分析功能 + */ + class SemanticModel + { + public: + explicit SemanticModel(const symbol::SymbolTable& symbol_table); + + // ===== 生命周期管理 ===== + + /** + * 清空所有语义信息 + */ + void Clear(); + + /** + * 当符号被删除时调用,清理相关的语义信息 + */ + void OnSymbolRemoved(symbol::SymbolId id); + + // ===== 关系图管理 ===== + + /** + * 添加符号引用 + */ + void AddReference( + symbol::SymbolId symbol_id, + const ast::Location& location, + bool is_definition = false, + bool is_write = false); + + /** + * 添加继承关系 + */ + void AddInheritance(symbol::SymbolId derived, symbol::SymbolId base); + + /** + * 添加调用关系 + */ + void AddCall( + symbol::SymbolId caller, + symbol::SymbolId callee, + const ast::Location& location); + + // ===== 访问器 ===== + + graph::Reference& references() { return reference_graph_; } + graph::Inheritance& inheritance() { return inheritance_graph_; } + graph::Call& calls() { return call_graph_; } + + const graph::Reference& references() const { return reference_graph_; } + const graph::Inheritance& inheritance() const { return inheritance_graph_; } + const graph::Call& calls() const { return call_graph_; } + + TypeSystem& type_system() { return type_system_; } + const TypeSystem& type_system() const { return type_system_; } + + NameResolver& name_resolver() { return *name_resolver_; } + const NameResolver& name_resolver() const { return *name_resolver_; } + + // ===== 辅助方法 ===== + + /** + * 获取符号的类型 + */ + std::shared_ptr GetSymbolType(symbol::SymbolId symbol_id) const + { + return type_system_.GetSymbolType(symbol_id); + } + + /** + * 设置符号的类型 + */ + void SetSymbolType(symbol::SymbolId symbol_id, std::shared_ptr type) + { + type_system_.RegisterSymbolType(symbol_id, std::move(type)); + } + + /** + * 检查一个类是否是另一个类的子类 + */ + bool IsSubclassOf(symbol::SymbolId derived, symbol::SymbolId base) const + { + return inheritance_graph_.IsSubclassOf(derived, base); + } + + private: + // 符号表引用(只读) + const symbol::SymbolTable& symbol_table_; + + // 语义关系图 + graph::Reference reference_graph_; + graph::Inheritance inheritance_graph_; + graph::Call call_graph_; + + // 类型系统 + TypeSystem type_system_; + + // 名称解析器 + std::unique_ptr name_resolver_; + }; + +} // namespace lsp::language::semantic diff --git a/lsp-server/src/language/semantic/type_system.cpp b/lsp-server/src/language/semantic/type_system.cpp new file mode 100644 index 0000000..79520b6 --- /dev/null +++ b/lsp-server/src/language/semantic/type_system.cpp @@ -0,0 +1,550 @@ +#include "./type_system.hpp" + +#include + +namespace lsp::language::semantic +{ + + // ===== PrimitiveType Implementation ===== + + std::string PrimitiveType::ToString() const + { + switch (kind_) + { + case PrimitiveTypeKind::kInt: + return "int"; + case PrimitiveTypeKind::kFloat: + return "float"; + case PrimitiveTypeKind::kString: + return "string"; + case PrimitiveTypeKind::kBool: + return "bool"; + case PrimitiveTypeKind::kChar: + return "char"; + default: + return "unknown"; + } + } + + // ===== Type Implementation ===== + + TypeKind Type::kind() const + { + return std::visit( + [](const auto& type_data) -> TypeKind { + using T = std::decay_t; + if constexpr (std::is_same_v) + { + return TypeKind::kPrimitive; + } + else if constexpr (std::is_same_v) + { + return TypeKind::kClass; + } + else if constexpr (std::is_same_v) + { + return TypeKind::kArray; + } + else if constexpr (std::is_same_v) + { + return TypeKind::kFunction; + } + else if constexpr (std::is_same_v) + { + return TypeKind::kOptional; + } + else if constexpr (std::is_same_v) + { + return TypeKind::kVoid; + } + else if constexpr (std::is_same_v) + { + return TypeKind::kUnknown; + } + else if constexpr (std::is_same_v) + { + return TypeKind::kError; + } + return TypeKind::kUnknown; + }, + data_); + } + + std::string Type::ToString() const + { + return std::visit( + [](const auto& type_data) -> std::string { + using T = std::decay_t; + if constexpr (std::is_same_v) + { + return type_data.ToString(); + } + else if constexpr (std::is_same_v) + { + return "class#" + std::to_string(type_data.class_id()); + } + else if constexpr (std::is_same_v) + { + return "array<" + type_data.element_type().ToString() + ">"; + } + else if constexpr (std::is_same_v) + { + std::string result = "function("; + const auto& params = type_data.param_types(); + for (size_t i = 0; i < params.size(); ++i) + { + if (i > 0) + result += ", "; + result += params[i]->ToString(); + } + result += ") -> " + type_data.return_type().ToString(); + return result; + } + else if constexpr (std::is_same_v) + { + return type_data.inner_type().ToString() + "?"; + } + else if constexpr (std::is_same_v) + { + return "void"; + } + else if constexpr (std::is_same_v) + { + return "unknown"; + } + else if constexpr (std::is_same_v) + { + return "error"; + } + return "unknown"; + }, + data_); + } + + bool Type::Equals(const Type& other) const + { + if (kind() != other.kind()) + { + return false; + } + + return std::visit( + [&other](const auto& type_data) -> bool { + using T = std::decay_t; + const auto* other_data = other.As(); + if (!other_data) + { + return false; + } + + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) + { + return type_data == *other_data; + } + else if constexpr (std::is_same_v) + { + return type_data.element_type().Equals(other_data->element_type()); + } + else if constexpr (std::is_same_v) + { + return type_data.inner_type().Equals(other_data->inner_type()); + } + else if constexpr (std::is_same_v) + { + const auto& params1 = type_data.param_types(); + const auto& params2 = other_data->param_types(); + if (params1.size() != params2.size()) + { + return false; + } + for (size_t i = 0; i < params1.size(); ++i) + { + if (!params1[i]->Equals(*params2[i])) + { + return false; + } + } + return type_data.return_type().Equals(other_data->return_type()); + } + return false; + }, + data_); + } + + // ===== TypeSystem Implementation ===== + + TypeSystem::TypeSystem() + { + // 初始化内置类型 + int_type_ = std::make_shared(PrimitiveType(PrimitiveTypeKind::kInt)); + float_type_ = std::make_shared(PrimitiveType(PrimitiveTypeKind::kFloat)); + string_type_ = std::make_shared(PrimitiveType(PrimitiveTypeKind::kString)); + bool_type_ = std::make_shared(PrimitiveType(PrimitiveTypeKind::kBool)); + char_type_ = std::make_shared(PrimitiveType(PrimitiveTypeKind::kChar)); + void_type_ = std::make_shared(VoidType()); + unknown_type_ = std::make_shared(UnknownType()); + error_type_ = std::make_shared(ErrorType()); + + // 注册类型名称 + type_by_name_["int"] = int_type_; + type_by_name_["integer"] = int_type_; + type_by_name_["float"] = float_type_; + type_by_name_["double"] = float_type_; + type_by_name_["string"] = string_type_; + type_by_name_["bool"] = bool_type_; + type_by_name_["boolean"] = bool_type_; + type_by_name_["char"] = char_type_; + type_by_name_["void"] = void_type_; + } + + std::shared_ptr TypeSystem::CreateClassType(symbol::SymbolId class_id) + { + return std::make_shared(ClassType(class_id)); + } + + std::shared_ptr TypeSystem::CreateArrayType( + std::shared_ptr element_type) + { + return std::make_shared(ArrayType(std::move(element_type))); + } + + std::shared_ptr TypeSystem::CreateFunctionType( + std::vector> param_types, + std::shared_ptr return_type) + { + return std::make_shared( + FunctionType(std::move(param_types), std::move(return_type))); + } + + std::shared_ptr TypeSystem::CreateOptionalType( + std::shared_ptr inner_type) + { + return std::make_shared(OptionalType(std::move(inner_type))); + } + + std::shared_ptr TypeSystem::GetTypeByName(const std::string& type_name) + { + auto it = type_by_name_.find(type_name); + if (it != type_by_name_.end()) + { + return it->second; + } + return unknown_type_; + } + + std::shared_ptr TypeSystem::GetSymbolType(symbol::SymbolId symbol_id) + { + auto it = symbol_types_.find(symbol_id); + if (it != symbol_types_.end()) + { + return it->second; + } + return unknown_type_; + } + + void TypeSystem::RegisterSymbolType(symbol::SymbolId symbol_id, + std::shared_ptr type) + { + symbol_types_[symbol_id] = std::move(type); + } + + TypeCompatibility TypeSystem::CheckCompatibility(const Type& from, + const Type& to) const + { + // 相同类型,完全兼容 + if (from.Equals(to)) + { + return TypeCompatibility::Exact(); + } + + // 错误类型与任何类型兼容 + if (from.kind() == TypeKind::kError || to.kind() == TypeKind::kError) + { + return TypeCompatibility::Exact(); + } + + // 未知类型与任何类型兼容(用于推断失败的情况) + if (from.kind() == TypeKind::kUnknown || to.kind() == TypeKind::kUnknown) + { + return TypeCompatibility::Exact(); + } + + // 基本类型之间的兼容性 + if (from.kind() == TypeKind::kPrimitive && to.kind() == TypeKind::kPrimitive) + { + return CheckPrimitiveCompatibility(*from.As(), + *to.As()); + } + + // 类类型之间的兼容性(继承关系) + if (from.kind() == TypeKind::kClass && to.kind() == TypeKind::kClass) + { + return CheckClassCompatibility(*from.As(), + *to.As()); + } + + // 数组类型的兼容性 + if (from.kind() == TypeKind::kArray && to.kind() == TypeKind::kArray) + { + return CheckArrayCompatibility(*from.As(), + *to.As()); + } + + // 函数类型的兼容性 + if (from.kind() == TypeKind::kFunction && to.kind() == TypeKind::kFunction) + { + return CheckFunctionCompatibility(*from.As(), + *to.As()); + } + + // 可选类型:T 可以赋值给 T? + if (to.kind() == TypeKind::kOptional) + { + const auto& inner = to.As()->inner_type(); + return CheckCompatibility(from, inner); + } + + return TypeCompatibility::Incompatible(); + } + + bool TypeSystem::IsAssignable(const Type& from, const Type& to) const + { + return CheckCompatibility(from, to).is_compatible; + } + + bool TypeSystem::RequiresExplicitCast(const Type& from, const Type& to) const + { + auto compat = CheckCompatibility(from, to); + return compat.is_compatible && compat.requires_cast; + } + + TypeCompatibility TypeSystem::CheckPrimitiveCompatibility( + const PrimitiveType& from, + const PrimitiveType& to) const + { + using Kind = PrimitiveTypeKind; + + // int -> float (隐式转换,代价为 1) + if (from.kind() == Kind::kInt && to.kind() == Kind::kFloat) + { + return TypeCompatibility::Implicit(1); + } + + // float -> int (需要显式转换,可能丢失精度) + if (from.kind() == Kind::kFloat && to.kind() == Kind::kInt) + { + return TypeCompatibility::ExplicitCast(2); + } + + // char -> string (隐式转换) + if (from.kind() == Kind::kChar && to.kind() == Kind::kString) + { + return TypeCompatibility::Implicit(1); + } + + // 其他基本类型不兼容 + return TypeCompatibility::Incompatible(); + } + + TypeCompatibility TypeSystem::CheckClassCompatibility( + const ClassType& from, + const ClassType& to) const + { + // 检查继承关系(如果有继承检查器) + if (is_subclass_of_ && is_subclass_of_(from.class_id(), to.class_id())) + { + // 子类可以隐式转换为父类 + return TypeCompatibility::Implicit(1); + } + + return TypeCompatibility::Incompatible(); + } + + TypeCompatibility TypeSystem::CheckArrayCompatibility( + const ArrayType& from, + const ArrayType& to) const + { + // 数组元素类型必须完全匹配(协变/逆变问题) + if (from.element_type().Equals(to.element_type())) + { + return TypeCompatibility::Exact(); + } + + return TypeCompatibility::Incompatible(); + } + + TypeCompatibility TypeSystem::CheckFunctionCompatibility( + const FunctionType& from, + const FunctionType& to) const + { + // 函数类型必须完全匹配(参数和返回值) + const auto& from_params = from.param_types(); + const auto& to_params = to.param_types(); + + if (from_params.size() != to_params.size()) + { + return TypeCompatibility::Incompatible(); + } + + for (size_t i = 0; i < from_params.size(); ++i) + { + if (!from_params[i]->Equals(*to_params[i])) + { + return TypeCompatibility::Incompatible(); + } + } + + if (!from.return_type().Equals(to.return_type())) + { + return TypeCompatibility::Incompatible(); + } + + return TypeCompatibility::Exact(); + } + + std::shared_ptr TypeSystem::InferBinaryExpressionType( + const Type& left, + const Type& right, + const std::string& op) const + { + // 算术运算符 + if (op == "+" || op == "-" || op == "*" || op == "/" || op == "%") + { + // int + int -> int + if (left.kind() == TypeKind::kPrimitive && + right.kind() == TypeKind::kPrimitive) + { + const auto* left_prim = left.As(); + const auto* right_prim = right.As(); + + using Kind = PrimitiveTypeKind; + + // 如果任一操作数是 float,结果是 float + if (left_prim->kind() == Kind::kFloat || + right_prim->kind() == Kind::kFloat) + { + return float_type_; + } + + // 两个 int -> int + if (left_prim->kind() == Kind::kInt && + right_prim->kind() == Kind::kInt) + { + return int_type_; + } + } + + // string + string -> string (字符串连接) + if (op == "+" && left.kind() == TypeKind::kPrimitive && + right.kind() == TypeKind::kPrimitive) + { + const auto* left_prim = left.As(); + const auto* right_prim = right.As(); + + if (left_prim->kind() == PrimitiveTypeKind::kString || + right_prim->kind() == PrimitiveTypeKind::kString) + { + return string_type_; + } + } + } + + // 比较运算符 + if (op == "==" || op == "!=" || op == "<" || op == ">" || op == "<=" || + op == ">=") + { + return bool_type_; + } + + // 逻辑运算符 + if (op == "&&" || op == "||" || op == "and" || op == "or") + { + return bool_type_; + } + + // 位运算符 + if (op == "&" || op == "|" || op == "^" || op == "<<" || op == ">>") + { + return int_type_; + } + + return unknown_type_; + } + + std::shared_ptr TypeSystem::InferUnaryExpressionType( + const Type& operand, + const std::string& op) const + { + // 数值取负 + if (op == "-" || op == "+") + { + if (operand.kind() == TypeKind::kPrimitive) + { + const auto* prim = operand.As(); + if (prim->kind() == PrimitiveTypeKind::kInt || + prim->kind() == PrimitiveTypeKind::kFloat) + { + return std::make_shared(*prim); + } + } + } + + // 逻辑取反 + if (op == "!" || op == "not") + { + return bool_type_; + } + + // 位取反 + if (op == "~") + { + return int_type_; + } + + // 自增/自减 + if (op == "++" || op == "--") + { + return std::make_shared(operand); + } + + return unknown_type_; + } + + std::shared_ptr TypeSystem::InferLiteralType( + const std::string& literal_value) const + { + // 简单的字面量类型推断 + if (literal_value.empty()) + { + return unknown_type_; + } + + // 布尔值 + if (literal_value == "true" || literal_value == "false") + { + return bool_type_; + } + + // 字符串(以引号开头) + if (literal_value[0] == '"' || literal_value[0] == '\'') + { + return string_type_; + } + + // 数字 + bool has_dot = literal_value.find('.') != std::string::npos; + if (has_dot) + { + return float_type_; + } + else + { + return int_type_; + } + } + +} // namespace lsp::language::semantic diff --git a/lsp-server/src/language/semantic/type_system.hpp b/lsp-server/src/language/semantic/type_system.hpp new file mode 100644 index 0000000..09b084b --- /dev/null +++ b/lsp-server/src/language/semantic/type_system.hpp @@ -0,0 +1,406 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "../symbol/types.hpp" + +namespace lsp::language::semantic +{ + + // Forward declarations + class Type; + class TypeSystem; + + // ===== Type Kinds ===== + + enum class TypeKind + { + kPrimitive, // 基本类型:int, float, string, bool + kClass, // 类类型 + kArray, // 数组类型 + kFunction, // 函数类型 + kOptional, // 可选类型(可为 null) + kVoid, // void 类型 + kUnknown, // 未知类型 + kError // 错误类型(类型检查失败) + }; + + // ===== Primitive Type Kinds ===== + + enum class PrimitiveTypeKind + { + kInt, + kFloat, + kString, + kBool, + kChar + }; + + // ===== Type Compatibility ===== + + /** + * 类型兼容性结果 + */ + struct TypeCompatibility + { + bool is_compatible = false; // 是否兼容 + int conversion_cost = -1; // 转换代价(-1 表示不兼容,0 表示无需转换) + bool requires_cast = false; // 是否需要显式转换 + + static TypeCompatibility Exact() + { + return {true, 0, false}; + } + + static TypeCompatibility Implicit(int cost) + { + return {true, cost, false}; + } + + static TypeCompatibility ExplicitCast(int cost) + { + return {true, cost, true}; + } + + static TypeCompatibility Incompatible() + { + return {false, -1, false}; + } + }; + + // ===== Type Representations ===== + + /** + * 基本类型 + */ + class PrimitiveType + { + public: + explicit PrimitiveType(PrimitiveTypeKind kind) : kind_(kind) {} + + PrimitiveTypeKind kind() const { return kind_; } + std::string ToString() const; + + bool operator==(const PrimitiveType& other) const + { + return kind_ == other.kind_; + } + + private: + PrimitiveTypeKind kind_; + }; + + /** + * 类类型 + */ + class ClassType + { + public: + explicit ClassType(symbol::SymbolId class_id) : class_id_(class_id) {} + + symbol::SymbolId class_id() const { return class_id_; } + + bool operator==(const ClassType& other) const + { + return class_id_ == other.class_id_; + } + + private: + symbol::SymbolId class_id_; + }; + + /** + * 数组类型 + */ + class ArrayType + { + public: + explicit ArrayType(std::shared_ptr element_type) + : element_type_(std::move(element_type)) {} + + const Type& element_type() const { return *element_type_; } + std::shared_ptr element_type_ptr() const { return element_type_; } + + private: + std::shared_ptr element_type_; + }; + + /** + * 函数类型 + */ + class FunctionType + { + public: + FunctionType(std::vector> param_types, + std::shared_ptr return_type) + : param_types_(std::move(param_types)), + return_type_(std::move(return_type)) {} + + const std::vector>& param_types() const + { + return param_types_; + } + + const Type& return_type() const { return *return_type_; } + std::shared_ptr return_type_ptr() const { return return_type_; } + + private: + std::vector> param_types_; + std::shared_ptr return_type_; + }; + + /** + * 可选类型(可为 null) + */ + class OptionalType + { + public: + explicit OptionalType(std::shared_ptr inner_type) + : inner_type_(std::move(inner_type)) {} + + const Type& inner_type() const { return *inner_type_; } + std::shared_ptr inner_type_ptr() const { return inner_type_; } + + private: + std::shared_ptr inner_type_; + }; + + /** + * Void 类型 + */ + class VoidType + { + public: + bool operator==(const VoidType&) const { return true; } + }; + + /** + * 未知类型 + */ + class UnknownType + { + public: + bool operator==(const UnknownType&) const { return true; } + }; + + /** + * 错误类型 + */ + class ErrorType + { + public: + explicit ErrorType(std::string message = "") + : message_(std::move(message)) {} + + const std::string& message() const { return message_; } + + bool operator==(const ErrorType&) const { return true; } + + private: + std::string message_; + }; + + // ===== Type Variant ===== + + using TypeData = std::variant< + PrimitiveType, + ClassType, + ArrayType, + FunctionType, + OptionalType, + VoidType, + UnknownType, + ErrorType>; + + // ===== Type Class ===== + + /** + * Type - 类型表示 + * + * 统一的类型表示,支持多种类型 + */ + class Type + { + public: + explicit Type(TypeData data) : data_(std::move(data)) {} + + // Type checking + template + bool Is() const + { + return std::holds_alternative(data_); + } + + template + const T* As() const + { + return std::get_if(&data_); + } + + template + T* As() + { + return std::get_if(&data_); + } + + TypeKind kind() const; + std::string ToString() const; + + // Equality comparison + bool Equals(const Type& other) const; + + const TypeData& data() const { return data_; } + + private: + TypeData data_; + }; + + // ===== Type System ===== + + /** + * TypeSystem - 类型系统 + * + * 负责: + * 1. 管理内置类型 + * 2. 类型兼容性检查 + * 3. 类型推断 + * 4. 类型转换规则 + */ + class TypeSystem + { + public: + TypeSystem(); + + // ===== 获取内置类型 ===== + + std::shared_ptr GetIntType() const { return int_type_; } + std::shared_ptr GetFloatType() const { return float_type_; } + std::shared_ptr GetStringType() const { return string_type_; } + std::shared_ptr GetBoolType() const { return bool_type_; } + std::shared_ptr GetCharType() const { return char_type_; } + std::shared_ptr GetVoidType() const { return void_type_; } + std::shared_ptr GetUnknownType() const { return unknown_type_; } + std::shared_ptr GetErrorType() const { return error_type_; } + + // ===== 创建复合类型 ===== + + std::shared_ptr CreateClassType(symbol::SymbolId class_id); + std::shared_ptr CreateArrayType(std::shared_ptr element_type); + std::shared_ptr CreateFunctionType( + std::vector> param_types, + std::shared_ptr return_type); + std::shared_ptr CreateOptionalType(std::shared_ptr inner_type); + + // ===== 类型查询 ===== + + /** + * 根据类型名称获取类型 + * @param type_name 类型名称(如 "int", "string") + * @return 类型指针,如果未找到返回 UnknownType + */ + std::shared_ptr GetTypeByName(const std::string& type_name); + + /** + * 根据符号ID获取类型 + * @param symbol_id 符号ID + * @return 类型指针 + */ + std::shared_ptr GetSymbolType(symbol::SymbolId symbol_id); + + /** + * 注册符号类型(用于记录变量、函数等的类型) + */ + void RegisterSymbolType(symbol::SymbolId symbol_id, std::shared_ptr type); + + // ===== 类型兼容性检查 ===== + + /** + * 检查两个类型是否兼容(from 能否赋值给 to) + */ + TypeCompatibility CheckCompatibility(const Type& from, const Type& to) const; + + /** + * 检查类型是否可以赋值 + */ + bool IsAssignable(const Type& from, const Type& to) const; + + /** + * 检查是否需要显式转换 + */ + bool RequiresExplicitCast(const Type& from, const Type& to) const; + + // ===== 类型推断 ===== + + /** + * 推断二元表达式的类型 + */ + std::shared_ptr InferBinaryExpressionType( + const Type& left, + const Type& right, + const std::string& op) const; + + /** + * 推断一元表达式的类型 + */ + std::shared_ptr InferUnaryExpressionType( + const Type& operand, + const std::string& op) const; + + /** + * 推断字面量类型 + */ + std::shared_ptr InferLiteralType(const std::string& literal_value) const; + + // ===== 继承关系检查(需要语义模型支持)===== + + void SetInheritanceChecker( + std::function checker) + { + is_subclass_of_ = std::move(checker); + } + + private: + // 内置类型 + std::shared_ptr int_type_; + std::shared_ptr float_type_; + std::shared_ptr string_type_; + std::shared_ptr bool_type_; + std::shared_ptr char_type_; + std::shared_ptr void_type_; + std::shared_ptr unknown_type_; + std::shared_ptr error_type_; + + // 类型名称映射 + std::unordered_map> type_by_name_; + + // 符号类型映射 + std::unordered_map> symbol_types_; + + // 继承关系检查器(由外部注入) + std::function + is_subclass_of_; + + // 辅助方法 + TypeCompatibility CheckPrimitiveCompatibility( + const PrimitiveType& from, + const PrimitiveType& to) const; + + TypeCompatibility CheckClassCompatibility( + const ClassType& from, + const ClassType& to) const; + + TypeCompatibility CheckArrayCompatibility( + const ArrayType& from, + const ArrayType& to) const; + + TypeCompatibility CheckFunctionCompatibility( + const FunctionType& from, + const FunctionType& to) const; + }; + +} // namespace lsp::language::semantic diff --git a/lsp-server/src/language/semantic/types.hpp b/lsp-server/src/language/semantic/types.hpp new file mode 100644 index 0000000..0c611fd --- /dev/null +++ b/lsp-server/src/language/semantic/types.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include "../ast/types.hpp" +#include "../symbol/types.hpp" + +namespace lsp::language::semantic +{ + + // ===== Semantic Relationship Types ===== + + /** + * Reference - 符号引用信息 + * 记录代码中对符号的引用位置和类型 + */ + struct Reference + { + ast::Location location; // 引用位置 + symbol::SymbolId symbol_id; // 被引用的符号ID + bool is_definition; // 是否是定义 + bool is_write; // 是否是写引用 + }; + + /** + * Call - 函数/方法调用关系 + * 记录函数调用关系,用于构建调用图 + */ + struct Call + { + symbol::SymbolId caller; // 调用者符号ID + symbol::SymbolId callee; // 被调用者符号ID + ast::Location call_site; // 调用位置 + }; + + /** + * Inheritance - 继承关系 + * 记录类继承关系信息 + */ + struct Inheritance + { + symbol::SymbolId derived; // 派生类符号ID + symbol::SymbolId base; // 基类符号ID + }; + +} // namespace lsp::language::semantic diff --git a/lsp-server/src/language/symbol/builder.cpp b/lsp-server/src/language/symbol/builder.cpp index 55ad9b4..e88a76a 100644 --- a/lsp-server/src/language/symbol/builder.cpp +++ b/lsp-server/src/language/symbol/builder.cpp @@ -1,759 +1,929 @@ -#include "./builder.hpp" - #include -namespace lsp::language::symbol { -Builder::Builder(SymbolTable& table) - : table_(table), - current_scope_id_(kInvalidScopeId), - in_interface_section_(false) {} +#include "./builder.hpp" -void Builder::Build(ast::ASTNode& root) { root.Accept(*this); } +namespace lsp::language::symbol +{ + Builder::Builder(SymbolTable& table) : table_(table), + current_scope_id_(kInvalidScopeId), + in_interface_section_(false) {} -SymbolId Builder::CreateSymbol(const std::string& name, SymbolKind kind, - const ast::Location& location, - const std::optional& type_hint) { - std::optional visibility; - if (in_interface_section_) visibility = UnitVisibility::kInterface; + void Builder::Build(ast::ASTNode& root) { root.Accept(*this); } + + SymbolId Builder::CreateSymbol(const std::string& name, protocol::SymbolKind kind, const ast::Location& location, const std::optional& type_hint) + { + std::optional visibility; + if (in_interface_section_) + visibility = UnitVisibility::kInterface; + + Symbol symbol = [&]() -> Symbol { + switch (kind) + { + case protocol::SymbolKind::Class: + { + Class cls; + cls.name = name; + cls.selection_range = location; + cls.range = location; + cls.unit_visibility = visibility; + return Symbol(std::move(cls)); + } + case protocol::SymbolKind::Function: + { + Function fn; + fn.name = name; + fn.selection_range = location; + fn.range = location; + fn.declaration_range = location; + fn.return_type = type_hint; + fn.unit_visibility = visibility; + return Symbol(std::move(fn)); + } + case protocol::SymbolKind::Method: + { + Method method; + method.name = name; + method.selection_range = location; + method.range = location; + method.declaration_range = location; + method.return_type = type_hint; + return Symbol(std::move(method)); + } + case protocol::SymbolKind::Property: + { + Property property; + property.name = name; + property.selection_range = location; + property.range = location; + property.type = type_hint; + return Symbol(std::move(property)); + } + case protocol::SymbolKind::Field: + { + Field field; + field.name = name; + field.selection_range = location; + field.range = location; + field.type = type_hint; + return Symbol(std::move(field)); + } + case protocol::SymbolKind::Variable: + { + Variable var; + var.name = name; + var.selection_range = location; + var.range = location; + var.type = type_hint; + var.unit_visibility = visibility; + return Symbol(std::move(var)); + } + case protocol::SymbolKind::Constant: + { + Constant constant; + constant.name = name; + constant.selection_range = location; + constant.range = location; + constant.type = type_hint; + return Symbol(std::move(constant)); + } + default: + { + Variable fallback; + fallback.name = name; + fallback.selection_range = location; + fallback.range = location; + fallback.type = type_hint; + fallback.unit_visibility = visibility; + return Symbol(std::move(fallback)); + } + } + }(); + + SymbolId id = table_.CreateSymbol(std::move(symbol)); + + if (current_scope_id_ != kInvalidScopeId) + { + table_.AddSymbolToScope(current_scope_id_, name, id); + } + + return id; + } + + SymbolId Builder::CreateSymbol(const std::string& name, protocol::SymbolKind kind, const ast::ASTNode& node, const std::optional& type_hint) + { + return CreateSymbol(name, kind, node.span, type_hint); + } + + SymbolId Builder::CreateFunctionSymbol( + const std::string& name, + const ast::Location& location, + const std::vector>& parameters, + const std::optional& return_type) + { + std::optional visibility; + if (in_interface_section_) + visibility = UnitVisibility::kInterface; - Symbol symbol = [&]() -> Symbol { - switch (kind) { - case SymbolKind::Class: { - Class cls; - cls.name = name; - cls.selection_range = location; - cls.range = location; - cls.unit_visibility = visibility; - return Symbol(std::move(cls)); - } - case SymbolKind::Function: { Function fn; fn.name = name; fn.selection_range = location; fn.range = location; fn.declaration_range = location; - fn.return_type = type_hint; + fn.return_type = ExtractTypeName(return_type); fn.unit_visibility = visibility; - return Symbol(std::move(fn)); - } - case SymbolKind::Method: { + fn.parameters = BuildParameters(parameters); + + Symbol symbol(std::move(fn)); + SymbolId id = table_.CreateSymbol(std::move(symbol)); + + if (current_scope_id_ != kInvalidScopeId) + { + table_.AddSymbolToScope(current_scope_id_, name, id); + } + + return id; + } + + SymbolId Builder::CreateMethodSymbol( + const std::string& name, + const ast::Location& location, + const std::vector>& parameters, + const std::optional& return_type) + { Method method; method.name = name; method.selection_range = location; method.range = location; method.declaration_range = location; - method.return_type = type_hint; - return Symbol(std::move(method)); - } - case SymbolKind::Property: { - Property property; - property.name = name; - property.selection_range = location; - property.range = location; - property.type = type_hint; - return Symbol(std::move(property)); - } - case SymbolKind::Field: { - Field field; - field.name = name; - field.selection_range = location; - field.range = location; - field.type = type_hint; - return Symbol(std::move(field)); - } - case SymbolKind::Variable: { - Variable var; - var.name = name; - var.selection_range = location; - var.range = location; - var.type = type_hint; - var.unit_visibility = visibility; - return Symbol(std::move(var)); - } - case SymbolKind::Constant: { - Constant constant; - constant.name = name; - constant.selection_range = location; - constant.range = location; - constant.type = type_hint; - return Symbol(std::move(constant)); - } - default: { - Variable fallback; - fallback.name = name; - fallback.selection_range = location; - fallback.range = location; - fallback.type = type_hint; - fallback.unit_visibility = visibility; - return Symbol(std::move(fallback)); - } - } - }(); + method.return_type = ExtractTypeName(return_type); + method.parameters = BuildParameters(parameters); - SymbolId id = table_.CreateSymbol(std::move(symbol)); + Symbol symbol(std::move(method)); + SymbolId id = table_.CreateSymbol(std::move(symbol)); - if (current_scope_id_ != kInvalidScopeId) { - table_.AddSymbolToScope(current_scope_id_, name, id); - } - - return id; -} - -SymbolId Builder::CreateSymbol(const std::string& name, SymbolKind kind, - const ast::ASTNode& node, - const std::optional& type_hint) { - return CreateSymbol(name, kind, node.span, type_hint); -} - -SymbolId Builder::CreateFunctionSymbol( - const std::string& name, const ast::Location& location, - const std::vector>& parameters, - const std::optional& return_type) { - std::optional visibility; - if (in_interface_section_) visibility = UnitVisibility::kInterface; - - Function fn; - fn.name = name; - fn.selection_range = location; - fn.range = location; - fn.declaration_range = location; - fn.return_type = ExtractTypeName(return_type); - fn.unit_visibility = visibility; - fn.parameters = BuildParameters(parameters); - - Symbol symbol(std::move(fn)); - SymbolId id = table_.CreateSymbol(std::move(symbol)); - - if (current_scope_id_ != kInvalidScopeId) { - table_.AddSymbolToScope(current_scope_id_, name, id); - } - - return id; -} - -SymbolId Builder::CreateMethodSymbol( - const std::string& name, const ast::Location& location, - const std::vector>& parameters, - const std::optional& return_type) { - Method method; - method.name = name; - method.selection_range = location; - method.range = location; - method.declaration_range = location; - method.return_type = ExtractTypeName(return_type); - method.parameters = BuildParameters(parameters); - - Symbol symbol(std::move(method)); - SymbolId id = table_.CreateSymbol(std::move(symbol)); - - if (current_scope_id_ != kInvalidScopeId) { - table_.AddSymbolToScope(current_scope_id_, name, id); - } - - return id; -} - -ScopeId Builder::EnterScopeWithSymbol(ScopeKind kind, SymbolId symbol_id, - const ast::Location& range) { - current_scope_id_ = - table_.CreateScope(kind, range, current_scope_id_, symbol_id); - return current_scope_id_; -} - -ScopeId Builder::EnterScope(ScopeKind kind, const ast::Location& range) { - current_scope_id_ = - table_.CreateScope(kind, range, current_scope_id_, std::nullopt); - return current_scope_id_; -} - -void Builder::ExitScope() { - const auto* scope_info = table_.scopes().scope(current_scope_id_); - if (scope_info && scope_info->parent) { - current_scope_id_ = *scope_info->parent; - } -} - -void Builder::VisitStatements( - const std::vector& statements) { - for (const auto& stmt : statements) { - if (stmt) stmt->Accept(*this); - } -} - -void Builder::VisitExpression(ast::Expression& expr) { expr.Accept(*this); } - -std::optional Builder::ExtractTypeName( - const std::optional& type) const { - if (type) return type->name; - return std::nullopt; -} - -std::vector Builder::BuildParameters( - const std::vector>& parameters) const { - std::vector result; - result.reserve(parameters.size()); - - for (const auto& param : parameters) { - if (!param) continue; - - language::symbol::Parameter p; - p.name = param->name; - if (param->type) p.type = param->type->name; - if (param->default_value) p.default_value = ""; - result.push_back(std::move(p)); - } - - return result; -} - -void Builder::VisitProgram(ast::Program& node) { - current_scope_id_ = table_.CreateScope(ScopeKind::kGlobal, node.span, - std::nullopt, std::nullopt); - VisitStatements(node.statements); -} - -void Builder::VisitUnitDefinition(ast::UnitDefinition& node) { - [[maybe_unused]] auto unit_scope = EnterScope(ScopeKind::kUnit, node.span); - - // Process interface section - in_interface_section_ = true; - VisitStatements(node.interface_statements); - - // Process implementation section - in_interface_section_ = false; - VisitStatements(node.implementation_statements); - - ExitScope(); -} - -void Builder::VisitClassDefinition(ast::ClassDefinition& node) { - auto class_id = CreateSymbol(node.name, SymbolKind::Class, node.location); - - [[maybe_unused]] auto class_scope = - EnterScopeWithSymbol(ScopeKind::kClass, class_id, node.span); - - auto prev_parent = current_parent_symbol_id_; - current_parent_symbol_id_ = class_id; - - for (auto& member : node.members) { - if (member) member->Accept(*this); - } - - current_parent_symbol_id_ = prev_parent; - ExitScope(); -} - -void Builder::VisitFunctionDeclaration(ast::FunctionDeclaration& node) { - auto location = ast::Location{}; - if (!node.name.empty() && !node.parameters.empty()) { - location = node.span; - } - CreateFunctionSymbol(node.name, location, node.parameters, node.return_type); -} - -void Builder::VisitFunctionDefinition(ast::FunctionDefinition& node) { - auto func_id = CreateFunctionSymbol(node.name, node.location, node.parameters, - node.return_type); - - if (node.body) { - [[maybe_unused]] auto func_scope = - EnterScopeWithSymbol(ScopeKind::kFunction, func_id, node.body->span); - - auto prev_function = current_function_id_; - current_function_id_ = func_id; - - for (auto& param : node.parameters) { - if (param) { - CreateSymbol(param->name, SymbolKind::Variable, param->location); - } - } - - VisitStatements(node.body->statements); - - current_function_id_ = prev_function; - ExitScope(); - } -} - -void Builder::VisitMethodDeclaration(ast::MethodDeclaration& node) { - auto method_id = CreateMethodSymbol(node.name, node.location, node.parameters, - node.return_type); - - if (node.body) { - [[maybe_unused]] auto method_scope = - EnterScopeWithSymbol(ScopeKind::kFunction, method_id, node.body->span); - - auto prev_function = current_function_id_; - current_function_id_ = method_id; - - for (auto& param : node.parameters) { - if (param) { - CreateSymbol(param->name, SymbolKind::Variable, param->location); - } - } - - VisitStatements(node.body->statements); - - current_function_id_ = prev_function; - ExitScope(); - } -} - -void Builder::VisitPropertyDeclaration(ast::PropertyDeclaration& node) { - auto type = ExtractTypeName(node.type); - CreateSymbol(node.name, SymbolKind::Property, node.location, type); -} - -void Builder::VisitExternalMethodDefinition( - ast::ExternalMethodDefinition& node) { - std::string method_name = node.name; - size_t dot_pos = method_name.find_last_of('.'); - if (dot_pos != std::string::npos) { - method_name = method_name.substr(dot_pos + 1); - } - - std::optional method_id; - - if (current_parent_symbol_id_) { - const auto* scope_info = table_.scopes().scope(current_scope_id_); - if (scope_info) { - method_id = - table_.scopes().FindSymbolInScope(current_scope_id_, method_name); - } - } - - if (!method_id) { - method_id = CreateMethodSymbol(method_name, node.location, node.parameters, - node.return_type); - } - - if (node.body) { - [[maybe_unused]] auto method_scope = - EnterScopeWithSymbol(ScopeKind::kFunction, *method_id, node.body->span); - - auto prev_function = current_function_id_; - current_function_id_ = *method_id; - - for (auto& param : node.parameters) { - if (param) { - CreateSymbol(param->name, SymbolKind::Variable, param->location); - } - } - - VisitStatements(node.body->statements); - - current_function_id_ = prev_function; - ExitScope(); - } -} - -void Builder::VisitClassMember(ast::ClassMember& node) { - std::visit( - [this](auto& member) { - if (member) member->Accept(*this); - }, - node.member); -} - -void Builder::VisitVarDeclaration(ast::VarDeclaration& node) { - CreateSymbol(node.name, SymbolKind::Variable, node.location, - ExtractTypeName(node.type)); - - if (node.initializer) { - VisitExpression(*node.initializer.value()); - } -} - -void Builder::VisitStaticDeclaration(ast::StaticDeclaration& node) { - CreateSymbol(node.name, SymbolKind::Variable, node.location, - ExtractTypeName(node.type)); - - if (node.initializer) { - VisitExpression(*node.initializer.value()); - } -} - -void Builder::VisitGlobalDeclaration(ast::GlobalDeclaration& node) { - CreateSymbol(node.name, SymbolKind::Variable, node.location, - ExtractTypeName(node.type)); - - if (node.initializer) { - VisitExpression(*node.initializer.value()); - } -} - -void Builder::VisitConstDeclaration(ast::ConstDeclaration& node) { - CreateSymbol(node.name, SymbolKind::Constant, node.location, - ExtractTypeName(node.type)); - - if (node.value) { - VisitExpression(*node.value); - } -} - -void Builder::VisitFieldDeclaration(ast::FieldDeclaration& node) { - CreateSymbol(node.name, SymbolKind::Field, node.location, - ExtractTypeName(node.type)); - - if (node.initializer) { - VisitExpression(*node.initializer.value()); - } -} - -void Builder::VisitUsesStatement(ast::UsesStatement& node) { (void)node; } - -void Builder::VisitIdentifier(ast::Identifier& node) { - auto symbol_id = - table_.scopes().FindSymbolInScopeChain(current_scope_id_, node.name); - if (symbol_id) { - table_.AddReference(*symbol_id, node.location, false, false); - } -} - -void Builder::VisitCallExpression(ast::CallExpression& node) { - if (node.callee) { - if (auto* id = dynamic_cast(node.callee.get())) { - auto symbol_id = - table_.scopes().FindSymbolInScopeChain(current_scope_id_, id->name); - if (symbol_id) { - table_.AddReference(*symbol_id, id->location, false, false); - - if (current_function_id_) { - table_.AddCall(*current_function_id_, *symbol_id, node.span); + if (current_scope_id_ != kInvalidScopeId) + { + table_.AddSymbolToScope(current_scope_id_, name, id); } - } - } else { - node.callee->Accept(*this); + + return id; } - } - for (auto& arg : node.arguments) { - if (arg.value) VisitExpression(*arg.value); - } -} + ScopeId Builder::EnterScopeWithSymbol(ScopeKind kind, SymbolId symbol_id, const ast::Location& range) + { + current_scope_id_ = + table_.CreateScope(kind, range, current_scope_id_, symbol_id); + return current_scope_id_; + } -void Builder::VisitAttributeExpression(ast::AttributeExpression& node) { - if (node.object) { - node.object->Accept(*this); - } + ScopeId Builder::EnterScope(ScopeKind kind, const ast::Location& range) + { + current_scope_id_ = + table_.CreateScope(kind, range, current_scope_id_, std::nullopt); + return current_scope_id_; + } - if (node.attribute) { - node.attribute->Accept(*this); - } -} - -void Builder::VisitAssignmentExpression(ast::AssignmentExpression& node) { - ProcessLValue(node.left, true); - - if (node.right) { - VisitExpression(*node.right); - } -} - -void Builder::ProcessLValue(const ast::LValue& lvalue, bool is_write) { - std::visit( - [this, is_write](auto& val) { - using T = std::decay_t; - if constexpr (std::is_same_v>) { - if (val) { - auto symbol_id = table_.scopes().FindSymbolInScopeChain( - current_scope_id_, val->name); - if (symbol_id) { - table_.AddReference(*symbol_id, val->location, false, is_write); - } - } - } else if constexpr (std::is_same_v>) { - if (val) { - if (val->object) { - val->object->Accept(*this); - } - if (val->attribute) { - val->attribute->Accept(*this); - } - } - } else if constexpr (std::is_same_v>) { - if (val) { - if (val->base) { - val->base->Accept(*this); - } - for (auto& idx : val->indices) { - if (idx.start) idx.start->Accept(*this); - if (idx.end) idx.end->Accept(*this); - if (idx.step) idx.step->Accept(*this); - } - } - } else if constexpr (std::is_same_v< - T, std::unique_ptr>) { - if (val && val->value) { - val->value->Accept(*this); - } + void Builder::ExitScope() + { + const auto* scope_info = table_.scopes().scope(current_scope_id_); + if (scope_info && scope_info->parent) + { + current_scope_id_ = *scope_info->parent; } - }, - lvalue); -} - -void Builder::VisitBlockStatement(ast::BlockStatement& node) { - [[maybe_unused]] auto block_scope = EnterScope(ScopeKind::kBlock, node.span); - VisitStatements(node.statements); - ExitScope(); -} - -void Builder::VisitIfStatement(ast::IfStatement& node) { - for (auto& branch : node.branches) { - if (branch.condition) VisitExpression(*branch.condition); - - if (branch.body) branch.body->Accept(*this); - } - - if (node.else_body && *node.else_body) { - (*node.else_body)->Accept(*this); - } -} - -void Builder::VisitForInStatement(ast::ForInStatement& node) { - [[maybe_unused]] auto for_scope = EnterScope(ScopeKind::kBlock, node.span); - - if (!node.key.empty()) { - CreateSymbol(node.key, SymbolKind::Variable, node.key_location); - } - - if (!node.value.empty()) { - CreateSymbol(node.value, SymbolKind::Variable, node.value_location); - } - - if (node.collection) VisitExpression(*node.collection); - - if (node.body) node.body->Accept(*this); - - ExitScope(); -} - -void Builder::VisitForToStatement(ast::ForToStatement& node) { - [[maybe_unused]] auto for_scope = EnterScope(ScopeKind::kBlock, node.span); - - CreateSymbol(node.counter, SymbolKind::Variable, node.counter_location); - - if (node.start) VisitExpression(*node.start); - if (node.end) VisitExpression(*node.end); - if (node.step) VisitExpression(*node.step); - if (node.body) node.body->Accept(*this); - - ExitScope(); -} - -void Builder::VisitWhileStatement(ast::WhileStatement& node) { - if (node.condition) VisitExpression(*node.condition); - if (node.body) node.body->Accept(*this); -} - -void Builder::VisitRepeatStatement(ast::RepeatStatement& node) { - VisitStatements(node.body); - if (node.condition) VisitExpression(*node.condition); -} - -void Builder::VisitCaseStatement(ast::CaseStatement& node) { - if (node.discriminant) VisitExpression(*node.discriminant); - - for (auto& branch : node.branches) { - for (auto& value : branch.values) { - if (value) VisitExpression(*value); - } - if (branch.body) branch.body->Accept(*this); - } - - if (node.else_body && *node.else_body) { - (*node.else_body)->Accept(*this); - } -} - -void Builder::VisitTryStatement(ast::TryStatement& node) { - if (node.try_body) { - VisitStatements(node.try_body->statements); - } - - if (node.except_body) { - VisitStatements(node.except_body->statements); - } -} - -void Builder::VisitMatrixIterationStatement( - ast::MatrixIterationStatement& node) { - [[maybe_unused]] auto iter_scope = EnterScope(ScopeKind::kBlock, node.span); - - if (node.target) VisitExpression(*node.target); - - if (node.body) VisitStatements(node.body->statements); - - ExitScope(); -} - -void Builder::VisitExpressionStatement(ast::ExpressionStatement& node) { - if (node.expression) VisitExpression(*node.expression); -} - -void Builder::VisitReturnStatement(ast::ReturnStatement& node) { - if (node.value && *node.value) VisitExpression(**node.value); -} - -void Builder::VisitBinaryExpression(ast::BinaryExpression& node) { - if (node.left) VisitExpression(*node.left); - if (node.right) VisitExpression(*node.right); -} - -void Builder::VisitTernaryExpression(ast::TernaryExpression& node) { - if (node.condition) VisitExpression(*node.condition); - if (node.consequence) VisitExpression(*node.consequence); - if (node.alternative) VisitExpression(*node.alternative); -} - -void Builder::VisitSubscriptExpression(ast::SubscriptExpression& node) { - if (node.base) VisitExpression(*node.base); - - for (auto& idx : node.indices) { - if (idx.start) VisitExpression(*idx.start); - if (idx.end) VisitExpression(*idx.end); - if (idx.step) VisitExpression(*idx.step); - } -} - -void Builder::VisitArrayExpression(ast::ArrayExpression& node) { - for (auto& elem : node.elements) { - if (elem.key) VisitExpression(*elem.key.value()); - if (elem.value) VisitExpression(*elem.value); - } -} - -void Builder::VisitAnonymousFunctionExpression( - ast::AnonymousFunctionExpression& node) { - auto func_id = CreateFunctionSymbol("", node.span, node.parameters, - node.return_type); - - if (node.body) { - [[maybe_unused]] auto func_scope = EnterScopeWithSymbol( - ScopeKind::kAnonymousFunction, func_id, node.body->span); - - auto prev_function = current_function_id_; - current_function_id_ = func_id; - - for (auto& param : node.parameters) { - if (param) { - CreateSymbol(param->name, SymbolKind::Variable, param->location); - } } - VisitStatements(node.body->statements); + void Builder::VisitStatements( + const std::vector& statements) + { + for (const auto& stmt : statements) + { + if (stmt) + stmt->Accept(*this); + } + } - current_function_id_ = prev_function; - ExitScope(); - } -} + void Builder::VisitExpression(ast::Expression& expr) { expr.Accept(*this); } -void Builder::VisitUnaryPlusExpression(ast::UnaryPlusExpression& node) { - if (node.argument) VisitExpression(*node.argument); -} + std::optional Builder::ExtractTypeName( + const std::optional& type) const + { + if (type) + return type->name; + return std::nullopt; + } -void Builder::VisitUnaryMinusExpression(ast::UnaryMinusExpression& node) { - if (node.argument) VisitExpression(*node.argument); -} + std::vector Builder::BuildParameters( + const std::vector>& parameters) const + { + std::vector result; + result.reserve(parameters.size()); -void Builder::VisitPrefixIncrementExpression( - ast::PrefixIncrementExpression& node) { - if (node.argument) VisitExpression(*node.argument); -} + for (const auto& param : parameters) + { + if (!param) + continue; -void Builder::VisitPrefixDecrementExpression( - ast::PrefixDecrementExpression& node) { - if (node.argument) VisitExpression(*node.argument); -} + language::symbol::Parameter p; + p.name = param->name; + if (param->type) + p.type = param->type->name; + if (param->default_value) + p.default_value = ""; + result.push_back(std::move(p)); + } -void Builder::VisitPostfixIncrementExpression( - ast::PostfixIncrementExpression& node) { - if (node.argument) VisitExpression(*node.argument); -} + return result; + } -void Builder::VisitPostfixDecrementExpression( - ast::PostfixDecrementExpression& node) { - if (node.argument) VisitExpression(*node.argument); -} + void Builder::VisitProgram(ast::Program& node) + { + current_scope_id_ = table_.CreateScope(ScopeKind::kGlobal, node.span, std::nullopt, std::nullopt); + VisitStatements(node.statements); + } -void Builder::VisitLogicalNotExpression(ast::LogicalNotExpression& node) { - if (node.argument) VisitExpression(*node.argument); -} + void Builder::VisitUnitDefinition(ast::UnitDefinition& node) + { + [[maybe_unused]] auto unit_scope = EnterScope(ScopeKind::kUnit, node.span); -void Builder::VisitBitwiseNotExpression(ast::BitwiseNotExpression& node) { - if (node.argument) VisitExpression(*node.argument); -} + // Process interface section + in_interface_section_ = true; + VisitStatements(node.interface_statements); -void Builder::VisitDerivativeExpression(ast::DerivativeExpression& node) { - if (node.argument) VisitExpression(*node.argument); -} + // Process implementation section + in_interface_section_ = false; + VisitStatements(node.implementation_statements); -void Builder::VisitMatrixTransposeExpression( - ast::MatrixTransposeExpression& node) { - if (node.argument) VisitExpression(*node.argument); -} + ExitScope(); + } -void Builder::VisitExprOperatorExpression(ast::ExprOperatorExpression& node) { - if (node.argument) VisitExpression(*node.argument); -} + void Builder::VisitClassDefinition(ast::ClassDefinition& node) + { + auto class_id = CreateSymbol(node.name, protocol::SymbolKind::Class, node.location); -void Builder::VisitFunctionPointerExpression( - [[maybe_unused]] ast::FunctionPointerExpression& node) { - if (node.argument) VisitExpression(*node.argument); -} + [[maybe_unused]] auto class_scope = + EnterScopeWithSymbol(ScopeKind::kClass, class_id, node.span); -void Builder::VisitNewExpression(ast::NewExpression& node) { - if (node.target) VisitExpression(*node.target); -} + auto prev_parent = current_parent_symbol_id_; + current_parent_symbol_id_ = class_id; -void Builder::VisitEchoExpression(ast::EchoExpression& node) { - for (auto& expr : node.expressions) { - if (expr) VisitExpression(*expr); - } -} + for (auto& member : node.members) + { + if (member) + member->Accept(*this); + } -void Builder::VisitRaiseExpression(ast::RaiseExpression& node) { - if (node.exception) VisitExpression(*node.exception); -} + current_parent_symbol_id_ = prev_parent; + ExitScope(); + } -void Builder::VisitInheritedExpression(ast::InheritedExpression& node) { - if (node.call && *node.call) { - (*node.call)->Accept(*this); - } -} + void Builder::VisitFunctionDeclaration(ast::FunctionDeclaration& node) + { + auto location = ast::Location{}; + if (!node.name.empty() && !node.parameters.empty()) + { + location = node.span; + } + CreateFunctionSymbol(node.name, location, node.parameters, node.return_type); + } -void Builder::VisitParenthesizedExpression(ast::ParenthesizedExpression& node) { - for (auto& elem : node.elements) { - if (elem.key) VisitExpression(*elem.key.value()); - if (elem.value) VisitExpression(*elem.value); - } -} + void Builder::VisitFunctionDefinition(ast::FunctionDefinition& node) + { + auto func_id = CreateFunctionSymbol(node.name, node.location, node.parameters, node.return_type); -void Builder::VisitColumnReference(ast::ColumnReference& node) { - if (node.value) { - VisitExpression(*node.value); - } -} + if (node.body) + { + [[maybe_unused]] auto func_scope = + EnterScopeWithSymbol(ScopeKind::kFunction, func_id, node.body->span); -void Builder::VisitUnpackPattern([[maybe_unused]] ast::UnpackPattern& node) { - (void)node; -} + auto prev_function = current_function_id_; + current_function_id_ = func_id; -void Builder::VisitCompilerDirective( - [[maybe_unused]] ast::CompilerDirective& node) { - (void)node; -} + for (auto& param : node.parameters) + { + if (param) + { + CreateSymbol(param->name, protocol::SymbolKind::Variable, param->location); + } + } -void Builder::VisitConditionalDirective( - [[maybe_unused]] ast::ConditionalDirective& node) { - (void)node; -} + VisitStatements(node.body->statements); -void Builder::VisitConditionalBlock(ast::ConditionalBlock& node) { - VisitStatements(node.consequence); - VisitStatements(node.alternative); -} + current_function_id_ = prev_function; + ExitScope(); + } + } -} // namespace lsp::language::symbol + void Builder::VisitMethodDeclaration(ast::MethodDeclaration& node) + { + auto method_id = CreateMethodSymbol(node.name, node.location, node.parameters, node.return_type); + + if (node.body) + { + [[maybe_unused]] auto method_scope = + EnterScopeWithSymbol(ScopeKind::kFunction, method_id, node.body->span); + + auto prev_function = current_function_id_; + current_function_id_ = method_id; + + for (auto& param : node.parameters) + { + if (param) + { + CreateSymbol(param->name, protocol::SymbolKind::Variable, param->location); + } + } + + VisitStatements(node.body->statements); + + current_function_id_ = prev_function; + ExitScope(); + } + } + + void Builder::VisitPropertyDeclaration(ast::PropertyDeclaration& node) + { + auto type = ExtractTypeName(node.type); + CreateSymbol(node.name, protocol::SymbolKind::Property, node.location, type); + } + + void Builder::VisitExternalMethodDefinition( + ast::ExternalMethodDefinition& node) + { + std::string method_name = node.name; + size_t dot_pos = method_name.find_last_of('.'); + if (dot_pos != std::string::npos) + { + method_name = method_name.substr(dot_pos + 1); + } + + std::optional method_id; + + if (current_parent_symbol_id_) + { + const auto* scope_info = table_.scopes().scope(current_scope_id_); + if (scope_info) + { + method_id = + table_.scopes().FindSymbolInScope(current_scope_id_, method_name); + } + } + + if (!method_id) + { + method_id = CreateMethodSymbol(method_name, node.location, node.parameters, node.return_type); + } + + if (node.body) + { + [[maybe_unused]] auto method_scope = + EnterScopeWithSymbol(ScopeKind::kFunction, *method_id, node.body->span); + + auto prev_function = current_function_id_; + current_function_id_ = *method_id; + + for (auto& param : node.parameters) + { + if (param) + { + CreateSymbol(param->name, protocol::SymbolKind::Variable, param->location); + } + } + + VisitStatements(node.body->statements); + + current_function_id_ = prev_function; + ExitScope(); + } + } + + void Builder::VisitClassMember(ast::ClassMember& node) + { + std::visit( + [this](auto& member) { + if (member) + member->Accept(*this); + }, + node.member); + } + + void Builder::VisitVarDeclaration(ast::VarDeclaration& node) + { + CreateSymbol(node.name, protocol::SymbolKind::Variable, node.location, ExtractTypeName(node.type)); + + if (node.initializer) + { + VisitExpression(*node.initializer.value()); + } + } + + void Builder::VisitStaticDeclaration(ast::StaticDeclaration& node) + { + CreateSymbol(node.name, protocol::SymbolKind::Variable, node.location, ExtractTypeName(node.type)); + + if (node.initializer) + { + VisitExpression(*node.initializer.value()); + } + } + + void Builder::VisitGlobalDeclaration(ast::GlobalDeclaration& node) + { + CreateSymbol(node.name, protocol::SymbolKind::Variable, node.location, ExtractTypeName(node.type)); + + if (node.initializer) + { + VisitExpression(*node.initializer.value()); + } + } + + void Builder::VisitConstDeclaration(ast::ConstDeclaration& node) + { + CreateSymbol(node.name, protocol::SymbolKind::Constant, node.location, ExtractTypeName(node.type)); + + if (node.value) + { + VisitExpression(*node.value); + } + } + + void Builder::VisitFieldDeclaration(ast::FieldDeclaration& node) + { + CreateSymbol(node.name, protocol::SymbolKind::Field, node.location, ExtractTypeName(node.type)); + + if (node.initializer) + { + VisitExpression(*node.initializer.value()); + } + } + + void Builder::VisitUsesStatement(ast::UsesStatement& node) { (void)node; } + + void Builder::VisitCallExpression(ast::CallExpression& node) + { + if (node.callee) + { + node.callee->Accept(*this); + } + + for (auto& arg : node.arguments) + { + if (arg.value) + VisitExpression(*arg.value); + } + } + + void Builder::VisitAttributeExpression(ast::AttributeExpression& node) + { + if (node.object) + { + node.object->Accept(*this); + } + + if (node.attribute) + { + node.attribute->Accept(*this); + } + } + + void Builder::VisitAssignmentExpression(ast::AssignmentExpression& node) + { + ProcessLValue(node.left, true); + + if (node.right) + { + VisitExpression(*node.right); + } + } + + void Builder::ProcessLValue(const ast::LValue& lvalue, bool is_write) + { + (void)is_write; + std::visit( + [this](auto& val) { + using T = std::decay_t; + if constexpr (std::is_same_v>) + { + (void)val; + } + else if constexpr (std::is_same_v>) + { + if (val) + { + if (val->object) + { + val->object->Accept(*this); + } + if (val->attribute) + { + val->attribute->Accept(*this); + } + } + } + else if constexpr (std::is_same_v>) + { + if (val) + { + if (val->base) + { + val->base->Accept(*this); + } + for (auto& idx : val->indices) + { + if (idx.start) + idx.start->Accept(*this); + if (idx.end) + idx.end->Accept(*this); + if (idx.step) + idx.step->Accept(*this); + } + } + } + else if constexpr (std::is_same_v< + T, + std::unique_ptr>) + { + if (val && val->value) + { + val->value->Accept(*this); + } + } + }, + lvalue); + } + + void Builder::VisitBlockStatement(ast::BlockStatement& node) + { + [[maybe_unused]] auto block_scope = EnterScope(ScopeKind::kBlock, node.span); + VisitStatements(node.statements); + ExitScope(); + } + + void Builder::VisitIfStatement(ast::IfStatement& node) + { + for (auto& branch : node.branches) + { + if (branch.condition) + VisitExpression(*branch.condition); + + if (branch.body) + branch.body->Accept(*this); + } + + if (node.else_body && *node.else_body) + { + (*node.else_body)->Accept(*this); + } + } + + void Builder::VisitForInStatement(ast::ForInStatement& node) + { + [[maybe_unused]] auto for_scope = EnterScope(ScopeKind::kBlock, node.span); + + if (!node.key.empty()) + { + CreateSymbol(node.key, protocol::SymbolKind::Variable, node.key_location); + } + + if (!node.value.empty()) + { + CreateSymbol(node.value, protocol::SymbolKind::Variable, node.value_location); + } + + if (node.collection) + VisitExpression(*node.collection); + + if (node.body) + node.body->Accept(*this); + + ExitScope(); + } + + void Builder::VisitForToStatement(ast::ForToStatement& node) + { + [[maybe_unused]] auto for_scope = EnterScope(ScopeKind::kBlock, node.span); + + CreateSymbol(node.counter, protocol::SymbolKind::Variable, node.counter_location); + + if (node.start) + VisitExpression(*node.start); + if (node.end) + VisitExpression(*node.end); + if (node.step) + VisitExpression(*node.step); + if (node.body) + node.body->Accept(*this); + + ExitScope(); + } + + void Builder::VisitWhileStatement(ast::WhileStatement& node) + { + if (node.condition) + VisitExpression(*node.condition); + if (node.body) + node.body->Accept(*this); + } + + void Builder::VisitRepeatStatement(ast::RepeatStatement& node) + { + VisitStatements(node.body); + if (node.condition) + VisitExpression(*node.condition); + } + + void Builder::VisitCaseStatement(ast::CaseStatement& node) + { + if (node.discriminant) + VisitExpression(*node.discriminant); + + for (auto& branch : node.branches) + { + for (auto& value : branch.values) + { + if (value) + VisitExpression(*value); + } + if (branch.body) + branch.body->Accept(*this); + } + + if (node.else_body && *node.else_body) + { + (*node.else_body)->Accept(*this); + } + } + + void Builder::VisitTryStatement(ast::TryStatement& node) + { + if (node.try_body) + { + VisitStatements(node.try_body->statements); + } + + if (node.except_body) + { + VisitStatements(node.except_body->statements); + } + } + + void Builder::VisitMatrixIterationStatement( + ast::MatrixIterationStatement& node) + { + [[maybe_unused]] auto iter_scope = EnterScope(ScopeKind::kBlock, node.span); + + if (node.target) + VisitExpression(*node.target); + + if (node.body) + VisitStatements(node.body->statements); + + ExitScope(); + } + + void Builder::VisitExpressionStatement(ast::ExpressionStatement& node) + { + if (node.expression) + VisitExpression(*node.expression); + } + + void Builder::VisitBreakStatement([[maybe_unused]] ast::BreakStatement& node) + { + } + + void Builder::VisitContinueStatement([[maybe_unused]] ast::ContinueStatement& node) + { + } + + void Builder::VisitReturnStatement(ast::ReturnStatement& node) + { + if (node.value && *node.value) + VisitExpression(**node.value); + } + + void Builder::VisitTSSQLExpression([[maybe_unused]] ast::TSSQLExpression& node) {} + + void Builder::VisitLiteral([[maybe_unused]] ast::Literal& node) + { + } + + void Builder::VisitBinaryExpression(ast::BinaryExpression& node) + { + if (node.left) + VisitExpression(*node.left); + if (node.right) + VisitExpression(*node.right); + } + + void Builder::VisitTernaryExpression(ast::TernaryExpression& node) + { + if (node.condition) + VisitExpression(*node.condition); + if (node.consequence) + VisitExpression(*node.consequence); + if (node.alternative) + VisitExpression(*node.alternative); + } + + void Builder::VisitSubscriptExpression(ast::SubscriptExpression& node) + { + if (node.base) + VisitExpression(*node.base); + + for (auto& idx : node.indices) + { + if (idx.start) + VisitExpression(*idx.start); + if (idx.end) + VisitExpression(*idx.end); + if (idx.step) + VisitExpression(*idx.step); + } + } + + void Builder::VisitArrayExpression(ast::ArrayExpression& node) + { + for (auto& elem : node.elements) + { + if (elem.key) + VisitExpression(*elem.key.value()); + if (elem.value) + VisitExpression(*elem.value); + } + } + + void Builder::VisitAnonymousFunctionExpression( + ast::AnonymousFunctionExpression& node) + { + auto func_id = CreateFunctionSymbol("", node.span, node.parameters, node.return_type); + + if (node.body) + { + [[maybe_unused]] auto func_scope = EnterScopeWithSymbol( + ScopeKind::kAnonymousFunction, func_id, node.body->span); + + auto prev_function = current_function_id_; + current_function_id_ = func_id; + + for (auto& param : node.parameters) + { + if (param) + { + CreateSymbol(param->name, protocol::SymbolKind::Variable, param->location); + } + } + + VisitStatements(node.body->statements); + + current_function_id_ = prev_function; + ExitScope(); + } + } + + void Builder::VisitUnaryPlusExpression(ast::UnaryPlusExpression& node) + { + if (node.argument) + VisitExpression(*node.argument); + } + + void Builder::VisitUnaryMinusExpression(ast::UnaryMinusExpression& node) + { + if (node.argument) + VisitExpression(*node.argument); + } + + void Builder::VisitPrefixIncrementExpression( + ast::PrefixIncrementExpression& node) + { + if (node.argument) + VisitExpression(*node.argument); + } + + void Builder::VisitPrefixDecrementExpression( + ast::PrefixDecrementExpression& node) + { + if (node.argument) + VisitExpression(*node.argument); + } + + void Builder::VisitPostfixIncrementExpression( + ast::PostfixIncrementExpression& node) + { + if (node.argument) + VisitExpression(*node.argument); + } + + void Builder::VisitPostfixDecrementExpression( + ast::PostfixDecrementExpression& node) + { + if (node.argument) + VisitExpression(*node.argument); + } + + void Builder::VisitLogicalNotExpression(ast::LogicalNotExpression& node) + { + if (node.argument) + VisitExpression(*node.argument); + } + + void Builder::VisitBitwiseNotExpression(ast::BitwiseNotExpression& node) + { + if (node.argument) + VisitExpression(*node.argument); + } + + void Builder::VisitDerivativeExpression(ast::DerivativeExpression& node) + { + if (node.argument) + VisitExpression(*node.argument); + } + + void Builder::VisitMatrixTransposeExpression( + ast::MatrixTransposeExpression& node) + { + if (node.argument) + VisitExpression(*node.argument); + } + + void Builder::VisitExprOperatorExpression(ast::ExprOperatorExpression& node) + { + if (node.argument) + VisitExpression(*node.argument); + } + + void Builder::VisitFunctionPointerExpression([[maybe_unused]] ast::FunctionPointerExpression& node) + { + if (node.argument) + VisitExpression(*node.argument); + } + + void Builder::VisitNewExpression(ast::NewExpression& node) + { + if (node.target) + VisitExpression(*node.target); + } + + void Builder::VisitEchoExpression(ast::EchoExpression& node) + { + for (auto& expr : node.expressions) + { + if (expr) + VisitExpression(*expr); + } + } + + void Builder::VisitRaiseExpression(ast::RaiseExpression& node) + { + if (node.exception) + VisitExpression(*node.exception); + } + + void Builder::VisitInheritedExpression(ast::InheritedExpression& node) + { + if (node.call && *node.call) + { + (*node.call)->Accept(*this); + } + } + + void Builder::VisitParenthesizedExpression(ast::ParenthesizedExpression& node) + { + for (auto& elem : node.elements) + { + if (elem.key) + VisitExpression(*elem.key.value()); + if (elem.value) + VisitExpression(*elem.value); + } + } + + void Builder::VisitColumnReference(ast::ColumnReference& node) + { + if (node.value) + { + VisitExpression(*node.value); + } + } + + void Builder::VisitUnpackPattern([[maybe_unused]] ast::UnpackPattern& node) + { + } + + void Builder::VisitCompilerDirective([[maybe_unused]] ast::CompilerDirective& node) + { + } + + void Builder::VisitConditionalDirective([[maybe_unused]] ast::ConditionalDirective& node) + { + } + + void Builder::VisitConditionalBlock(ast::ConditionalBlock& node) + { + VisitStatements(node.consequence); + VisitStatements(node.alternative); + } + + void Builder::VisitTSLXBlock([[maybe_unused]] ast::TSLXBlock& node) {} + + void Builder::VisitParameter([[maybe_unused]] ast::Parameter& node) {} + +} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/builder.hpp b/lsp-server/src/language/symbol/builder.hpp index 74f0753..2eb7faa 100644 --- a/lsp-server/src/language/symbol/builder.hpp +++ b/lsp-server/src/language/symbol/builder.hpp @@ -2,157 +2,146 @@ #include "./table.hpp" -namespace lsp::language::symbol { -class Builder : public ast::ASTVisitor { - public: - explicit Builder(SymbolTable& table); +namespace lsp::language::symbol +{ + class Builder : public ast::ASTVisitor + { + public: + explicit Builder(SymbolTable& table); - void Build(ast::ASTNode& root); + void Build(ast::ASTNode& root); - // Visit statements and declarations - void VisitProgram(ast::Program& node) override; - void VisitUnitDefinition(ast::UnitDefinition& node) override; - void VisitClassDefinition(ast::ClassDefinition& node) override; - void VisitFunctionDefinition(ast::FunctionDefinition& node) override; - void VisitFunctionDeclaration(ast::FunctionDeclaration& node) override; - void VisitMethodDeclaration(ast::MethodDeclaration& node) override; - void VisitPropertyDeclaration(ast::PropertyDeclaration& node) override; - void VisitExternalMethodDefinition( - ast::ExternalMethodDefinition& node) override; - void VisitClassMember(ast::ClassMember& node) override; + // Visit statements and declarations + void VisitProgram(ast::Program& node) override; + void VisitUnitDefinition(ast::UnitDefinition& node) override; + void VisitClassDefinition(ast::ClassDefinition& node) override; + void VisitFunctionDefinition(ast::FunctionDefinition& node) override; + void VisitFunctionDeclaration(ast::FunctionDeclaration& node) override; + void VisitMethodDeclaration(ast::MethodDeclaration& node) override; + void VisitPropertyDeclaration(ast::PropertyDeclaration& node) override; + void VisitExternalMethodDefinition(ast::ExternalMethodDefinition& node) override; + void VisitClassMember(ast::ClassMember& node) override; - // Declarations - void VisitVarDeclaration(ast::VarDeclaration& node) override; - void VisitStaticDeclaration(ast::StaticDeclaration& node) override; - void VisitGlobalDeclaration(ast::GlobalDeclaration& node) override; - void VisitConstDeclaration(ast::ConstDeclaration& node) override; - void VisitFieldDeclaration(ast::FieldDeclaration& node) override; + // Declarations + void VisitVarDeclaration(ast::VarDeclaration& node) override; + void VisitStaticDeclaration(ast::StaticDeclaration& node) override; + void VisitGlobalDeclaration(ast::GlobalDeclaration& node) override; + void VisitConstDeclaration(ast::ConstDeclaration& node) override; + void VisitFieldDeclaration(ast::FieldDeclaration& node) override; - void VisitBlockStatement(ast::BlockStatement& node) override; - void VisitIfStatement(ast::IfStatement& node) override; - void VisitForInStatement(ast::ForInStatement& node) override; - void VisitForToStatement(ast::ForToStatement& node) override; - void VisitWhileStatement(ast::WhileStatement& node) override; - void VisitRepeatStatement(ast::RepeatStatement& node) override; - void VisitCaseStatement(ast::CaseStatement& node) override; - void VisitTryStatement(ast::TryStatement& node) override; + void VisitBlockStatement(ast::BlockStatement& node) override; + void VisitIfStatement(ast::IfStatement& node) override; + void VisitForInStatement(ast::ForInStatement& node) override; + void VisitForToStatement(ast::ForToStatement& node) override; + void VisitWhileStatement(ast::WhileStatement& node) override; + void VisitRepeatStatement(ast::RepeatStatement& node) override; + void VisitCaseStatement(ast::CaseStatement& node) override; + void VisitTryStatement(ast::TryStatement& node) override; - // Uses statement handling - void VisitUsesStatement(ast::UsesStatement& node) override; + // Uses statement handling + void VisitUsesStatement(ast::UsesStatement& node) override; - // Visit expressions (collect references) - void VisitIdentifier(ast::Identifier& node) override; - void VisitCallExpression(ast::CallExpression& node) override; - void VisitAttributeExpression(ast::AttributeExpression& node) override; - void VisitAssignmentExpression(ast::AssignmentExpression& node) override; + // Visit expressions (只遍历,不收集引用) + void VisitIdentifier(ast::Identifier& node) override {} + void VisitCallExpression(ast::CallExpression& node) override; + void VisitAttributeExpression(ast::AttributeExpression& node) override; + void VisitAssignmentExpression(ast::AssignmentExpression& node) override; - // Other node visits - void VisitLiteral([[maybe_unused]] ast::Literal& node) override {} - void VisitBinaryExpression(ast::BinaryExpression& node) override; - void VisitTernaryExpression(ast::TernaryExpression& node) override; - void VisitSubscriptExpression(ast::SubscriptExpression& node) override; - void VisitArrayExpression(ast::ArrayExpression& node) override; - void VisitAnonymousFunctionExpression( - ast::AnonymousFunctionExpression& node) override; + // Other node visits + void VisitLiteral(ast::Literal& node) override; + void VisitBinaryExpression(ast::BinaryExpression& node) override; + void VisitTernaryExpression(ast::TernaryExpression& node) override; + void VisitSubscriptExpression(ast::SubscriptExpression& node) override; + void VisitArrayExpression(ast::ArrayExpression& node) override; + void VisitAnonymousFunctionExpression(ast::AnonymousFunctionExpression& node) override; - // Unary expressions (refined types) - void VisitUnaryPlusExpression(ast::UnaryPlusExpression& node) override; - void VisitUnaryMinusExpression(ast::UnaryMinusExpression& node) override; - void VisitPrefixIncrementExpression( - ast::PrefixIncrementExpression& node) override; - void VisitPrefixDecrementExpression( - ast::PrefixDecrementExpression& node) override; - void VisitPostfixIncrementExpression( - ast::PostfixIncrementExpression& node) override; - void VisitPostfixDecrementExpression( - ast::PostfixDecrementExpression& node) override; - void VisitLogicalNotExpression(ast::LogicalNotExpression& node) override; - void VisitBitwiseNotExpression(ast::BitwiseNotExpression& node) override; - void VisitDerivativeExpression(ast::DerivativeExpression& node) override; - void VisitMatrixTransposeExpression( - ast::MatrixTransposeExpression& node) override; - void VisitExprOperatorExpression(ast::ExprOperatorExpression& node) override; + // Unary expressions + void VisitUnaryPlusExpression(ast::UnaryPlusExpression& node) override; + void VisitUnaryMinusExpression(ast::UnaryMinusExpression& node) override; + void VisitPrefixIncrementExpression(ast::PrefixIncrementExpression& node) override; + void VisitPrefixDecrementExpression(ast::PrefixDecrementExpression& node) override; + void VisitPostfixIncrementExpression(ast::PostfixIncrementExpression& node) override; + void VisitPostfixDecrementExpression(ast::PostfixDecrementExpression& node) override; + void VisitLogicalNotExpression(ast::LogicalNotExpression& node) override; + void VisitBitwiseNotExpression(ast::BitwiseNotExpression& node) override; + void VisitDerivativeExpression(ast::DerivativeExpression& node) override; + void VisitMatrixTransposeExpression(ast::MatrixTransposeExpression& node) override; + void VisitExprOperatorExpression(ast::ExprOperatorExpression& node) override; - void VisitFunctionPointerExpression( - [[maybe_unused]] ast::FunctionPointerExpression& node) override; + void VisitFunctionPointerExpression(ast::FunctionPointerExpression& node) override; - // Other expressions - void VisitNewExpression(ast::NewExpression& node) override; - void VisitEchoExpression(ast::EchoExpression& node) override; - void VisitRaiseExpression(ast::RaiseExpression& node) override; - void VisitInheritedExpression(ast::InheritedExpression& node) override; - void VisitParenthesizedExpression( - ast::ParenthesizedExpression& node) override; + // Other expressions + void VisitNewExpression(ast::NewExpression& node) override; + void VisitEchoExpression(ast::EchoExpression& node) override; + void VisitRaiseExpression(ast::RaiseExpression& node) override; + void VisitInheritedExpression(ast::InheritedExpression& node) override; + void VisitParenthesizedExpression(ast::ParenthesizedExpression& node) override; - void VisitExpressionStatement(ast::ExpressionStatement& node) override; - void VisitBreakStatement( - [[maybe_unused]] ast::BreakStatement& node) override {} - void VisitContinueStatement( - [[maybe_unused]] ast::ContinueStatement& node) override {} - void VisitReturnStatement(ast::ReturnStatement& node) override; - void VisitTSSQLExpression( - [[maybe_unused]] ast::TSSQLExpression& node) override {} - void VisitColumnReference(ast::ColumnReference& node) override; - void VisitUnpackPattern([[maybe_unused]] ast::UnpackPattern& node) override; - void VisitMatrixIterationStatement( - ast::MatrixIterationStatement& node) override; + void VisitExpressionStatement(ast::ExpressionStatement& node) override; + void VisitBreakStatement(ast::BreakStatement& node) override; + void VisitContinueStatement(ast::ContinueStatement& node) override; + void VisitReturnStatement(ast::ReturnStatement& node) override; + void VisitTSSQLExpression(ast::TSSQLExpression& node) override; + void VisitColumnReference(ast::ColumnReference& node) override; + void VisitUnpackPattern(ast::UnpackPattern& node) override; + void VisitMatrixIterationStatement(ast::MatrixIterationStatement& node) override; - // Compiler directives (correct names without Statement suffix) - void VisitCompilerDirective( - [[maybe_unused]] ast::CompilerDirective& node) override; - void VisitConditionalDirective(ast::ConditionalDirective& node) override; - void VisitConditionalBlock(ast::ConditionalBlock& node) override; - void VisitTSLXBlock([[maybe_unused]] ast::TSLXBlock& node) override {} + // Compiler directives + void VisitCompilerDirective(ast::CompilerDirective& node) override; + void VisitConditionalDirective(ast::ConditionalDirective& node) override; + void VisitConditionalBlock(ast::ConditionalBlock& node) override; + void VisitTSLXBlock(ast::TSLXBlock& node) override; - void VisitParameter([[maybe_unused]] ast::Parameter& node) override {} + void VisitParameter(ast::Parameter& node) override; - private: - SymbolId CreateFunctionSymbol( - const std::string& name, const ast::Location& location, - const std::vector>& parameters, - const std::optional& return_type); + private: + SymbolId CreateFunctionSymbol( + const std::string& name, + const ast::Location& location, + const std::vector>& parameters, + const std::optional& return_type); - SymbolId CreateMethodSymbol( - const std::string& name, const ast::Location& location, - const std::vector>& parameters, - const std::optional& return_type); + SymbolId CreateMethodSymbol( + const std::string& name, + const ast::Location& location, + const std::vector>& parameters, + const std::optional& return_type); - SymbolId CreateSymbol( - const std::string& name, SymbolKind kind, const ast::Location& location, - const std::optional& type_hint = std::nullopt); + SymbolId CreateSymbol( + const std::string& name, + protocol::SymbolKind kind, + const ast::Location& location, + const std::optional& type_hint = std::nullopt); - SymbolId CreateSymbol( - const std::string& name, SymbolKind kind, const ast::ASTNode& node, - const std::optional& type_hint = std::nullopt); + SymbolId CreateSymbol( + const std::string& name, + protocol::SymbolKind kind, + const ast::ASTNode& node, + const std::optional& type_hint = std::nullopt); - // Scope management - ScopeId EnterScopeWithSymbol(ScopeKind kind, SymbolId symbol_id, - const ast::Location& range); - ScopeId EnterScope(ScopeKind kind, const ast::Location& range); - void ExitScope(); + // Scope management + ScopeId EnterScopeWithSymbol(ScopeKind kind, SymbolId symbol_id, const ast::Location& range); + ScopeId EnterScope(ScopeKind kind, const ast::Location& range); + void ExitScope(); - // Traversal helpers - void VisitStatements(const std::vector& statements); - void VisitExpression(ast::Expression& expr); + // Traversal helpers + void VisitStatements(const std::vector& statements); + void VisitExpression(ast::Expression& expr); - // LValue processing for assignments - void ProcessLValue(const ast::LValue& lvalue, bool is_write); + // Type and parameter extraction + std::optional ExtractTypeName(const std::optional& type) const; - // Type and parameter extraction - std::optional ExtractTypeName( - const std::optional& type) const; + std::vector BuildParameters(const std::vector>& parameters) const; + void ProcessLValue(const ast::LValue& lvalue, bool is_write); - std::vector BuildParameters( - const std::vector>& parameters) const; + private: + SymbolTable& table_; - private: - SymbolTable& table_; + ScopeId current_scope_id_; + std::optional current_parent_symbol_id_; + std::optional current_function_id_; - ScopeId current_scope_id_; - std::optional current_parent_symbol_id_; - std::optional current_function_id_; + bool in_interface_section_; + }; - bool in_interface_section_; -}; - -} // namespace lsp::language::symbol +} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/graph/call.cpp b/lsp-server/src/language/symbol/graph/call.cpp deleted file mode 100644 index 6b13a54..0000000 --- a/lsp-server/src/language/symbol/graph/call.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include "call.hpp" - -#include - -namespace lsp::language::symbol::graph { - -void Call::OnSymbolRemoved(SymbolId id) { - callers_map_.erase(id); - callees_map_.erase(id); - - for (auto& [_, calls] : callers_map_) { - calls.erase(std::remove_if(calls.begin(), calls.end(), - [id](const symbol::Call& call) { - return call.caller == id; - }), - calls.end()); - } - - for (auto& [_, calls] : callees_map_) { - calls.erase(std::remove_if(calls.begin(), calls.end(), - [id](const symbol::Call& call) { - return call.callee == id; - }), - calls.end()); - } -} - -void Call::Clear() { - callers_map_.clear(); - callees_map_.clear(); -} - -void Call::AddCall(SymbolId caller, SymbolId callee, - const ast::Location& location) { - symbol::Call call{caller, callee, location}; - callers_map_[callee].push_back(call); - callees_map_[caller].push_back(call); -} - -const std::vector& Call::callers(SymbolId id) const { - static const std::vector kEmpty; - auto it = callers_map_.find(id); - return it != callers_map_.end() ? it->second : kEmpty; -} - -const std::vector& Call::callees(SymbolId id) const { - static const std::vector kEmpty; - auto it = callees_map_.find(id); - return it != callees_map_.end() ? it->second : kEmpty; -} - -} // namespace lsp::language::symbol::graph diff --git a/lsp-server/src/language/symbol/graph/call.hpp b/lsp-server/src/language/symbol/graph/call.hpp deleted file mode 100644 index ffbbe8e..0000000 --- a/lsp-server/src/language/symbol/graph/call.hpp +++ /dev/null @@ -1,26 +0,0 @@ -#pragma once - -#include -#include - -#include "../interface.hpp" -#include "../types.hpp" - -namespace lsp::language::symbol::graph { - -class Call : public ISymbolGraph { - public: - void OnSymbolRemoved(SymbolId id) override; - void Clear() override; - - void AddCall(SymbolId caller, SymbolId callee, const ast::Location& location); - - const std::vector& callers(SymbolId id) const; - const std::vector& callees(SymbolId id) const; - - private: - std::unordered_map> callers_map_; - std::unordered_map> callees_map_; -}; - -} // namespace lsp::language::symbol::graph diff --git a/lsp-server/src/language/symbol/graph/inheritance.cpp b/lsp-server/src/language/symbol/graph/inheritance.cpp deleted file mode 100644 index 19cfc74..0000000 --- a/lsp-server/src/language/symbol/graph/inheritance.cpp +++ /dev/null @@ -1,58 +0,0 @@ -#include "inheritance.hpp" - -#include - -namespace lsp::language::symbol::graph { - -void Inheritance::OnSymbolRemoved(SymbolId id) { - base_classes_.erase(id); - derived_classes_.erase(id); - - for (auto& [_, bases] : base_classes_) { - bases.erase(std::remove(bases.begin(), bases.end(), id), bases.end()); - } - - for (auto& [_, derived] : derived_classes_) { - derived.erase(std::remove(derived.begin(), derived.end(), id), - derived.end()); - } -} - -void Inheritance::Clear() { - base_classes_.clear(); - derived_classes_.clear(); -} - -void Inheritance::AddInheritance(SymbolId derived, SymbolId base) { - base_classes_[derived].push_back(base); - derived_classes_[base].push_back(derived); -} - -const std::vector& Inheritance::base_classes(SymbolId id) const { - static const std::vector kEmpty; - auto it = base_classes_.find(id); - return it != base_classes_.end() ? it->second : kEmpty; -} - -const std::vector& Inheritance::derived_classes(SymbolId id) const { - static const std::vector kEmpty; - auto it = derived_classes_.find(id); - return it != derived_classes_.end() ? it->second : kEmpty; -} - -bool Inheritance::IsSubclassOf(SymbolId derived, SymbolId base) const { - auto it = base_classes_.find(derived); - if (it == base_classes_.end()) { - return false; - } - - for (SymbolId parent : it->second) { - if (parent == base || IsSubclassOf(parent, base)) { - return true; - } - } - - return false; -} - -} // namespace lsp::language::symbol::graph diff --git a/lsp-server/src/language/symbol/graph/inheritance.hpp b/lsp-server/src/language/symbol/graph/inheritance.hpp deleted file mode 100644 index 826af2b..0000000 --- a/lsp-server/src/language/symbol/graph/inheritance.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include -#include - -#include "../interface.hpp" -#include "../types.hpp" - -namespace lsp::language::symbol::graph { - -class Inheritance : public ISymbolGraph { - public: - void OnSymbolRemoved(SymbolId id) override; - void Clear() override; - - void AddInheritance(SymbolId derived, SymbolId base); - - const std::vector& base_classes(SymbolId id) const; - const std::vector& derived_classes(SymbolId id) const; - bool IsSubclassOf(SymbolId derived, SymbolId base) const; - - private: - std::unordered_map> base_classes_; - std::unordered_map> derived_classes_; -}; - -} // namespace lsp::language::symbol::graph diff --git a/lsp-server/src/language/symbol/graph/reference.cpp b/lsp-server/src/language/symbol/graph/reference.cpp deleted file mode 100644 index d5cc7b8..0000000 --- a/lsp-server/src/language/symbol/graph/reference.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include "reference.hpp" - -#include - -namespace lsp::language::symbol::graph { - -void Reference::OnSymbolRemoved(SymbolId id) { - references_.erase(id); - - for (auto& [_, refs] : references_) { - refs.erase(std::remove_if(refs.begin(), refs.end(), - [id](const symbol::Reference& ref) { - return ref.symbol_id == id; - }), - refs.end()); - } -} - -void Reference::Clear() { references_.clear(); } - -void Reference::AddReference(SymbolId symbol_id, const ast::Location& location, - bool is_definition, bool is_write) { - references_[symbol_id].push_back( - {location, symbol_id, is_definition, is_write}); -} - -const std::vector& Reference::references(SymbolId id) const { - static const std::vector kEmpty; - auto it = references_.find(id); - return it != references_.end() ? it->second : kEmpty; -} - -std::optional Reference::FindDefinitionLocation( - SymbolId id) const { - auto it = references_.find(id); - if (it != references_.end()) { - for (const auto& ref : it->second) { - if (ref.is_definition) { - return ref.location; - } - } - } - return std::nullopt; -} - -} // namespace lsp::language::symbol::graph diff --git a/lsp-server/src/language/symbol/graph/reference.hpp b/lsp-server/src/language/symbol/graph/reference.hpp deleted file mode 100644 index 79b0371..0000000 --- a/lsp-server/src/language/symbol/graph/reference.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include -#include -#include - -#include "../interface.hpp" -#include "../types.hpp" - -namespace lsp::language::symbol::graph { - -class Reference : public ISymbolGraph { - public: - void OnSymbolRemoved(SymbolId id) override; - void Clear() override; - - void AddReference(SymbolId symbol_id, const ast::Location& location, - bool is_definition = false, bool is_write = false); - - const std::vector& references(SymbolId id) const; - std::optional FindDefinitionLocation(SymbolId id) const; - - private: - std::unordered_map> references_; -}; - -} // namespace lsp::language::symbol::graph diff --git a/lsp-server/src/language/symbol/index/dispatcher.hpp b/lsp-server/src/language/symbol/index/dispatcher.hpp new file mode 100644 index 0000000..7e1c8ba --- /dev/null +++ b/lsp-server/src/language/symbol/index/dispatcher.hpp @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +#include + +#include "../interface.hpp" + +namespace lsp::language::symbol +{ + + /** + * IndexDispatcher - 索引分发器 + * + * 职责: + * - 维护所有索引的注册表 + * - 当符号表变化时,分发事件给所有索引 + * - 使用观察者模式保持索引同步 + */ + class IndexDispatcher + { + public: + /** + * 注册索引 + */ + void RegisterIndex(std::shared_ptr index) + { + indexes_.push_back(std::move(index)); + } + + /** + * 通知所有索引:符号已添加 + */ + void NotifySymbolAdded(const Symbol& symbol) const + { + for (const auto& index : indexes_) + { + index->OnSymbolAdded(symbol); + } + } + + /** + * 通知所有索引:符号已删除 + */ + void NotifySymbolRemoved(SymbolId id) const + { + for (const auto& index : indexes_) + { + index->OnSymbolRemoved(id); + } + } + + /** + * 清空所有索引 + */ + void Clear() const + { + for (const auto& index : indexes_) + { + index->Clear(); + } + } + + private: + std::vector> indexes_; + }; + +} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/index/location.cpp b/lsp-server/src/language/symbol/index/location.cpp index 061d9a4..a8cedbb 100644 --- a/lsp-server/src/language/symbol/index/location.cpp +++ b/lsp-server/src/language/symbol/index/location.cpp @@ -2,67 +2,78 @@ #include -namespace lsp::language::symbol::index { +namespace lsp::language::symbol::index +{ -void Location::OnSymbolAdded(const Symbol& symbol) { - const auto& loc = symbol.selection_range(); - entries_.push_back({loc.start_offset, loc.end_offset, symbol.id()}); - needs_sort_ = true; -} - -void Location::OnSymbolRemoved(SymbolId id) { - entries_.erase( - std::remove_if(entries_.begin(), entries_.end(), - [id](const Entry& e) { return e.symbol_id == id; }), - entries_.end()); -} - -void Location::Clear() { - entries_.clear(); - needs_sort_ = false; -} - -std::optional Location::FindSymbolAt( - const ast::Location& location) const { - EnsureSorted(); - uint32_t pos = location.start_offset; - - auto it = - std::lower_bound(entries_.begin(), entries_.end(), pos, - [](const Entry& e, uint32_t p) { return e.start < p; }); - - if (it != entries_.begin()) { - --it; - } - - std::optional result; - uint32_t min_span = UINT32_MAX; - - for (; it != entries_.end() && it->start <= pos; ++it) { - if (pos >= it->start && pos < it->end) { - uint32_t span = it->end - it->start; - if (span < min_span) { - min_span = span; - result = it->symbol_id; - } + void Location::OnSymbolAdded(const Symbol& symbol) + { + const auto& loc = symbol.selection_range(); + entries_.push_back({ loc.start_offset, loc.end_offset, symbol.id() }); + needs_sort_ = true; } - } - return result; -} + void Location::OnSymbolRemoved(SymbolId id) + { + entries_.erase( + std::remove_if(entries_.begin(), entries_.end(), [id](const Entry& e) { return e.symbol_id == id; }), + entries_.end()); + } -bool Location::Entry::operator<(const Entry& other) const { - if (start != other.start) { - return start < other.start; - } - return end > other.end; -} + void Location::Clear() + { + entries_.clear(); + needs_sort_ = false; + } -void Location::EnsureSorted() const { - if (needs_sort_) { - std::sort(entries_.begin(), entries_.end()); - needs_sort_ = false; - } -} + std::optional Location::FindSymbolAt( + const ast::Location& location) const + { + EnsureSorted(); + uint32_t pos = location.start_offset; -} // namespace lsp::language::symbol::index + auto it = + std::lower_bound(entries_.begin(), entries_.end(), pos, [](const Entry& e, uint32_t p) { return e.start < p; }); + + if (it != entries_.begin()) + { + --it; + } + + std::optional result; + uint32_t min_span = UINT32_MAX; + + for (; it != entries_.end() && it->start <= pos; ++it) + { + if (pos >= it->start && pos < it->end) + { + uint32_t span = it->end - it->start; + if (span < min_span) + { + min_span = span; + result = it->symbol_id; + } + } + } + + return result; + } + + bool Location::Entry::operator<(const Entry& other) const + { + if (start != other.start) + { + return start < other.start; + } + return end > other.end; + } + + void Location::EnsureSorted() const + { + if (needs_sort_) + { + std::sort(entries_.begin(), entries_.end()); + needs_sort_ = false; + } + } + +} // namespace lsp::language::symbol::index diff --git a/lsp-server/src/language/symbol/index/location.hpp b/lsp-server/src/language/symbol/index/location.hpp index 9f5cfe9..27a37fb 100644 --- a/lsp-server/src/language/symbol/index/location.hpp +++ b/lsp-server/src/language/symbol/index/location.hpp @@ -6,29 +6,32 @@ #include "../interface.hpp" #include "../types.hpp" -namespace lsp::language::symbol::index { +namespace lsp::language::symbol::index +{ -class Location : public ISymbolIndex { - public: - void OnSymbolAdded(const Symbol& symbol) override; - void OnSymbolRemoved(SymbolId id) override; - void Clear() override; + class Location : public ISymbolIndex + { + public: + void OnSymbolAdded(const Symbol& symbol) override; + void OnSymbolRemoved(SymbolId id) override; + void Clear() override; - std::optional FindSymbolAt(const ast::Location& location) const; + std::optional FindSymbolAt(const ast::Location& location) const; - private: - struct Entry { - uint32_t start; - uint32_t end; - SymbolId symbol_id; + private: + struct Entry + { + uint32_t start; + uint32_t end; + SymbolId symbol_id; - bool operator<(const Entry& other) const; - }; + bool operator<(const Entry& other) const; + }; - void EnsureSorted() const; + void EnsureSorted() const; - mutable std::vector entries_; - mutable bool needs_sort_ = false; -}; + mutable std::vector entries_; + mutable bool needs_sort_ = false; + }; -} // namespace lsp::language::symbol::index +} // namespace lsp::language::symbol::index diff --git a/lsp-server/src/language/symbol/index/scope.cpp b/lsp-server/src/language/symbol/index/scope.cpp index c71cc80..b1c3268 100644 --- a/lsp-server/src/language/symbol/index/scope.cpp +++ b/lsp-server/src/language/symbol/index/scope.cpp @@ -2,92 +2,110 @@ #include -namespace lsp::language::symbol::index { +namespace lsp::language::symbol::index +{ -void Scope::OnSymbolAdded(const Symbol&) {} + void Scope::OnSymbolAdded(const Symbol&) {} -void Scope::OnSymbolRemoved(SymbolId id) { - for (auto& [_, scope] : scopes_) { - for (auto it = scope.symbols.begin(); it != scope.symbols.end();) { - if (it->second == id) { - it = scope.symbols.erase(it); - } else { - ++it; - } - } - } -} - -void Scope::Clear() { - scopes_.clear(); - next_scope_id_ = 1; - global_scope_ = kInvalidScopeId; -} - -ScopeId Scope::CreateScope(ScopeKind kind, const ast::Location& range, - std::optional parent, - std::optional owner) { - ScopeId id = next_scope_id_++; - scopes_[id] = {id, kind, range, parent, owner, {}}; - - if (kind == ScopeKind::kGlobal) { - global_scope_ = id; - } - - return id; -} - -void Scope::AddSymbol(ScopeId scope_id, const std::string& name, - SymbolId symbol_id) { - auto it = scopes_.find(scope_id); - if (it != scopes_.end()) { - it->second.symbols[ToLower(name)] = symbol_id; - } -} - -std::optional Scope::FindSymbolInScope( - ScopeId scope_id, const std::string& name) const { - auto it = scopes_.find(scope_id); - if (it == scopes_.end()) { - return std::nullopt; - } - - auto sym_it = it->second.symbols.find(ToLower(name)); - return sym_it != it->second.symbols.end() ? std::optional(sym_it->second) - : std::nullopt; -} - -std::optional Scope::FindSymbolInScopeChain( - ScopeId scope_id, const std::string& name) const { - std::optional current = scope_id; - - while (current) { - if (auto result = FindSymbolInScope(*current, name)) { - return result; + void Scope::OnSymbolRemoved(SymbolId id) + { + for (auto& [_, scope] : scopes_) + { + for (auto it = scope.symbols.begin(); it != scope.symbols.end();) + { + if (it->second == id) + { + it = scope.symbols.erase(it); + } + else + { + ++it; + } + } + } } - auto it = scopes_.find(*current); - current = it != scopes_.end() ? it->second.parent : std::nullopt; - } + void Scope::Clear() + { + scopes_.clear(); + next_scope_id_ = 1; + global_scope_ = kInvalidScopeId; + } - return std::nullopt; -} + ScopeId Scope::CreateScope(ScopeKind kind, const ast::Location& range, std::optional parent, std::optional owner) + { + ScopeId id = next_scope_id_++; + scopes_[id] = { id, kind, range, parent, owner, {} }; -const symbol::Scope* Scope::scope(ScopeId id) const { - auto it = scopes_.find(id); - return it != scopes_.end() ? &it->second : nullptr; -} + if (kind == ScopeKind::kGlobal) + { + global_scope_ = id; + } -ScopeId Scope::global_scope() const { return global_scope_; } + return id; + } -const std::unordered_map& Scope::all_scopes() const { - return scopes_; -} + void Scope::AddSymbol(ScopeId scope_id, const std::string& name, SymbolId symbol_id) + { + auto it = scopes_.find(scope_id); + if (it != scopes_.end()) + { + it->second.symbols[ToLower(name)] = symbol_id; + } + } -std::string Scope::ToLower(const std::string& s) { - std::string result = s; - std::transform(result.begin(), result.end(), result.begin(), ::tolower); - return result; -} + std::optional Scope::FindSymbolInScope( + ScopeId scope_id, + const std::string& name) const + { + auto it = scopes_.find(scope_id); + if (it == scopes_.end()) + { + return std::nullopt; + } -} // namespace lsp::language::symbol::index + auto sym_it = it->second.symbols.find(ToLower(name)); + return sym_it != it->second.symbols.end() ? std::optional(sym_it->second) : std::nullopt; + } + + std::optional Scope::FindSymbolInScopeChain( + ScopeId scope_id, + const std::string& name) const + { + std::optional current = scope_id; + + while (current) + { + if (auto result = FindSymbolInScope(*current, name)) + { + return result; + } + + auto it = scopes_.find(*current); + current = it != scopes_.end() ? it->second.parent : std::nullopt; + } + + return std::nullopt; + } + + const symbol::Scope* Scope::scope(ScopeId id) const + { + auto it = scopes_.find(id); + return it != scopes_.end() ? &it->second : nullptr; + } + + ScopeId Scope::global_scope() const { return global_scope_; } + + const std::unordered_map& Scope::all_scopes() const + { + return scopes_; + } + + std::string Scope::ToLower(const std::string& s) + { + std::string result = s; + std::transform(result.begin(), result.end(), result.begin(), ::tolower); + return result; + } + +} // namespace lsp::language::symbol::index diff --git a/lsp-server/src/language/symbol/index/scope.hpp b/lsp-server/src/language/symbol/index/scope.hpp index cbdae52..eed0e69 100644 --- a/lsp-server/src/language/symbol/index/scope.hpp +++ b/lsp-server/src/language/symbol/index/scope.hpp @@ -7,46 +7,48 @@ #include "../interface.hpp" #include "../types.hpp" -namespace lsp::language::symbol { +namespace lsp::language::symbol +{ -struct Scope { - ScopeId id; - ScopeKind kind; - ast::Location range; - std::optional parent; - std::optional owner; - std::unordered_map symbols; -}; + struct Scope + { + ScopeId id; + ScopeKind kind; + ast::Location range; + std::optional parent; + std::optional owner; + std::unordered_map symbols; + }; -} // namespace lsp::language::symbol +} // namespace lsp::language::symbol -namespace lsp::language::symbol::index { +namespace lsp::language::symbol::index +{ -class Scope : public ISymbolIndex { - public: - void OnSymbolAdded(const Symbol&) override; - void OnSymbolRemoved(SymbolId id) override; - void Clear() override; + class Scope : public ISymbolIndex + { + public: + void OnSymbolAdded(const Symbol&) override; + void OnSymbolRemoved(SymbolId id) override; + void Clear() override; - ScopeId CreateScope(ScopeKind kind, const ast::Location& range, - std::optional parent = std::nullopt, - std::optional owner = std::nullopt); - void AddSymbol(ScopeId scope_id, const std::string& name, SymbolId symbol_id); - std::optional FindSymbolInScope(ScopeId scope_id, - const std::string& name) const; - std::optional FindSymbolInScopeChain(ScopeId scope_id, - const std::string& name) const; + ScopeId CreateScope(ScopeKind kind, const ast::Location& range, std::optional parent = std::nullopt, std::optional owner = std::nullopt); + void AddSymbol(ScopeId scope_id, const std::string& name, SymbolId symbol_id); + std::optional FindSymbolInScope(ScopeId scope_id, + const std::string& name) const; + std::optional FindSymbolInScopeChain(ScopeId scope_id, + const std::string& name) const; - const symbol::Scope* scope(ScopeId id) const; - ScopeId global_scope() const; - const std::unordered_map& all_scopes() const; + const symbol::Scope* scope(ScopeId id) const; + ScopeId global_scope() const; + const std::unordered_map& all_scopes() const; - private: - static std::string ToLower(const std::string& s); + private: + static std::string ToLower(const std::string& s); - ScopeId next_scope_id_ = 1; - ScopeId global_scope_ = kInvalidScopeId; - std::unordered_map scopes_; -}; + ScopeId next_scope_id_ = 1; + ScopeId global_scope_ = kInvalidScopeId; + std::unordered_map scopes_; + }; -} // namespace lsp::language::symbol::index +} // namespace lsp::language::symbol::index diff --git a/lsp-server/src/language/symbol/interface.hpp b/lsp-server/src/language/symbol/interface.hpp index 5554f1f..29ef135 100644 --- a/lsp-server/src/language/symbol/interface.hpp +++ b/lsp-server/src/language/symbol/interface.hpp @@ -2,21 +2,36 @@ #include "./types.hpp" -namespace lsp::language::symbol { +namespace lsp::language::symbol +{ -class ISymbolIndex { - public: - virtual ~ISymbolIndex() = default; - virtual void OnSymbolAdded(const Symbol& symbol) = 0; - virtual void OnSymbolRemoved(SymbolId id) = 0; - virtual void Clear() = 0; -}; + /** + * ISymbolIndex - 符号索引接口 + * + * 符号表的索引(如位置索引、作用域索引等)应实现此接口 + * 用于观察符号表的变化并更新索引 + */ + class ISymbolIndex + { + public: + virtual ~ISymbolIndex() = default; -class ISymbolGraph { - public: - virtual ~ISymbolGraph() = default; - virtual void OnSymbolRemoved(SymbolId id) = 0; - virtual void Clear() = 0; -}; + /** + * 当符号被添加时调用 + * @param symbol 被添加的符号 + */ + virtual void OnSymbolAdded(const Symbol& symbol) = 0; -} // namespace lsp::language::symbol + /** + * 当符号被删除时调用 + * @param id 被删除的符号ID + */ + virtual void OnSymbolRemoved(SymbolId id) = 0; + + /** + * 清空所有索引数据 + */ + virtual void Clear() = 0; + }; + +} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/store.cpp b/lsp-server/src/language/symbol/store.cpp index 13d7ee0..8391c04 100644 --- a/lsp-server/src/language/symbol/store.cpp +++ b/lsp-server/src/language/symbol/store.cpp @@ -1,59 +1,69 @@ -#include "store.hpp" - #include -namespace lsp::language::symbol { +#include "./store.hpp" -SymbolId SymbolStore::Add(Symbol def) { - SymbolId id = next_id_++; - std::visit([id](auto& s) { s.id = id; }, def.mutable_data()); +namespace lsp::language::symbol +{ - auto [it, _] = definitions_.emplace(id, std::move(def)); - const auto& stored = it->second; - by_name_[stored.name()].push_back(id); - return id; -} + SymbolId SymbolStore::Add(Symbol def) + { + SymbolId id = next_id_++; + std::visit([id](auto& s) { s.id = id; }, def.mutable_data()); -bool SymbolStore::Remove(SymbolId id) { - auto it = definitions_.find(id); - if (it == definitions_.end()) { - return false; - } + auto [it, _] = definitions_.emplace(id, std::move(def)); + const auto& stored = it->second; + by_name_[stored.name()].push_back(id); + return id; + } - const std::string& name = it->second.name(); - auto& ids = by_name_[name]; - ids.erase(std::remove(ids.begin(), ids.end(), id), ids.end()); - if (ids.empty()) { - by_name_.erase(name); - } + bool SymbolStore::Remove(SymbolId id) + { + auto it = definitions_.find(id); + if (it == definitions_.end()) + { + return false; + } - definitions_.erase(it); - return true; -} + const std::string& name = it->second.name(); + auto& ids = by_name_[name]; + ids.erase(std::remove(ids.begin(), ids.end(), id), ids.end()); + if (ids.empty()) + { + by_name_.erase(name); + } -void SymbolStore::Clear() { - definitions_.clear(); - by_name_.clear(); - next_id_ = 1; -} + definitions_.erase(it); + return true; + } -const Symbol* SymbolStore::Get(SymbolId id) const { - auto it = definitions_.find(id); - return it != definitions_.end() ? &it->second : nullptr; -} + void SymbolStore::Clear() + { + definitions_.clear(); + by_name_.clear(); + next_id_ = 1; + } -std::vector> SymbolStore::GetAll() const { - std::vector> result; - result.reserve(definitions_.size()); - for (const auto& [_, def] : definitions_) { - result.push_back(std::cref(def)); - } - return result; -} + const Symbol* SymbolStore::Get(SymbolId id) const + { + auto it = definitions_.find(id); + return it != definitions_.end() ? &it->second : nullptr; + } -std::vector SymbolStore::FindByName(const std::string& name) const { - auto it = by_name_.find(name); - return it != by_name_.end() ? it->second : std::vector(); -} + std::vector> SymbolStore::GetAll() const + { + std::vector> result; + result.reserve(definitions_.size()); + for (const auto& [_, def] : definitions_) + { + result.push_back(std::cref(def)); + } + return result; + } -} // namespace lsp::language::symbol + std::vector SymbolStore::FindByName(const std::string& name) const + { + auto it = by_name_.find(name); + return it != by_name_.end() ? it->second : std::vector(); + } + +} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/store.hpp b/lsp-server/src/language/symbol/store.hpp index d023cc2..8fbbb47 100644 --- a/lsp-server/src/language/symbol/store.hpp +++ b/lsp-server/src/language/symbol/store.hpp @@ -7,22 +7,24 @@ #include "./types.hpp" -namespace lsp::language::symbol { +namespace lsp::language::symbol +{ -class SymbolStore { - public: - SymbolId Add(Symbol def); - bool Remove(SymbolId id); - void Clear(); + class SymbolStore + { + public: + SymbolId Add(Symbol def); + bool Remove(SymbolId id); + void Clear(); - const Symbol* Get(SymbolId id) const; - std::vector> GetAll() const; - std::vector FindByName(const std::string& name) const; + const Symbol* Get(SymbolId id) const; + std::vector> GetAll() const; + std::vector FindByName(const std::string& name) const; - private: - SymbolId next_id_ = 1; - std::unordered_map definitions_; - std::unordered_map> by_name_; -}; + private: + SymbolId next_id_ = 1; + std::unordered_map definitions_; + std::unordered_map> by_name_; + }; -} // namespace lsp::language::symbol +} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/table.cpp b/lsp-server/src/language/symbol/table.cpp index e5fe8e1..23a0be4 100644 --- a/lsp-server/src/language/symbol/table.cpp +++ b/lsp-server/src/language/symbol/table.cpp @@ -1,104 +1,100 @@ -#include "table.hpp" +#include "./table.hpp" -#include +namespace lsp::language::symbol +{ -namespace lsp::language::symbol { + SymbolTable::SymbolTable() + { + location_index_ = std::make_shared(); + scope_index_ = std::make_shared(); -SymbolId SymbolTable::CreateSymbol(Symbol symbol) { - auto def = Symbol(std::move(symbol)); - auto id = store_.Add(def); + index_dispatcher_.RegisterIndex(location_index_); + index_dispatcher_.RegisterIndex(scope_index_); + } - location_index_.OnSymbolAdded(def); - scope_index_.OnSymbolAdded(def); + SymbolId SymbolTable::CreateSymbol(Symbol symbol) + { + auto id = store_.Add(std::move(symbol)); - return id; -} + if (const Symbol* stored = store_.Get(id)) + { + index_dispatcher_.NotifySymbolAdded(*stored); + } -bool SymbolTable::RemoveSymbol(SymbolId id) { - location_index_.OnSymbolRemoved(id); - scope_index_.OnSymbolRemoved(id); - reference_graph_.OnSymbolRemoved(id); - inheritance_graph_.OnSymbolRemoved(id); - call_graph_.OnSymbolRemoved(id); + return id; + } - return store_.Remove(id); -} + bool SymbolTable::RemoveSymbol(SymbolId id) + { + if (!store_.Remove(id)) + { + return false; + } -void SymbolTable::Clear() { - store_.Clear(); - location_index_.Clear(); - scope_index_.Clear(); - reference_graph_.Clear(); - inheritance_graph_.Clear(); - call_graph_.Clear(); -} + index_dispatcher_.NotifySymbolRemoved(id); + return true; + } -std::vector SymbolTable::FindSymbolsByName( - const std::string& name) const { - return store_.FindByName(name); -} + void SymbolTable::Clear() + { + store_.Clear(); + index_dispatcher_.Clear(); + } -std::optional SymbolTable::FindSymbolAt( - const ast::Location& location) const { - return location_index_.FindSymbolAt(location); -} + std::vector SymbolTable::FindSymbolsByName( + const std::string& name) const + { + return store_.FindByName(name); + } -const Symbol* SymbolTable::definition(SymbolId id) const { - return store_.Get(id); -} + std::optional SymbolTable::FindSymbolAt( + const ast::Location& location) const + { + return location_index_->FindSymbolAt(location); + } -std::vector> SymbolTable::all_definitions() - const { - return store_.GetAll(); -} + const Symbol* SymbolTable::definition(SymbolId id) const + { + return store_.Get(id); + } -index::Location& SymbolTable::locations() { return location_index_; } -index::Scope& SymbolTable::scopes() { return scope_index_; } + std::vector> SymbolTable::all_definitions() + const + { + return store_.GetAll(); + } -const index::Location& SymbolTable::locations() const { - return location_index_; -} + void SymbolTable::RegisterIndex(std::shared_ptr index) + { + if (!index) + { + return; + } + for (const auto& symbol_ref : store_.GetAll()) + { + index->OnSymbolAdded(symbol_ref.get()); + } + index_dispatcher_.RegisterIndex(std::move(index)); + } -const index::Scope& SymbolTable::scopes() const { return scope_index_; } + index::Location& SymbolTable::locations() { return *location_index_; } + index::Scope& SymbolTable::scopes() { return *scope_index_; } -graph::Reference& SymbolTable::references() { return reference_graph_; } -graph::Inheritance& SymbolTable::inheritance() { return inheritance_graph_; } -graph::Call& SymbolTable::calls() { return call_graph_; } + const index::Location& SymbolTable::locations() const + { + return *location_index_; + } -const graph::Reference& SymbolTable::references() const { - return reference_graph_; -} + const index::Scope& SymbolTable::scopes() const { return *scope_index_; } -const graph::Inheritance& SymbolTable::inheritance() const { - return inheritance_graph_; -} + ScopeId SymbolTable::CreateScope(ScopeKind kind, const ast::Location& range, std::optional parent, std::optional owner) + { + return scope_index_->CreateScope(kind, range, parent, owner); + } -const graph::Call& SymbolTable::calls() const { return call_graph_; } + void SymbolTable::AddSymbolToScope(ScopeId scope_id, const std::string& name, SymbolId symbol_id) + { + scope_index_->AddSymbol(scope_id, name, symbol_id); + } -ScopeId SymbolTable::CreateScope(ScopeKind kind, const ast::Location& range, - std::optional parent, - std::optional owner) { - return scope_index_.CreateScope(kind, range, parent, owner); -} - -void SymbolTable::AddSymbolToScope(ScopeId scope_id, const std::string& name, - SymbolId symbol_id) { - scope_index_.AddSymbol(scope_id, name, symbol_id); -} - -void SymbolTable::AddReference(SymbolId symbol_id, - const ast::Location& location, - bool is_definition, bool is_write) { - reference_graph_.AddReference(symbol_id, location, is_definition, is_write); -} - -void SymbolTable::AddInheritance(SymbolId derived, SymbolId base) { - inheritance_graph_.AddInheritance(derived, base); -} - -void SymbolTable::AddCall(SymbolId caller, SymbolId callee, - const ast::Location& location) { - call_graph_.AddCall(caller, callee, location); -} - -} // namespace lsp::language::symbol +} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/table.hpp b/lsp-server/src/language/symbol/table.hpp index 7df6c64..0ea572e 100644 --- a/lsp-server/src/language/symbol/table.hpp +++ b/lsp-server/src/language/symbol/table.hpp @@ -1,63 +1,49 @@ #pragma once #include +#include #include #include -#include "./graph/call.hpp" -#include "./graph/inheritance.hpp" -#include "./graph/reference.hpp" #include "./index/location.hpp" +#include "./index/dispatcher.hpp" #include "./index/scope.hpp" #include "./store.hpp" -namespace lsp::language::symbol { +namespace lsp::language::symbol +{ -class SymbolTable { - public: - SymbolTable() = default; + class SymbolTable + { + public: + SymbolTable(); - SymbolId CreateSymbol(Symbol symbol); - bool RemoveSymbol(SymbolId id); - void Clear(); + SymbolId CreateSymbol(Symbol symbol); + bool RemoveSymbol(SymbolId id); + void Clear(); - std::vector FindSymbolsByName(const std::string& name) const; - std::optional FindSymbolAt(const ast::Location& location) const; + std::vector FindSymbolsByName(const std::string& name) const; + std::optional FindSymbolAt(const ast::Location& location) const; - const Symbol* definition(SymbolId id) const; - std::vector> all_definitions() const; + const Symbol* definition(SymbolId id) const; + std::vector> all_definitions() const; - index::Location& locations(); - index::Scope& scopes(); + void RegisterIndex(std::shared_ptr index); - const index::Location& locations() const; - const index::Scope& scopes() const; + index::Location& locations(); + index::Scope& scopes(); - graph::Reference& references(); - graph::Inheritance& inheritance(); - graph::Call& calls(); + const index::Location& locations() const; + const index::Scope& scopes() const; - const graph::Reference& references() const; - const graph::Inheritance& inheritance() const; - const graph::Call& calls() const; + ScopeId CreateScope(ScopeKind kind, const ast::Location& range, std::optional parent = std::nullopt, std::optional owner = std::nullopt); + void AddSymbolToScope(ScopeId scope_id, const std::string& name, SymbolId symbol_id); - ScopeId CreateScope(ScopeKind kind, const ast::Location& range, - std::optional parent = std::nullopt, - std::optional owner = std::nullopt); - void AddSymbolToScope(ScopeId scope_id, const std::string& name, - SymbolId symbol_id); - void AddReference(SymbolId symbol_id, const ast::Location& location, - bool is_definition = false, bool is_write = false); - void AddInheritance(SymbolId derived, SymbolId base); - void AddCall(SymbolId caller, SymbolId callee, const ast::Location& location); + private: + SymbolStore store_; + IndexDispatcher index_dispatcher_; + std::shared_ptr location_index_; + std::shared_ptr scope_index_; + }; - private: - SymbolStore store_; - index::Location location_index_; - index::Scope scope_index_; - graph::Reference reference_graph_; - graph::Inheritance inheritance_graph_; - graph::Call call_graph_; -}; - -} // namespace lsp::language::symbol +} // namespace lsp::language::symbol diff --git a/lsp-server/src/language/symbol/types.hpp b/lsp-server/src/language/symbol/types.hpp index d5ffa18..e6aeda3 100644 --- a/lsp-server/src/language/symbol/types.hpp +++ b/lsp-server/src/language/symbol/types.hpp @@ -8,243 +8,259 @@ #include "../../protocol/protocol.hpp" #include "../ast/types.hpp" -namespace lsp::language::symbol { +namespace lsp::language::symbol +{ -// ===== Basic Types ===== + // ===== Basic Types ===== -using SymbolId = uint64_t; -constexpr SymbolId kInvalidSymbolId = 0; + using SymbolId = uint64_t; + constexpr SymbolId kInvalidSymbolId = 0; -using ScopeId = uint64_t; -constexpr ScopeId kInvalidScopeId = 0; + using ScopeId = uint64_t; + constexpr ScopeId kInvalidScopeId = 0; -using SymbolKind = protocol::SymbolKind; + enum class ScopeKind + { + kGlobal, + kUnit, + kClass, + kFunction, + kAnonymousFunction, + kBlock + }; -enum class ScopeKind { - kGlobal, - kUnit, - kClass, - kFunction, - kAnonymousFunction, - kBlock -}; + enum class VariableScope + { + kAutomatic, + kStatic, + kGlobal, + kParameter, + kField + }; -enum class VariableScope { kAutomatic, kStatic, kGlobal, kParameter, kField }; + enum class UnitVisibility + { + kInterface, + kImplementation + }; -enum class UnitVisibility { kInterface, kImplementation }; + // ===== Parameters ===== -// ===== Parameters ===== + struct Parameter + { + std::string name; + std::optional type; + std::optional default_value; + }; -struct Parameter { - std::string name; - std::optional type; - std::optional default_value; -}; + struct UnitImport + { + std::string unit_name; + ast::Location location; + }; -struct UnitImport { - std::string unit_name; - ast::Location location; -}; + // ===== Symbol Types (all fields directly expanded) ===== -// ===== Symbol Types (all fields directly expanded) ===== + struct Function + { + static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Function; -struct Function { - static constexpr SymbolKind kind = SymbolKind::Function; + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; // Name identifier location (for cursor) + ast::Location range; // Full range (including modifiers, body) - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; // Name identifier location (for cursor) - ast::Location range; // Full range (including modifiers, body) + ast::Location declaration_range; + std::optional implementation_range; + std::vector parameters; + std::optional return_type; - ast::Location declaration_range; - std::optional implementation_range; - std::vector parameters; - std::optional return_type; + std::vector imports; + std::optional unit_visibility; - std::vector imports; - std::optional unit_visibility; + bool HasImplementation() const { return implementation_range.has_value(); } + }; - bool HasImplementation() const { return implementation_range.has_value(); } -}; + struct Class + { + static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Class; -struct Class { - static constexpr SymbolKind kind = SymbolKind::Class; + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; + ast::Location range; - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; - ast::Location range; + std::optional unit_visibility; - std::optional unit_visibility; + std::vector base_classes; + std::vector members; - std::vector base_classes; - std::vector members; + std::vector imports; + }; - std::vector imports; -}; + struct Method + { + static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Method; -struct Method { - static constexpr SymbolKind kind = SymbolKind::Method; + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; + ast::Location range; - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; - ast::Location range; + // Location information + ast::Location declaration_range; + std::optional implementation_range; - // Location information - ast::Location declaration_range; - std::optional implementation_range; + // Method-specific + ast::MethodKind method_kind = ast::MethodKind::kOrdinary; + ast::AccessModifier access = ast::AccessModifier::kPublic; + std::optional method_modifier = + ast::MethodModifier::kNone; + bool is_static = false; + std::vector parameters; + std::optional return_type; - // Method-specific - ast::MethodKind method_kind = ast::MethodKind::kOrdinary; - ast::AccessModifier access = ast::AccessModifier::kPublic; - std::optional method_modifier = - ast::MethodModifier::kNone; - bool is_static = false; - std::vector parameters; - std::optional return_type; + std::vector imports; - std::vector imports; + bool HasImplementation() const { return implementation_range.has_value(); } + }; - bool HasImplementation() const { return implementation_range.has_value(); } -}; + struct Property + { + static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Property; -struct Property { - static constexpr SymbolKind kind = SymbolKind::Property; + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; + ast::Location range; - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; - ast::Location range; + // Property-specific + ast::AccessModifier access = ast::AccessModifier::kPublic; + std::optional type; + std::optional getter; + std::optional setter; + }; - // Property-specific - ast::AccessModifier access = ast::AccessModifier::kPublic; - std::optional type; - std::optional getter; - std::optional setter; -}; + struct Field + { + static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Field; -struct Field { - static constexpr SymbolKind kind = SymbolKind::Field; + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; + ast::Location range; - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; - ast::Location range; + ast::AccessModifier access = ast::AccessModifier::kPublic; + std::optional reference_modifier; + std::optional type; + bool is_static = false; + }; - ast::AccessModifier access = ast::AccessModifier::kPublic; - std::optional reference_modifier; - std::optional type; - bool is_static = false; -}; + struct Variable + { + static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Variable; -struct Variable { - static constexpr SymbolKind kind = SymbolKind::Variable; + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; + ast::Location range; - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; - ast::Location range; + std::optional type; + std::optional reference_modifier; + VariableScope storage = VariableScope::kAutomatic; + std::optional unit_visibility; + bool has_initializer = false; + }; - std::optional type; - std::optional reference_modifier; - VariableScope storage = VariableScope::kAutomatic; - std::optional unit_visibility; - bool has_initializer = false; -}; + struct Constant + { + static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Constant; -struct Constant { - static constexpr SymbolKind kind = SymbolKind::Constant; + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; + ast::Location range; - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; - ast::Location range; + std::optional type; + std::string value; + }; - std::optional type; - std::string value; -}; + struct Unit + { + static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Namespace; -struct Unit { - static constexpr SymbolKind kind = SymbolKind::Namespace; + SymbolId id = kInvalidSymbolId; + std::string name; + ast::Location selection_range; + ast::Location range; - SymbolId id = kInvalidSymbolId; - std::string name; - ast::Location selection_range; - ast::Location range; + std::vector interface_imports; + std::vector implementation_imports; + }; - std::vector interface_imports; - std::vector implementation_imports; -}; + // ===== Symbol Data Variant ===== -struct Reference { - ast::Location location; - SymbolId symbol_id; - bool is_definition; - bool is_write; -}; + using SymbolData = std::variant; -struct Call { - SymbolId caller; - SymbolId callee; - ast::Location call_site; -}; + // ===== Symbol ===== -// ===== Symbol Data Variant ===== + class Symbol + { + public: + explicit Symbol(SymbolData data) : data_(std::move(data)) {} -using SymbolData = std::variant; + // Type checking and conversion + template + bool Is() const + { + return std::holds_alternative(data_); + } -// ===== Symbol ===== + template + const T* As() const + { + return std::get_if(&data_); + } -class Symbol { - public: - explicit Symbol(SymbolData data) : data_(std::move(data)) {} + template + T* As() + { + return std::get_if(&data_); + } - // Type checking and conversion - template - bool Is() const { - return std::holds_alternative(data_); - } + // Accessors (snake_case per Google style) + const SymbolData& data() const { return data_; } + SymbolData& mutable_data() { return data_; } - template - const T* As() const { - return std::get_if(&data_); - } + // Common accessors (all symbol types have these) + SymbolId id() const + { + return std::visit([](const auto& s) { return s.id; }, data_); + } - template - T* As() { - return std::get_if(&data_); - } + const std::string& name() const + { + return std::visit([](const auto& s) -> const auto& { return s.name; }, + data_); + } - // Accessors (snake_case per Google style) - const SymbolData& data() const { return data_; } - SymbolData& mutable_data() { return data_; } + ast::Location selection_range() const + { + return std::visit([](const auto& s) { return s.selection_range; }, data_); + } - // Common accessors (all symbol types have these) - SymbolId id() const { - return std::visit([](const auto& s) { return s.id; }, data_); - } + ast::Location range() const + { + return std::visit([](const auto& s) { return s.range; }, data_); + } - const std::string& name() const { - return std::visit([](const auto& s) -> const auto& { return s.name; }, - data_); - } + protocol::SymbolKind kind() const + { + return std::visit([](const auto& s) { return s.kind; }, data_); + } - ast::Location selection_range() const { - return std::visit([](const auto& s) { return s.selection_range; }, data_); - } + private: + SymbolData data_; + }; - ast::Location range() const { - return std::visit([](const auto& s) { return s.range; }, data_); - } - - SymbolKind kind() const { - return std::visit([](const auto& s) { return s.kind; }, data_); - } - - private: - SymbolData data_; -}; - -} // namespace lsp::language::symbol +} // namespace lsp::language::symbol diff --git a/lsp-server/test/test_symbol/CMakeLists.txt b/lsp-server/test/test_symbol/CMakeLists.txt index eedef97..3a86dae 100644 --- a/lsp-server/test/test_symbol/CMakeLists.txt +++ b/lsp-server/test/test_symbol/CMakeLists.txt @@ -27,9 +27,6 @@ set(SOURCES ../../src/language/symbol/table.cpp ../../src/language/symbol/index/location.cpp ../../src/language/symbol/index/scope.cpp - ../../src/language/symbol/graph/call.cpp - ../../src/language/symbol/graph/inheritance.cpp - ../../src/language/symbol/graph/reference.cpp ../../src/tree-sitter/scanner.c ../../src/tree-sitter/parser.c) diff --git a/lsp-server/test/test_symbol/debug_printer.cpp b/lsp-server/test/test_symbol/debug_printer.cpp index 493abd6..576684c 100644 --- a/lsp-server/test/test_symbol/debug_printer.cpp +++ b/lsp-server/test/test_symbol/debug_printer.cpp @@ -87,7 +87,6 @@ namespace lsp::language::symbol::debug PrintOptions opts; opts.show_details = true; opts.show_children = true; - opts.show_references = true; return opts; } @@ -107,27 +106,11 @@ namespace lsp::language::symbol::debug symbol_counts.clear(); scope_counts.clear(); - symbols_with_refs = 0; - total_references = 0; - max_references = 0; - most_referenced = kInvalidSymbolId; for (const auto& ref : all_defs) { const auto& sym = ref.get(); symbol_counts[sym.kind()]++; - - const auto& refs = table.references().references(sym.id()); - if (!refs.empty()) - { - symbols_with_refs++; - total_references += refs.size(); - if (refs.size() > max_references) - { - max_references = refs.size(); - most_referenced = sym.id(); - } - } } const auto& scopes = table.scopes().all_scopes(); @@ -153,15 +136,6 @@ namespace lsp::language::symbol::debug os << color(Color::Bold) << "Overview:" << color(Color::Reset) << "\n"; os << " Total Symbols: " << color(Color::Cyan) << total_symbols << color(Color::Reset) << "\n"; os << " Total Scopes: " << color(Color::Cyan) << total_scopes << color(Color::Reset) << "\n"; - os << " Total References: " << color(Color::Cyan) << total_references << color(Color::Reset) << "\n"; - os << " Symbols w/ Refs: " << color(Color::Green) << symbols_with_refs << color(Color::Reset) << "\n"; - - if (most_referenced != kInvalidSymbolId) - { - os << " Most Referenced: " << color(Color::Yellow) << "ID=" << most_referenced - << " (" << max_references << " refs)" << color(Color::Reset) << "\n"; - } - os << "\n"; os << color(Color::Bold) << "Symbol Distribution:" << color(Color::Reset) << "\n"; @@ -507,12 +481,6 @@ namespace lsp::language::symbol::debug } } - const auto& refs = table_.references().references(symbol.id()); - if (options_.show_references && !refs.empty()) - { - os << " " << Dim("[refs: " + std::to_string(refs.size()) + "]"); - } - os << "\n"; }, symbol.data()); @@ -713,135 +681,6 @@ namespace lsp::language::symbol::debug PrintScopeTree(global, os, 0); } - void DebugPrinter::PrintReferences(SymbolId id, std::ostream& os) - { - const auto& refs = table_.references().references(id); - if (refs.empty()) - return; - - PrintSubHeader("References for symbol " + std::to_string(id), os); - for (const auto& ref : refs) - { - os << " - " << FormatLocation(ref.location); - if (ref.is_definition) - os << " (def)"; - if (ref.is_write) - os << " (write)"; - os << "\n"; - } - } - - void DebugPrinter::PrintInheritance(SymbolId class_id, std::ostream& os) - { - const auto& bases = table_.inheritance().base_classes(class_id); - const auto& derived = table_.inheritance().derived_classes(class_id); - - PrintSubHeader("Inheritance for class " + std::to_string(class_id), os); - if (bases.empty() && derived.empty()) - { - os << " (no inheritance info)\n"; - return; - } - - if (!bases.empty()) - { - os << " Base classes: "; - for (size_t i = 0; i < bases.size(); ++i) - { - os << bases[i]; - if (i + 1 < bases.size()) - os << ", "; - } - os << "\n"; - } - if (!derived.empty()) - { - os << " Derived classes: "; - for (size_t i = 0; i < derived.size(); ++i) - { - os << derived[i]; - if (i + 1 < derived.size()) - os << ", "; - } - os << "\n"; - } - } - - void DebugPrinter::PrintCallGraph(SymbolId function_id, std::ostream& os) - { - const auto& incoming = table_.calls().callers(function_id); - const auto& outgoing = table_.calls().callees(function_id); - PrintSubHeader("Call graph for " + std::to_string(function_id), os); - if (incoming.empty() && outgoing.empty()) - { - os << " (no call info)\n"; - return; - } - - if (!incoming.empty()) - { - os << " Called by: "; - for (size_t i = 0; i < incoming.size(); ++i) - { - os << incoming[i].caller; - if (i + 1 < incoming.size()) - os << ", "; - } - os << "\n"; - } - - if (!outgoing.empty()) - { - os << " Calls: "; - for (size_t i = 0; i < outgoing.size(); ++i) - { - os << outgoing[i].callee; - if (i + 1 < outgoing.size()) - os << ", "; - } - os << "\n"; - } - } - - void DebugPrinter::PrintAllReferences(std::ostream& os) - { - PrintHeader("References", os); - auto all_defs = table_.all_definitions(); - for (const auto& ref : all_defs) - { - const auto& sym = ref.get(); - PrintReferences(sym.id(), os); - } - } - - void DebugPrinter::PrintAllInheritance(std::ostream& os) - { - PrintHeader("Inheritance Graph", os); - auto all_defs = table_.all_definitions(); - for (const auto& ref : all_defs) - { - const auto& sym = ref.get(); - if (sym.kind() == SymbolKind::Class) - { - PrintInheritance(sym.id(), os); - } - } - } - - void DebugPrinter::PrintAllCalls(std::ostream& os) - { - PrintHeader("Call Graph", os); - auto all_defs = table_.all_definitions(); - for (const auto& ref : all_defs) - { - const auto& sym = ref.get(); - if (sym.kind() == SymbolKind::Function || sym.kind() == SymbolKind::Method) - { - PrintCallGraph(sym.id(), os); - } - } - } - void DebugPrinter::FindAndPrint(const std::string& name, std::ostream& os) { PrintHeader("Search: " + name, os); @@ -855,7 +694,6 @@ namespace lsp::language::symbol::debug for (auto id : matches) { PrintSymbol(id, os); - PrintReferences(id, os); } } @@ -875,7 +713,6 @@ namespace lsp::language::symbol::debug PrintHeader("Overview", os); os << "Symbols: " << stats_.total_symbols << "\n"; os << "Scopes: " << stats_.total_scopes << "\n"; - os << "Refs: " << stats_.total_references << "\n"; } void DebugPrinter::PrintStatistics(std::ostream& os) @@ -888,13 +725,6 @@ namespace lsp::language::symbol::debug PrintOverview(os); PrintSymbolList(os); PrintScopeHierarchy(os); - - if (options_.show_references) - { - PrintAllReferences(os); - } - PrintAllInheritance(os); - PrintAllCalls(os); PrintStatistics(os); } diff --git a/lsp-server/test/test_symbol/debug_printer.hpp b/lsp-server/test/test_symbol/debug_printer.hpp index 141c2ee..1f03fe7 100644 --- a/lsp-server/test/test_symbol/debug_printer.hpp +++ b/lsp-server/test/test_symbol/debug_printer.hpp @@ -4,10 +4,13 @@ #include #include #include "../../src/language/symbol/table.hpp" +#include "../../src/protocol/protocol.hpp" namespace lsp::language::symbol::debug { + using SymbolKind = protocol::SymbolKind; + // ==================== 打印选项 ==================== struct PrintOptions @@ -16,7 +19,6 @@ namespace lsp::language::symbol::debug bool show_location = true; // 显示位置信息 bool show_details = true; // 显示详细信息 bool show_children = true; // 显示子符号 - bool show_references = false; // 显示引用列表 bool compact_mode = false; // 紧凑模式 int indent_size = 2; // 缩进大小 int max_depth = -1; // 最大深度 (-1 = 无限制) @@ -33,15 +35,10 @@ namespace lsp::language::symbol::debug { size_t total_symbols = 0; size_t total_scopes = 0; - size_t total_references = 0; std::unordered_map symbol_counts; std::unordered_map scope_counts; - size_t symbols_with_refs = 0; - size_t max_references = 0; - SymbolId most_referenced = kInvalidSymbolId; - void Compute(const SymbolTable& table); void Print(std::ostream& os, bool use_color = true) const; }; @@ -69,14 +66,6 @@ namespace lsp::language::symbol::debug void PrintScopeTree(ScopeId id, std::ostream& os = std::cout, int depth = 0); void PrintScopeHierarchy(std::ostream& os = std::cout); - // ===== 关系打印 ===== - void PrintReferences(SymbolId id, std::ostream& os = std::cout); - void PrintInheritance(SymbolId class_id, std::ostream& os = std::cout); - void PrintCallGraph(SymbolId function_id, std::ostream& os = std::cout); - void PrintAllReferences(std::ostream& os = std::cout); - void PrintAllInheritance(std::ostream& os = std::cout); - void PrintAllCalls(std::ostream& os = std::cout); - // ===== 搜索和查询 ===== void FindAndPrint(const std::string& name, std::ostream& os = std::cout); void FindAtLocation(const ast::Location& loc, std::ostream& os = std::cout); diff --git a/lsp-server/test/test_symbol/test.cpp b/lsp-server/test/test_symbol/test.cpp index 4bd7a91..7a57f7d 100644 --- a/lsp-server/test/test_symbol/test.cpp +++ b/lsp-server/test/test_symbol/test.cpp @@ -115,9 +115,6 @@ struct Options bool print_all = true; bool print_definitions = false; bool print_scopes = false; - bool print_references = false; - bool print_inheritance = false; - bool print_calls = false; bool compact_mode = false; bool statistics_only = false; std::string search_symbol; @@ -135,9 +132,6 @@ void PrintUsage(const char* program_name) std::cout << " -o, --output Write output to file instead of stdout\n"; std::cout << " -d, --definitions Print only symbol definitions\n"; std::cout << " -s, --scopes Print only scope hierarchy\n"; - std::cout << " -r, --references Print only references\n"; - std::cout << " -i, --inheritance Print only inheritance graph\n"; - std::cout << " -c, --calls Print only call graph\n"; std::cout << " -C, --compact Use compact output format\n"; std::cout << " -S, --stats Print statistics only\n"; std::cout << " -O, --overview Print overview only\n"; @@ -193,21 +187,6 @@ bool ParseArguments(int argc, char* argv[], Options& options) options.print_scopes = true; any_specific_print = true; } - else if (arg == "-r" || arg == "--references") - { - options.print_references = true; - any_specific_print = true; - } - else if (arg == "-i" || arg == "--inheritance") - { - options.print_inheritance = true; - any_specific_print = true; - } - else if (arg == "-c" || arg == "--calls") - { - options.print_calls = true; - any_specific_print = true; - } else if (arg == "-C" || arg == "--compact") { options.compact_mode = true; @@ -550,15 +529,6 @@ private: EnterScope(symbol::ScopeKind::kClass, node.span, class_id); - for (const auto& parent : node.parent_classes) - { - auto it = name_index_.find(ToLower(parent.name)); - if (it != name_index_.end()) - { - table_.AddInheritance(class_id, it->second); - } - } - for (auto& member : node.members) { if (member) @@ -1013,8 +983,6 @@ void AnalyzeFile(const Options& options) if (options.no_color || !options.output_file.empty()) print_opts.use_color = false; - print_opts.show_references = options.print_references || options.verbose; - symbol::debug::DebugPrinter printer(table, print_opts); if (!options.search_symbol.empty()) @@ -1053,24 +1021,6 @@ void AnalyzeFile(const Options& options) printed_anything = true; } - if (options.print_references) - { - printer.PrintAllReferences(*out); - printed_anything = true; - } - - if (options.print_inheritance) - { - printer.PrintAllInheritance(*out); - printed_anything = true; - } - - if (options.print_calls) - { - printer.PrintAllCalls(*out); - printed_anything = true; - } - if (printed_anything) { *out << "\n";