tsl-devkit/lsp-server/src/provider/completion_item/resolve.cppm

372 lines
13 KiB
C++

module;
export module lsp.provider.completion_item.resolve;
import spdlog;
import std;
import lsp.protocol;
import lsp.codec.facade;
import lsp.provider.base.interface;
import lsp.language.symbol;
import lsp.language.ast;
import lsp.utils.string;
namespace transform = lsp::codec;
export namespace lsp::provider::completion_item
{
/// Request provider for the LSP `completionItem/resolve` method.
///
/// Derives from AutoRegisterProvider<Resolve, IRequestProvider>.
/// NOTE(review): AutoRegisterProvider is declared outside this file; presumably it
/// registers the provider using kMethod / kProviderName -- confirm against its definition.
class Resolve : public AutoRegisterProvider<Resolve, IRequestProvider>
{
public:
/// LSP method name this provider answers.
static constexpr std::string_view kMethod = "completionItem/resolve";
/// Provider name; NOTE(review): presumably what GetProviderName() reports in log messages -- confirm.
static constexpr std::string_view kProviderName = "CompletionItemResolve";
Resolve() = default;
/// Decodes the completion item carried in `request.params`, optionally fills in its
/// constructor signature and insert-text snippet, and returns the response as a string.
///
/// @param request            Incoming request; an error response is returned when it has no params.
/// @param execution_context  Gives access to the manager hub whose symbol tables are searched.
/// @return The serialized response holding the item, or an error response when params are
///         missing or serialization fails.
std::string ProvideResponse(const protocol::RequestMessage& request, ExecutionContext& execution_context) override;
};
}
namespace lsp::provider::completion_item
{
namespace
{
/// Looks up `key` in `obj` and returns its value as a string.
///
/// @return The stored string, or std::nullopt when the key is absent or the
///         stored value is not a protocol::string.
std::optional<std::string> GetStringField(const protocol::LSPObject& obj, const std::string& key)
{
    const auto entry = obj.find(key);
    const bool holds_string = entry != obj.end() && entry->second.Is<protocol::string>();
    if (holds_string)
    {
        return entry->second.Get<protocol::string>();
    }
    return std::nullopt;
}
/// Looks up `key` in `obj` and returns its value as a boolean.
///
/// @return The stored flag, or std::nullopt when the key is absent or the
///         stored value is not a protocol::boolean.
std::optional<bool> GetBoolField(const protocol::LSPObject& obj, const std::string& key)
{
    const auto entry = obj.find(key);
    const bool holds_bool = entry != obj.end() && entry->second.Is<protocol::boolean>();
    if (holds_bool)
    {
        return entry->second.Get<protocol::boolean>();
    }
    return std::nullopt;
}
/// Returns the name of the first definition in `table` whose kind is
/// protocol::SymbolKind::Module, or an empty string when there is none.
std::string GetModuleName(const language::symbol::SymbolTable& table)
{
    for (const auto& entry : table.all_definitions())
    {
        const auto& candidate = entry.get();
        if (candidate.kind() != protocol::SymbolKind::Module)
        {
            continue;
        }
        return candidate.name();
    }
    return {};
}
/// Finds the first symbol named `class_name` in `table` whose kind is
/// protocol::SymbolKind::Class.
///
/// @return A non-null pointer to the class definition, or std::nullopt when no
///         symbol with that name resolves to a class definition.
std::optional<const language::symbol::Symbol*> FindClassSymbol(
    const language::symbol::SymbolTable& table,
    const std::string& class_name)
{
    for (auto candidate_id : table.FindSymbolsByName(class_name))
    {
        const auto* definition = table.definition(candidate_id);
        const bool is_class = definition != nullptr && definition->kind() == protocol::SymbolKind::Class;
        if (is_class)
        {
            return definition;
        }
    }
    return std::nullopt;
}
/// Scans every scope of `table` for the first one of the requested `kind`
/// whose owner is set and equal to `owner_id`.
///
/// @return The id of the matching scope, or std::nullopt when none matches.
std::optional<language::symbol::ScopeId> FindScopeOwnedBy(
    const language::symbol::SymbolTable& table,
    language::symbol::ScopeKind kind,
    language::symbol::SymbolId owner_id)
{
    for (const auto& [candidate_id, candidate] : table.scopes().all_scopes())
    {
        if (!(candidate.kind == kind))
        {
            continue;
        }
        // Scopes without an owner can never match.
        if (!candidate.owner || !(*candidate.owner == owner_id))
        {
            continue;
        }
        return candidate_id;
    }
    return std::nullopt;
}
/// Gathers every constructor method declared in the class scope owned by `class_id`.
///
/// @return Pointers to the constructor methods found; empty when the class has
///         no class-kind scope, the scope cannot be looked up, or it declares
///         no constructors.
std::vector<const language::symbol::Method*> CollectConstructors(
    const language::symbol::SymbolTable& table,
    language::symbol::SymbolId class_id)
{
    std::vector<const language::symbol::Method*> constructors;

    const auto class_scope_id = FindScopeOwnedBy(table, language::symbol::ScopeKind::kClass, class_id);
    const auto* class_scope = class_scope_id ? table.scopes().scope(*class_scope_id) : nullptr;
    if (class_scope == nullptr)
    {
        return constructors;
    }

    for (const auto& [name_key, member_ids] : class_scope->symbols)
    {
        for (auto member_id : member_ids)
        {
            const auto* member = table.definition(member_id);
            if (member == nullptr || member->kind() != protocol::SymbolKind::Method)
            {
                continue;
            }
            // Keep only members that really are Method symbols flagged as constructors.
            const auto* as_method = member->As<language::symbol::Method>();
            if (as_method != nullptr && as_method->method_kind == language::ast::MethodKind::kConstructor)
            {
                constructors.push_back(as_method);
            }
        }
    }
    return constructors;
}
/// Chooses the constructor that is cheapest to call: the one with the fewest
/// parameters lacking a default value, ties broken by the smaller total
/// parameter count. On a full tie the earliest candidate wins.
///
/// @param ctors Candidate constructors; null entries are skipped.
/// @return The preferred constructor, or nullptr when there is no non-null candidate.
const language::symbol::Method* PickBestConstructor(const std::vector<const language::symbol::Method*>& ctors)
{
    constexpr std::size_t kUnset = std::numeric_limits<std::size_t>::max();

    const language::symbol::Method* winner = nullptr;
    // (required parameters, total parameters) -- compared lexicographically.
    std::pair<std::size_t, std::size_t> winning_score{kUnset, kUnset};

    for (const auto* candidate : ctors)
    {
        if (candidate == nullptr)
        {
            continue;
        }
        const auto& params = candidate->parameters;
        const auto required = static_cast<std::size_t>(
            std::count_if(params.begin(), params.end(), [](const auto& param) {
                return !param.default_value.has_value();
            }));
        const std::pair<std::size_t, std::size_t> score{required, params.size()};
        if (score < winning_score)
        {
            winner = candidate;
            winning_score = score;
        }
    }
    return winner;
}
/// Renders a parameter list as `(name: type, name, ...)`, followed by
/// `: return_type` when a non-empty return type is given. A parameter's
/// `: type` suffix is omitted when its type is unset or empty.
std::string BuildSignature(const std::vector<language::symbol::Parameter>& params, const std::optional<std::string>& return_type)
{
    std::string signature{"("};
    bool needs_separator = false;
    for (const auto& param : params)
    {
        if (needs_separator)
        {
            signature += ", ";
        }
        needs_separator = true;

        signature += param.name;
        if (param.type && !param.type->empty())
        {
            signature += ": ";
            signature += *param.type;
        }
    }
    signature += ")";

    const bool has_return_type = return_type && !return_type->empty();
    if (has_return_type)
    {
        signature += ": ";
        signature += *return_type;
    }
    return signature;
}
/// Builds the snippet `ClassName(${1:a}, ${2:b})$0`, using one numbered
/// placeholder per parameter of `ctor`.
///
/// @param ctor Constructor supplying the parameters; may be null, in which case
///             the parentheses are left empty.
std::string BuildNewSnippet(const std::string& class_name, const language::symbol::Method* ctor)
{
    std::string snippet = class_name;
    snippet += "(";
    if (ctor != nullptr)
    {
        std::size_t tab_stop = 0;
        for (const auto& param : ctor->parameters)
        {
            if (tab_stop > 0)
            {
                snippet += ", ";
            }
            ++tab_stop; // placeholders are numbered from 1
            snippet += "${";
            snippet += std::to_string(tab_stop);
            snippet += ":";
            snippet += param.name;
            snippet += "}";
        }
    }
    snippet += ")$0";
    return snippet;
}
/// Builds the snippet for the class-name argument of a `createobject` call:
/// the quoted class name followed by `, ${n:param}` for every parameter of
/// `ctor`, ending with `$0`.
///
/// @param ctor           Constructor supplying the parameters; may be null.
/// @param has_open_quote When true the opening quote is not emitted; only the
///                       closing one is.
/// @param quote_char     Quote character placed around the class name.
std::string BuildCreateObjectSnippet(const std::string& class_name,
                                     const language::symbol::Method* ctor,
                                     bool has_open_quote,
                                     char quote_char)
{
    std::string snippet;
    if (!has_open_quote)
    {
        snippet += quote_char;
    }
    snippet += class_name;
    snippet += quote_char;

    if (ctor != nullptr)
    {
        std::size_t tab_stop = 0;
        for (const auto& param : ctor->parameters)
        {
            ++tab_stop; // placeholders are numbered from 1
            // Every argument follows the class-name string, so the separator always precedes it.
            snippet += ", ${";
            snippet += std::to_string(tab_stop);
            snippet += ":";
            snippet += param.name;
            snippet += "}";
        }
    }
    snippet += "$0";
    return snippet;
}
}
/// Answers a `completionItem/resolve` request.
///
/// The completion item is decoded from `request.params`. When its `data` object
/// carries string fields `ctx`, a non-empty `class` and `uri`, the class is
/// looked up first in the table for `uri`, then in the workspace tables, then
/// in the system tables; a non-empty `unit` field restricts the lookup to tables
/// whose module name equals it (compared with utils::IEquals). If the class is
/// found, the label detail is set to the signature of its preferred constructor
/// (empty when it has none). For `ctx` values "new" and "createobject" the item
/// additionally receives a snippet insert text and the Constructor kind.
///
/// @return The serialized response carrying the item, or an error response when
///         params are missing or serialization fails.
std::string Resolve::ProvideResponse(const protocol::RequestMessage& request,
                                     ExecutionContext& execution_context)
{
    using TablePtr = const language::symbol::SymbolTable*;

    if (!request.params.has_value())
    {
        spdlog::warn("{}: Missing params in request", GetProviderName());
        return BuildErrorResponseMessage(request,
                                         protocol::ErrorCodes::InvalidParams,
                                         "Missing params");
    }

    protocol::CompletionItem item = transform::FromLSPAny.template operator()<protocol::CompletionItem>(request.params.value());

    // Enrichment is best effort: every early return leaves the item as decoded so far.
    const auto enrich_item = [&]() -> void {
        if (!item.data || !item.data->Is<protocol::LSPObject>())
        {
            return;
        }
        const auto& payload = item.data->Get<protocol::LSPObject>();
        auto ctx = GetStringField(payload, "ctx");
        auto class_name = GetStringField(payload, "class");
        auto unit_name = GetStringField(payload, "unit");
        auto uri = GetStringField(payload, "uri");
        if (!ctx || !class_name || class_name->empty() || !uri)
        {
            return;
        }

        auto& hub = execution_context.GetManagerHub();
        TablePtr editing_table = hub.symbols().GetSymbolTable(*uri);
        auto workspace_tables = hub.symbols().GetWorkspaceSymbolTables();
        auto system_tables = hub.symbols().GetSystemSymbolTables();

        // True when `table` passes the optional unit filter and declares the class.
        const auto declares_class = [&](const language::symbol::SymbolTable& table) -> bool {
            if (unit_name && !unit_name->empty())
            {
                const auto module_name = GetModuleName(table);
                if (module_name.empty() || !utils::IEquals(module_name, *unit_name))
                {
                    return false;
                }
            }
            return FindClassSymbol(table, *class_name).has_value();
        };

        // First non-null table in `tables` that declares the class, or nullptr.
        const auto first_match = [&](const auto& tables) -> TablePtr {
            for (const auto* candidate : tables)
            {
                if (candidate && declares_class(*candidate))
                {
                    return candidate;
                }
            }
            return nullptr;
        };

        // Search order: the document being edited, then the workspace, then system tables.
        TablePtr owning_table = nullptr;
        if (editing_table && declares_class(*editing_table))
        {
            owning_table = editing_table;
        }
        if (!owning_table)
        {
            owning_table = first_match(workspace_tables);
        }
        if (!owning_table)
        {
            owning_table = first_match(system_tables);
        }

        const language::symbol::Method* best_ctor = nullptr;
        if (owning_table)
        {
            if (auto class_symbol = FindClassSymbol(*owning_table, *class_name))
            {
                auto ctors = CollectConstructors(*owning_table, (*class_symbol)->id());
                best_ctor = PickBestConstructor(ctors);
                if (!item.labelDetails)
                {
                    item.labelDetails = protocol::CompletionItemLabelDetails{};
                }
                item.labelDetails->detail = best_ctor ? BuildSignature(best_ctor->parameters, best_ctor->return_type) : "";
            }
        }

        if (*ctx == "new")
        {
            item.insertText = BuildNewSnippet(*class_name, best_ctor);
        }
        else if (*ctx == "createobject")
        {
            const bool has_open_quote = GetBoolField(payload, "has_open_quote").value_or(false);
            const auto quote = GetStringField(payload, "quote");
            // Default to a double quote unless the item data names another quote character.
            const char quote_char = (quote && !quote->empty()) ? (*quote)[0] : '"';
            item.insertText = BuildCreateObjectSnippet(*class_name, best_ctor, has_open_quote, quote_char);
        }
        else
        {
            return;
        }
        item.insertTextFormat = protocol::InsertTextFormat::Snippet;
        item.kind = protocol::CompletionItemKind::Constructor;
    };
    enrich_item();

    protocol::ResponseMessage response;
    response.id = request.id;
    response.result = transform::ToLSPAny(item);
    auto json = transform::Serialize(response);
    if (json)
    {
        return *json;
    }
    return BuildErrorResponseMessage(request,
                                     protocol::ErrorCodes::InternalError,
                                     "Failed to serialize completion resolve response");
}
}