tsl-devkit/lsp-server/src/language/semantic/name_resolver.cppm

548 lines
16 KiB
C++

module;
export module lsp.language.semantic:name_resolver;
import std;
import :type_system;
import lsp.language.ast;
import lsp.language.symbol;
export namespace lsp::language::semantic
{
struct NameResolutionResult
{
symbol::SymbolId symbol_id = symbol::kInvalidSymbolId;
bool is_ambiguous = false;
std::vector<symbol::SymbolId> candidates;
bool IsResolved() const
{
return symbol_id != symbol::kInvalidSymbolId;
}
static NameResolutionResult Success(symbol::SymbolId id)
{
return { id, false, { id } };
}
static NameResolutionResult Ambiguous(std::vector<symbol::SymbolId> symbols)
{
return {
symbols.empty() ? symbol::kInvalidSymbolId : symbols[0],
true,
std::move(symbols)
};
}
static NameResolutionResult NotFound()
{
return { symbol::kInvalidSymbolId, false, {} };
}
};
/// One overload considered while resolving a call.
struct OverloadCandidate
{
    /// The candidate function/method. Defaults to the invalid id so that a
    /// default-constructed candidate never carries an indeterminate value.
    symbol::SymbolId symbol_id = symbol::kInvalidSymbolId;
    /// Higher is a better match; a negative score marks a non-viable candidate.
    int match_score = 0;
    /// Per-argument compatibility results, in argument order.
    std::vector<TypeCompatibility> arg_conversions;

    /// Deliberately inverted: orders candidates by *descending* score, so
    /// sorting with this operator puts the best match first.
    bool operator<(const OverloadCandidate& other) const
    {
        return match_score > other.match_score;
    }
};
/// Resolves identifiers to symbol ids: plain names through the scope chain,
/// `object.member` and `Qualifier.name` forms through class members and unit
/// scopes, and calls through overload scoring against argument types.
///
/// Only references to the symbol table and the type system are stored, so
/// both must outlive the resolver.
class NameResolver
{
public:
    explicit NameResolver(const symbol::SymbolTable& symbol_table,
                          const TypeSystem& type_system)
        : symbol_table_(symbol_table)
        , type_system_(type_system)
    {
    }

    /// Looks up `name` relative to `scope_id`; when `search_parent` is true
    /// (the default) the enclosing scopes are walked outward.
    NameResolutionResult ResolveName(const std::string& name,
                                     symbol::ScopeId scope_id,
                                     bool search_parent = true) const;

    /// Resolves `name` in the scope enclosing `location`, falling back to the
    /// global scope.
    NameResolutionResult ResolveNameAtLocation(const std::string& name,
                                               const ast::Location& location) const;

    /// Resolves `object.member_name` through the class type of the object.
    NameResolutionResult ResolveMemberAccess(symbol::SymbolId object_symbol_id,
                                             const std::string& member_name) const;

    /// Looks up a direct member of class `class_id`; `static_only` filters
    /// out instance methods and instance fields.
    NameResolutionResult ResolveClassMember(symbol::SymbolId class_id,
                                            const std::string& member_name,
                                            bool static_only = false) const;

    /// Resolves a free-function call, ranking overloads against `arg_types`.
    NameResolutionResult ResolveFunctionCall(const std::string& function_name,
                                             const std::vector<std::shared_ptr<Type>>& arg_types,
                                             symbol::ScopeId scope_id) const;

    /// Resolves `object.method(...)`, ranking overloads against `arg_types`.
    NameResolutionResult ResolveMethodCall(symbol::SymbolId object_symbol_id,
                                           const std::string& method_name,
                                           const std::vector<std::shared_ptr<Type>>& arg_types) const;

    /// Resolves `qualifier.name`, where `qualifier` names a class (static
    /// members) or a unit (symbols declared in the unit's scope).
    NameResolutionResult ResolveQualifiedName(const std::string& qualifier,
                                              const std::string& name,
                                              symbol::ScopeId scope_id) const;

    /// Scope that declares `symbol_id`, if known.
    std::optional<symbol::ScopeId> GetSymbolScope(symbol::SymbolId symbol_id) const;

    /// Whether `symbol_id` is visible from `scope_id`.
    bool IsSymbolVisibleInScope(symbol::SymbolId symbol_id,
                                symbol::ScopeId scope_id) const;

private:
    const symbol::SymbolTable& symbol_table_;
    const TypeSystem& type_system_;

    // --- lookup helpers ---------------------------------------------------

    std::vector<symbol::SymbolId> SearchScopeChain(const std::string& name,
                                                   symbol::ScopeId start_scope) const;

    std::optional<symbol::ScopeId> FindScopeOwnedBy(symbol::SymbolId owner) const;

    // --- overload ranking -------------------------------------------------

    NameResolutionResult SelectBestOverload(
        const std::vector<OverloadCandidate>& candidates) const;

    OverloadCandidate CalculateOverloadScore(
        symbol::SymbolId candidate_id,
        const std::vector<std::shared_ptr<Type>>& arg_types) const;

    std::vector<std::shared_ptr<Type>> GetParameterTypes(symbol::SymbolId symbol_id) const;

    // --- ownership / access ----------------------------------------------

    std::optional<symbol::SymbolId> GetOwnerClassId(symbol::SymbolId symbol_id) const;

    bool CheckMemberAccessibility(const symbol::Symbol& member,
                                  symbol::ScopeId access_scope) const;
};
}
namespace lsp::language::semantic
{
/// Looks up `name` relative to `scope_id`.
///
/// With `search_parent` set, the scope chain is walked outward from
/// `scope_id` and the innermost scope declaring `name` wins. Without it, only
/// `scope_id` itself is searched; parent scopes (and unrelated scopes) are
/// not consulted. An invalid `scope_id` yields NotFound().
///
/// Returns Success for one match, Ambiguous for several, NotFound otherwise.
NameResolutionResult NameResolver::ResolveName(
    const std::string& name,
    symbol::ScopeId scope_id,
    bool search_parent) const
{
    std::vector<symbol::SymbolId> candidates;
    if (search_parent)
    {
        candidates = SearchScopeChain(name, scope_id);
    }
    else if (scope_id != symbol::kInvalidScopeId)
    {
        // Restrict the lookup to the requested scope instead of searching the
        // entire symbol table by name, which would ignore `scope_id` entirely.
        auto local_symbols = symbol_table_.scopes().FindSymbols(scope_id, name);
        candidates.assign(local_symbols.begin(), local_symbols.end());
    }

    if (candidates.empty())
    {
        return NameResolutionResult::NotFound();
    }
    if (candidates.size() == 1)
    {
        return NameResolutionResult::Success(candidates.front());
    }
    return NameResolutionResult::Ambiguous(std::move(candidates));
}
/// Resolves `name` as seen from a source location: the innermost scope that
/// encloses `location` is used, or the global scope when no scope covers it.
/// Returns NotFound() when neither is available.
NameResolutionResult NameResolver::ResolveNameAtLocation(
    const std::string& name,
    const ast::Location& location) const
{
    auto lookup_scope = symbol_table_.scopes().FindScopeAt(location);
    if (!lookup_scope)
    {
        if (const auto global = symbol_table_.scopes().global_scope();
            global != symbol::kInvalidScopeId)
        {
            lookup_scope = global;
        }
        else
        {
            return NameResolutionResult::NotFound();
        }
    }
    return ResolveName(name, *lookup_scope, true);
}
/// Resolves `object.member_name`. Member lookup only applies when the
/// object's type (per the type system) is a class type; the member is then
/// searched among that class's members, static and instance alike.
NameResolutionResult NameResolver::ResolveMemberAccess(
    symbol::SymbolId object_symbol_id,
    const std::string& member_name) const
{
    auto object_type = type_system_.GetSymbolType(object_symbol_id);
    const bool is_class_object = object_type && object_type->kind() == TypeKind::kClass;
    if (!is_class_object)
    {
        return NameResolutionResult::NotFound();
    }

    const auto owning_class = object_type->As<ClassType>()->class_id();
    return ResolveClassMember(owning_class, member_name, false);
}
/// Collects the direct members of class `class_id` named `member_name`.
///
/// When `static_only` is set, non-static methods and non-static fields are
/// skipped; members of any other kind are kept. Returns NotFound() if
/// `class_id` is not a class or nothing matches, Success for one match and
/// Ambiguous for several (e.g. overloaded methods).
NameResolutionResult NameResolver::ResolveClassMember(
    symbol::SymbolId class_id,
    const std::string& member_name,
    bool static_only) const
{
    const auto* owner = symbol_table_.definition(class_id);
    if (owner == nullptr || !owner->Is<symbol::Class>())
    {
        return NameResolutionResult::NotFound();
    }

    std::vector<symbol::SymbolId> matches;
    for (const auto member_id : owner->As<symbol::Class>()->members)
    {
        const auto* member = symbol_table_.definition(member_id);
        if (member == nullptr || member->name() != member_name)
        {
            continue;
        }

        bool keep = true;
        if (static_only)
        {
            if (member->Is<symbol::Method>())
            {
                keep = member->As<symbol::Method>()->is_static;
            }
            else if (member->Is<symbol::Field>())
            {
                keep = member->As<symbol::Field>()->is_static;
            }
        }
        if (keep)
        {
            matches.push_back(member_id);
        }
    }

    // TODO: also search base classes (requires inheritance-graph support).
    switch (matches.size())
    {
    case 0:
        return NameResolutionResult::NotFound();
    case 1:
        return NameResolutionResult::Success(matches.front());
    default:
        return NameResolutionResult::Ambiguous(std::move(matches));
    }
}
/// Resolves a call to `function_name` visible from `scope_id`.
///
/// A single visible candidate is returned as-is without checking the
/// arguments. With several candidates, each is scored against `arg_types`;
/// non-viable ones (negative score) are dropped and the best remaining one is
/// selected (possibly an ambiguous tie, or NotFound if none is viable).
NameResolutionResult NameResolver::ResolveFunctionCall(
    const std::string& function_name,
    const std::vector<std::shared_ptr<Type>>& arg_types,
    symbol::ScopeId scope_id) const
{
    const auto visible = SearchScopeChain(function_name, scope_id);
    switch (visible.size())
    {
    case 0:
        return NameResolutionResult::NotFound();
    case 1:
        return NameResolutionResult::Success(visible.front());
    default:
        break;
    }

    std::vector<OverloadCandidate> viable;
    viable.reserve(visible.size());
    for (const auto id : visible)
    {
        auto scored = CalculateOverloadScore(id, arg_types);
        if (scored.match_score < 0)
        {
            continue;  // not callable with these arguments
        }
        viable.push_back(std::move(scored));
    }
    return SelectBestOverload(viable);
}
/// Resolves `object.method_name(args...)`.
///
/// The member is looked up first; an unresolved or unambiguous lookup is
/// returned directly. Only when several members share the name are they
/// scored against `arg_types` and the best viable overload selected.
NameResolutionResult NameResolver::ResolveMethodCall(
    symbol::SymbolId object_symbol_id,
    const std::string& method_name,
    const std::vector<std::shared_ptr<Type>>& arg_types) const
{
    auto member_lookup = ResolveMemberAccess(object_symbol_id, method_name);
    if (!member_lookup.IsResolved() || !member_lookup.is_ambiguous)
    {
        return member_lookup;
    }

    std::vector<OverloadCandidate> viable;
    viable.reserve(member_lookup.candidates.size());
    for (const auto id : member_lookup.candidates)
    {
        auto scored = CalculateOverloadScore(id, arg_types);
        if (scored.match_score < 0)
        {
            continue;  // not callable with these arguments
        }
        viable.push_back(std::move(scored));
    }
    return SelectBestOverload(viable);
}
/// Resolves `qualifier.name`.
///
/// `qualifier` is looked up from `scope_id`; if that lookup is ambiguous its
/// first candidate is used. A class qualifier resolves `name` among the
/// class's static members; a unit qualifier resolves `name` directly in the
/// scope owned by that unit. Any other qualifier kind yields NotFound().
NameResolutionResult NameResolver::ResolveQualifiedName(
    const std::string& qualifier,
    const std::string& name,
    symbol::ScopeId scope_id) const
{
    auto head = ResolveName(qualifier, scope_id, true);
    if (!head.IsResolved())
    {
        return head;
    }

    const auto* head_symbol = symbol_table_.definition(head.symbol_id);
    if (head_symbol == nullptr)
    {
        return NameResolutionResult::NotFound();
    }

    if (head_symbol->Is<symbol::Class>())
    {
        return ResolveClassMember(head.symbol_id, name, true);
    }
    if (!head_symbol->Is<symbol::Unit>())
    {
        return NameResolutionResult::NotFound();
    }

    const auto unit_scope = FindScopeOwnedBy(head.symbol_id);
    if (!unit_scope)
    {
        return NameResolutionResult::NotFound();
    }
    const auto member_id = symbol_table_.scopes().FindSymbolInScope(*unit_scope, name);
    return member_id ? NameResolutionResult::Success(*member_id)
                     : NameResolutionResult::NotFound();
}
/// Returns the scope that declares `symbol_id`.
/// Currently a stub: no symbol-to-scope mapping exists, so it always reports
/// "unknown" (an empty optional).
std::optional<symbol::ScopeId> NameResolver::GetSymbolScope(
    [[maybe_unused]] symbol::SymbolId symbol_id) const
{
    // TODO: implement the symbol-to-scope mapping.
    return {};
}
/// Reports whether `symbol_id` is visible from `scope_id`.
/// Currently a stub: visibility rules are not modelled, so every symbol is
/// treated as visible.
bool NameResolver::IsSymbolVisibleInScope(
    [[maybe_unused]] symbol::SymbolId symbol_id,
    [[maybe_unused]] symbol::ScopeId scope_id) const
{
    // TODO: implement visibility checking.
    constexpr bool kAssumeVisible = true;
    return kAssumeVisible;
}
std::vector<symbol::SymbolId> NameResolver::SearchScopeChain(
const std::string& name,
symbol::ScopeId start_scope) const
{
std::vector<symbol::SymbolId> results;
symbol::ScopeId current_scope = start_scope;
while (current_scope != symbol::kInvalidScopeId)
{
auto scope_symbols = symbol_table_.scopes().FindSymbols(current_scope, name);
results.insert(results.end(), scope_symbols.begin(), scope_symbols.end());
if (!results.empty())
{
break;
}
auto parent = symbol_table_.scopes().GetParent(current_scope);
if (!parent)
{
break;
}
current_scope = *parent;
}
return results;
}
/// Linear scan over all scopes for one whose owner is `owner`; returns the
/// first such scope in the container's iteration order, or nullopt if none.
std::optional<symbol::ScopeId> NameResolver::FindScopeOwnedBy(symbol::SymbolId owner) const
{
    // Bind the scope manager first so the scope container outlives the loop.
    const auto& scope_manager = symbol_table_.scopes();
    for (const auto& [candidate_id, candidate] : scope_manager.all_scopes())
    {
        const bool owned_by_target = candidate.owner && *candidate.owner == owner;
        if (owned_by_target)
        {
            return candidate_id;
        }
    }
    return std::nullopt;
}
/// Picks the highest-scoring candidate.
///
/// Returns NotFound() for an empty list, Success when one candidate has the
/// strictly best score, and Ambiguous when several tie for the best score.
/// Tied candidates are reported in their input order, so the ambiguous list
/// and its representative (first) id are deterministic. A copy plus unstable
/// sort would leave the order of equal elements unspecified.
NameResolutionResult NameResolver::SelectBestOverload(
    const std::vector<OverloadCandidate>& candidates) const
{
    if (candidates.empty())
    {
        return NameResolutionResult::NotFound();
    }

    // Only the best score matters, so a linear max search replaces sorting.
    const int best_score = std::max_element(
        candidates.begin(), candidates.end(),
        [](const auto& lhs, const auto& rhs) { return lhs.match_score < rhs.match_score; })
        ->match_score;

    std::vector<symbol::SymbolId> best_ids;
    for (const auto& candidate : candidates)
    {
        if (candidate.match_score == best_score)
        {
            best_ids.push_back(candidate.symbol_id);
        }
    }

    if (best_ids.size() == 1)
    {
        return NameResolutionResult::Success(best_ids.front());
    }
    return NameResolutionResult::Ambiguous(std::move(best_ids));
}
/// Scores how well `candidate_id` accepts arguments of `arg_types`.
///
/// Returns a candidate with `match_score == -1` (non-viable) when the
/// parameter count differs from the argument count or any argument is
/// incompatible. Otherwise the score is the sum of per-argument scores:
///   - 100 for a zero-cost conversion, else max(0, 50 - 10 * cost);
///   - minus 20 when an explicit cast is required;
///   - each argument's score clamped at 0.
/// The clamp keeps every viable candidate at a score >= 0. Callers treat a
/// negative score as "not viable", so a compatible-but-costly cast must not
/// drive the total negative.
/// A null argument or parameter type is compared as the unknown type rather
/// than dereferenced.
OverloadCandidate NameResolver::CalculateOverloadScore(
    symbol::SymbolId candidate_id,
    const std::vector<std::shared_ptr<Type>>& arg_types) const
{
    OverloadCandidate result;
    result.symbol_id = candidate_id;

    const auto param_types = GetParameterTypes(candidate_id);
    if (param_types.size() != arg_types.size())
    {
        result.match_score = -1;  // arity mismatch
        return result;
    }

    // Substitute the unknown type for missing (null) types so that
    // CheckCompatibility never receives a dereferenced null pointer.
    const auto or_unknown = [this](const std::shared_ptr<Type>& type) -> std::shared_ptr<Type>
    {
        if (type)
        {
            return type;
        }
        return type_system_.GetUnknownType();
    };

    result.arg_conversions.reserve(arg_types.size());
    int total_score = 0;
    for (std::size_t i = 0; i < arg_types.size(); ++i)
    {
        const auto arg = or_unknown(arg_types[i]);
        const auto param = or_unknown(param_types[i]);
        auto compat = type_system_.CheckCompatibility(*arg, *param);
        if (!compat.is_compatible)
        {
            result.match_score = -1;
            return result;
        }

        int arg_score = compat.conversion_cost == 0
                            ? 100
                            : std::max(0, 50 - compat.conversion_cost * 10);
        if (compat.requires_cast)
        {
            arg_score -= 20;
        }
        total_score += std::max(0, arg_score);

        result.arg_conversions.push_back(std::move(compat));
    }
    result.match_score = total_score;
    return result;
}
/// Returns the declared parameter types of a function or method symbol, in
/// declaration order. Untyped parameters map to the unknown type. So do
/// parameters whose type name the type system cannot resolve (if
/// GetTypeByName returns null). This means the result never contains null
/// entries. Returns an empty vector for unknown ids and for symbols that are
/// neither functions nor methods.
std::vector<std::shared_ptr<Type>> NameResolver::GetParameterTypes(
    symbol::SymbolId symbol_id) const
{
    std::vector<std::shared_ptr<Type>> param_types;

    // Named `callee` (not `symbol`) to avoid shadowing the `symbol` namespace.
    const auto* callee = symbol_table_.definition(symbol_id);
    if (callee == nullptr)
    {
        return param_types;
    }

    // Functions and methods expose their parameter lists the same way, so one
    // routine handles both.
    const auto append_parameters = [&](const auto& parameters)
    {
        for (const auto& param : parameters)
        {
            std::shared_ptr<Type> type;
            if (param.type)
            {
                type = type_system_.GetTypeByName(*param.type);
            }
            if (!type)
            {
                type = type_system_.GetUnknownType();
            }
            param_types.push_back(std::move(type));
        }
    };

    if (callee->Is<symbol::Function>())
    {
        append_parameters(callee->As<symbol::Function>()->parameters);
    }
    else if (callee->Is<symbol::Method>())
    {
        append_parameters(callee->As<symbol::Method>()->parameters);
    }
    return param_types;
}
/// Returns the class that owns `symbol_id`.
/// Currently a stub: no symbol-to-owning-class mapping exists, so it always
/// reports "unknown" (an empty optional).
std::optional<symbol::SymbolId> NameResolver::GetOwnerClassId(
    [[maybe_unused]] symbol::SymbolId symbol_id) const
{
    // TODO: implement the symbol-to-owning-class mapping.
    return {};
}
/// Reports whether `member` may be accessed from `access_scope`.
/// Currently a stub: access modifiers are not enforced, so every member is
/// treated as accessible.
bool NameResolver::CheckMemberAccessibility(
    [[maybe_unused]] const symbol::Symbol& member,
    [[maybe_unused]] symbol::ScopeId access_scope) const
{
    // TODO: implement access-permission checking.
    constexpr bool kAssumeAccessible = true;
    return kAssumeAccessible;
}
}