新增语义模块

♻️ 重构符号表,职责更清晰单一

🐛 同步修复`test_symbol`
This commit is contained in:
csh 2025-11-18 23:11:40 +08:00
parent 3274af67d5
commit 4c2e242920
40 changed files with 4635 additions and 1892 deletions

View File

@ -49,11 +49,15 @@ set(SOURCES
language/symbol/builder.cpp
language/symbol/index/location.cpp
language/symbol/index/scope.cpp
language/symbol/graph/call.cpp
language/symbol/graph/inheritance.cpp
language/symbol/graph/reference.cpp
language/symbol/store.cpp
language/symbol/table.cpp
language/semantic/graph/call.cpp
language/semantic/graph/inheritance.cpp
language/semantic/graph/reference.cpp
language/semantic/semantic_model.cpp
language/semantic/type_system.cpp
language/semantic/name_resolver.cpp
language/semantic/analyzer.cpp
language/keyword/repo.cpp
provider/base/bootstrap.cpp
provider/base/interface.cpp

View File

@ -0,0 +1,555 @@
#include "./analyzer.hpp"
namespace lsp::language::semantic
{
Analyzer::Analyzer(const symbol::SymbolTable& symbol_table,
SemanticModel& semantic_model)
: symbol_table_(symbol_table),
semantic_model_(semantic_model),
current_function_id_(std::nullopt),
current_class_id_(std::nullopt)
{
}
void Analyzer::Analyze(ast::ASTNode& root)
{
root.Accept(*this);
}
// ===== Statements =====
void Analyzer::VisitProgram(ast::Program& node)
{
VisitStatements(node.statements);
}
void Analyzer::VisitUnitDefinition(ast::UnitDefinition& node)
{
// TODO: 处理 Unit 导入关系
VisitStatements(node.interface_section.statements);
VisitStatements(node.implementation_section.statements);
}
void Analyzer::VisitClassDefinition(ast::ClassDefinition& node)
{
// 查找类符号
auto class_id = ResolveIdentifier(node.name, node.location);
if (!class_id)
{
return;
}
// 设置当前类上下文
auto prev_class = current_class_id_;
current_class_id_ = class_id;
// 处理继承关系
for (const auto& base_class : node.base_classes)
{
auto base_id = ResolveIdentifier(base_class, node.location);
if (base_id)
{
semantic_model_.AddInheritance(*class_id, *base_id);
}
}
// 访问类成员
for (const auto& member : node.members)
{
member->Accept(*this);
}
// 恢复上下文
current_class_id_ = prev_class;
}
void Analyzer::VisitFunctionDefinition(ast::FunctionDefinition& node)
{
auto func_id = ResolveIdentifier(node.name, node.location);
if (!func_id)
{
return;
}
auto prev_func = current_function_id_;
current_function_id_ = func_id;
// 访问函数体
if (node.body)
{
node.body->Accept(*this);
}
current_function_id_ = prev_func;
}
void Analyzer::VisitFunctionDeclaration(ast::FunctionDeclaration& node)
{
// 函数声明不需要额外的语义分析
}
void Analyzer::VisitMethodDeclaration(ast::MethodDeclaration& node)
{
auto method_id = ResolveIdentifier(node.name, node.location);
if (!method_id)
{
return;
}
auto prev_func = current_function_id_;
current_function_id_ = method_id;
if (node.body)
{
node.body->Accept(*this);
}
current_function_id_ = prev_func;
}
void Analyzer::VisitExternalMethodDefinition(ast::ExternalMethodDefinition& node)
{
auto method_id = ResolveIdentifier(node.name, node.location);
if (!method_id)
{
return;
}
auto prev_func = current_function_id_;
current_function_id_ = method_id;
if (node.body)
{
node.body->Accept(*this);
}
current_function_id_ = prev_func;
}
// ===== Declarations =====
void Analyzer::VisitVarDeclaration(ast::VarDeclaration& node)
{
if (node.initializer)
{
VisitExpression(*node.initializer);
}
}
void Analyzer::VisitStaticDeclaration(ast::StaticDeclaration& node)
{
if (node.initializer)
{
VisitExpression(*node.initializer);
}
}
void Analyzer::VisitGlobalDeclaration(ast::GlobalDeclaration& node)
{
if (node.initializer)
{
VisitExpression(*node.initializer);
}
}
void Analyzer::VisitConstDeclaration(ast::ConstDeclaration& node)
{
VisitExpression(*node.value);
}
void Analyzer::VisitFieldDeclaration(ast::FieldDeclaration& node)
{
// 字段声明不需要额外的语义分析
}
void Analyzer::VisitClassMember(ast::ClassMember& node)
{
node.member->Accept(*this);
}
void Analyzer::VisitPropertyDeclaration(ast::PropertyDeclaration& node)
{
// TODO: 处理 property getter/setter
}
// ===== Control Flow =====
void Analyzer::VisitBlockStatement(ast::BlockStatement& node)
{
VisitStatements(node.statements);
}
void Analyzer::VisitIfStatement(ast::IfStatement& node)
{
VisitExpression(*node.condition);
node.then_branch->Accept(*this);
if (node.else_branch)
{
node.else_branch->Accept(*this);
}
}
void Analyzer::VisitForInStatement(ast::ForInStatement& node)
{
VisitExpression(*node.iterable);
node.body->Accept(*this);
}
void Analyzer::VisitForToStatement(ast::ForToStatement& node)
{
VisitExpression(*node.start);
VisitExpression(*node.end);
if (node.step)
{
VisitExpression(*node.step);
}
node.body->Accept(*this);
}
void Analyzer::VisitWhileStatement(ast::WhileStatement& node)
{
VisitExpression(*node.condition);
node.body->Accept(*this);
}
void Analyzer::VisitRepeatStatement(ast::RepeatStatement& node)
{
VisitStatements(node.statements);
VisitExpression(*node.condition);
}
void Analyzer::VisitCaseStatement(ast::CaseStatement& node)
{
VisitExpression(*node.expression);
for (const auto& branch : node.branches)
{
for (const auto& value : branch.values)
{
VisitExpression(*value);
}
branch.statement->Accept(*this);
}
if (node.else_branch)
{
node.else_branch->Accept(*this);
}
}
void Analyzer::VisitTryStatement(ast::TryStatement& node)
{
node.try_block->Accept(*this);
if (node.except_block)
{
node.except_block->Accept(*this);
}
if (node.finally_block)
{
node.finally_block->Accept(*this);
}
}
void Analyzer::VisitUsesStatement(ast::UsesStatement& node)
{
// TODO: 处理 Unit 引用
}
// ===== Expressions =====
void Analyzer::VisitIdentifier(ast::Identifier& node)
{
auto symbol_id = ResolveIdentifier(node.name, node.location);
if (symbol_id)
{
TrackReference(*symbol_id, node.location, false);
}
}
void Analyzer::VisitCallExpression(ast::CallExpression& node)
{
// 访问被调用的表达式
VisitExpression(*node.callee);
// 访问参数
for (const auto& arg : node.arguments)
{
VisitExpression(*arg);
}
// 如果 callee 是标识符,记录调用关系
if (auto* ident = dynamic_cast<ast::Identifier*>(node.callee.get()))
{
auto callee_id = ResolveIdentifier(ident->name, ident->location);
if (callee_id)
{
TrackCall(*callee_id, node.location);
}
}
}
void Analyzer::VisitAttributeExpression(ast::AttributeExpression& node)
{
VisitExpression(*node.object);
// TODO: 解析成员访问
}
void Analyzer::VisitAssignmentExpression(ast::AssignmentExpression& node)
{
// 处理左值(写引用)
ProcessLValue(node.target);
// 处理右值
VisitExpression(*node.value);
}
void Analyzer::VisitLiteral(ast::Literal& node)
{
// 字面量不需要处理
}
void Analyzer::VisitBinaryExpression(ast::BinaryExpression& node)
{
VisitExpression(*node.left);
VisitExpression(*node.right);
}
void Analyzer::VisitTernaryExpression(ast::TernaryExpression& node)
{
VisitExpression(*node.condition);
VisitExpression(*node.true_expr);
VisitExpression(*node.false_expr);
}
void Analyzer::VisitSubscriptExpression(ast::SubscriptExpression& node)
{
VisitExpression(*node.object);
for (const auto& index : node.indices)
{
VisitExpression(*index);
}
}
void Analyzer::VisitArrayExpression(ast::ArrayExpression& node)
{
for (const auto& element : node.elements)
{
VisitExpression(*element);
}
}
void Analyzer::VisitAnonymousFunctionExpression(ast::AnonymousFunctionExpression& node)
{
if (node.body)
{
node.body->Accept(*this);
}
}
// Unary expressions
void Analyzer::VisitUnaryPlusExpression(ast::UnaryPlusExpression& node)
{
VisitExpression(*node.operand);
}
void Analyzer::VisitUnaryMinusExpression(ast::UnaryMinusExpression& node)
{
VisitExpression(*node.operand);
}
void Analyzer::VisitPrefixIncrementExpression(ast::PrefixIncrementExpression& node)
{
VisitExpression(*node.operand);
}
void Analyzer::VisitPrefixDecrementExpression(ast::PrefixDecrementExpression& node)
{
VisitExpression(*node.operand);
}
void Analyzer::VisitPostfixIncrementExpression(ast::PostfixIncrementExpression& node)
{
VisitExpression(*node.operand);
}
void Analyzer::VisitPostfixDecrementExpression(ast::PostfixDecrementExpression& node)
{
VisitExpression(*node.operand);
}
void Analyzer::VisitLogicalNotExpression(ast::LogicalNotExpression& node)
{
VisitExpression(*node.operand);
}
void Analyzer::VisitBitwiseNotExpression(ast::BitwiseNotExpression& node)
{
VisitExpression(*node.operand);
}
void Analyzer::VisitDerivativeExpression(ast::DerivativeExpression& node)
{
VisitExpression(*node.operand);
}
void Analyzer::VisitMatrixTransposeExpression(ast::MatrixTransposeExpression& node)
{
VisitExpression(*node.operand);
}
void Analyzer::VisitExprOperatorExpression(ast::ExprOperatorExpression& node)
{
VisitExpression(*node.operand);
}
void Analyzer::VisitFunctionPointerExpression(ast::FunctionPointerExpression& node)
{
// TODO: 处理函数指针
}
void Analyzer::VisitNewExpression(ast::NewExpression& node)
{
// TODO: 处理对象创建
}
void Analyzer::VisitEchoExpression(ast::EchoExpression& node)
{
VisitExpression(*node.expression);
}
void Analyzer::VisitRaiseExpression(ast::RaiseExpression& node)
{
if (node.exception)
{
VisitExpression(*node.exception);
}
}
void Analyzer::VisitInheritedExpression(ast::InheritedExpression& node)
{
// TODO: 处理 inherited 调用
}
void Analyzer::VisitParenthesizedExpression(ast::ParenthesizedExpression& node)
{
VisitExpression(*node.expression);
}
void Analyzer::VisitExpressionStatement(ast::ExpressionStatement& node)
{
VisitExpression(*node.expression);
}
void Analyzer::VisitReturnStatement(ast::ReturnStatement& node)
{
if (node.value)
{
VisitExpression(*node.value);
}
}
void Analyzer::VisitColumnReference(ast::ColumnReference& node)
{
// TODO: 处理列引用
}
void Analyzer::VisitUnpackPattern(ast::UnpackPattern& node)
{
// TODO: 处理解包模式
}
void Analyzer::VisitMatrixIterationStatement(ast::MatrixIterationStatement& node)
{
VisitExpression(*node.matrix);
node.body->Accept(*this);
}
void Analyzer::VisitCompilerDirective(ast::CompilerDirective& node)
{
// 编译器指令不需要语义分析
}
void Analyzer::VisitConditionalDirective(ast::ConditionalDirective& node)
{
for (const auto& block : node.blocks)
{
block->Accept(*this);
}
}
void Analyzer::VisitConditionalBlock(ast::ConditionalBlock& node)
{
VisitStatements(node.statements);
}
// ===== Helper Methods =====
std::optional<symbol::SymbolId> Analyzer::ResolveIdentifier(
const std::string& name,
const ast::Location& location)
{
auto result = semantic_model_.name_resolver().ResolveNameAtLocation(name, location);
return result.IsResolved() ? std::optional(result.symbol_id) : std::nullopt;
}
void Analyzer::TrackReference(symbol::SymbolId symbol_id,
const ast::Location& location,
bool is_write)
{
semantic_model_.AddReference(symbol_id, location, false, is_write);
}
void Analyzer::TrackCall(symbol::SymbolId callee,
const ast::Location& location)
{
if (current_function_id_)
{
semantic_model_.AddCall(*current_function_id_, callee, location);
}
}
std::shared_ptr<Type> Analyzer::InferExpressionType(ast::Expression& expr)
{
// TODO: 实现完整的类型推断
return semantic_model_.type_system().GetUnknownType();
}
void Analyzer::VisitStatements(const std::vector<ast::StatementPtr>& statements)
{
for (const auto& stmt : statements)
{
stmt->Accept(*this);
}
}
void Analyzer::VisitExpression(ast::Expression& expr)
{
expr.Accept(*this);
}
void Analyzer::ProcessLValue(const ast::LValue& lvalue)
{
std::visit(
[this](const auto& value) {
using T = std::decay_t<decltype(value)>;
if constexpr (std::is_same_v<T, std::string>)
{
// 简单标识符
auto symbol_id = ResolveIdentifier(value, ast::Location());
if (symbol_id)
{
TrackReference(*symbol_id, ast::Location(), true);
}
}
else if constexpr (std::is_same_v<T, std::unique_ptr<ast::Expression>>)
{
// 复杂左值(如 a.b, a[i]
if (value)
{
VisitExpression(*value);
}
}
},
lvalue);
}
} // namespace lsp::language::semantic

View File

@ -0,0 +1,162 @@
#pragma once
#include "../ast/types.hpp"
#include "../symbol/table.hpp"
#include "./semantic_model.hpp"
namespace lsp::language::semantic
{
/**
* SemanticAnalyzer -
*
*
* 1. AST
* 2.
* 3.
* 4.
* 5.
*
* SemanticAnalyzer
*/
class Analyzer : public ast::ASTVisitor
{
public:
explicit Analyzer(const symbol::SymbolTable& symbol_table, SemanticModel& semantic_model);
void Analyze(ast::ASTNode& root);
// ===== Visit statements and declarations =====
void VisitProgram(ast::Program& node) override;
void VisitUnitDefinition(ast::UnitDefinition& node) override;
void VisitClassDefinition(ast::ClassDefinition& node) override;
void VisitFunctionDefinition(ast::FunctionDefinition& node) override;
void VisitFunctionDeclaration(ast::FunctionDeclaration& node) override;
void VisitMethodDeclaration(ast::MethodDeclaration& node) override;
void VisitExternalMethodDefinition(ast::ExternalMethodDefinition& node) override;
// Declarations
void VisitVarDeclaration(ast::VarDeclaration& node) override;
void VisitStaticDeclaration(ast::StaticDeclaration& node) override;
void VisitGlobalDeclaration(ast::GlobalDeclaration& node) override;
void VisitConstDeclaration(ast::ConstDeclaration& node) override;
void VisitFieldDeclaration(ast::FieldDeclaration& node) override;
void VisitClassMember(ast::ClassMember& node) override;
void VisitPropertyDeclaration(ast::PropertyDeclaration& node) override;
void VisitBlockStatement(ast::BlockStatement& node) override;
void VisitIfStatement(ast::IfStatement& node) override;
void VisitForInStatement(ast::ForInStatement& node) override;
void VisitForToStatement(ast::ForToStatement& node) override;
void VisitWhileStatement(ast::WhileStatement& node) override;
void VisitRepeatStatement(ast::RepeatStatement& node) override;
void VisitCaseStatement(ast::CaseStatement& node) override;
void VisitTryStatement(ast::TryStatement& node) override;
// Uses statement handling
void VisitUsesStatement(ast::UsesStatement& node) override;
// ===== Visit expressions (collect references) =====
void VisitIdentifier(ast::Identifier& node) override;
void VisitCallExpression(ast::CallExpression& node) override;
void VisitAttributeExpression(ast::AttributeExpression& node) override;
void VisitAssignmentExpression(ast::AssignmentExpression& node) override;
// Other expressions
void VisitLiteral(ast::Literal& node) override;
void VisitBinaryExpression(ast::BinaryExpression& node) override;
void VisitTernaryExpression(ast::TernaryExpression& node) override;
void VisitSubscriptExpression(ast::SubscriptExpression& node) override;
void VisitArrayExpression(ast::ArrayExpression& node) override;
void VisitAnonymousFunctionExpression(ast::AnonymousFunctionExpression& node) override;
// Unary expressions
void VisitUnaryPlusExpression(ast::UnaryPlusExpression& node) override;
void VisitUnaryMinusExpression(ast::UnaryMinusExpression& node) override;
void VisitPrefixIncrementExpression(ast::PrefixIncrementExpression& node) override;
void VisitPrefixDecrementExpression(ast::PrefixDecrementExpression& node) override;
void VisitPostfixIncrementExpression(ast::PostfixIncrementExpression& node) override;
void VisitPostfixDecrementExpression(ast::PostfixDecrementExpression& node) override;
void VisitLogicalNotExpression(ast::LogicalNotExpression& node) override;
void VisitBitwiseNotExpression(ast::BitwiseNotExpression& node) override;
void VisitDerivativeExpression(ast::DerivativeExpression& node) override;
void VisitMatrixTransposeExpression(ast::MatrixTransposeExpression& node) override;
void VisitExprOperatorExpression(ast::ExprOperatorExpression& node) override;
void VisitFunctionPointerExpression(ast::FunctionPointerExpression& node) override;
// Other expressions
void VisitNewExpression(ast::NewExpression& node) override;
void VisitEchoExpression(ast::EchoExpression& node) override;
void VisitRaiseExpression(ast::RaiseExpression& node) override;
void VisitInheritedExpression(ast::InheritedExpression& node) override;
void VisitParenthesizedExpression(ast::ParenthesizedExpression& node) override;
void VisitExpressionStatement(ast::ExpressionStatement& node) override;
void VisitBreakStatement(ast::BreakStatement& node) override {}
void VisitContinueStatement(ast::ContinueStatement& node) override {}
void VisitReturnStatement(ast::ReturnStatement& node) override;
void VisitTSSQLExpression(ast::TSSQLExpression& node) override {}
void VisitColumnReference(ast::ColumnReference& node) override;
void VisitUnpackPattern(ast::UnpackPattern& node) override;
void VisitMatrixIterationStatement(ast::MatrixIterationStatement& node) override;
// Compiler directives
void VisitCompilerDirective(ast::CompilerDirective& node) override;
void VisitConditionalDirective(ast::ConditionalDirective& node) override;
void VisitConditionalBlock(ast::ConditionalBlock& node) override;
void VisitTSLXBlock(ast::TSLXBlock& node) override {}
void VisitParameter(ast::Parameter& node) override {}
private:
const symbol::SymbolTable& symbol_table_;
SemanticModel& semantic_model_;
// 当前上下文
std::optional<symbol::SymbolId> current_function_id_;
std::optional<symbol::SymbolId> current_class_id_;
// ===== Helper methods =====
/**
*
*/
std::optional<symbol::SymbolId> ResolveIdentifier(
const std::string& name,
const ast::Location& location);
/**
*
*/
void TrackReference(symbol::SymbolId symbol_id,
const ast::Location& location,
bool is_write = false);
/**
*
*/
void TrackCall(symbol::SymbolId callee,
const ast::Location& location);
/**
*
*/
std::shared_ptr<Type> InferExpressionType(ast::Expression& expr);
/**
*
*/
void VisitStatements(const std::vector<ast::StatementPtr>& statements);
/**
* 访
*/
void VisitExpression(ast::Expression& expr);
/**
*
*/
void ProcessLValue(const ast::LValue& lvalue);
};
} // namespace lsp::language::semantic

View File

@ -0,0 +1,59 @@
#include "call.hpp"
#include <algorithm>
namespace lsp::language::semantic::graph
{
void Call::OnSymbolRemoved(SymbolId id)
{
callers_map_.erase(id);
callees_map_.erase(id);
for (auto& [_, calls] : callers_map_)
{
calls.erase(std::remove_if(calls.begin(), calls.end(),
[id](const semantic::Call& call) {
return call.caller == id;
}),
calls.end());
}
for (auto& [_, calls] : callees_map_)
{
calls.erase(std::remove_if(calls.begin(), calls.end(),
[id](const semantic::Call& call) {
return call.callee == id;
}),
calls.end());
}
}
void Call::Clear()
{
callers_map_.clear();
callees_map_.clear();
}
void Call::AddCall(SymbolId caller, SymbolId callee, const ast::Location& location)
{
semantic::Call call{caller, callee, location};
callers_map_[callee].push_back(call);
callees_map_[caller].push_back(call);
}
const std::vector<semantic::Call>& Call::callers(SymbolId id) const
{
static const std::vector<semantic::Call> kEmpty;
auto it = callers_map_.find(id);
return it != callers_map_.end() ? it->second : kEmpty;
}
const std::vector<semantic::Call>& Call::callees(SymbolId id) const
{
static const std::vector<semantic::Call> kEmpty;
auto it = callees_map_.find(id);
return it != callees_map_.end() ? it->second : kEmpty;
}
} // namespace lsp::language::semantic::graph

View File

@ -0,0 +1,45 @@
#pragma once
#include <unordered_map>
#include <vector>
#include "../interface.hpp"
#include "../types.hpp"
namespace lsp::language::semantic::graph
{
using symbol::SymbolId;
/**
* Call -
*
* /
*/
class Call : public ISemanticGraph
{
public:
void OnSymbolRemoved(SymbolId id) override;
void Clear() override;
/**
*
*/
void AddCall(SymbolId caller, SymbolId callee, const ast::Location& location);
/**
*
*/
const std::vector<semantic::Call>& callers(SymbolId id) const;
/**
*
*/
const std::vector<semantic::Call>& callees(SymbolId id) const;
private:
std::unordered_map<SymbolId, std::vector<semantic::Call>> callers_map_;
std::unordered_map<SymbolId, std::vector<semantic::Call>> callees_map_;
};
} // namespace lsp::language::semantic::graph

View File

@ -0,0 +1,70 @@
#include "inheritance.hpp"
#include <algorithm>
namespace lsp::language::semantic::graph
{
void Inheritance::OnSymbolRemoved(SymbolId id)
{
base_classes_.erase(id);
derived_classes_.erase(id);
for (auto& [_, bases] : base_classes_)
{
bases.erase(std::remove(bases.begin(), bases.end(), id), bases.end());
}
for (auto& [_, derived] : derived_classes_)
{
derived.erase(std::remove(derived.begin(), derived.end(), id),
derived.end());
}
}
void Inheritance::Clear()
{
base_classes_.clear();
derived_classes_.clear();
}
void Inheritance::AddInheritance(SymbolId derived, SymbolId base)
{
base_classes_[derived].push_back(base);
derived_classes_[base].push_back(derived);
}
const std::vector<SymbolId>& Inheritance::base_classes(SymbolId id) const
{
static const std::vector<SymbolId> kEmpty;
auto it = base_classes_.find(id);
return it != base_classes_.end() ? it->second : kEmpty;
}
const std::vector<SymbolId>& Inheritance::derived_classes(SymbolId id) const
{
static const std::vector<SymbolId> kEmpty;
auto it = derived_classes_.find(id);
return it != derived_classes_.end() ? it->second : kEmpty;
}
bool Inheritance::IsSubclassOf(SymbolId derived, SymbolId base) const
{
auto it = base_classes_.find(derived);
if (it == base_classes_.end())
{
return false;
}
for (SymbolId parent : it->second)
{
if (parent == base || IsSubclassOf(parent, base))
{
return true;
}
}
return false;
}
} // namespace lsp::language::semantic::graph

View File

@ -0,0 +1,50 @@
#pragma once
#include <unordered_map>
#include <vector>
#include "../interface.hpp"
#include "../../symbol/types.hpp"
namespace lsp::language::semantic::graph
{
using symbol::SymbolId;
/**
* Inheritance -
*
*
*/
class Inheritance : public ISemanticGraph
{
public:
void OnSymbolRemoved(SymbolId id) override;
void Clear() override;
/**
*
*/
void AddInheritance(SymbolId derived, SymbolId base);
/**
*
*/
const std::vector<SymbolId>& base_classes(SymbolId id) const;
/**
*
*/
const std::vector<SymbolId>& derived_classes(SymbolId id) const;
/**
* derived base
*/
bool IsSubclassOf(SymbolId derived, SymbolId base) const;
private:
std::unordered_map<SymbolId, std::vector<SymbolId>> base_classes_;
std::unordered_map<SymbolId, std::vector<SymbolId>> derived_classes_;
};
} // namespace lsp::language::semantic::graph

View File

@ -0,0 +1,55 @@
#include "reference.hpp"
#include <algorithm>
namespace lsp::language::semantic::graph
{
void Reference::OnSymbolRemoved(SymbolId id)
{
references_.erase(id);
for (auto& [_, refs] : references_)
{
refs.erase(std::remove_if(refs.begin(), refs.end(),
[id](const semantic::Reference& ref) {
return ref.symbol_id == id;
}),
refs.end());
}
}
void Reference::Clear() { references_.clear(); }
void Reference::AddReference(SymbolId symbol_id, const ast::Location& location,
bool is_definition, bool is_write)
{
references_[symbol_id].push_back(
{location, symbol_id, is_definition, is_write});
}
const std::vector<semantic::Reference>& Reference::references(SymbolId id) const
{
static const std::vector<semantic::Reference> kEmpty;
auto it = references_.find(id);
return it != references_.end() ? it->second : kEmpty;
}
std::optional<ast::Location> Reference::FindDefinitionLocation(
SymbolId id) const
{
auto it = references_.find(id);
if (it != references_.end())
{
for (const auto& ref : it->second)
{
if (ref.is_definition)
{
return ref.location;
}
}
}
return std::nullopt;
}
} // namespace lsp::language::semantic::graph

View File

@ -0,0 +1,46 @@
#pragma once
#include <optional>
#include <unordered_map>
#include <vector>
#include "../interface.hpp"
#include "../types.hpp"
namespace lsp::language::semantic::graph
{
using symbol::SymbolId;
/**
* Reference -
*
*
*/
class Reference : public ISemanticGraph
{
public:
void OnSymbolRemoved(SymbolId id) override;
void Clear() override;
/**
*
*/
void AddReference(SymbolId symbol_id, const ast::Location& location,
bool is_definition = false, bool is_write = false);
/**
*
*/
const std::vector<semantic::Reference>& references(SymbolId id) const;
/**
*
*/
std::optional<ast::Location> FindDefinitionLocation(SymbolId id) const;
private:
std::unordered_map<SymbolId, std::vector<semantic::Reference>> references_;
};
} // namespace lsp::language::semantic::graph

View File

@ -0,0 +1,31 @@
#pragma once
#include "../symbol/types.hpp"
namespace lsp::language::semantic
{
/**
* ISemanticGraph -
*
*
*
*/
class ISemanticGraph
{
public:
virtual ~ISemanticGraph() = default;
/**
*
* @param id ID
*/
virtual void OnSymbolRemoved(symbol::SymbolId id) = 0;
/**
*
*/
virtual void Clear() = 0;
};
} // namespace lsp::language::semantic

View File

@ -0,0 +1,443 @@
#include "./name_resolver.hpp"
#include <algorithm>
namespace lsp::language::semantic
{
NameResolutionResult NameResolver::ResolveName(
const std::string& name,
symbol::ScopeId scope_id,
bool search_parent) const
{
if (!search_parent)
{
// 仅在当前作用域查找
auto symbols = symbol_table_.FindSymbolsByName(name);
if (symbols.empty())
{
return NameResolutionResult::NotFound();
}
if (symbols.size() == 1)
{
return NameResolutionResult::Success(symbols[0]);
}
return NameResolutionResult::Ambiguous(std::move(symbols));
}
// 作用域链查找
auto candidates = SearchScopeChain(name, scope_id);
if (candidates.empty())
{
return NameResolutionResult::NotFound();
}
if (candidates.size() == 1)
{
return NameResolutionResult::Success(candidates[0]);
}
return NameResolutionResult::Ambiguous(std::move(candidates));
}
NameResolutionResult NameResolver::ResolveNameAtLocation(
const std::string& name,
const ast::Location& location) const
{
// 根据位置找到对应的作用域
auto scope_id = symbol_table_.scopes().FindScopeAt(location);
if (!scope_id)
{
// 如果找不到作用域,使用全局作用域
scope_id = 1; // 假设全局作用域ID为1
}
return ResolveName(name, *scope_id, true);
}
NameResolutionResult NameResolver::ResolveMemberAccess(
symbol::SymbolId object_symbol_id,
const std::string& member_name) const
{
// 获取对象的类型
auto object_type = type_system_.GetSymbolType(object_symbol_id);
if (!object_type || object_type->kind() != TypeKind::kClass)
{
return NameResolutionResult::NotFound();
}
const auto* class_type = object_type->As<ClassType>();
return ResolveClassMember(class_type->class_id(), member_name, false);
}
NameResolutionResult NameResolver::ResolveClassMember(
symbol::SymbolId class_id,
const std::string& member_name,
bool static_only) const
{
const auto* class_symbol = symbol_table_.definition(class_id);
if (!class_symbol || !class_symbol->Is<symbol::Class>())
{
return NameResolutionResult::NotFound();
}
const auto* class_data = class_symbol->As<symbol::Class>();
std::vector<symbol::SymbolId> candidates;
// 在类的成员中查找
for (auto member_id : class_data->members)
{
const auto* member = symbol_table_.definition(member_id);
if (!member)
continue;
if (member->name() != member_name)
continue;
// 如果仅查找静态成员
if (static_only)
{
if (member->Is<symbol::Method>())
{
const auto* method = member->As<symbol::Method>();
if (!method->is_static)
continue;
}
else if (member->Is<symbol::Field>())
{
const auto* field = member->As<symbol::Field>();
if (!field->is_static)
continue;
}
}
candidates.push_back(member_id);
}
// TODO: 在基类中查找(需要继承图支持)
if (candidates.empty())
{
return NameResolutionResult::NotFound();
}
if (candidates.size() == 1)
{
return NameResolutionResult::Success(candidates[0]);
}
return NameResolutionResult::Ambiguous(std::move(candidates));
}
NameResolutionResult NameResolver::ResolveFunctionCall(
const std::string& function_name,
const std::vector<std::shared_ptr<Type>>& arg_types,
symbol::ScopeId scope_id) const
{
// 查找所有同名函数
auto candidates_ids = SearchScopeChain(function_name, scope_id);
if (candidates_ids.empty())
{
return NameResolutionResult::NotFound();
}
// 如果只有一个候选,直接返回
if (candidates_ids.size() == 1)
{
return NameResolutionResult::Success(candidates_ids[0]);
}
// 计算每个候选的匹配得分
std::vector<OverloadCandidate> candidates;
for (auto candidate_id : candidates_ids)
{
auto candidate = CalculateOverloadScore(candidate_id, arg_types);
if (candidate.match_score >= 0)
{
candidates.push_back(candidate);
}
}
return SelectBestOverload(candidates);
}
NameResolutionResult NameResolver::ResolveMethodCall(
symbol::SymbolId object_symbol_id,
const std::string& method_name,
const std::vector<std::shared_ptr<Type>>& arg_types) const
{
// 先解析方法名称,获取所有候选
auto resolution = ResolveMemberAccess(object_symbol_id, method_name);
if (!resolution.IsResolved())
{
return resolution;
}
// 如果没有歧义,直接返回
if (!resolution.is_ambiguous)
{
return resolution;
}
// 执行重载解析
std::vector<OverloadCandidate> candidates;
for (auto candidate_id : resolution.candidates)
{
auto candidate = CalculateOverloadScore(candidate_id, arg_types);
if (candidate.match_score >= 0)
{
candidates.push_back(candidate);
}
}
return SelectBestOverload(candidates);
}
NameResolutionResult NameResolver::ResolveQualifiedName(
const std::vector<std::string>& qualified_name,
symbol::ScopeId scope_id) const
{
if (qualified_name.empty())
{
return NameResolutionResult::NotFound();
}
// 从第一个名称开始解析
auto current_result = ResolveName(qualified_name[0], scope_id, true);
if (!current_result.IsResolved())
{
return current_result;
}
symbol::SymbolId current_id = current_result.symbol_id;
// 依次解析后续的限定名称
for (size_t i = 1; i < qualified_name.size(); ++i)
{
const auto* current_symbol = symbol_table_.definition(current_id);
if (!current_symbol)
{
return NameResolutionResult::NotFound();
}
// 根据当前符号类型决定如何查找下一个名称
if (current_symbol->Is<symbol::Class>())
{
// 在类中查找成员
current_result = ResolveClassMember(current_id, qualified_name[i], true);
}
else if (current_symbol->Is<symbol::Unit>())
{
// 在 Unit 中查找符号TODO: 需要 Unit 作用域支持)
return NameResolutionResult::NotFound();
}
else
{
// 其他类型不支持限定名称
return NameResolutionResult::NotFound();
}
if (!current_result.IsResolved())
{
return current_result;
}
current_id = current_result.symbol_id;
}
return NameResolutionResult::Success(current_id);
}
std::optional<symbol::ScopeId> NameResolver::GetSymbolScope(
symbol::SymbolId symbol_id) const
{
// TODO: 实现符号到作用域的映射
// 目前返回 nullopt
return std::nullopt;
}
bool NameResolver::IsSymbolVisibleInScope(
symbol::SymbolId symbol_id,
symbol::ScopeId scope_id) const
{
// TODO: 实现可见性检查
// 考虑访问修饰符、作用域层次等
return true;
}
// ===== Private Methods =====
std::vector<symbol::SymbolId> NameResolver::SearchScopeChain(
const std::string& name,
symbol::ScopeId start_scope) const
{
std::vector<symbol::SymbolId> results;
symbol::ScopeId current_scope = start_scope;
while (current_scope != symbol::kInvalidScopeId)
{
// 在当前作用域中查找
auto scope_symbols = symbol_table_.scopes().FindSymbols(current_scope, name);
results.insert(results.end(), scope_symbols.begin(), scope_symbols.end());
// 如果找到了符号,停止向上查找
if (!results.empty())
{
break;
}
// 移到父作用域
auto parent = symbol_table_.scopes().GetParent(current_scope);
if (!parent)
{
break;
}
current_scope = *parent;
}
return results;
}
NameResolutionResult NameResolver::SelectBestOverload(
const std::vector<OverloadCandidate>& candidates) const
{
if (candidates.empty())
{
return NameResolutionResult::NotFound();
}
// 按匹配得分排序
auto sorted = candidates;
std::sort(sorted.begin(), sorted.end());
// 检查是否有唯一的最佳匹配
if (sorted.size() > 1 && sorted[0].match_score == sorted[1].match_score)
{
// 存在歧义
std::vector<symbol::SymbolId> ambiguous_ids;
for (const auto& candidate : sorted)
{
if (candidate.match_score == sorted[0].match_score)
{
ambiguous_ids.push_back(candidate.symbol_id);
}
}
return NameResolutionResult::Ambiguous(std::move(ambiguous_ids));
}
return NameResolutionResult::Success(sorted[0].symbol_id);
}
OverloadCandidate NameResolver::CalculateOverloadScore(
symbol::SymbolId candidate_id,
const std::vector<std::shared_ptr<Type>>& arg_types) const
{
OverloadCandidate result;
result.symbol_id = candidate_id;
// 获取候选的参数类型
auto param_types = GetParameterTypes(candidate_id);
// 参数数量不匹配
if (param_types.size() != arg_types.size())
{
result.match_score = -1;
return result;
}
// 计算每个参数的转换代价
int total_score = 0;
for (size_t i = 0; i < arg_types.size(); ++i)
{
auto compat = type_system_.CheckCompatibility(*arg_types[i], *param_types[i]);
if (!compat.is_compatible)
{
result.match_score = -1;
return result;
}
result.arg_conversions.push_back(compat);
// 计算得分:完全匹配得分最高
if (compat.conversion_cost == 0)
{
total_score += 100;
}
else
{
// 转换代价越低,得分越高
total_score += std::max(0, 50 - compat.conversion_cost * 10);
}
// 需要显式转换的降低得分
if (compat.requires_cast)
{
total_score -= 20;
}
}
result.match_score = total_score;
return result;
}
std::vector<std::shared_ptr<Type>> NameResolver::GetParameterTypes(
symbol::SymbolId symbol_id) const
{
const auto* symbol = symbol_table_.definition(symbol_id);
if (!symbol)
{
return {};
}
std::vector<std::shared_ptr<Type>> param_types;
if (symbol->Is<symbol::Function>())
{
const auto* func = symbol->As<symbol::Function>();
for (const auto& param : func->parameters)
{
if (param.type)
{
auto type = type_system_.GetTypeByName(*param.type);
param_types.push_back(type);
}
else
{
param_types.push_back(type_system_.GetUnknownType());
}
}
}
else if (symbol->Is<symbol::Method>())
{
const auto* method = symbol->As<symbol::Method>();
for (const auto& param : method->parameters)
{
if (param.type)
{
auto type = type_system_.GetTypeByName(*param.type);
param_types.push_back(type);
}
else
{
param_types.push_back(type_system_.GetUnknownType());
}
}
}
return param_types;
}
std::optional<symbol::SymbolId> NameResolver::GetOwnerClassId(
symbol::SymbolId symbol_id) const
{
// TODO: 实现符号到所属类的映射
// 需要在符号表中维护这个关系
return std::nullopt;
}
bool NameResolver::CheckMemberAccessibility(
const symbol::Symbol& member,
symbol::ScopeId access_scope) const
{
// TODO: 实现访问权限检查
// 检查 public/private/protected 访问权限
return true;
}
} // namespace lsp::language::semantic

View File

@ -0,0 +1,226 @@
#pragma once
#include <optional>
#include <string>
#include <vector>
#include "../symbol/table.hpp"
#include "./type_system.hpp"
namespace lsp::language::semantic
{
/**
* NameResolutionResult -
*/
struct NameResolutionResult
{
symbol::SymbolId symbol_id = symbol::kInvalidSymbolId;
bool is_ambiguous = false; // 是否存在歧义(多个候选)
std::vector<symbol::SymbolId> candidates; // 所有候选符号
bool IsResolved() const
{
return symbol_id != symbol::kInvalidSymbolId;
}
static NameResolutionResult Success(symbol::SymbolId id)
{
return {id, false, {id}};
}
static NameResolutionResult Ambiguous(std::vector<symbol::SymbolId> symbols)
{
return {
symbols.empty() ? symbol::kInvalidSymbolId : symbols[0],
true,
std::move(symbols)};
}
static NameResolutionResult NotFound()
{
return {symbol::kInvalidSymbolId, false, {}};
}
};
/**
* OverloadCandidate -
* /
*/
struct OverloadCandidate
{
symbol::SymbolId symbol_id;
int match_score = 0; // 匹配得分(越高越好)
std::vector<TypeCompatibility> arg_conversions; // 每个参数的转换信息
bool operator<(const OverloadCandidate& other) const
{
return match_score > other.match_score; // 降序排序
}
};
/**
* NameResolver -
*
*
* 1.
* 2.
* 3.
* 4. 访
*/
class NameResolver
{
public:
explicit NameResolver(const symbol::SymbolTable& symbol_table,
const TypeSystem& type_system)
: symbol_table_(symbol_table), type_system_(type_system) {}
// ===== 基本名称解析 =====
/**
*
* @param name
* @param scope_id ID
* @param search_parent
* @return
*/
NameResolutionResult ResolveName(
const std::string& name,
symbol::ScopeId scope_id,
bool search_parent = true) const;
/**
*
* @param name
* @param location
* @return
*/
NameResolutionResult ResolveNameAtLocation(
const std::string& name,
const ast::Location& location) const;
// ===== 成员访问解析 =====
/**
* 访 (object.member)
* @param object_symbol_id ID
* @param member_name
* @return
*/
NameResolutionResult ResolveMemberAccess(
symbol::SymbolId object_symbol_id,
const std::string& member_name) const;
/**
* 访 (Class.StaticMember)
* @param class_id ID
* @param member_name
* @param static_only
* @return
*/
NameResolutionResult ResolveClassMember(
symbol::SymbolId class_id,
const std::string& member_name,
bool static_only = false) const;
// ===== 重载解析 =====
/**
*
* @param function_name
* @param arg_types
* @param scope_id
* @return ID
*/
NameResolutionResult ResolveFunctionCall(
const std::string& function_name,
const std::vector<std::shared_ptr<Type>>& arg_types,
symbol::ScopeId scope_id) const;
/**
*
* @param object_symbol_id ID
* @param method_name
* @param arg_types
* @return ID
*/
NameResolutionResult ResolveMethodCall(
symbol::SymbolId object_symbol_id,
const std::string& method_name,
const std::vector<std::shared_ptr<Type>>& arg_types) const;
// ===== 限定名称解析 =====
/**
* (Unit.Class.Method)
* @param qualified_name
* @param scope_id
* @return
*/
NameResolutionResult ResolveQualifiedName(
const std::vector<std::string>& qualified_name,
symbol::ScopeId scope_id) const;
// ===== 辅助方法 =====
/**
*
*
*/
std::optional<symbol::ScopeId> GetSymbolScope(
symbol::SymbolId symbol_id) const;
/**
*
*/
bool IsSymbolVisibleInScope(
symbol::SymbolId symbol_id,
symbol::ScopeId scope_id) const;
private:
const symbol::SymbolTable& symbol_table_;
const TypeSystem& type_system_;
// ===== 内部辅助方法 =====
/**
*
*/
std::vector<symbol::SymbolId> SearchScopeChain(
const std::string& name,
symbol::ScopeId start_scope) const;
/**
*
*/
NameResolutionResult SelectBestOverload(
const std::vector<OverloadCandidate>& candidates) const;
/**
*
*/
OverloadCandidate CalculateOverloadScore(
symbol::SymbolId candidate_id,
const std::vector<std::shared_ptr<Type>>& arg_types) const;
/**
* /
*/
std::vector<std::shared_ptr<Type>> GetParameterTypes(
symbol::SymbolId symbol_id) const;
/**
* ID
*/
std::optional<symbol::SymbolId> GetOwnerClassId(
symbol::SymbolId symbol_id) const;
/**
* 访
*/
bool CheckMemberAccessibility(
const symbol::Symbol& member,
symbol::ScopeId access_scope) const;
};
} // namespace lsp::language::semantic

View File

@ -0,0 +1,56 @@
#include "semantic_model.hpp"
namespace lsp::language::semantic
{
SemanticModel::SemanticModel(const symbol::SymbolTable& symbol_table)
: symbol_table_(symbol_table),
reference_graph_(),
inheritance_graph_(),
call_graph_(),
type_system_(),
name_resolver_(std::make_unique<NameResolver>(symbol_table, type_system_))
{
// 设置类型系统的继承检查器
type_system_.SetInheritanceChecker(
[this](symbol::SymbolId derived, symbol::SymbolId base) {
return inheritance_graph_.IsSubclassOf(derived, base);
});
}
void SemanticModel::Clear()
{
reference_graph_.Clear();
inheritance_graph_.Clear();
call_graph_.Clear();
// Note: TypeSystem 和 NameResolver 的状态由外部管理
}
void SemanticModel::OnSymbolRemoved(symbol::SymbolId id)
{
reference_graph_.OnSymbolRemoved(id);
inheritance_graph_.OnSymbolRemoved(id);
call_graph_.OnSymbolRemoved(id);
}
void SemanticModel::AddReference(symbol::SymbolId symbol_id,
const ast::Location& location,
bool is_definition,
bool is_write)
{
reference_graph_.AddReference(symbol_id, location, is_definition, is_write);
}
void SemanticModel::AddInheritance(symbol::SymbolId derived,
symbol::SymbolId base)
{
inheritance_graph_.AddInheritance(derived, base);
}
void SemanticModel::AddCall(symbol::SymbolId caller, symbol::SymbolId callee,
const ast::Location& location)
{
call_graph_.AddCall(caller, callee, location);
}
} // namespace lsp::language::semantic

View File

@ -0,0 +1,125 @@
#pragma once
#include <memory>
#include "../symbol/table.hpp"
#include "./graph/call.hpp"
#include "./graph/inheritance.hpp"
#include "./graph/reference.hpp"
#include "./name_resolver.hpp"
#include "./type_system.hpp"
namespace lsp::language::semantic
{
/**
* SemanticModel -
*
*
* 1.
* 2.
* 3.
* 4.
*
*
*/
class SemanticModel
{
public:
explicit SemanticModel(const symbol::SymbolTable& symbol_table);
// ===== 生命周期管理 =====
/**
*
*/
void Clear();
/**
*
*/
void OnSymbolRemoved(symbol::SymbolId id);
// ===== 关系图管理 =====
/**
*
*/
void AddReference(
symbol::SymbolId symbol_id,
const ast::Location& location,
bool is_definition = false,
bool is_write = false);
/**
*
*/
void AddInheritance(symbol::SymbolId derived, symbol::SymbolId base);
/**
*
*/
void AddCall(
symbol::SymbolId caller,
symbol::SymbolId callee,
const ast::Location& location);
// ===== 访问器 =====
graph::Reference& references() { return reference_graph_; }
graph::Inheritance& inheritance() { return inheritance_graph_; }
graph::Call& calls() { return call_graph_; }
const graph::Reference& references() const { return reference_graph_; }
const graph::Inheritance& inheritance() const { return inheritance_graph_; }
const graph::Call& calls() const { return call_graph_; }
TypeSystem& type_system() { return type_system_; }
const TypeSystem& type_system() const { return type_system_; }
NameResolver& name_resolver() { return *name_resolver_; }
const NameResolver& name_resolver() const { return *name_resolver_; }
// ===== 辅助方法 =====
/**
*
*/
std::shared_ptr<Type> GetSymbolType(symbol::SymbolId symbol_id) const
{
return type_system_.GetSymbolType(symbol_id);
}
/**
*
*/
void SetSymbolType(symbol::SymbolId symbol_id, std::shared_ptr<Type> type)
{
type_system_.RegisterSymbolType(symbol_id, std::move(type));
}
/**
*
*/
bool IsSubclassOf(symbol::SymbolId derived, symbol::SymbolId base) const
{
return inheritance_graph_.IsSubclassOf(derived, base);
}
private:
// 符号表引用(只读)
const symbol::SymbolTable& symbol_table_;
// 语义关系图
graph::Reference reference_graph_;
graph::Inheritance inheritance_graph_;
graph::Call call_graph_;
// 类型系统
TypeSystem type_system_;
// 名称解析器
std::unique_ptr<NameResolver> name_resolver_;
};
} // namespace lsp::language::semantic

View File

@ -0,0 +1,550 @@
#include "./type_system.hpp"
#include <sstream>
namespace lsp::language::semantic
{
// ===== PrimitiveType Implementation =====
std::string PrimitiveType::ToString() const
{
switch (kind_)
{
case PrimitiveTypeKind::kInt:
return "int";
case PrimitiveTypeKind::kFloat:
return "float";
case PrimitiveTypeKind::kString:
return "string";
case PrimitiveTypeKind::kBool:
return "bool";
case PrimitiveTypeKind::kChar:
return "char";
default:
return "unknown";
}
}
// ===== Type Implementation =====
TypeKind Type::kind() const
{
return std::visit(
[](const auto& type_data) -> TypeKind {
using T = std::decay_t<decltype(type_data)>;
if constexpr (std::is_same_v<T, PrimitiveType>)
{
return TypeKind::kPrimitive;
}
else if constexpr (std::is_same_v<T, ClassType>)
{
return TypeKind::kClass;
}
else if constexpr (std::is_same_v<T, ArrayType>)
{
return TypeKind::kArray;
}
else if constexpr (std::is_same_v<T, FunctionType>)
{
return TypeKind::kFunction;
}
else if constexpr (std::is_same_v<T, OptionalType>)
{
return TypeKind::kOptional;
}
else if constexpr (std::is_same_v<T, VoidType>)
{
return TypeKind::kVoid;
}
else if constexpr (std::is_same_v<T, UnknownType>)
{
return TypeKind::kUnknown;
}
else if constexpr (std::is_same_v<T, ErrorType>)
{
return TypeKind::kError;
}
return TypeKind::kUnknown;
},
data_);
}
std::string Type::ToString() const
{
return std::visit(
[](const auto& type_data) -> std::string {
using T = std::decay_t<decltype(type_data)>;
if constexpr (std::is_same_v<T, PrimitiveType>)
{
return type_data.ToString();
}
else if constexpr (std::is_same_v<T, ClassType>)
{
return "class#" + std::to_string(type_data.class_id());
}
else if constexpr (std::is_same_v<T, ArrayType>)
{
return "array<" + type_data.element_type().ToString() + ">";
}
else if constexpr (std::is_same_v<T, FunctionType>)
{
std::string result = "function(";
const auto& params = type_data.param_types();
for (size_t i = 0; i < params.size(); ++i)
{
if (i > 0)
result += ", ";
result += params[i]->ToString();
}
result += ") -> " + type_data.return_type().ToString();
return result;
}
else if constexpr (std::is_same_v<T, OptionalType>)
{
return type_data.inner_type().ToString() + "?";
}
else if constexpr (std::is_same_v<T, VoidType>)
{
return "void";
}
else if constexpr (std::is_same_v<T, UnknownType>)
{
return "unknown";
}
else if constexpr (std::is_same_v<T, ErrorType>)
{
return "error";
}
return "unknown";
},
data_);
}
bool Type::Equals(const Type& other) const
{
if (kind() != other.kind())
{
return false;
}
return std::visit(
[&other](const auto& type_data) -> bool {
using T = std::decay_t<decltype(type_data)>;
const auto* other_data = other.As<T>();
if (!other_data)
{
return false;
}
if constexpr (std::is_same_v<T, PrimitiveType> ||
std::is_same_v<T, ClassType> ||
std::is_same_v<T, VoidType> ||
std::is_same_v<T, UnknownType> ||
std::is_same_v<T, ErrorType>)
{
return type_data == *other_data;
}
else if constexpr (std::is_same_v<T, ArrayType>)
{
return type_data.element_type().Equals(other_data->element_type());
}
else if constexpr (std::is_same_v<T, OptionalType>)
{
return type_data.inner_type().Equals(other_data->inner_type());
}
else if constexpr (std::is_same_v<T, FunctionType>)
{
const auto& params1 = type_data.param_types();
const auto& params2 = other_data->param_types();
if (params1.size() != params2.size())
{
return false;
}
for (size_t i = 0; i < params1.size(); ++i)
{
if (!params1[i]->Equals(*params2[i]))
{
return false;
}
}
return type_data.return_type().Equals(other_data->return_type());
}
return false;
},
data_);
}
// ===== TypeSystem Implementation =====
TypeSystem::TypeSystem()
{
// 初始化内置类型
int_type_ = std::make_shared<Type>(PrimitiveType(PrimitiveTypeKind::kInt));
float_type_ = std::make_shared<Type>(PrimitiveType(PrimitiveTypeKind::kFloat));
string_type_ = std::make_shared<Type>(PrimitiveType(PrimitiveTypeKind::kString));
bool_type_ = std::make_shared<Type>(PrimitiveType(PrimitiveTypeKind::kBool));
char_type_ = std::make_shared<Type>(PrimitiveType(PrimitiveTypeKind::kChar));
void_type_ = std::make_shared<Type>(VoidType());
unknown_type_ = std::make_shared<Type>(UnknownType());
error_type_ = std::make_shared<Type>(ErrorType());
// 注册类型名称
type_by_name_["int"] = int_type_;
type_by_name_["integer"] = int_type_;
type_by_name_["float"] = float_type_;
type_by_name_["double"] = float_type_;
type_by_name_["string"] = string_type_;
type_by_name_["bool"] = bool_type_;
type_by_name_["boolean"] = bool_type_;
type_by_name_["char"] = char_type_;
type_by_name_["void"] = void_type_;
}
std::shared_ptr<Type> TypeSystem::CreateClassType(symbol::SymbolId class_id)
{
return std::make_shared<Type>(ClassType(class_id));
}
std::shared_ptr<Type> TypeSystem::CreateArrayType(
std::shared_ptr<Type> element_type)
{
return std::make_shared<Type>(ArrayType(std::move(element_type)));
}
std::shared_ptr<Type> TypeSystem::CreateFunctionType(
std::vector<std::shared_ptr<Type>> param_types,
std::shared_ptr<Type> return_type)
{
return std::make_shared<Type>(
FunctionType(std::move(param_types), std::move(return_type)));
}
std::shared_ptr<Type> TypeSystem::CreateOptionalType(
std::shared_ptr<Type> inner_type)
{
return std::make_shared<Type>(OptionalType(std::move(inner_type)));
}
std::shared_ptr<Type> TypeSystem::GetTypeByName(const std::string& type_name)
{
auto it = type_by_name_.find(type_name);
if (it != type_by_name_.end())
{
return it->second;
}
return unknown_type_;
}
std::shared_ptr<Type> TypeSystem::GetSymbolType(symbol::SymbolId symbol_id)
{
auto it = symbol_types_.find(symbol_id);
if (it != symbol_types_.end())
{
return it->second;
}
return unknown_type_;
}
void TypeSystem::RegisterSymbolType(symbol::SymbolId symbol_id,
std::shared_ptr<Type> type)
{
symbol_types_[symbol_id] = std::move(type);
}
TypeCompatibility TypeSystem::CheckCompatibility(const Type& from,
const Type& to) const
{
// 相同类型,完全兼容
if (from.Equals(to))
{
return TypeCompatibility::Exact();
}
// 错误类型与任何类型兼容
if (from.kind() == TypeKind::kError || to.kind() == TypeKind::kError)
{
return TypeCompatibility::Exact();
}
// 未知类型与任何类型兼容(用于推断失败的情况)
if (from.kind() == TypeKind::kUnknown || to.kind() == TypeKind::kUnknown)
{
return TypeCompatibility::Exact();
}
// 基本类型之间的兼容性
if (from.kind() == TypeKind::kPrimitive && to.kind() == TypeKind::kPrimitive)
{
return CheckPrimitiveCompatibility(*from.As<PrimitiveType>(),
*to.As<PrimitiveType>());
}
// 类类型之间的兼容性(继承关系)
if (from.kind() == TypeKind::kClass && to.kind() == TypeKind::kClass)
{
return CheckClassCompatibility(*from.As<ClassType>(),
*to.As<ClassType>());
}
// 数组类型的兼容性
if (from.kind() == TypeKind::kArray && to.kind() == TypeKind::kArray)
{
return CheckArrayCompatibility(*from.As<ArrayType>(),
*to.As<ArrayType>());
}
// 函数类型的兼容性
if (from.kind() == TypeKind::kFunction && to.kind() == TypeKind::kFunction)
{
return CheckFunctionCompatibility(*from.As<FunctionType>(),
*to.As<FunctionType>());
}
// 可选类型T 可以赋值给 T?
if (to.kind() == TypeKind::kOptional)
{
const auto& inner = to.As<OptionalType>()->inner_type();
return CheckCompatibility(from, inner);
}
return TypeCompatibility::Incompatible();
}
bool TypeSystem::IsAssignable(const Type& from, const Type& to) const
{
return CheckCompatibility(from, to).is_compatible;
}
bool TypeSystem::RequiresExplicitCast(const Type& from, const Type& to) const
{
auto compat = CheckCompatibility(from, to);
return compat.is_compatible && compat.requires_cast;
}
TypeCompatibility TypeSystem::CheckPrimitiveCompatibility(
const PrimitiveType& from,
const PrimitiveType& to) const
{
using Kind = PrimitiveTypeKind;
// int -> float (隐式转换,代价为 1)
if (from.kind() == Kind::kInt && to.kind() == Kind::kFloat)
{
return TypeCompatibility::Implicit(1);
}
// float -> int (需要显式转换,可能丢失精度)
if (from.kind() == Kind::kFloat && to.kind() == Kind::kInt)
{
return TypeCompatibility::ExplicitCast(2);
}
// char -> string (隐式转换)
if (from.kind() == Kind::kChar && to.kind() == Kind::kString)
{
return TypeCompatibility::Implicit(1);
}
// 其他基本类型不兼容
return TypeCompatibility::Incompatible();
}
TypeCompatibility TypeSystem::CheckClassCompatibility(
const ClassType& from,
const ClassType& to) const
{
// 检查继承关系(如果有继承检查器)
if (is_subclass_of_ && is_subclass_of_(from.class_id(), to.class_id()))
{
// 子类可以隐式转换为父类
return TypeCompatibility::Implicit(1);
}
return TypeCompatibility::Incompatible();
}
TypeCompatibility TypeSystem::CheckArrayCompatibility(
const ArrayType& from,
const ArrayType& to) const
{
// 数组元素类型必须完全匹配(协变/逆变问题)
if (from.element_type().Equals(to.element_type()))
{
return TypeCompatibility::Exact();
}
return TypeCompatibility::Incompatible();
}
TypeCompatibility TypeSystem::CheckFunctionCompatibility(
const FunctionType& from,
const FunctionType& to) const
{
// 函数类型必须完全匹配(参数和返回值)
const auto& from_params = from.param_types();
const auto& to_params = to.param_types();
if (from_params.size() != to_params.size())
{
return TypeCompatibility::Incompatible();
}
for (size_t i = 0; i < from_params.size(); ++i)
{
if (!from_params[i]->Equals(*to_params[i]))
{
return TypeCompatibility::Incompatible();
}
}
if (!from.return_type().Equals(to.return_type()))
{
return TypeCompatibility::Incompatible();
}
return TypeCompatibility::Exact();
}
std::shared_ptr<Type> TypeSystem::InferBinaryExpressionType(
const Type& left,
const Type& right,
const std::string& op) const
{
// 算术运算符
if (op == "+" || op == "-" || op == "*" || op == "/" || op == "%")
{
// int + int -> int
if (left.kind() == TypeKind::kPrimitive &&
right.kind() == TypeKind::kPrimitive)
{
const auto* left_prim = left.As<PrimitiveType>();
const auto* right_prim = right.As<PrimitiveType>();
using Kind = PrimitiveTypeKind;
// 如果任一操作数是 float结果是 float
if (left_prim->kind() == Kind::kFloat ||
right_prim->kind() == Kind::kFloat)
{
return float_type_;
}
// 两个 int -> int
if (left_prim->kind() == Kind::kInt &&
right_prim->kind() == Kind::kInt)
{
return int_type_;
}
}
// string + string -> string (字符串连接)
if (op == "+" && left.kind() == TypeKind::kPrimitive &&
right.kind() == TypeKind::kPrimitive)
{
const auto* left_prim = left.As<PrimitiveType>();
const auto* right_prim = right.As<PrimitiveType>();
if (left_prim->kind() == PrimitiveTypeKind::kString ||
right_prim->kind() == PrimitiveTypeKind::kString)
{
return string_type_;
}
}
}
// 比较运算符
if (op == "==" || op == "!=" || op == "<" || op == ">" || op == "<=" ||
op == ">=")
{
return bool_type_;
}
// 逻辑运算符
if (op == "&&" || op == "||" || op == "and" || op == "or")
{
return bool_type_;
}
// 位运算符
if (op == "&" || op == "|" || op == "^" || op == "<<" || op == ">>")
{
return int_type_;
}
return unknown_type_;
}
std::shared_ptr<Type> TypeSystem::InferUnaryExpressionType(
const Type& operand,
const std::string& op) const
{
// 数值取负
if (op == "-" || op == "+")
{
if (operand.kind() == TypeKind::kPrimitive)
{
const auto* prim = operand.As<PrimitiveType>();
if (prim->kind() == PrimitiveTypeKind::kInt ||
prim->kind() == PrimitiveTypeKind::kFloat)
{
return std::make_shared<Type>(*prim);
}
}
}
// 逻辑取反
if (op == "!" || op == "not")
{
return bool_type_;
}
// 位取反
if (op == "~")
{
return int_type_;
}
// 自增/自减
if (op == "++" || op == "--")
{
return std::make_shared<Type>(operand);
}
return unknown_type_;
}
std::shared_ptr<Type> TypeSystem::InferLiteralType(
const std::string& literal_value) const
{
// 简单的字面量类型推断
if (literal_value.empty())
{
return unknown_type_;
}
// 布尔值
if (literal_value == "true" || literal_value == "false")
{
return bool_type_;
}
// 字符串(以引号开头)
if (literal_value[0] == '"' || literal_value[0] == '\'')
{
return string_type_;
}
// 数字
bool has_dot = literal_value.find('.') != std::string::npos;
if (has_dot)
{
return float_type_;
}
else
{
return int_type_;
}
}
} // namespace lsp::language::semantic

View File

@ -0,0 +1,406 @@
#pragma once
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <variant>
#include <vector>
#include "../symbol/types.hpp"
namespace lsp::language::semantic
{
// Forward declarations
class Type;
class TypeSystem;
// ===== Type Kinds =====
enum class TypeKind
{
kPrimitive, // 基本类型int, float, string, bool
kClass, // 类类型
kArray, // 数组类型
kFunction, // 函数类型
kOptional, // 可选类型(可为 null
kVoid, // void 类型
kUnknown, // 未知类型
kError // 错误类型(类型检查失败)
};
// ===== Primitive Type Kinds =====
enum class PrimitiveTypeKind
{
kInt,
kFloat,
kString,
kBool,
kChar
};
// ===== Type Compatibility =====
/**
*
*/
struct TypeCompatibility
{
bool is_compatible = false; // 是否兼容
int conversion_cost = -1; // 转换代价(-1 表示不兼容0 表示无需转换)
bool requires_cast = false; // 是否需要显式转换
static TypeCompatibility Exact()
{
return {true, 0, false};
}
static TypeCompatibility Implicit(int cost)
{
return {true, cost, false};
}
static TypeCompatibility ExplicitCast(int cost)
{
return {true, cost, true};
}
static TypeCompatibility Incompatible()
{
return {false, -1, false};
}
};
// ===== Type Representations =====
/**
*
*/
class PrimitiveType
{
public:
explicit PrimitiveType(PrimitiveTypeKind kind) : kind_(kind) {}
PrimitiveTypeKind kind() const { return kind_; }
std::string ToString() const;
bool operator==(const PrimitiveType& other) const
{
return kind_ == other.kind_;
}
private:
PrimitiveTypeKind kind_;
};
/**
*
*/
class ClassType
{
public:
explicit ClassType(symbol::SymbolId class_id) : class_id_(class_id) {}
symbol::SymbolId class_id() const { return class_id_; }
bool operator==(const ClassType& other) const
{
return class_id_ == other.class_id_;
}
private:
symbol::SymbolId class_id_;
};
/**
*
*/
class ArrayType
{
public:
explicit ArrayType(std::shared_ptr<Type> element_type)
: element_type_(std::move(element_type)) {}
const Type& element_type() const { return *element_type_; }
std::shared_ptr<Type> element_type_ptr() const { return element_type_; }
private:
std::shared_ptr<Type> element_type_;
};
/**
*
*/
class FunctionType
{
public:
FunctionType(std::vector<std::shared_ptr<Type>> param_types,
std::shared_ptr<Type> return_type)
: param_types_(std::move(param_types)),
return_type_(std::move(return_type)) {}
const std::vector<std::shared_ptr<Type>>& param_types() const
{
return param_types_;
}
const Type& return_type() const { return *return_type_; }
std::shared_ptr<Type> return_type_ptr() const { return return_type_; }
private:
std::vector<std::shared_ptr<Type>> param_types_;
std::shared_ptr<Type> return_type_;
};
/**
* null
*/
class OptionalType
{
public:
explicit OptionalType(std::shared_ptr<Type> inner_type)
: inner_type_(std::move(inner_type)) {}
const Type& inner_type() const { return *inner_type_; }
std::shared_ptr<Type> inner_type_ptr() const { return inner_type_; }
private:
std::shared_ptr<Type> inner_type_;
};
/**
* Void
*/
class VoidType
{
public:
bool operator==(const VoidType&) const { return true; }
};
/**
*
*/
class UnknownType
{
public:
bool operator==(const UnknownType&) const { return true; }
};
/**
*
*/
class ErrorType
{
public:
explicit ErrorType(std::string message = "")
: message_(std::move(message)) {}
const std::string& message() const { return message_; }
bool operator==(const ErrorType&) const { return true; }
private:
std::string message_;
};
// ===== Type Variant =====
using TypeData = std::variant<
PrimitiveType,
ClassType,
ArrayType,
FunctionType,
OptionalType,
VoidType,
UnknownType,
ErrorType>;
// ===== Type Class =====
/**
* Type -
*
*
*/
class Type
{
public:
explicit Type(TypeData data) : data_(std::move(data)) {}
// Type checking
template <typename T>
bool Is() const
{
return std::holds_alternative<T>(data_);
}
template <typename T>
const T* As() const
{
return std::get_if<T>(&data_);
}
template <typename T>
T* As()
{
return std::get_if<T>(&data_);
}
TypeKind kind() const;
std::string ToString() const;
// Equality comparison
bool Equals(const Type& other) const;
const TypeData& data() const { return data_; }
private:
TypeData data_;
};
// ===== Type System =====
/**
* TypeSystem -
*
*
* 1.
* 2.
* 3.
* 4.
*/
class TypeSystem
{
public:
TypeSystem();
// ===== 获取内置类型 =====
std::shared_ptr<Type> GetIntType() const { return int_type_; }
std::shared_ptr<Type> GetFloatType() const { return float_type_; }
std::shared_ptr<Type> GetStringType() const { return string_type_; }
std::shared_ptr<Type> GetBoolType() const { return bool_type_; }
std::shared_ptr<Type> GetCharType() const { return char_type_; }
std::shared_ptr<Type> GetVoidType() const { return void_type_; }
std::shared_ptr<Type> GetUnknownType() const { return unknown_type_; }
std::shared_ptr<Type> GetErrorType() const { return error_type_; }
// ===== 创建复合类型 =====
std::shared_ptr<Type> CreateClassType(symbol::SymbolId class_id);
std::shared_ptr<Type> CreateArrayType(std::shared_ptr<Type> element_type);
std::shared_ptr<Type> CreateFunctionType(
std::vector<std::shared_ptr<Type>> param_types,
std::shared_ptr<Type> return_type);
std::shared_ptr<Type> CreateOptionalType(std::shared_ptr<Type> inner_type);
// ===== 类型查询 =====
/**
*
* @param type_name "int", "string"
* @return UnknownType
*/
std::shared_ptr<Type> GetTypeByName(const std::string& type_name);
/**
* ID获取类型
* @param symbol_id ID
* @return
*/
std::shared_ptr<Type> GetSymbolType(symbol::SymbolId symbol_id);
/**
*
*/
void RegisterSymbolType(symbol::SymbolId symbol_id, std::shared_ptr<Type> type);
// ===== 类型兼容性检查 =====
/**
* from to
*/
TypeCompatibility CheckCompatibility(const Type& from, const Type& to) const;
/**
*
*/
bool IsAssignable(const Type& from, const Type& to) const;
/**
*
*/
bool RequiresExplicitCast(const Type& from, const Type& to) const;
// ===== 类型推断 =====
/**
*
*/
std::shared_ptr<Type> InferBinaryExpressionType(
const Type& left,
const Type& right,
const std::string& op) const;
/**
*
*/
std::shared_ptr<Type> InferUnaryExpressionType(
const Type& operand,
const std::string& op) const;
/**
*
*/
std::shared_ptr<Type> InferLiteralType(const std::string& literal_value) const;
// ===== 继承关系检查(需要语义模型支持)=====
void SetInheritanceChecker(
std::function<bool(symbol::SymbolId, symbol::SymbolId)> checker)
{
is_subclass_of_ = std::move(checker);
}
private:
// 内置类型
std::shared_ptr<Type> int_type_;
std::shared_ptr<Type> float_type_;
std::shared_ptr<Type> string_type_;
std::shared_ptr<Type> bool_type_;
std::shared_ptr<Type> char_type_;
std::shared_ptr<Type> void_type_;
std::shared_ptr<Type> unknown_type_;
std::shared_ptr<Type> error_type_;
// 类型名称映射
std::unordered_map<std::string, std::shared_ptr<Type>> type_by_name_;
// 符号类型映射
std::unordered_map<symbol::SymbolId, std::shared_ptr<Type>> symbol_types_;
// 继承关系检查器(由外部注入)
std::function<bool(symbol::SymbolId derived, symbol::SymbolId base)>
is_subclass_of_;
// 辅助方法
TypeCompatibility CheckPrimitiveCompatibility(
const PrimitiveType& from,
const PrimitiveType& to) const;
TypeCompatibility CheckClassCompatibility(
const ClassType& from,
const ClassType& to) const;
TypeCompatibility CheckArrayCompatibility(
const ArrayType& from,
const ArrayType& to) const;
TypeCompatibility CheckFunctionCompatibility(
const FunctionType& from,
const FunctionType& to) const;
};
} // namespace lsp::language::semantic

View File

@ -0,0 +1,44 @@
#pragma once
#include "../ast/types.hpp"
#include "../symbol/types.hpp"
namespace lsp::language::semantic
{
// ===== Semantic Relationship Types =====
/**
* Reference -
*
*/
struct Reference
{
ast::Location location; // 引用位置
symbol::SymbolId symbol_id; // 被引用的符号ID
bool is_definition; // 是否是定义
bool is_write; // 是否是写引用
};
/**
* Call - /
*
*/
struct Call
{
symbol::SymbolId caller; // 调用者符号ID
symbol::SymbolId callee; // 被调用者符号ID
ast::Location call_site; // 调用位置
};
/**
* Inheritance -
*
*/
struct Inheritance
{
symbol::SymbolId derived; // 派生类符号ID
symbol::SymbolId base; // 基类符号ID
};
} // namespace lsp::language::semantic

File diff suppressed because it is too large Load Diff

View File

@ -2,157 +2,146 @@
#include "./table.hpp"
namespace lsp::language::symbol {
class Builder : public ast::ASTVisitor {
public:
explicit Builder(SymbolTable& table);
namespace lsp::language::symbol
{
class Builder : public ast::ASTVisitor
{
public:
explicit Builder(SymbolTable& table);
void Build(ast::ASTNode& root);
void Build(ast::ASTNode& root);
// Visit statements and declarations
void VisitProgram(ast::Program& node) override;
void VisitUnitDefinition(ast::UnitDefinition& node) override;
void VisitClassDefinition(ast::ClassDefinition& node) override;
void VisitFunctionDefinition(ast::FunctionDefinition& node) override;
void VisitFunctionDeclaration(ast::FunctionDeclaration& node) override;
void VisitMethodDeclaration(ast::MethodDeclaration& node) override;
void VisitPropertyDeclaration(ast::PropertyDeclaration& node) override;
void VisitExternalMethodDefinition(
ast::ExternalMethodDefinition& node) override;
void VisitClassMember(ast::ClassMember& node) override;
// Visit statements and declarations
void VisitProgram(ast::Program& node) override;
void VisitUnitDefinition(ast::UnitDefinition& node) override;
void VisitClassDefinition(ast::ClassDefinition& node) override;
void VisitFunctionDefinition(ast::FunctionDefinition& node) override;
void VisitFunctionDeclaration(ast::FunctionDeclaration& node) override;
void VisitMethodDeclaration(ast::MethodDeclaration& node) override;
void VisitPropertyDeclaration(ast::PropertyDeclaration& node) override;
void VisitExternalMethodDefinition(ast::ExternalMethodDefinition& node) override;
void VisitClassMember(ast::ClassMember& node) override;
// Declarations
void VisitVarDeclaration(ast::VarDeclaration& node) override;
void VisitStaticDeclaration(ast::StaticDeclaration& node) override;
void VisitGlobalDeclaration(ast::GlobalDeclaration& node) override;
void VisitConstDeclaration(ast::ConstDeclaration& node) override;
void VisitFieldDeclaration(ast::FieldDeclaration& node) override;
// Declarations
void VisitVarDeclaration(ast::VarDeclaration& node) override;
void VisitStaticDeclaration(ast::StaticDeclaration& node) override;
void VisitGlobalDeclaration(ast::GlobalDeclaration& node) override;
void VisitConstDeclaration(ast::ConstDeclaration& node) override;
void VisitFieldDeclaration(ast::FieldDeclaration& node) override;
void VisitBlockStatement(ast::BlockStatement& node) override;
void VisitIfStatement(ast::IfStatement& node) override;
void VisitForInStatement(ast::ForInStatement& node) override;
void VisitForToStatement(ast::ForToStatement& node) override;
void VisitWhileStatement(ast::WhileStatement& node) override;
void VisitRepeatStatement(ast::RepeatStatement& node) override;
void VisitCaseStatement(ast::CaseStatement& node) override;
void VisitTryStatement(ast::TryStatement& node) override;
void VisitBlockStatement(ast::BlockStatement& node) override;
void VisitIfStatement(ast::IfStatement& node) override;
void VisitForInStatement(ast::ForInStatement& node) override;
void VisitForToStatement(ast::ForToStatement& node) override;
void VisitWhileStatement(ast::WhileStatement& node) override;
void VisitRepeatStatement(ast::RepeatStatement& node) override;
void VisitCaseStatement(ast::CaseStatement& node) override;
void VisitTryStatement(ast::TryStatement& node) override;
// Uses statement handling
void VisitUsesStatement(ast::UsesStatement& node) override;
// Uses statement handling
void VisitUsesStatement(ast::UsesStatement& node) override;
// Visit expressions (collect references)
void VisitIdentifier(ast::Identifier& node) override;
void VisitCallExpression(ast::CallExpression& node) override;
void VisitAttributeExpression(ast::AttributeExpression& node) override;
void VisitAssignmentExpression(ast::AssignmentExpression& node) override;
// Visit expressions (只遍历,不收集引用)
void VisitIdentifier(ast::Identifier& node) override {}
void VisitCallExpression(ast::CallExpression& node) override;
void VisitAttributeExpression(ast::AttributeExpression& node) override;
void VisitAssignmentExpression(ast::AssignmentExpression& node) override;
// Other node visits
void VisitLiteral([[maybe_unused]] ast::Literal& node) override {}
void VisitBinaryExpression(ast::BinaryExpression& node) override;
void VisitTernaryExpression(ast::TernaryExpression& node) override;
void VisitSubscriptExpression(ast::SubscriptExpression& node) override;
void VisitArrayExpression(ast::ArrayExpression& node) override;
void VisitAnonymousFunctionExpression(
ast::AnonymousFunctionExpression& node) override;
// Other node visits
void VisitLiteral(ast::Literal& node) override;
void VisitBinaryExpression(ast::BinaryExpression& node) override;
void VisitTernaryExpression(ast::TernaryExpression& node) override;
void VisitSubscriptExpression(ast::SubscriptExpression& node) override;
void VisitArrayExpression(ast::ArrayExpression& node) override;
void VisitAnonymousFunctionExpression(ast::AnonymousFunctionExpression& node) override;
// Unary expressions (refined types)
void VisitUnaryPlusExpression(ast::UnaryPlusExpression& node) override;
void VisitUnaryMinusExpression(ast::UnaryMinusExpression& node) override;
void VisitPrefixIncrementExpression(
ast::PrefixIncrementExpression& node) override;
void VisitPrefixDecrementExpression(
ast::PrefixDecrementExpression& node) override;
void VisitPostfixIncrementExpression(
ast::PostfixIncrementExpression& node) override;
void VisitPostfixDecrementExpression(
ast::PostfixDecrementExpression& node) override;
void VisitLogicalNotExpression(ast::LogicalNotExpression& node) override;
void VisitBitwiseNotExpression(ast::BitwiseNotExpression& node) override;
void VisitDerivativeExpression(ast::DerivativeExpression& node) override;
void VisitMatrixTransposeExpression(
ast::MatrixTransposeExpression& node) override;
void VisitExprOperatorExpression(ast::ExprOperatorExpression& node) override;
// Unary expressions
void VisitUnaryPlusExpression(ast::UnaryPlusExpression& node) override;
void VisitUnaryMinusExpression(ast::UnaryMinusExpression& node) override;
void VisitPrefixIncrementExpression(ast::PrefixIncrementExpression& node) override;
void VisitPrefixDecrementExpression(ast::PrefixDecrementExpression& node) override;
void VisitPostfixIncrementExpression(ast::PostfixIncrementExpression& node) override;
void VisitPostfixDecrementExpression(ast::PostfixDecrementExpression& node) override;
void VisitLogicalNotExpression(ast::LogicalNotExpression& node) override;
void VisitBitwiseNotExpression(ast::BitwiseNotExpression& node) override;
void VisitDerivativeExpression(ast::DerivativeExpression& node) override;
void VisitMatrixTransposeExpression(ast::MatrixTransposeExpression& node) override;
void VisitExprOperatorExpression(ast::ExprOperatorExpression& node) override;
void VisitFunctionPointerExpression(
[[maybe_unused]] ast::FunctionPointerExpression& node) override;
void VisitFunctionPointerExpression(ast::FunctionPointerExpression& node) override;
// Other expressions
void VisitNewExpression(ast::NewExpression& node) override;
void VisitEchoExpression(ast::EchoExpression& node) override;
void VisitRaiseExpression(ast::RaiseExpression& node) override;
void VisitInheritedExpression(ast::InheritedExpression& node) override;
void VisitParenthesizedExpression(
ast::ParenthesizedExpression& node) override;
// Other expressions
void VisitNewExpression(ast::NewExpression& node) override;
void VisitEchoExpression(ast::EchoExpression& node) override;
void VisitRaiseExpression(ast::RaiseExpression& node) override;
void VisitInheritedExpression(ast::InheritedExpression& node) override;
void VisitParenthesizedExpression(ast::ParenthesizedExpression& node) override;
void VisitExpressionStatement(ast::ExpressionStatement& node) override;
void VisitBreakStatement(
[[maybe_unused]] ast::BreakStatement& node) override {}
void VisitContinueStatement(
[[maybe_unused]] ast::ContinueStatement& node) override {}
void VisitReturnStatement(ast::ReturnStatement& node) override;
void VisitTSSQLExpression(
[[maybe_unused]] ast::TSSQLExpression& node) override {}
void VisitColumnReference(ast::ColumnReference& node) override;
void VisitUnpackPattern([[maybe_unused]] ast::UnpackPattern& node) override;
void VisitMatrixIterationStatement(
ast::MatrixIterationStatement& node) override;
void VisitExpressionStatement(ast::ExpressionStatement& node) override;
void VisitBreakStatement(ast::BreakStatement& node) override;
void VisitContinueStatement(ast::ContinueStatement& node) override;
void VisitReturnStatement(ast::ReturnStatement& node) override;
void VisitTSSQLExpression(ast::TSSQLExpression& node) override;
void VisitColumnReference(ast::ColumnReference& node) override;
void VisitUnpackPattern(ast::UnpackPattern& node) override;
void VisitMatrixIterationStatement(ast::MatrixIterationStatement& node) override;
// Compiler directives (correct names without Statement suffix)
void VisitCompilerDirective(
[[maybe_unused]] ast::CompilerDirective& node) override;
void VisitConditionalDirective(ast::ConditionalDirective& node) override;
void VisitConditionalBlock(ast::ConditionalBlock& node) override;
void VisitTSLXBlock([[maybe_unused]] ast::TSLXBlock& node) override {}
// Compiler directives
void VisitCompilerDirective(ast::CompilerDirective& node) override;
void VisitConditionalDirective(ast::ConditionalDirective& node) override;
void VisitConditionalBlock(ast::ConditionalBlock& node) override;
void VisitTSLXBlock(ast::TSLXBlock& node) override;
void VisitParameter([[maybe_unused]] ast::Parameter& node) override {}
void VisitParameter(ast::Parameter& node) override;
private:
SymbolId CreateFunctionSymbol(
const std::string& name, const ast::Location& location,
const std::vector<std::unique_ptr<ast::Parameter>>& parameters,
const std::optional<ast::TypeAnnotation>& return_type);
private:
SymbolId CreateFunctionSymbol(
const std::string& name,
const ast::Location& location,
const std::vector<std::unique_ptr<ast::Parameter>>& parameters,
const std::optional<ast::TypeAnnotation>& return_type);
SymbolId CreateMethodSymbol(
const std::string& name, const ast::Location& location,
const std::vector<std::unique_ptr<ast::Parameter>>& parameters,
const std::optional<ast::TypeAnnotation>& return_type);
SymbolId CreateMethodSymbol(
const std::string& name,
const ast::Location& location,
const std::vector<std::unique_ptr<ast::Parameter>>& parameters,
const std::optional<ast::TypeAnnotation>& return_type);
SymbolId CreateSymbol(
const std::string& name, SymbolKind kind, const ast::Location& location,
const std::optional<std::string>& type_hint = std::nullopt);
SymbolId CreateSymbol(
const std::string& name,
protocol::SymbolKind kind,
const ast::Location& location,
const std::optional<std::string>& type_hint = std::nullopt);
SymbolId CreateSymbol(
const std::string& name, SymbolKind kind, const ast::ASTNode& node,
const std::optional<std::string>& type_hint = std::nullopt);
SymbolId CreateSymbol(
const std::string& name,
protocol::SymbolKind kind,
const ast::ASTNode& node,
const std::optional<std::string>& type_hint = std::nullopt);
// Scope management
ScopeId EnterScopeWithSymbol(ScopeKind kind, SymbolId symbol_id,
const ast::Location& range);
ScopeId EnterScope(ScopeKind kind, const ast::Location& range);
void ExitScope();
// Scope management
ScopeId EnterScopeWithSymbol(ScopeKind kind, SymbolId symbol_id, const ast::Location& range);
ScopeId EnterScope(ScopeKind kind, const ast::Location& range);
void ExitScope();
// Traversal helpers
void VisitStatements(const std::vector<ast::StatementPtr>& statements);
void VisitExpression(ast::Expression& expr);
// Traversal helpers
void VisitStatements(const std::vector<ast::StatementPtr>& statements);
void VisitExpression(ast::Expression& expr);
// LValue processing for assignments
void ProcessLValue(const ast::LValue& lvalue, bool is_write);
// Type and parameter extraction
std::optional<std::string> ExtractTypeName(const std::optional<ast::TypeAnnotation>& type) const;
// Type and parameter extraction
std::optional<std::string> ExtractTypeName(
const std::optional<ast::TypeAnnotation>& type) const;
std::vector<language::symbol::Parameter> BuildParameters(const std::vector<std::unique_ptr<ast::Parameter>>& parameters) const;
void ProcessLValue(const ast::LValue& lvalue, bool is_write);
std::vector<language::symbol::Parameter> BuildParameters(
const std::vector<std::unique_ptr<ast::Parameter>>& parameters) const;
private:
SymbolTable& table_;
private:
SymbolTable& table_;
ScopeId current_scope_id_;
std::optional<SymbolId> current_parent_symbol_id_;
std::optional<SymbolId> current_function_id_;
ScopeId current_scope_id_;
std::optional<SymbolId> current_parent_symbol_id_;
std::optional<SymbolId> current_function_id_;
bool in_interface_section_;
};
bool in_interface_section_;
};
} // namespace lsp::language::symbol
} // namespace lsp::language::symbol

View File

@ -1,52 +0,0 @@
#include "call.hpp"
#include <algorithm>
namespace lsp::language::symbol::graph {
void Call::OnSymbolRemoved(SymbolId id) {
callers_map_.erase(id);
callees_map_.erase(id);
for (auto& [_, calls] : callers_map_) {
calls.erase(std::remove_if(calls.begin(), calls.end(),
[id](const symbol::Call& call) {
return call.caller == id;
}),
calls.end());
}
for (auto& [_, calls] : callees_map_) {
calls.erase(std::remove_if(calls.begin(), calls.end(),
[id](const symbol::Call& call) {
return call.callee == id;
}),
calls.end());
}
}
void Call::Clear() {
callers_map_.clear();
callees_map_.clear();
}
void Call::AddCall(SymbolId caller, SymbolId callee,
const ast::Location& location) {
symbol::Call call{caller, callee, location};
callers_map_[callee].push_back(call);
callees_map_[caller].push_back(call);
}
const std::vector<symbol::Call>& Call::callers(SymbolId id) const {
static const std::vector<symbol::Call> kEmpty;
auto it = callers_map_.find(id);
return it != callers_map_.end() ? it->second : kEmpty;
}
const std::vector<symbol::Call>& Call::callees(SymbolId id) const {
static const std::vector<symbol::Call> kEmpty;
auto it = callees_map_.find(id);
return it != callees_map_.end() ? it->second : kEmpty;
}
} // namespace lsp::language::symbol::graph

View File

@ -1,26 +0,0 @@
#pragma once
#include <unordered_map>
#include <vector>
#include "../interface.hpp"
#include "../types.hpp"
namespace lsp::language::symbol::graph {
class Call : public ISymbolGraph {
public:
void OnSymbolRemoved(SymbolId id) override;
void Clear() override;
void AddCall(SymbolId caller, SymbolId callee, const ast::Location& location);
const std::vector<symbol::Call>& callers(SymbolId id) const;
const std::vector<symbol::Call>& callees(SymbolId id) const;
private:
std::unordered_map<SymbolId, std::vector<symbol::Call>> callers_map_;
std::unordered_map<SymbolId, std::vector<symbol::Call>> callees_map_;
};
} // namespace lsp::language::symbol::graph

View File

@ -1,58 +0,0 @@
#include "inheritance.hpp"
#include <algorithm>
namespace lsp::language::symbol::graph {
void Inheritance::OnSymbolRemoved(SymbolId id) {
base_classes_.erase(id);
derived_classes_.erase(id);
for (auto& [_, bases] : base_classes_) {
bases.erase(std::remove(bases.begin(), bases.end(), id), bases.end());
}
for (auto& [_, derived] : derived_classes_) {
derived.erase(std::remove(derived.begin(), derived.end(), id),
derived.end());
}
}
void Inheritance::Clear() {
base_classes_.clear();
derived_classes_.clear();
}
void Inheritance::AddInheritance(SymbolId derived, SymbolId base) {
base_classes_[derived].push_back(base);
derived_classes_[base].push_back(derived);
}
const std::vector<SymbolId>& Inheritance::base_classes(SymbolId id) const {
static const std::vector<SymbolId> kEmpty;
auto it = base_classes_.find(id);
return it != base_classes_.end() ? it->second : kEmpty;
}
const std::vector<SymbolId>& Inheritance::derived_classes(SymbolId id) const {
static const std::vector<SymbolId> kEmpty;
auto it = derived_classes_.find(id);
return it != derived_classes_.end() ? it->second : kEmpty;
}
bool Inheritance::IsSubclassOf(SymbolId derived, SymbolId base) const {
auto it = base_classes_.find(derived);
if (it == base_classes_.end()) {
return false;
}
for (SymbolId parent : it->second) {
if (parent == base || IsSubclassOf(parent, base)) {
return true;
}
}
return false;
}
} // namespace lsp::language::symbol::graph

View File

@ -1,27 +0,0 @@
#pragma once
#include <unordered_map>
#include <vector>
#include "../interface.hpp"
#include "../types.hpp"
namespace lsp::language::symbol::graph {
class Inheritance : public ISymbolGraph {
public:
void OnSymbolRemoved(SymbolId id) override;
void Clear() override;
void AddInheritance(SymbolId derived, SymbolId base);
const std::vector<SymbolId>& base_classes(SymbolId id) const;
const std::vector<SymbolId>& derived_classes(SymbolId id) const;
bool IsSubclassOf(SymbolId derived, SymbolId base) const;
private:
std::unordered_map<SymbolId, std::vector<SymbolId>> base_classes_;
std::unordered_map<SymbolId, std::vector<SymbolId>> derived_classes_;
};
} // namespace lsp::language::symbol::graph

View File

@ -1,46 +0,0 @@
#include "reference.hpp"
#include <algorithm>
namespace lsp::language::symbol::graph {
void Reference::OnSymbolRemoved(SymbolId id) {
references_.erase(id);
for (auto& [_, refs] : references_) {
refs.erase(std::remove_if(refs.begin(), refs.end(),
[id](const symbol::Reference& ref) {
return ref.symbol_id == id;
}),
refs.end());
}
}
void Reference::Clear() { references_.clear(); }
void Reference::AddReference(SymbolId symbol_id, const ast::Location& location,
bool is_definition, bool is_write) {
references_[symbol_id].push_back(
{location, symbol_id, is_definition, is_write});
}
const std::vector<symbol::Reference>& Reference::references(SymbolId id) const {
static const std::vector<symbol::Reference> kEmpty;
auto it = references_.find(id);
return it != references_.end() ? it->second : kEmpty;
}
std::optional<ast::Location> Reference::FindDefinitionLocation(
SymbolId id) const {
auto it = references_.find(id);
if (it != references_.end()) {
for (const auto& ref : it->second) {
if (ref.is_definition) {
return ref.location;
}
}
}
return std::nullopt;
}
} // namespace lsp::language::symbol::graph

View File

@ -1,27 +0,0 @@
#pragma once
#include <optional>
#include <unordered_map>
#include <vector>
#include "../interface.hpp"
#include "../types.hpp"
namespace lsp::language::symbol::graph {
class Reference : public ISymbolGraph {
public:
void OnSymbolRemoved(SymbolId id) override;
void Clear() override;
void AddReference(SymbolId symbol_id, const ast::Location& location,
bool is_definition = false, bool is_write = false);
const std::vector<symbol::Reference>& references(SymbolId id) const;
std::optional<ast::Location> FindDefinitionLocation(SymbolId id) const;
private:
std::unordered_map<SymbolId, std::vector<symbol::Reference>> references_;
};
} // namespace lsp::language::symbol::graph

View File

@ -0,0 +1,68 @@
#pragma once
#include <memory>
#include <utility>
#include <vector>
#include "../interface.hpp"
namespace lsp::language::symbol
{
/**
* IndexDispatcher -
*
*
* -
* -
* - 使
*/
class IndexDispatcher
{
public:
/**
*
*/
void RegisterIndex(std::shared_ptr<ISymbolIndex> index)
{
indexes_.push_back(std::move(index));
}
/**
*
*/
void NotifySymbolAdded(const Symbol& symbol) const
{
for (const auto& index : indexes_)
{
index->OnSymbolAdded(symbol);
}
}
/**
*
*/
void NotifySymbolRemoved(SymbolId id) const
{
for (const auto& index : indexes_)
{
index->OnSymbolRemoved(id);
}
}
/**
*
*/
void Clear() const
{
for (const auto& index : indexes_)
{
index->Clear();
}
}
private:
std::vector<std::shared_ptr<ISymbolIndex>> indexes_;
};
} // namespace lsp::language::symbol

View File

@ -2,67 +2,78 @@
#include <algorithm>
namespace lsp::language::symbol::index {
namespace lsp::language::symbol::index
{
void Location::OnSymbolAdded(const Symbol& symbol) {
const auto& loc = symbol.selection_range();
entries_.push_back({loc.start_offset, loc.end_offset, symbol.id()});
needs_sort_ = true;
}
void Location::OnSymbolRemoved(SymbolId id) {
entries_.erase(
std::remove_if(entries_.begin(), entries_.end(),
[id](const Entry& e) { return e.symbol_id == id; }),
entries_.end());
}
void Location::Clear() {
entries_.clear();
needs_sort_ = false;
}
std::optional<SymbolId> Location::FindSymbolAt(
const ast::Location& location) const {
EnsureSorted();
uint32_t pos = location.start_offset;
auto it =
std::lower_bound(entries_.begin(), entries_.end(), pos,
[](const Entry& e, uint32_t p) { return e.start < p; });
if (it != entries_.begin()) {
--it;
}
std::optional<SymbolId> result;
uint32_t min_span = UINT32_MAX;
for (; it != entries_.end() && it->start <= pos; ++it) {
if (pos >= it->start && pos < it->end) {
uint32_t span = it->end - it->start;
if (span < min_span) {
min_span = span;
result = it->symbol_id;
}
void Location::OnSymbolAdded(const Symbol& symbol)
{
const auto& loc = symbol.selection_range();
entries_.push_back({ loc.start_offset, loc.end_offset, symbol.id() });
needs_sort_ = true;
}
}
return result;
}
void Location::OnSymbolRemoved(SymbolId id)
{
entries_.erase(
std::remove_if(entries_.begin(), entries_.end(), [id](const Entry& e) { return e.symbol_id == id; }),
entries_.end());
}
bool Location::Entry::operator<(const Entry& other) const {
if (start != other.start) {
return start < other.start;
}
return end > other.end;
}
void Location::Clear()
{
entries_.clear();
needs_sort_ = false;
}
void Location::EnsureSorted() const {
if (needs_sort_) {
std::sort(entries_.begin(), entries_.end());
needs_sort_ = false;
}
}
std::optional<SymbolId> Location::FindSymbolAt(
const ast::Location& location) const
{
EnsureSorted();
uint32_t pos = location.start_offset;
} // namespace lsp::language::symbol::index
auto it =
std::lower_bound(entries_.begin(), entries_.end(), pos, [](const Entry& e, uint32_t p) { return e.start < p; });
if (it != entries_.begin())
{
--it;
}
std::optional<SymbolId> result;
uint32_t min_span = UINT32_MAX;
for (; it != entries_.end() && it->start <= pos; ++it)
{
if (pos >= it->start && pos < it->end)
{
uint32_t span = it->end - it->start;
if (span < min_span)
{
min_span = span;
result = it->symbol_id;
}
}
}
return result;
}
bool Location::Entry::operator<(const Entry& other) const
{
if (start != other.start)
{
return start < other.start;
}
return end > other.end;
}
void Location::EnsureSorted() const
{
if (needs_sort_)
{
std::sort(entries_.begin(), entries_.end());
needs_sort_ = false;
}
}
} // namespace lsp::language::symbol::index

View File

@ -6,29 +6,32 @@
#include "../interface.hpp"
#include "../types.hpp"
namespace lsp::language::symbol::index {
namespace lsp::language::symbol::index
{
class Location : public ISymbolIndex {
public:
void OnSymbolAdded(const Symbol& symbol) override;
void OnSymbolRemoved(SymbolId id) override;
void Clear() override;
class Location : public ISymbolIndex
{
public:
void OnSymbolAdded(const Symbol& symbol) override;
void OnSymbolRemoved(SymbolId id) override;
void Clear() override;
std::optional<SymbolId> FindSymbolAt(const ast::Location& location) const;
std::optional<SymbolId> FindSymbolAt(const ast::Location& location) const;
private:
struct Entry {
uint32_t start;
uint32_t end;
SymbolId symbol_id;
private:
struct Entry
{
uint32_t start;
uint32_t end;
SymbolId symbol_id;
bool operator<(const Entry& other) const;
};
bool operator<(const Entry& other) const;
};
void EnsureSorted() const;
void EnsureSorted() const;
mutable std::vector<Entry> entries_;
mutable bool needs_sort_ = false;
};
mutable std::vector<Entry> entries_;
mutable bool needs_sort_ = false;
};
} // namespace lsp::language::symbol::index
} // namespace lsp::language::symbol::index

View File

@ -2,92 +2,110 @@
#include <algorithm>
namespace lsp::language::symbol::index {
namespace lsp::language::symbol::index
{
void Scope::OnSymbolAdded(const Symbol&) {}
void Scope::OnSymbolAdded(const Symbol&) {}
void Scope::OnSymbolRemoved(SymbolId id) {
for (auto& [_, scope] : scopes_) {
for (auto it = scope.symbols.begin(); it != scope.symbols.end();) {
if (it->second == id) {
it = scope.symbols.erase(it);
} else {
++it;
}
}
}
}
void Scope::Clear() {
scopes_.clear();
next_scope_id_ = 1;
global_scope_ = kInvalidScopeId;
}
ScopeId Scope::CreateScope(ScopeKind kind, const ast::Location& range,
std::optional<ScopeId> parent,
std::optional<SymbolId> owner) {
ScopeId id = next_scope_id_++;
scopes_[id] = {id, kind, range, parent, owner, {}};
if (kind == ScopeKind::kGlobal) {
global_scope_ = id;
}
return id;
}
void Scope::AddSymbol(ScopeId scope_id, const std::string& name,
SymbolId symbol_id) {
auto it = scopes_.find(scope_id);
if (it != scopes_.end()) {
it->second.symbols[ToLower(name)] = symbol_id;
}
}
std::optional<SymbolId> Scope::FindSymbolInScope(
ScopeId scope_id, const std::string& name) const {
auto it = scopes_.find(scope_id);
if (it == scopes_.end()) {
return std::nullopt;
}
auto sym_it = it->second.symbols.find(ToLower(name));
return sym_it != it->second.symbols.end() ? std::optional(sym_it->second)
: std::nullopt;
}
std::optional<SymbolId> Scope::FindSymbolInScopeChain(
ScopeId scope_id, const std::string& name) const {
std::optional<ScopeId> current = scope_id;
while (current) {
if (auto result = FindSymbolInScope(*current, name)) {
return result;
void Scope::OnSymbolRemoved(SymbolId id)
{
for (auto& [_, scope] : scopes_)
{
for (auto it = scope.symbols.begin(); it != scope.symbols.end();)
{
if (it->second == id)
{
it = scope.symbols.erase(it);
}
else
{
++it;
}
}
}
}
auto it = scopes_.find(*current);
current = it != scopes_.end() ? it->second.parent : std::nullopt;
}
void Scope::Clear()
{
scopes_.clear();
next_scope_id_ = 1;
global_scope_ = kInvalidScopeId;
}
return std::nullopt;
}
ScopeId Scope::CreateScope(ScopeKind kind, const ast::Location& range, std::optional<ScopeId> parent, std::optional<SymbolId> owner)
{
ScopeId id = next_scope_id_++;
scopes_[id] = { id, kind, range, parent, owner, {} };
const symbol::Scope* Scope::scope(ScopeId id) const {
auto it = scopes_.find(id);
return it != scopes_.end() ? &it->second : nullptr;
}
if (kind == ScopeKind::kGlobal)
{
global_scope_ = id;
}
ScopeId Scope::global_scope() const { return global_scope_; }
return id;
}
const std::unordered_map<ScopeId, symbol::Scope>& Scope::all_scopes() const {
return scopes_;
}
void Scope::AddSymbol(ScopeId scope_id, const std::string& name, SymbolId symbol_id)
{
auto it = scopes_.find(scope_id);
if (it != scopes_.end())
{
it->second.symbols[ToLower(name)] = symbol_id;
}
}
std::string Scope::ToLower(const std::string& s) {
std::string result = s;
std::transform(result.begin(), result.end(), result.begin(), ::tolower);
return result;
}
std::optional<SymbolId> Scope::FindSymbolInScope(
ScopeId scope_id,
const std::string& name) const
{
auto it = scopes_.find(scope_id);
if (it == scopes_.end())
{
return std::nullopt;
}
} // namespace lsp::language::symbol::index
auto sym_it = it->second.symbols.find(ToLower(name));
return sym_it != it->second.symbols.end() ? std::optional(sym_it->second) : std::nullopt;
}
std::optional<SymbolId> Scope::FindSymbolInScopeChain(
ScopeId scope_id,
const std::string& name) const
{
std::optional<ScopeId> current = scope_id;
while (current)
{
if (auto result = FindSymbolInScope(*current, name))
{
return result;
}
auto it = scopes_.find(*current);
current = it != scopes_.end() ? it->second.parent : std::nullopt;
}
return std::nullopt;
}
const symbol::Scope* Scope::scope(ScopeId id) const
{
auto it = scopes_.find(id);
return it != scopes_.end() ? &it->second : nullptr;
}
ScopeId Scope::global_scope() const { return global_scope_; }
const std::unordered_map<ScopeId, symbol::Scope>& Scope::all_scopes() const
{
return scopes_;
}
std::string Scope::ToLower(const std::string& s)
{
std::string result = s;
std::transform(result.begin(), result.end(), result.begin(), ::tolower);
return result;
}
} // namespace lsp::language::symbol::index

View File

@ -7,46 +7,48 @@
#include "../interface.hpp"
#include "../types.hpp"
namespace lsp::language::symbol {
namespace lsp::language::symbol
{
struct Scope {
ScopeId id;
ScopeKind kind;
ast::Location range;
std::optional<ScopeId> parent;
std::optional<SymbolId> owner;
std::unordered_map<std::string, SymbolId> symbols;
};
struct Scope
{
ScopeId id;
ScopeKind kind;
ast::Location range;
std::optional<ScopeId> parent;
std::optional<SymbolId> owner;
std::unordered_map<std::string, SymbolId> symbols;
};
} // namespace lsp::language::symbol
} // namespace lsp::language::symbol
namespace lsp::language::symbol::index {
namespace lsp::language::symbol::index
{
class Scope : public ISymbolIndex {
public:
void OnSymbolAdded(const Symbol&) override;
void OnSymbolRemoved(SymbolId id) override;
void Clear() override;
class Scope : public ISymbolIndex
{
public:
void OnSymbolAdded(const Symbol&) override;
void OnSymbolRemoved(SymbolId id) override;
void Clear() override;
ScopeId CreateScope(ScopeKind kind, const ast::Location& range,
std::optional<ScopeId> parent = std::nullopt,
std::optional<SymbolId> owner = std::nullopt);
void AddSymbol(ScopeId scope_id, const std::string& name, SymbolId symbol_id);
std::optional<SymbolId> FindSymbolInScope(ScopeId scope_id,
const std::string& name) const;
std::optional<SymbolId> FindSymbolInScopeChain(ScopeId scope_id,
const std::string& name) const;
ScopeId CreateScope(ScopeKind kind, const ast::Location& range, std::optional<ScopeId> parent = std::nullopt, std::optional<SymbolId> owner = std::nullopt);
void AddSymbol(ScopeId scope_id, const std::string& name, SymbolId symbol_id);
std::optional<SymbolId> FindSymbolInScope(ScopeId scope_id,
const std::string& name) const;
std::optional<SymbolId> FindSymbolInScopeChain(ScopeId scope_id,
const std::string& name) const;
const symbol::Scope* scope(ScopeId id) const;
ScopeId global_scope() const;
const std::unordered_map<ScopeId, symbol::Scope>& all_scopes() const;
const symbol::Scope* scope(ScopeId id) const;
ScopeId global_scope() const;
const std::unordered_map<ScopeId, symbol::Scope>& all_scopes() const;
private:
static std::string ToLower(const std::string& s);
private:
static std::string ToLower(const std::string& s);
ScopeId next_scope_id_ = 1;
ScopeId global_scope_ = kInvalidScopeId;
std::unordered_map<ScopeId, symbol::Scope> scopes_;
};
ScopeId next_scope_id_ = 1;
ScopeId global_scope_ = kInvalidScopeId;
std::unordered_map<ScopeId, symbol::Scope> scopes_;
};
} // namespace lsp::language::symbol::index
} // namespace lsp::language::symbol::index

View File

@ -2,21 +2,36 @@
#include "./types.hpp"
namespace lsp::language::symbol {
namespace lsp::language::symbol
{
class ISymbolIndex {
public:
virtual ~ISymbolIndex() = default;
virtual void OnSymbolAdded(const Symbol& symbol) = 0;
virtual void OnSymbolRemoved(SymbolId id) = 0;
virtual void Clear() = 0;
};
/**
* ISymbolIndex -
*
*
*
*/
class ISymbolIndex
{
public:
virtual ~ISymbolIndex() = default;
class ISymbolGraph {
public:
virtual ~ISymbolGraph() = default;
virtual void OnSymbolRemoved(SymbolId id) = 0;
virtual void Clear() = 0;
};
/**
*
* @param symbol
*/
virtual void OnSymbolAdded(const Symbol& symbol) = 0;
} // namespace lsp::language::symbol
/**
*
* @param id ID
*/
virtual void OnSymbolRemoved(SymbolId id) = 0;
/**
*
*/
virtual void Clear() = 0;
};
} // namespace lsp::language::symbol

View File

@ -1,59 +1,69 @@
#include "store.hpp"
#include <algorithm>
namespace lsp::language::symbol {
#include "./store.hpp"
SymbolId SymbolStore::Add(Symbol def) {
SymbolId id = next_id_++;
std::visit([id](auto& s) { s.id = id; }, def.mutable_data());
namespace lsp::language::symbol
{
auto [it, _] = definitions_.emplace(id, std::move(def));
const auto& stored = it->second;
by_name_[stored.name()].push_back(id);
return id;
}
SymbolId SymbolStore::Add(Symbol def)
{
SymbolId id = next_id_++;
std::visit([id](auto& s) { s.id = id; }, def.mutable_data());
bool SymbolStore::Remove(SymbolId id) {
auto it = definitions_.find(id);
if (it == definitions_.end()) {
return false;
}
auto [it, _] = definitions_.emplace(id, std::move(def));
const auto& stored = it->second;
by_name_[stored.name()].push_back(id);
return id;
}
const std::string& name = it->second.name();
auto& ids = by_name_[name];
ids.erase(std::remove(ids.begin(), ids.end(), id), ids.end());
if (ids.empty()) {
by_name_.erase(name);
}
bool SymbolStore::Remove(SymbolId id)
{
auto it = definitions_.find(id);
if (it == definitions_.end())
{
return false;
}
definitions_.erase(it);
return true;
}
const std::string& name = it->second.name();
auto& ids = by_name_[name];
ids.erase(std::remove(ids.begin(), ids.end(), id), ids.end());
if (ids.empty())
{
by_name_.erase(name);
}
void SymbolStore::Clear() {
definitions_.clear();
by_name_.clear();
next_id_ = 1;
}
definitions_.erase(it);
return true;
}
const Symbol* SymbolStore::Get(SymbolId id) const {
auto it = definitions_.find(id);
return it != definitions_.end() ? &it->second : nullptr;
}
void SymbolStore::Clear()
{
definitions_.clear();
by_name_.clear();
next_id_ = 1;
}
std::vector<std::reference_wrapper<const Symbol>> SymbolStore::GetAll() const {
std::vector<std::reference_wrapper<const Symbol>> result;
result.reserve(definitions_.size());
for (const auto& [_, def] : definitions_) {
result.push_back(std::cref(def));
}
return result;
}
const Symbol* SymbolStore::Get(SymbolId id) const
{
auto it = definitions_.find(id);
return it != definitions_.end() ? &it->second : nullptr;
}
std::vector<SymbolId> SymbolStore::FindByName(const std::string& name) const {
auto it = by_name_.find(name);
return it != by_name_.end() ? it->second : std::vector<SymbolId>();
}
std::vector<std::reference_wrapper<const Symbol>> SymbolStore::GetAll() const
{
std::vector<std::reference_wrapper<const Symbol>> result;
result.reserve(definitions_.size());
for (const auto& [_, def] : definitions_)
{
result.push_back(std::cref(def));
}
return result;
}
} // namespace lsp::language::symbol
std::vector<SymbolId> SymbolStore::FindByName(const std::string& name) const
{
auto it = by_name_.find(name);
return it != by_name_.end() ? it->second : std::vector<SymbolId>();
}
} // namespace lsp::language::symbol

View File

@ -7,22 +7,24 @@
#include "./types.hpp"
namespace lsp::language::symbol {
namespace lsp::language::symbol
{
class SymbolStore {
public:
SymbolId Add(Symbol def);
bool Remove(SymbolId id);
void Clear();
class SymbolStore
{
public:
SymbolId Add(Symbol def);
bool Remove(SymbolId id);
void Clear();
const Symbol* Get(SymbolId id) const;
std::vector<std::reference_wrapper<const Symbol>> GetAll() const;
std::vector<SymbolId> FindByName(const std::string& name) const;
const Symbol* Get(SymbolId id) const;
std::vector<std::reference_wrapper<const Symbol>> GetAll() const;
std::vector<SymbolId> FindByName(const std::string& name) const;
private:
SymbolId next_id_ = 1;
std::unordered_map<SymbolId, Symbol> definitions_;
std::unordered_map<std::string, std::vector<SymbolId>> by_name_;
};
private:
SymbolId next_id_ = 1;
std::unordered_map<SymbolId, Symbol> definitions_;
std::unordered_map<std::string, std::vector<SymbolId>> by_name_;
};
} // namespace lsp::language::symbol
} // namespace lsp::language::symbol

View File

@ -1,104 +1,100 @@
#include "table.hpp"
#include "./table.hpp"
#include <utility>
namespace lsp::language::symbol
{
namespace lsp::language::symbol {
SymbolTable::SymbolTable()
{
location_index_ = std::make_shared<index::Location>();
scope_index_ = std::make_shared<index::Scope>();
SymbolId SymbolTable::CreateSymbol(Symbol symbol) {
auto def = Symbol(std::move(symbol));
auto id = store_.Add(def);
index_dispatcher_.RegisterIndex(location_index_);
index_dispatcher_.RegisterIndex(scope_index_);
}
location_index_.OnSymbolAdded(def);
scope_index_.OnSymbolAdded(def);
SymbolId SymbolTable::CreateSymbol(Symbol symbol)
{
auto id = store_.Add(std::move(symbol));
return id;
}
if (const Symbol* stored = store_.Get(id))
{
index_dispatcher_.NotifySymbolAdded(*stored);
}
bool SymbolTable::RemoveSymbol(SymbolId id) {
location_index_.OnSymbolRemoved(id);
scope_index_.OnSymbolRemoved(id);
reference_graph_.OnSymbolRemoved(id);
inheritance_graph_.OnSymbolRemoved(id);
call_graph_.OnSymbolRemoved(id);
return id;
}
return store_.Remove(id);
}
bool SymbolTable::RemoveSymbol(SymbolId id)
{
if (!store_.Remove(id))
{
return false;
}
void SymbolTable::Clear() {
store_.Clear();
location_index_.Clear();
scope_index_.Clear();
reference_graph_.Clear();
inheritance_graph_.Clear();
call_graph_.Clear();
}
index_dispatcher_.NotifySymbolRemoved(id);
return true;
}
std::vector<SymbolId> SymbolTable::FindSymbolsByName(
const std::string& name) const {
return store_.FindByName(name);
}
void SymbolTable::Clear()
{
store_.Clear();
index_dispatcher_.Clear();
}
std::optional<SymbolId> SymbolTable::FindSymbolAt(
const ast::Location& location) const {
return location_index_.FindSymbolAt(location);
}
std::vector<SymbolId> SymbolTable::FindSymbolsByName(
const std::string& name) const
{
return store_.FindByName(name);
}
const Symbol* SymbolTable::definition(SymbolId id) const {
return store_.Get(id);
}
std::optional<SymbolId> SymbolTable::FindSymbolAt(
const ast::Location& location) const
{
return location_index_->FindSymbolAt(location);
}
std::vector<std::reference_wrapper<const Symbol>> SymbolTable::all_definitions()
const {
return store_.GetAll();
}
const Symbol* SymbolTable::definition(SymbolId id) const
{
return store_.Get(id);
}
index::Location& SymbolTable::locations() { return location_index_; }
index::Scope& SymbolTable::scopes() { return scope_index_; }
std::vector<std::reference_wrapper<const Symbol>> SymbolTable::all_definitions()
const
{
return store_.GetAll();
}
const index::Location& SymbolTable::locations() const {
return location_index_;
}
void SymbolTable::RegisterIndex(std::shared_ptr<ISymbolIndex> index)
{
if (!index)
{
return;
}
for (const auto& symbol_ref : store_.GetAll())
{
index->OnSymbolAdded(symbol_ref.get());
}
index_dispatcher_.RegisterIndex(std::move(index));
}
const index::Scope& SymbolTable::scopes() const { return scope_index_; }
index::Location& SymbolTable::locations() { return *location_index_; }
index::Scope& SymbolTable::scopes() { return *scope_index_; }
graph::Reference& SymbolTable::references() { return reference_graph_; }
graph::Inheritance& SymbolTable::inheritance() { return inheritance_graph_; }
graph::Call& SymbolTable::calls() { return call_graph_; }
const index::Location& SymbolTable::locations() const
{
return *location_index_;
}
const graph::Reference& SymbolTable::references() const {
return reference_graph_;
}
const index::Scope& SymbolTable::scopes() const { return *scope_index_; }
const graph::Inheritance& SymbolTable::inheritance() const {
return inheritance_graph_;
}
ScopeId SymbolTable::CreateScope(ScopeKind kind, const ast::Location& range, std::optional<ScopeId> parent, std::optional<SymbolId> owner)
{
return scope_index_->CreateScope(kind, range, parent, owner);
}
const graph::Call& SymbolTable::calls() const { return call_graph_; }
void SymbolTable::AddSymbolToScope(ScopeId scope_id, const std::string& name, SymbolId symbol_id)
{
scope_index_->AddSymbol(scope_id, name, symbol_id);
}
ScopeId SymbolTable::CreateScope(ScopeKind kind, const ast::Location& range,
std::optional<ScopeId> parent,
std::optional<SymbolId> owner) {
return scope_index_.CreateScope(kind, range, parent, owner);
}
void SymbolTable::AddSymbolToScope(ScopeId scope_id, const std::string& name,
SymbolId symbol_id) {
scope_index_.AddSymbol(scope_id, name, symbol_id);
}
void SymbolTable::AddReference(SymbolId symbol_id,
const ast::Location& location,
bool is_definition, bool is_write) {
reference_graph_.AddReference(symbol_id, location, is_definition, is_write);
}
void SymbolTable::AddInheritance(SymbolId derived, SymbolId base) {
inheritance_graph_.AddInheritance(derived, base);
}
void SymbolTable::AddCall(SymbolId caller, SymbolId callee,
const ast::Location& location) {
call_graph_.AddCall(caller, callee, location);
}
} // namespace lsp::language::symbol
} // namespace lsp::language::symbol

View File

@ -1,63 +1,49 @@
#pragma once
#include <functional>
#include <memory>
#include <optional>
#include <vector>
#include "./graph/call.hpp"
#include "./graph/inheritance.hpp"
#include "./graph/reference.hpp"
#include "./index/location.hpp"
#include "./index/dispatcher.hpp"
#include "./index/scope.hpp"
#include "./store.hpp"
namespace lsp::language::symbol {
namespace lsp::language::symbol
{
class SymbolTable {
public:
SymbolTable() = default;
class SymbolTable
{
public:
SymbolTable();
SymbolId CreateSymbol(Symbol symbol);
bool RemoveSymbol(SymbolId id);
void Clear();
SymbolId CreateSymbol(Symbol symbol);
bool RemoveSymbol(SymbolId id);
void Clear();
std::vector<SymbolId> FindSymbolsByName(const std::string& name) const;
std::optional<SymbolId> FindSymbolAt(const ast::Location& location) const;
std::vector<SymbolId> FindSymbolsByName(const std::string& name) const;
std::optional<SymbolId> FindSymbolAt(const ast::Location& location) const;
const Symbol* definition(SymbolId id) const;
std::vector<std::reference_wrapper<const Symbol>> all_definitions() const;
const Symbol* definition(SymbolId id) const;
std::vector<std::reference_wrapper<const Symbol>> all_definitions() const;
index::Location& locations();
index::Scope& scopes();
void RegisterIndex(std::shared_ptr<ISymbolIndex> index);
const index::Location& locations() const;
const index::Scope& scopes() const;
index::Location& locations();
index::Scope& scopes();
graph::Reference& references();
graph::Inheritance& inheritance();
graph::Call& calls();
const index::Location& locations() const;
const index::Scope& scopes() const;
const graph::Reference& references() const;
const graph::Inheritance& inheritance() const;
const graph::Call& calls() const;
ScopeId CreateScope(ScopeKind kind, const ast::Location& range, std::optional<ScopeId> parent = std::nullopt, std::optional<SymbolId> owner = std::nullopt);
void AddSymbolToScope(ScopeId scope_id, const std::string& name, SymbolId symbol_id);
ScopeId CreateScope(ScopeKind kind, const ast::Location& range,
std::optional<ScopeId> parent = std::nullopt,
std::optional<SymbolId> owner = std::nullopt);
void AddSymbolToScope(ScopeId scope_id, const std::string& name,
SymbolId symbol_id);
void AddReference(SymbolId symbol_id, const ast::Location& location,
bool is_definition = false, bool is_write = false);
void AddInheritance(SymbolId derived, SymbolId base);
void AddCall(SymbolId caller, SymbolId callee, const ast::Location& location);
private:
SymbolStore store_;
IndexDispatcher index_dispatcher_;
std::shared_ptr<index::Location> location_index_;
std::shared_ptr<index::Scope> scope_index_;
};
private:
SymbolStore store_;
index::Location location_index_;
index::Scope scope_index_;
graph::Reference reference_graph_;
graph::Inheritance inheritance_graph_;
graph::Call call_graph_;
};
} // namespace lsp::language::symbol
} // namespace lsp::language::symbol

View File

@ -8,243 +8,259 @@
#include "../../protocol/protocol.hpp"
#include "../ast/types.hpp"
namespace lsp::language::symbol {
namespace lsp::language::symbol
{
// ===== Basic Types =====
// ===== Basic Types =====
using SymbolId = uint64_t;
constexpr SymbolId kInvalidSymbolId = 0;
using SymbolId = uint64_t;
constexpr SymbolId kInvalidSymbolId = 0;
using ScopeId = uint64_t;
constexpr ScopeId kInvalidScopeId = 0;
using ScopeId = uint64_t;
constexpr ScopeId kInvalidScopeId = 0;
using SymbolKind = protocol::SymbolKind;
enum class ScopeKind
{
kGlobal,
kUnit,
kClass,
kFunction,
kAnonymousFunction,
kBlock
};
enum class ScopeKind {
kGlobal,
kUnit,
kClass,
kFunction,
kAnonymousFunction,
kBlock
};
enum class VariableScope
{
kAutomatic,
kStatic,
kGlobal,
kParameter,
kField
};
enum class VariableScope { kAutomatic, kStatic, kGlobal, kParameter, kField };
enum class UnitVisibility
{
kInterface,
kImplementation
};
enum class UnitVisibility { kInterface, kImplementation };
// ===== Parameters =====
// ===== Parameters =====
struct Parameter
{
std::string name;
std::optional<std::string> type;
std::optional<std::string> default_value;
};
struct Parameter {
std::string name;
std::optional<std::string> type;
std::optional<std::string> default_value;
};
struct UnitImport
{
std::string unit_name;
ast::Location location;
};
struct UnitImport {
std::string unit_name;
ast::Location location;
};
// ===== Symbol Types (all fields directly expanded) =====
// ===== Symbol Types (all fields directly expanded) =====
struct Function
{
static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Function;
struct Function {
static constexpr SymbolKind kind = SymbolKind::Function;
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range; // Name identifier location (for cursor)
ast::Location range; // Full range (including modifiers, body)
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range; // Name identifier location (for cursor)
ast::Location range; // Full range (including modifiers, body)
ast::Location declaration_range;
std::optional<ast::Location> implementation_range;
std::vector<Parameter> parameters;
std::optional<std::string> return_type;
ast::Location declaration_range;
std::optional<ast::Location> implementation_range;
std::vector<Parameter> parameters;
std::optional<std::string> return_type;
std::vector<UnitImport> imports;
std::optional<UnitVisibility> unit_visibility;
std::vector<UnitImport> imports;
std::optional<UnitVisibility> unit_visibility;
bool HasImplementation() const { return implementation_range.has_value(); }
};
bool HasImplementation() const { return implementation_range.has_value(); }
};
struct Class
{
static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Class;
struct Class {
static constexpr SymbolKind kind = SymbolKind::Class;
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range;
ast::Location range;
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range;
ast::Location range;
std::optional<UnitVisibility> unit_visibility;
std::optional<UnitVisibility> unit_visibility;
std::vector<SymbolId> base_classes;
std::vector<SymbolId> members;
std::vector<SymbolId> base_classes;
std::vector<SymbolId> members;
std::vector<UnitImport> imports;
};
std::vector<UnitImport> imports;
};
struct Method
{
static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Method;
struct Method {
static constexpr SymbolKind kind = SymbolKind::Method;
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range;
ast::Location range;
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range;
ast::Location range;
// Location information
ast::Location declaration_range;
std::optional<ast::Location> implementation_range;
// Location information
ast::Location declaration_range;
std::optional<ast::Location> implementation_range;
// Method-specific
ast::MethodKind method_kind = ast::MethodKind::kOrdinary;
ast::AccessModifier access = ast::AccessModifier::kPublic;
std::optional<ast::MethodModifier> method_modifier =
ast::MethodModifier::kNone;
bool is_static = false;
std::vector<Parameter> parameters;
std::optional<std::string> return_type;
// Method-specific
ast::MethodKind method_kind = ast::MethodKind::kOrdinary;
ast::AccessModifier access = ast::AccessModifier::kPublic;
std::optional<ast::MethodModifier> method_modifier =
ast::MethodModifier::kNone;
bool is_static = false;
std::vector<Parameter> parameters;
std::optional<std::string> return_type;
std::vector<UnitImport> imports;
std::vector<UnitImport> imports;
bool HasImplementation() const { return implementation_range.has_value(); }
};
bool HasImplementation() const { return implementation_range.has_value(); }
};
struct Property
{
static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Property;
struct Property {
static constexpr SymbolKind kind = SymbolKind::Property;
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range;
ast::Location range;
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range;
ast::Location range;
// Property-specific
ast::AccessModifier access = ast::AccessModifier::kPublic;
std::optional<std::string> type;
std::optional<SymbolId> getter;
std::optional<SymbolId> setter;
};
// Property-specific
ast::AccessModifier access = ast::AccessModifier::kPublic;
std::optional<std::string> type;
std::optional<SymbolId> getter;
std::optional<SymbolId> setter;
};
struct Field
{
static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Field;
struct Field {
static constexpr SymbolKind kind = SymbolKind::Field;
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range;
ast::Location range;
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range;
ast::Location range;
ast::AccessModifier access = ast::AccessModifier::kPublic;
std::optional<ast::ReferenceModifier> reference_modifier;
std::optional<std::string> type;
bool is_static = false;
};
ast::AccessModifier access = ast::AccessModifier::kPublic;
std::optional<ast::ReferenceModifier> reference_modifier;
std::optional<std::string> type;
bool is_static = false;
};
struct Variable
{
static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Variable;
struct Variable {
static constexpr SymbolKind kind = SymbolKind::Variable;
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range;
ast::Location range;
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range;
ast::Location range;
std::optional<std::string> type;
std::optional<ast::ReferenceModifier> reference_modifier;
VariableScope storage = VariableScope::kAutomatic;
std::optional<UnitVisibility> unit_visibility;
bool has_initializer = false;
};
std::optional<std::string> type;
std::optional<ast::ReferenceModifier> reference_modifier;
VariableScope storage = VariableScope::kAutomatic;
std::optional<UnitVisibility> unit_visibility;
bool has_initializer = false;
};
struct Constant
{
static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Constant;
struct Constant {
static constexpr SymbolKind kind = SymbolKind::Constant;
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range;
ast::Location range;
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range;
ast::Location range;
std::optional<std::string> type;
std::string value;
};
std::optional<std::string> type;
std::string value;
};
struct Unit
{
static constexpr protocol::SymbolKind kind = protocol::SymbolKind::Namespace;
struct Unit {
static constexpr SymbolKind kind = SymbolKind::Namespace;
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range;
ast::Location range;
SymbolId id = kInvalidSymbolId;
std::string name;
ast::Location selection_range;
ast::Location range;
std::vector<UnitImport> interface_imports;
std::vector<UnitImport> implementation_imports;
};
std::vector<UnitImport> interface_imports;
std::vector<UnitImport> implementation_imports;
};
// ===== Symbol Data Variant =====
struct Reference {
ast::Location location;
SymbolId symbol_id;
bool is_definition;
bool is_write;
};
using SymbolData = std::variant<Function, Class, Method, Property, Field, Variable, Constant, Unit>;
struct Call {
SymbolId caller;
SymbolId callee;
ast::Location call_site;
};
// ===== Symbol =====
// ===== Symbol Data Variant =====
class Symbol
{
public:
explicit Symbol(SymbolData data) : data_(std::move(data)) {}
using SymbolData = std::variant<Function, Class, Method, Property, Field,
Variable, Constant, Unit>;
// Type checking and conversion
template<typename T>
bool Is() const
{
return std::holds_alternative<T>(data_);
}
// ===== Symbol =====
template<typename T>
const T* As() const
{
return std::get_if<T>(&data_);
}
class Symbol {
public:
explicit Symbol(SymbolData data) : data_(std::move(data)) {}
template<typename T>
T* As()
{
return std::get_if<T>(&data_);
}
// Type checking and conversion
template <typename T>
bool Is() const {
return std::holds_alternative<T>(data_);
}
// Accessors (snake_case per Google style)
const SymbolData& data() const { return data_; }
SymbolData& mutable_data() { return data_; }
template <typename T>
const T* As() const {
return std::get_if<T>(&data_);
}
// Common accessors (all symbol types have these)
SymbolId id() const
{
return std::visit([](const auto& s) { return s.id; }, data_);
}
template <typename T>
T* As() {
return std::get_if<T>(&data_);
}
const std::string& name() const
{
return std::visit([](const auto& s) -> const auto& { return s.name; },
data_);
}
// Accessors (snake_case per Google style)
const SymbolData& data() const { return data_; }
SymbolData& mutable_data() { return data_; }
ast::Location selection_range() const
{
return std::visit([](const auto& s) { return s.selection_range; }, data_);
}
// Common accessors (all symbol types have these)
SymbolId id() const {
return std::visit([](const auto& s) { return s.id; }, data_);
}
ast::Location range() const
{
return std::visit([](const auto& s) { return s.range; }, data_);
}
const std::string& name() const {
return std::visit([](const auto& s) -> const auto& { return s.name; },
data_);
}
protocol::SymbolKind kind() const
{
return std::visit([](const auto& s) { return s.kind; }, data_);
}
ast::Location selection_range() const {
return std::visit([](const auto& s) { return s.selection_range; }, data_);
}
private:
SymbolData data_;
};
ast::Location range() const {
return std::visit([](const auto& s) { return s.range; }, data_);
}
SymbolKind kind() const {
return std::visit([](const auto& s) { return s.kind; }, data_);
}
private:
SymbolData data_;
};
} // namespace lsp::language::symbol
} // namespace lsp::language::symbol

View File

@ -27,9 +27,6 @@ set(SOURCES
../../src/language/symbol/table.cpp
../../src/language/symbol/index/location.cpp
../../src/language/symbol/index/scope.cpp
../../src/language/symbol/graph/call.cpp
../../src/language/symbol/graph/inheritance.cpp
../../src/language/symbol/graph/reference.cpp
../../src/tree-sitter/scanner.c
../../src/tree-sitter/parser.c)

View File

@ -87,7 +87,6 @@ namespace lsp::language::symbol::debug
PrintOptions opts;
opts.show_details = true;
opts.show_children = true;
opts.show_references = true;
return opts;
}
@ -107,27 +106,11 @@ namespace lsp::language::symbol::debug
symbol_counts.clear();
scope_counts.clear();
symbols_with_refs = 0;
total_references = 0;
max_references = 0;
most_referenced = kInvalidSymbolId;
for (const auto& ref : all_defs)
{
const auto& sym = ref.get();
symbol_counts[sym.kind()]++;
const auto& refs = table.references().references(sym.id());
if (!refs.empty())
{
symbols_with_refs++;
total_references += refs.size();
if (refs.size() > max_references)
{
max_references = refs.size();
most_referenced = sym.id();
}
}
}
const auto& scopes = table.scopes().all_scopes();
@ -153,15 +136,6 @@ namespace lsp::language::symbol::debug
os << color(Color::Bold) << "Overview:" << color(Color::Reset) << "\n";
os << " Total Symbols: " << color(Color::Cyan) << total_symbols << color(Color::Reset) << "\n";
os << " Total Scopes: " << color(Color::Cyan) << total_scopes << color(Color::Reset) << "\n";
os << " Total References: " << color(Color::Cyan) << total_references << color(Color::Reset) << "\n";
os << " Symbols w/ Refs: " << color(Color::Green) << symbols_with_refs << color(Color::Reset) << "\n";
if (most_referenced != kInvalidSymbolId)
{
os << " Most Referenced: " << color(Color::Yellow) << "ID=" << most_referenced
<< " (" << max_references << " refs)" << color(Color::Reset) << "\n";
}
os << "\n";
os << color(Color::Bold) << "Symbol Distribution:" << color(Color::Reset) << "\n";
@ -507,12 +481,6 @@ namespace lsp::language::symbol::debug
}
}
const auto& refs = table_.references().references(symbol.id());
if (options_.show_references && !refs.empty())
{
os << " " << Dim("[refs: " + std::to_string(refs.size()) + "]");
}
os << "\n";
},
symbol.data());
@ -713,135 +681,6 @@ namespace lsp::language::symbol::debug
PrintScopeTree(global, os, 0);
}
void DebugPrinter::PrintReferences(SymbolId id, std::ostream& os)
{
const auto& refs = table_.references().references(id);
if (refs.empty())
return;
PrintSubHeader("References for symbol " + std::to_string(id), os);
for (const auto& ref : refs)
{
os << " - " << FormatLocation(ref.location);
if (ref.is_definition)
os << " (def)";
if (ref.is_write)
os << " (write)";
os << "\n";
}
}
void DebugPrinter::PrintInheritance(SymbolId class_id, std::ostream& os)
{
const auto& bases = table_.inheritance().base_classes(class_id);
const auto& derived = table_.inheritance().derived_classes(class_id);
PrintSubHeader("Inheritance for class " + std::to_string(class_id), os);
if (bases.empty() && derived.empty())
{
os << " (no inheritance info)\n";
return;
}
if (!bases.empty())
{
os << " Base classes: ";
for (size_t i = 0; i < bases.size(); ++i)
{
os << bases[i];
if (i + 1 < bases.size())
os << ", ";
}
os << "\n";
}
if (!derived.empty())
{
os << " Derived classes: ";
for (size_t i = 0; i < derived.size(); ++i)
{
os << derived[i];
if (i + 1 < derived.size())
os << ", ";
}
os << "\n";
}
}
void DebugPrinter::PrintCallGraph(SymbolId function_id, std::ostream& os)
{
const auto& incoming = table_.calls().callers(function_id);
const auto& outgoing = table_.calls().callees(function_id);
PrintSubHeader("Call graph for " + std::to_string(function_id), os);
if (incoming.empty() && outgoing.empty())
{
os << " (no call info)\n";
return;
}
if (!incoming.empty())
{
os << " Called by: ";
for (size_t i = 0; i < incoming.size(); ++i)
{
os << incoming[i].caller;
if (i + 1 < incoming.size())
os << ", ";
}
os << "\n";
}
if (!outgoing.empty())
{
os << " Calls: ";
for (size_t i = 0; i < outgoing.size(); ++i)
{
os << outgoing[i].callee;
if (i + 1 < outgoing.size())
os << ", ";
}
os << "\n";
}
}
void DebugPrinter::PrintAllReferences(std::ostream& os)
{
PrintHeader("References", os);
auto all_defs = table_.all_definitions();
for (const auto& ref : all_defs)
{
const auto& sym = ref.get();
PrintReferences(sym.id(), os);
}
}
void DebugPrinter::PrintAllInheritance(std::ostream& os)
{
PrintHeader("Inheritance Graph", os);
auto all_defs = table_.all_definitions();
for (const auto& ref : all_defs)
{
const auto& sym = ref.get();
if (sym.kind() == SymbolKind::Class)
{
PrintInheritance(sym.id(), os);
}
}
}
void DebugPrinter::PrintAllCalls(std::ostream& os)
{
PrintHeader("Call Graph", os);
auto all_defs = table_.all_definitions();
for (const auto& ref : all_defs)
{
const auto& sym = ref.get();
if (sym.kind() == SymbolKind::Function || sym.kind() == SymbolKind::Method)
{
PrintCallGraph(sym.id(), os);
}
}
}
void DebugPrinter::FindAndPrint(const std::string& name, std::ostream& os)
{
PrintHeader("Search: " + name, os);
@ -855,7 +694,6 @@ namespace lsp::language::symbol::debug
for (auto id : matches)
{
PrintSymbol(id, os);
PrintReferences(id, os);
}
}
@ -875,7 +713,6 @@ namespace lsp::language::symbol::debug
PrintHeader("Overview", os);
os << "Symbols: " << stats_.total_symbols << "\n";
os << "Scopes: " << stats_.total_scopes << "\n";
os << "Refs: " << stats_.total_references << "\n";
}
void DebugPrinter::PrintStatistics(std::ostream& os)
@ -888,13 +725,6 @@ namespace lsp::language::symbol::debug
PrintOverview(os);
PrintSymbolList(os);
PrintScopeHierarchy(os);
if (options_.show_references)
{
PrintAllReferences(os);
}
PrintAllInheritance(os);
PrintAllCalls(os);
PrintStatistics(os);
}

View File

@ -4,10 +4,13 @@
#include <string>
#include <unordered_map>
#include "../../src/language/symbol/table.hpp"
#include "../../src/protocol/protocol.hpp"
namespace lsp::language::symbol::debug
{
using SymbolKind = protocol::SymbolKind;
// ==================== 打印选项 ====================
struct PrintOptions
@ -16,7 +19,6 @@ namespace lsp::language::symbol::debug
bool show_location = true; // 显示位置信息
bool show_details = true; // 显示详细信息
bool show_children = true; // 显示子符号
bool show_references = false; // 显示引用列表
bool compact_mode = false; // 紧凑模式
int indent_size = 2; // 缩进大小
int max_depth = -1; // 最大深度 (-1 = 无限制)
@ -33,15 +35,10 @@ namespace lsp::language::symbol::debug
{
size_t total_symbols = 0;
size_t total_scopes = 0;
size_t total_references = 0;
std::unordered_map<SymbolKind, size_t> symbol_counts;
std::unordered_map<ScopeKind, size_t> scope_counts;
size_t symbols_with_refs = 0;
size_t max_references = 0;
SymbolId most_referenced = kInvalidSymbolId;
void Compute(const SymbolTable& table);
void Print(std::ostream& os, bool use_color = true) const;
};
@ -69,14 +66,6 @@ namespace lsp::language::symbol::debug
void PrintScopeTree(ScopeId id, std::ostream& os = std::cout, int depth = 0);
void PrintScopeHierarchy(std::ostream& os = std::cout);
// ===== 关系打印 =====
void PrintReferences(SymbolId id, std::ostream& os = std::cout);
void PrintInheritance(SymbolId class_id, std::ostream& os = std::cout);
void PrintCallGraph(SymbolId function_id, std::ostream& os = std::cout);
void PrintAllReferences(std::ostream& os = std::cout);
void PrintAllInheritance(std::ostream& os = std::cout);
void PrintAllCalls(std::ostream& os = std::cout);
// ===== 搜索和查询 =====
void FindAndPrint(const std::string& name, std::ostream& os = std::cout);
void FindAtLocation(const ast::Location& loc, std::ostream& os = std::cout);

View File

@ -115,9 +115,6 @@ struct Options
bool print_all = true;
bool print_definitions = false;
bool print_scopes = false;
bool print_references = false;
bool print_inheritance = false;
bool print_calls = false;
bool compact_mode = false;
bool statistics_only = false;
std::string search_symbol;
@ -135,9 +132,6 @@ void PrintUsage(const char* program_name)
std::cout << " -o, --output <file> Write output to file instead of stdout\n";
std::cout << " -d, --definitions Print only symbol definitions\n";
std::cout << " -s, --scopes Print only scope hierarchy\n";
std::cout << " -r, --references Print only references\n";
std::cout << " -i, --inheritance Print only inheritance graph\n";
std::cout << " -c, --calls Print only call graph\n";
std::cout << " -C, --compact Use compact output format\n";
std::cout << " -S, --stats Print statistics only\n";
std::cout << " -O, --overview Print overview only\n";
@ -193,21 +187,6 @@ bool ParseArguments(int argc, char* argv[], Options& options)
options.print_scopes = true;
any_specific_print = true;
}
else if (arg == "-r" || arg == "--references")
{
options.print_references = true;
any_specific_print = true;
}
else if (arg == "-i" || arg == "--inheritance")
{
options.print_inheritance = true;
any_specific_print = true;
}
else if (arg == "-c" || arg == "--calls")
{
options.print_calls = true;
any_specific_print = true;
}
else if (arg == "-C" || arg == "--compact")
{
options.compact_mode = true;
@ -550,15 +529,6 @@ private:
EnterScope(symbol::ScopeKind::kClass, node.span, class_id);
for (const auto& parent : node.parent_classes)
{
auto it = name_index_.find(ToLower(parent.name));
if (it != name_index_.end())
{
table_.AddInheritance(class_id, it->second);
}
}
for (auto& member : node.members)
{
if (member)
@ -1013,8 +983,6 @@ void AnalyzeFile(const Options& options)
if (options.no_color || !options.output_file.empty())
print_opts.use_color = false;
print_opts.show_references = options.print_references || options.verbose;
symbol::debug::DebugPrinter printer(table, print_opts);
if (!options.search_symbol.empty())
@ -1053,24 +1021,6 @@ void AnalyzeFile(const Options& options)
printed_anything = true;
}
if (options.print_references)
{
printer.PrintAllReferences(*out);
printed_anything = true;
}
if (options.print_inheritance)
{
printer.PrintAllInheritance(*out);
printed_anything = true;
}
if (options.print_calls)
{
printer.PrintAllCalls(*out);
printed_anything = true;
}
if (printed_anything)
{
*out << "\n";