diff --git a/lsp-server/CMakeLists.txt b/lsp-server/CMakeLists.txt index 1537e96..a4ae009 100644 --- a/lsp-server/CMakeLists.txt +++ b/lsp-server/CMakeLists.txt @@ -77,9 +77,9 @@ set(SOURCES src/main.cpp src/utils/args_parser.cpp src/language/tsl_keywords.cpp - src/lsp/dispacther.cpp - src/lsp/server.cpp - src/lsp/request_scheduler.cpp + src/core/dispacther.cpp + src/core/server.cpp + src/scheduler/request_scheduler.cpp src/provider/base/provider_registry.cpp src/provider/base/provider_interface.cpp src/provider/initialize/initialize_provider.cpp @@ -88,6 +88,8 @@ set(SOURCES src/provider/text_document/did_change_provider.cpp src/provider/text_document/completion_provider.cpp src/provider/shutdown/shutdown_provider.cpp + src/provider/exit/exit_provider.cpp + src/provider/cancel_request/cancel_request_provider.cpp src/provider/trace/set_trace_provider.cpp) add_executable(${PROJECT_NAME} ${SOURCES}) diff --git a/lsp-server/src/core/dispacther.cpp b/lsp-server/src/core/dispacther.cpp new file mode 100644 index 0000000..e193b0f --- /dev/null +++ b/lsp-server/src/core/dispacther.cpp @@ -0,0 +1,193 @@ +#include +#include +#include "./dispacther.hpp" +#include "../protocol/transform/facade.hpp" + +namespace lsp::core +{ + void RequestDispatcher::SetRequestScheduler(scheduler::RequestScheduler* scheduler) + { + scheduler_ = scheduler; + spdlog::debug("RequestScheduler set in dispatcher"); + } + + void RequestDispatcher::RegisterRequestProvider(std::shared_ptr provider) + { + std::unique_lock lock(providers_mutex_); + std::string method = provider->GetMethod(); + + providers_[method] = provider; + spdlog::info("Registered request provider '{}' for method: {}", provider->GetProviderName(), method); + } + + void RequestDispatcher::RegisterNotificationProvider(std::shared_ptr provider) + { + std::unique_lock lock(notification_providers_mutex_); + std::string method = provider->GetMethod(); + + notification_providers_[method] = provider; + spdlog::info("Registered notification provider '{}' for method: {}", provider->GetProviderName(), method); + } + + void RequestDispatcher::RegisterLifecycleCallback(LifecycleCallback callback) + { + std::lock_guard lock(callbacks_mutex_); + lifecycle_callbacks_.push_back(std::move(callback)); + spdlog::debug("Registered lifecycle callback"); + } + + std::string RequestDispatcher::Dispatch(const protocol::RequestMessage& request) + { + providers::ProviderContext context( + scheduler_, + [this](ServerLifecycleEvent event) + { + NotifyAllLifecycleListeners(event); + }); + + std::shared_lock lock(providers_mutex_); + auto it = providers_.find(request.method); + if (it != providers_.end()) + { + auto provider = it->second; + lock.unlock(); + try + { + spdlog::debug("Dispatching request '{}' to provider '{}'", request.method, provider->GetProviderName()); + return provider->ProvideResponse(request, context); + } + catch (const std::exception& e) + { + spdlog::error("Provider error for method {}: {}", request.method, e.what()); + return BuildErrorResponseMessage(request, protocol::ErrorCode::kInternalError, e.what()); + } + } + return HandleUnknownRequest(request); + } + + void RequestDispatcher::Dispatch(const protocol::NotificationMessage& notification) + { + // 创建 provider context + providers::ProviderContext context( + scheduler_, + [this](ServerLifecycleEvent event) + { + NotifyAllLifecycleListeners(event); + }); + + std::shared_lock lock(notification_providers_mutex_); + + // 先尝试精确匹配 + auto it = notification_providers_.find(notification.method); + if (it != notification_providers_.end()) + { + auto provider = it->second; + lock.unlock(); + + try + { + spdlog::debug("Dispatching notification '{}' to provider '{}'", notification.method, provider->GetProviderName()); + provider->HandleNotification(notification, context); + return; + } + catch (const std::exception& e) + { + spdlog::error("Notification provider '{}' threw exception for method '{}': {}", provider->GetProviderName(), notification.method, e.what()); + return; + } + } + HandleUnknownNotification(notification); + } + + bool RequestDispatcher::SupportsRequest(const std::string& method) const + { + std::shared_lock lock(providers_mutex_); + return providers_.find(method) != providers_.end(); + } + + std::vector RequestDispatcher::GetSupportedRequests() const + { + std::shared_lock lock(providers_mutex_); + std::vector methods; + methods.reserve(providers_.size()); + for (const auto& [method, _] : providers_) + methods.push_back(method); + return methods; + } + + std::vector RequestDispatcher::GetSupportedNotifications() const + { + std::shared_lock lock(notification_providers_mutex_); + std::vector methods; + methods.reserve(notification_providers_.size()); + for (const auto& [method, provider] : notification_providers_) + methods.push_back(method); + return methods; + } + + std::vector RequestDispatcher::GetAllSupportedMethods() const + { + auto requests = GetSupportedRequests(); + auto notifications = GetSupportedNotifications(); + + // 合并两个列表 + requests.insert(requests.end(), notifications.begin(), notifications.end()); + return requests; + } + + std::string RequestDispatcher::BuildErrorResponseMessage(const protocol::RequestMessage& request, protocol::ErrorCode code, const std::string& message) + { + return providers::IRequestProvider::BuildErrorResponseMessage(request, code, message); + } + + void RequestDispatcher::NotifyAllLifecycleListeners(ServerLifecycleEvent event) + { + std::lock_guard lock(callbacks_mutex_); + + std::string event_name; + switch (event) + { + case ServerLifecycleEvent::kInitializing: + event_name = "Initializing"; + break; + case ServerLifecycleEvent::kInitialized: + event_name = "Initialized"; + break; + case ServerLifecycleEvent::kInitializeFailed: + event_name = "InitializeFailed"; + break; + case ServerLifecycleEvent::kShuttingDown: + event_name = "ShuttingDown"; + break; + case ServerLifecycleEvent::kShutdown: + event_name = "Shutdown"; + break; + } + + spdlog::info("Lifecycle event: {}", event_name); + + for (const auto& callback : lifecycle_callbacks_) + { + try + { + callback(event); + } + catch (const std::exception& e) + { + spdlog::error("Lifecycle callback error: {}", e.what()); + } + } + } + + std::string RequestDispatcher::HandleUnknownRequest(const protocol::RequestMessage& request) + { + return BuildErrorResponseMessage(request, protocol::ErrorCode::kMethodNotFound, "Method not found: " + request.method); + } + + void RequestDispatcher::HandleUnknownNotification(const protocol::NotificationMessage& notification) + { + spdlog::debug("No handler found for notification: {}", notification.method); + // 通知没有响应,所以只记录日志 + } + +} diff --git a/lsp-server/src/core/dispacther.hpp b/lsp-server/src/core/dispacther.hpp new file mode 100644 index 0000000..c56f903 --- /dev/null +++ b/lsp-server/src/core/dispacther.hpp @@ -0,0 +1,56 @@ +#pragma once +#include +#include +#include +#include "../protocol/protocol.hpp" +#include "../provider/base/provider_interface.hpp" + +namespace lsp::core +{ + + using ServerLifecycleEvent = providers::ServerLifecycleEvent; + using LifecycleCallback = providers::LifecycleCallback; + + class RequestDispatcher + { + public: + RequestDispatcher() = default; + ~RequestDispatcher() = default; + + void SetRequestScheduler(scheduler::RequestScheduler* scheduler); + + void RegisterRequestProvider(std::shared_ptr provider); + void RegisterNotificationProvider(std::shared_ptr provider); + + void RegisterLifecycleCallback(LifecycleCallback callback); + + std::string Dispatch(const protocol::RequestMessage& request); + void Dispatch(const protocol::NotificationMessage& notification); + + bool SupportsRequest(const std::string& method) const; + bool SupportsNotification(const std::string& method) const; + std::vector GetSupportedRequests() const; + std::vector GetSupportedNotifications() const; + std::vector GetAllSupportedMethods() const; + + std::string BuildErrorResponseMessage(const protocol::RequestMessage& request, protocol::ErrorCode code, const std::string& message); + + private: + void NotifyAllLifecycleListeners(ServerLifecycleEvent event); + std::string HandleUnknownRequest(const protocol::RequestMessage& request); + void HandleUnknownNotification(const protocol::NotificationMessage& notification); + + private: + mutable std::shared_mutex providers_mutex_; + std::unordered_map> providers_; + + mutable std::shared_mutex notification_providers_mutex_; + std::unordered_map> notification_providers_; + + std::mutex callbacks_mutex_; + std::vector lifecycle_callbacks_; + + scheduler::RequestScheduler* scheduler_ = nullptr; + }; + +} diff --git a/lsp-server/src/lsp/server.cpp b/lsp-server/src/core/server.cpp similarity index 57% rename from lsp-server/src/lsp/server.cpp rename to lsp-server/src/core/server.cpp index 14acae7..84cf7d6 100644 --- a/lsp-server/src/lsp/server.cpp +++ b/lsp-server/src/core/server.cpp @@ -1,376 +1,303 @@ -#include "request_scheduler.hpp" -#include -#include -#include -#ifdef _WIN32 -#include -#include -#endif -#include "../provider/base/provider_registry.hpp" -#include "../protocol/transform/facade.hpp" -#include "./server.hpp" - -namespace lsp -{ - LspServer::LspServer(size_t concurrency) : - scheduler_(4) - { - spdlog::info("Initializing LSP server with {} worker threads", concurrency); - - dispatcher_.RegisterLifecycleCallback( - [this](providers::ServerLifecycleEvent event) { - OnLifecycleEvent(event); - }); - providers::RegisterAllProviders(dispatcher_); - - scheduler_.SetResponseCallback([this](const std::string& response) { - SendResponse(response); - }); - - spdlog::debug("LSP server initialized with {} providers.", dispatcher_.GetSupportedMethods().size()); - } - - LspServer::~LspServer() - { - is_shutting_down_ = true; - spdlog::info("LSP server shutting down..."); - } - - void LspServer::Run() - { - spdlog::info("LSP server starting main loop..."); - - // 设置二进制模式 -#ifdef _WIN32 - _setmode(_fileno(stdout), _O_BINARY); - _setmode(_fileno(stdin), _O_BINARY); -#endif - - while (!is_shutting_down_) - { - try - { - std::optional message = ReadMessage(); - if (!message) - { - spdlog::debug("No message received, continuing..."); - continue; - } - - HandleMessage(*message); - } - catch (const std::exception& e) - { - spdlog::error("Error in main loop: {}", e.what()); - } - } - spdlog::info("LSP server main loop ended"); - } - - std::optional LspServer::ReadMessage() - { - std::string line; - size_t content_length = 0; - - // 读取 LSP Header - while (std::getline(std::cin, line)) - { - // 去掉尾部 \r - if (!line.empty() && line.back() == '\r') - { - line.pop_back(); - } - - if (line.empty()) - { - break; // 空行表示 header 结束 - } - - if (line.find("Content-Length:") == 0) - { - std::string length_str = line.substr(15); // 跳过 "Content-Length:" - size_t start = length_str.find_first_not_of(" "); - if (start != std::string::npos) - { - length_str = length_str.substr(start); - try - { - content_length = std::stoul(length_str); - spdlog::trace("Content-Length: {}", content_length); - } - catch (const std::exception& e) - { - spdlog::error("Failed to parse Content-Length: {}", e.what()); - return std::nullopt; - } - } - } - } - - if (content_length == 0) - { - spdlog::debug("No Content-Length found in header"); - return std::nullopt; - } - - // 读取内容体 - std::string body(content_length, '\0'); - std::cin.read(&body[0], content_length); - - if (std::cin.gcount() != static_cast(content_length)) - { - spdlog::error("Read incomplete message body, expected: {}, got: {}", content_length, std::cin.gcount()); - return std::nullopt; - } - - spdlog::trace("Received message: {}", body); - return body; - } - - void LspServer::HandleMessage(const std::string& raw_message) - { - try - { - // 解析 JSON 判断消息类型 - glz::json_t doc; - auto error = glz::read_json(doc, raw_message); - if (error) - { - spdlog::error("Failed to parse message as JSON: {}", glz::format_error(error, raw_message)); - return; - } - - auto& obj = doc.get(); - bool has_method = obj.contains("method"); - - if (has_method) - { - bool has_id = obj.contains("id"); - - if (has_id) - { - // RequestMessage - protocol::RequestMessage request; - error = glz::read_json(request, raw_message); - if (error) - { - spdlog::error("Failed to parse request: {}", glz::format_error(error, raw_message)); - return; - } - HandleRequest(request); - } - else - { - // NotificationMessage - protocol::NotificationMessage notification; - error = glz::read_json(notification, raw_message); - if (error) - { - spdlog::error("Failed to parse notification: {}", glz::format_error(error, raw_message)); - return; - } - HandleNotification(notification); - } - } - else if (obj.contains("id") && (obj.contains("result") || obj.contains("error"))) - { - // ResponseMessage - protocol::ResponseMessage response; - error = glz::read_json(response, raw_message); - if (error) - { - spdlog::error("Failed to parse response: {}", - glz::format_error(error, raw_message)); - return; - } - HandleResponse(response); - } - else - { - spdlog::error("Unknown message type"); - } - } - catch (const std::exception& e) - { - spdlog::error("Failed to handle message: {}", e.what()); - } - } - - void LspServer::HandleRequest(const protocol::RequestMessage& request) - { - std::string request_id = transform::debug::GetIdString(request.id); - spdlog::debug("Processing request - id: {}, method: {}", request_id, request.method); - - // 检查是否可以处理请求 - if (!CanProcessRequest(request.method)) - { - protocol::ErrorCode error_code; - std::string message; - - if (!is_initialized_) - { - error_code = protocol::ErrorCode::kServerNotInitialized; - message = "Server not initialized"; - } - else if (is_shutting_down_) - { - error_code = protocol::ErrorCode::kInvalidRequest; - message = "Server is shutting down, only 'exit' is allowed"; - } - else - { - error_code = protocol::ErrorCode::kInternalError; - message = "Request not allowed in current state"; - } - - SendResponse(dispatcher_.BuildErrorResponseMessage(request, error_code, message)); - return; - } - - // 决定同步还是异步处理 - if (RequiresSyncProcessing(request.method)) - { - SendResponse(dispatcher_.Dispatch(request)); - } - else - { - // 异步处理 - scheduler_.Submit(request_id, [this, request]() -> std::optional { - if (is_shutting_down_) - { - spdlog::debug("Skipping request {} due to shutdown", request.method); - return std::nullopt; - } - - try - { - return dispatcher_.Dispatch(request); - } - catch (const std::exception& e) - { - spdlog::error("Request processing failed: {}", e.what()); - return dispatcher_.BuildErrorResponseMessage( request, protocol::ErrorCode::kInternalError, e.what()); - } - }); - } - spdlog::debug("Processing request method: {}", request.method); - } - - void LspServer::HandleNotification(const protocol::NotificationMessage& notification) - { - spdlog::debug("Processing notification - method: {}", notification.method); - - // 处理 $/ 开头的通知 - if (notification.method.starts_with("$/")) - { - if (notification.method == "$/cancelRequest") - { - HandleCancelRequest(notification); - } - else - { - spdlog::debug("Ignoring protocol-specific notification: {}", notification.method); - } - return; - } - - // 处理标准通知 - if (notification.method == "initialized") - { - spdlog::info("Client acknowledged initialization"); - } - else if (notification.method == "exit") - { - spdlog::info("Exit notification received"); - is_shutting_down_ = true; - // 给一点时间让任务完成 - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - std::exit(0); - } - } - - void LspServer::HandleResponse(const protocol::ResponseMessage& response) - { - std::string id_str = transform::debug::GetIdString(response.id); - spdlog::debug("Received response - id: {}", id_str); - } - - void LspServer::OnLifecycleEvent(ServerLifecycleEvent event) - { - switch (event) - { - case ServerLifecycleEvent::kInitializing: - spdlog::info("Server initializing..."); - break; - - case ServerLifecycleEvent::kInitialized: - is_initialized_ = true; - spdlog::info("Server initialized successfully"); - break; - - case ServerLifecycleEvent::kInitializeFailed: - is_initialized_ = false; - spdlog::error("Server initialization failed"); - break; - - case ServerLifecycleEvent::kShuttingDown: - is_shutting_down_ = true; - spdlog::info("Server entering shutdown state"); - break; - - case ServerLifecycleEvent::kShutdown: - is_shutting_down_ = true; - spdlog::info("Server shutdown complete"); - break; - } - } - - bool LspServer::RequiresSyncProcessing(const std::string& method) const - { - static const std::unordered_set sync_methods = { - "initialize", // 必须同步完成 - "shutdown" // 必须同步完成 - }; - - return sync_methods.count(method) > 0; - } - - bool LspServer::CanProcessRequest(const std::string& method) const - { - // 未初始化状态 - if (!is_initialized_) - { - return method == "initialize" || method == "exit"; - } - - // 关闭中状态 - if (is_shutting_down_) - { - return method == "exit"; - } - - // 正常状态 - 接受所有请求 - return true; - } - - void LspServer::HandleCancelRequest(const protocol::NotificationMessage& notification) - { - spdlog::info("Handle cancel request - method: {}", notification.method); - } - - void LspServer::SendResponse(const std::string& response) - { - std::lock_guard lock(output_mutex_); - - size_t byte_length = response.length(); - std::string header = "Content-Length: " + std::to_string(byte_length) + "\r\n\r\n"; - - // 发送 header 和 body - std::cout.write(header.c_str(), header.length()); - std::cout.write(response.c_str(), response.length()); - std::cout.flush(); - - spdlog::trace("Response sent - length: {}", byte_length); - spdlog::trace("Response sent - body: {}", response); - } -} +#include +#include +#include +#ifdef _WIN32 +#include +#include +#endif +#include "../provider/base/provider_registry.hpp" +#include "../protocol/transform/facade.hpp" +#include "../scheduler/request_scheduler.hpp" +#include "./server.hpp" + +namespace lsp::core +{ + LspServer::LspServer(size_t concurrency) : + scheduler_(4) + { + spdlog::info("Initializing LSP server with {} worker threads", concurrency); + + dispatcher_.SetRequestScheduler(&scheduler_); + + dispatcher_.RegisterLifecycleCallback( + [this](providers::ServerLifecycleEvent event) { + OnLifecycleEvent(event); + }); + providers::RegisterAllProviders(dispatcher_); + + scheduler_.SetResponseCallback([this](const std::string& response) { + SendResponse(response); + }); + + spdlog::debug("LSP server initialized with {} providers.", dispatcher_.GetAllSupportedMethods().size()); + } + + LspServer::~LspServer() + { + is_shutting_down_ = true; + spdlog::info("LSP server shutting down..."); + } + + void LspServer::Run() + { + spdlog::info("LSP server starting main loop..."); + + // 设置二进制模式 +#ifdef _WIN32 + _setmode(_fileno(stdout), _O_BINARY); + _setmode(_fileno(stdin), _O_BINARY); +#endif + + while (!is_shutting_down_) + { + try + { + std::optional message = ReadMessage(); + if (!message) + { + if (std::cin.eof()) + { + spdlog::info("End of input stream, exiting main loop"); + break; // EOF + } + spdlog::debug("No message received, continuing..."); + continue; + } + + HandleMessage(*message); + } + catch (const std::exception& e) + { + spdlog::error("Error in main loop: {}", e.what()); + } + } + spdlog::info("LSP server main loop ended"); + } + + std::optional LspServer::ReadMessage() + { + std::string line; + size_t content_length = 0; + + // 读取 LSP Header + while (std::getline(std::cin, line)) + { + // 去掉尾部 \r + if (!line.empty() && line.back() == '\r') + { + line.pop_back(); + } + + if (line.empty()) + { + break; // 空行表示 header 结束 + } + + if (line.find("Content-Length:") == 0) + { + std::string length_str = line.substr(15); // 跳过 "Content-Length:" + size_t start = length_str.find_first_not_of(" "); + if (start != std::string::npos) + { + length_str = length_str.substr(start); + try + { + content_length = std::stoul(length_str); + spdlog::trace("Content-Length: {}", content_length); + } + catch (const std::exception& e) + { + spdlog::error("Failed to parse Content-Length: {}", e.what()); + return std::nullopt; + } + } + } + } + + if (content_length == 0) + { + spdlog::debug("No Content-Length found in header"); + return std::nullopt; + } + + // 读取内容体 + std::string body(content_length, '\0'); + std::cin.read(&body[0], content_length); + + if (std::cin.gcount() != static_cast(content_length)) + { + spdlog::error("Read incomplete message body, expected: {}, got: {}", content_length, std::cin.gcount()); + return std::nullopt; + } + + spdlog::trace("Received message: {}", body); + return body; + } + + void LspServer::HandleMessage(const std::string& raw_message) + { + if (auto request = transform::Deserialize(raw_message)) + HandleRequest(*request); + else if (auto notification = transform::Deserialize(raw_message)) + HandleNotification(*notification); + else if (auto response = transform::Deserialize(raw_message)) + HandleResponse(*response); + else + spdlog::error("Failed to deserialize message as any LSP message type"); + } + + void LspServer::HandleRequest(const protocol::RequestMessage& request) + { + std::string request_id = transform::debug::GetIdString(request.id); + spdlog::debug("Processing request - id: {}, method: {}", request_id, request.method); + + // 检查是否可以处理请求 + if (!CanProcessRequest(request.method)) + { + SendStateError(request); + } + else + { + // 决定同步还是异步处理 + if (RequiresSyncProcessing(request.method)) + { + SendResponse(dispatcher_.Dispatch(request)); + } + else + { + // 异步处理 + scheduler_.Submit(request_id, [this, request]() -> std::optional { + if (is_shutting_down_) + { + spdlog::debug("Skipping request {} due to shutdown", request.method); + return std::nullopt; + } + + try + { + return dispatcher_.Dispatch(request); + } + catch (const std::exception& e) + { + spdlog::error("Request processing failed: {}", e.what()); + return dispatcher_.BuildErrorResponseMessage( request, protocol::ErrorCode::kInternalError, e.what()); + } + }); + } + spdlog::debug("Processing request method: {}", request.method); + } + + } + + void LspServer::HandleNotification(const protocol::NotificationMessage& notification) + { + spdlog::debug("Processing notification - method: {}", notification.method); + + try + { + dispatcher_.Dispatch(notification); + } + catch (const std::exception& e) + { + spdlog::error("Notification processing failed for '{}': {}", notification.method, e.what()); + } + } + + void LspServer::HandleResponse(const protocol::ResponseMessage& response) + { + std::string id_str = transform::debug::GetIdString(response.id); + spdlog::debug("Received response - id: {}", id_str); + } + + void LspServer::OnLifecycleEvent(ServerLifecycleEvent event) + { + switch (event) + { + case ServerLifecycleEvent::kInitializing: + spdlog::info("Server initializing..."); + break; + + case ServerLifecycleEvent::kInitialized: + is_initialized_ = true; + spdlog::info("Server initialized successfully"); + break; + + case ServerLifecycleEvent::kInitializeFailed: + is_initialized_ = false; + spdlog::error("Server initialization failed"); + break; + + case ServerLifecycleEvent::kShuttingDown: + is_shutting_down_ = true; + spdlog::info("Server entering shutdown state"); + break; + + case ServerLifecycleEvent::kShutdown: + is_shutting_down_ = true; + spdlog::info("Server shutdown complete"); + break; + } + } + + bool LspServer::RequiresSyncProcessing(const std::string& method) const + { + static const std::unordered_set sync_methods = { + "initialize", // 必须同步完成 + "shutdown" // 必须同步完成 + }; + + return sync_methods.count(method) > 0; + } + + bool LspServer::CanProcessRequest(const std::string& method) const + { + // 未初始化状态 + if (!is_initialized_) + return method == "initialize" || method == "exit"; + + // 关闭中状态 + if (is_shutting_down_) + return method == "exit"; + + // 正常状态 - 接受所有请求 + return true; + } + + void LspServer::HandleCancelRequest(const protocol::NotificationMessage& notification) + { + spdlog::info("Handle cancel request - method: {}", notification.method); + } + + void LspServer::SendResponse(const std::string& response) + { + std::lock_guard lock(output_mutex_); + + size_t byte_length = response.length(); + std::string header = "Content-Length: " + std::to_string(byte_length) + "\r\n\r\n"; + + // 发送 header 和 body + std::cout.write(header.c_str(), header.length()); + std::cout.write(response.c_str(), response.length()); + std::cout.flush(); + + spdlog::trace("Response sent - length: {}", byte_length); + spdlog::trace("Response sent - body: {}", response); + } + + void LspServer::SendError(const protocol::RequestMessage& request, protocol::ErrorCode code, const std::string& message) + { + spdlog::warn("Sending error response - method: {}, code: {}, message: {}", request.method, static_cast(code), message); + std::string error_response = dispatcher_.BuildErrorResponseMessage(request, code, message); + SendResponse(error_response); + } + + void LspServer::SendStateError(const protocol::RequestMessage& request) + { + if (!is_initialized_) + SendError(request, protocol::ErrorCode::kServerNotInitialized, "Server not initialized"); + else if (is_shutting_down_) + SendError(request, protocol::ErrorCode::kInvalidRequest, "Server is shutting down, only 'exit' is allowed"); + else + SendError(request, protocol::ErrorCode::kInternalError, "Request not allowed in current state"); + } +} diff --git a/lsp-server/src/lsp/server.hpp b/lsp-server/src/core/server.hpp similarity index 82% rename from lsp-server/src/lsp/server.hpp rename to lsp-server/src/core/server.hpp index 33752cb..53478bc 100644 --- a/lsp-server/src/lsp/server.hpp +++ b/lsp-server/src/core/server.hpp @@ -1,53 +1,57 @@ -#pragma once -#include -#include -#include -#include "./dispacther.hpp" -#include "./request_scheduler.hpp" -#include "../provider/base/provider_registry.hpp" - -namespace lsp -{ - class LspServer - { - public: - LspServer(size_t concurrency = std::thread::hardware_concurrency()); - ~LspServer(); - void Run(); - - private: - // 读取LSP消息 - std::optional ReadMessage(); - - // 处理LSP请求 - 返回序列化的响应或空字符串(对于通知) - void HandleMessage(const std::string& raw_message); - - // 发送LSP响应 - void SendResponse(const std::string& response); - - // 处理不同类型的消息 - void HandleRequest(const protocol::RequestMessage& request); - void HandleNotification(const protocol::NotificationMessage& notification); - void HandleResponse(const protocol::ResponseMessage& response); - - // 生命周期事件处理 - void OnLifecycleEvent(providers::ServerLifecycleEvent event); - - // 判断是否需要同步处理 - bool RequiresSyncProcessing(const std::string& method) const; - - // 检查是否可以处理请求 - bool CanProcessRequest(const std::string& method) const; - - // 处理取消请求 - void HandleCancelRequest(const protocol::NotificationMessage& notification); - - private: - RequestDispatcher dispatcher_; - RequestScheduler scheduler_; - - std::atomic is_initialized_ = false; - std::atomic is_shutting_down_ = false; - std::mutex output_mutex_; - }; -} +#pragma once +#include +#include +#include +#include "./dispacther.hpp" +#include "../scheduler/request_scheduler.hpp" +#include "../provider/base/provider_registry.hpp" + +namespace lsp::core +{ + class LspServer + { + public: + LspServer(size_t concurrency = std::thread::hardware_concurrency()); + ~LspServer(); + void Run(); + + private: + // 读取LSP消息 + std::optional ReadMessage(); + + // 处理LSP请求 - 返回序列化的响应或空字符串(对于通知) + void HandleMessage(const std::string& raw_message); + + // 发送LSP响应 + void SendResponse(const std::string& response); + + // 处理不同类型的消息 + void HandleRequest(const protocol::RequestMessage& request); + void HandleNotification(const protocol::NotificationMessage& notification); + void HandleResponse(const protocol::ResponseMessage& response); + + // 生命周期事件处理 + void OnLifecycleEvent(providers::ServerLifecycleEvent event); + + // 判断是否需要同步处理 + bool RequiresSyncProcessing(const std::string& method) const; + + // 检查是否可以处理请求 + bool CanProcessRequest(const std::string& method) const; + + // 处理取消请求 + void HandleCancelRequest(const protocol::NotificationMessage& notification); + + private: + void SendError(const protocol::RequestMessage& request, protocol::ErrorCode code, const std::string& message); + void SendStateError(const protocol::RequestMessage& request); + + private: + RequestDispatcher dispatcher_; + scheduler::RequestScheduler scheduler_; + + std::atomic is_initialized_ = false; + std::atomic is_shutting_down_ = false; + std::mutex output_mutex_; + }; +} diff --git a/lsp-server/src/lsp/dispacther.cpp b/lsp-server/src/lsp/dispacther.cpp deleted file mode 100644 index c1065a9..0000000 --- a/lsp-server/src/lsp/dispacther.cpp +++ /dev/null @@ -1,124 +0,0 @@ -#include -#include -#include "./dispacther.hpp" -#include "../protocol/transform/facade.hpp" - -namespace lsp -{ - void RequestDispatcher::RegisterProvider(std::shared_ptr provider) - { - std::unique_lock lock(providers_mutex_); - std::string method = provider->GetMethod(); - - // 如果是生命周期感知的 Provider,设置回调 - if (auto lifecycle_aware = std::dynamic_pointer_cast(provider)) - { - lifecycle_aware->SetLifecycleCallback( - [this](ServerLifecycleEvent event) { - NotifyAllLifecycleListeners(event); - }); - spdlog::debug("Registered lifecycle-aware provider for method: {}", method); - } - else - { - spdlog::debug("Registered standard provider for method: {}", method); - } - providers_[method] = provider; - spdlog::debug("Registered provider for method: {}", method); - } - - void RequestDispatcher::RegisterLifecycleCallback(LifecycleCallback callback) - { - std::lock_guard lock(callbacks_mutex_); - lifecycle_callbacks_.push_back(std::move(callback)); - spdlog::debug("Registered lifecycle callback, total callbacks: {}", lifecycle_callbacks_.size()); - } - - std::string RequestDispatcher::Dispatch(const protocol::RequestMessage& request) - { - std::shared_lock lock(providers_mutex_); - auto it = providers_.find(request.method); - if (it != providers_.end()) - { - auto provider = it->second; - lock.unlock(); - try - { - return provider->ProvideResponse(request); - } - catch (const std::exception& e) - { - spdlog::error("Provider error for method {}: {}", request.method, e.what()); - return BuildErrorResponseMessage(request, protocol::ErrorCode::kInternalError, e.what()); - } - } - return HandleUnknownMethod(request); - } - - bool RequestDispatcher::SupportsMethod(const std::string& method) const - { - std::shared_lock lock(providers_mutex_); - return providers_.find(method) != providers_.end(); - } - - std::vector RequestDispatcher::GetSupportedMethods() const - { - std::shared_lock lock(providers_mutex_); - std::vector methods; - methods.reserve(providers_.size()); - for (const auto& [method, _] : providers_) - { - methods.push_back(method); - } - return methods; - } - - void RequestDispatcher::NotifyAllLifecycleListeners(ServerLifecycleEvent event) - { - std::lock_guard lock(callbacks_mutex_); - - std::string event_name; - switch (event) - { - case ServerLifecycleEvent::kInitializing: - event_name = "Initializing"; - break; - case ServerLifecycleEvent::kInitialized: - event_name = "Initialized"; - break; - case ServerLifecycleEvent::kInitializeFailed: - event_name = "InitializeFailed"; - break; - case ServerLifecycleEvent::kShuttingDown: - event_name = "ShuttingDown"; - break; - case ServerLifecycleEvent::kShutdown: - event_name = "Shutdown"; - break; - } - - spdlog::info("Lifecycle event: {}", event_name); - - for (const auto& callback : lifecycle_callbacks_) - { - try - { - callback(event); - } - catch (const std::exception& e) - { - spdlog::error("Lifecycle callback error: {}", e.what()); - } - } - } - - std::string RequestDispatcher::HandleUnknownMethod(const protocol::RequestMessage& request) - { - return BuildErrorResponseMessage(request, protocol::ErrorCode::kMethodNotFound, "Method not found: " + request.method); - } - - std::string RequestDispatcher::BuildErrorResponseMessage(const protocol::RequestMessage& request, protocol::ErrorCode code, const std::string& message) - { - return providers::ILspProvider::BuildErrorResponseMessage(request, code, message); - } -} diff --git a/lsp-server/src/lsp/dispacther.hpp b/lsp-server/src/lsp/dispacther.hpp deleted file mode 100644 index 68045ac..0000000 --- a/lsp-server/src/lsp/dispacther.hpp +++ /dev/null @@ -1,38 +0,0 @@ -#pragma once -#include -#include -#include -#include "../protocol/protocol.hpp" -#include "../provider/base/provider_interface.hpp" - -namespace lsp -{ - - using ServerLifecycleEvent = providers::ServerLifecycleEvent; - using LifecycleCallback = providers::LifecycleCallback; - - class RequestDispatcher - { - public: - RequestDispatcher() = default; - - void RegisterProvider(std::shared_ptr provider); - void RegisterLifecycleCallback(LifecycleCallback callback); - std::string Dispatch(const protocol::RequestMessage& request); - bool SupportsMethod(const std::string& method) const; - std::vector GetSupportedMethods() const; - - std::string BuildErrorResponseMessage(const protocol::RequestMessage& request, protocol::ErrorCode code, const std::string& message); - - private: - void NotifyAllLifecycleListeners(ServerLifecycleEvent event); - std::string HandleUnknownMethod(const protocol::RequestMessage& request); - - private: - mutable std::shared_mutex providers_mutex_; - std::unordered_map> providers_; - - std::mutex callbacks_mutex_; - std::vector lifecycle_callbacks_; - }; -} diff --git a/lsp-server/src/main.cpp b/lsp-server/src/main.cpp index 52ea267..3095b32 100644 --- a/lsp-server/src/main.cpp +++ b/lsp-server/src/main.cpp @@ -1,11 +1,9 @@ #include #include -#include -#include #include #include #include -#include "./lsp/server.hpp" +#include "./core/server.hpp" #include "./utils/args_parser.hpp" int main(int argc, char* argv[]) @@ -23,7 +21,7 @@ int main(int argc, char* argv[]) try { spdlog::info("TSL-LSP server starting..."); - lsp::LspServer server(config.thread_count); + lsp::core::LspServer server(config.thread_count); server.Run(); } catch (const std::exception& e) diff --git a/lsp-server/src/protocol/transform/facade.hpp b/lsp-server/src/protocol/transform/facade.hpp index cd14e6a..ceda7f7 100644 --- a/lsp-server/src/protocol/transform/facade.hpp +++ b/lsp-server/src/protocol/transform/facade.hpp @@ -5,6 +5,12 @@ namespace lsp::transform { // ===== 全局便利函数 ===== + template + std::optional Deserialize(const std::string& json); + + template + std::optional Serialize(const T& obj); + // 基本类型 template protocol::LSPAny LSPAny(const T& obj); diff --git a/lsp-server/src/protocol/transform/facade.inl b/lsp-server/src/protocol/transform/facade.inl index 21180f0..891724d 100644 --- a/lsp-server/src/protocol/transform/facade.inl +++ b/lsp-server/src/protocol/transform/facade.inl @@ -3,6 +3,28 @@ namespace lsp::transform { + template + inline std::optional Deserialize(const std::string& json) + { + T obj; + auto ce = glz::read_json(obj, json); + if (ce) + return std::nullopt; + else + return obj; + } + + template + std::optional Serialize(const T& obj) + { + std::string json; + auto ce = glz::write_json(obj, json); + if (ce) + return std::nullopt; + else + return json; + } + template inline protocol::LSPAny LSPAny(const T& obj) { diff --git a/lsp-server/src/provider/base/provider_interface.cpp b/lsp-server/src/provider/base/provider_interface.cpp index 68e205f..ace5a38 100644 --- a/lsp-server/src/provider/base/provider_interface.cpp +++ b/lsp-server/src/provider/base/provider_interface.cpp @@ -1,10 +1,9 @@ -#include #include "./provider_interface.hpp" namespace lsp::providers { - std::string ILspProvider::BuildErrorResponseMessage(const protocol::RequestMessage& request, protocol::ErrorCode code, const std::string& message) + std::string IRequestProvider::BuildErrorResponseMessage(const protocol::RequestMessage& request, protocol::ErrorCode code, const std::string& message) { protocol::ResponseMessage response; response.id = request.id; @@ -23,18 +22,4 @@ namespace lsp::providers return json; } - void ILifecycleAwareProvider::SetLifecycleCallback(LifecycleCallback callback) - { - lifecycle_callback_ = std::move(callback); - } - - void ILifecycleAwareProvider::NotifyLifecycleEvent(ServerLifecycleEvent event) - { - if (lifecycle_callback_) - { - lifecycle_callback_(event); - spdlog::debug("Provider {} triggered event: {}", GetProviderName(), static_cast(event)); - } - } - } diff --git a/lsp-server/src/provider/base/provider_interface.hpp b/lsp-server/src/provider/base/provider_interface.hpp index 932b4b4..455e821 100644 --- a/lsp-server/src/provider/base/provider_interface.hpp +++ b/lsp-server/src/provider/base/provider_interface.hpp @@ -1,6 +1,8 @@ #pragma once #include +#include #include "../../protocol/protocol.hpp" +#include "../../scheduler/request_scheduler.hpp" namespace lsp::providers { @@ -15,37 +17,57 @@ namespace lsp::providers using LifecycleCallback = std::function; + class ProviderContext; + // LSP请求提供者接口基类 - class ILspProvider + class IProvider { public: - virtual ~ILspProvider() = default; + virtual ~IProvider() = default; - // 处理LSP请求 - virtual std::string ProvideResponse(const protocol::RequestMessage& request) = 0; // 获取支持的LSP方法名 virtual std::string GetMethod() const = 0; // 获取提供者名称(用于日志和调试) virtual std::string GetProviderName() const = 0; + }; + // LSP 请求处理器接口 + class IRequestProvider : public IProvider + { + public: + virtual ~IRequestProvider() = default; + + // 处理LSP请求 + virtual std::string ProvideResponse(const protocol::RequestMessage& request, ProviderContext& context) = 0; static std::string BuildErrorResponseMessage(const protocol::RequestMessage& request, protocol::ErrorCode code, const std::string& message); }; - // 生命周期感知的 Provider 接口 - class ILifecycleAwareProvider : public ILspProvider + // LSP 通知处理器接口 + class INotificationProvider : public IProvider { public: - virtual ~ILifecycleAwareProvider() = default; + virtual ~INotificationProvider() = default; - // 设置生命周期回调 - void SetLifecycleCallback(LifecycleCallback callback); - - protected: - // 触发生命周期事件 - void NotifyLifecycleEvent(ServerLifecycleEvent event); - - private: - LifecycleCallback lifecycle_callback_; + // 处理LSP通知 + virtual void HandleNotification(const protocol::NotificationMessage& notification, ProviderContext& context) = 0; }; + class ProviderContext + { + public: + ProviderContext(scheduler::RequestScheduler* scheduler, LifecycleCallback lifecycle_callback) : + scheduler_(scheduler), lifecycle_callback_(lifecycle_callback) {} + + scheduler::RequestScheduler* GetScheduler() const { return scheduler_; } + + void TriggerLifecycleEvent(ServerLifecycleEvent event) const + { + if (lifecycle_callback_) + lifecycle_callback_(event); + } + + private: + scheduler::RequestScheduler* scheduler_; + LifecycleCallback lifecycle_callback_; + }; } diff --git a/lsp-server/src/provider/base/provider_registry.cpp b/lsp-server/src/provider/base/provider_registry.cpp index 1f68bd2..bf1b887 100644 --- a/lsp-server/src/provider/base/provider_registry.cpp +++ b/lsp-server/src/provider/base/provider_registry.cpp @@ -7,11 +7,13 @@ #include "../text_document/completion_provider.hpp" #include "../trace/set_trace_provider.hpp" #include "../shutdown/shutdown_provider.hpp" +#include "../cancel_request/cancel_request_provider.hpp" +#include "../exit/exit_provider.hpp" namespace lsp::providers { - void RegisterAllProviders(RequestDispatcher& dispatcher) + void RegisterAllProviders(core::RequestDispatcher& dispatcher) { spdlog::info("Registering LSP providers..."); @@ -20,10 +22,12 @@ namespace lsp::providers RegisterProvider(dispatcher); RegisterProvider(dispatcher); RegisterProvider(dispatcher); - RegisterProvider(dispatcher); + RegisterProvider(dispatcher); RegisterProvider(dispatcher); + RegisterProvider(dispatcher); + RegisterProvider(dispatcher); - spdlog::info("Successfully registered {} LSP providers", dispatcher.GetSupportedMethods().size()); + spdlog::info("Successfully registered {} LSP providers", dispatcher.GetAllSupportedMethods().size()); } } diff --git a/lsp-server/src/provider/base/provider_registry.hpp b/lsp-server/src/provider/base/provider_registry.hpp index ea8f2f7..54d52de 100644 --- a/lsp-server/src/provider/base/provider_registry.hpp +++ b/lsp-server/src/provider/base/provider_registry.hpp @@ -1,24 +1,32 @@ #pragma once #include -#include "../../lsp/dispacther.hpp" +#include +#include "../../core/dispacther.hpp" #include "./provider_interface.hpp" namespace lsp::providers { - // 模板函数:注册provider + // 注册请求处理器的模板函数 template - void RegisterProvider(RequestDispatcher& dispatcher) + typename std::enable_if_t> + RegisterProvider(core::RequestDispatcher& dispatcher) { - static_assert(std::is_base_of_v, - "Provider must inherit from ILspProvider"); - auto provider = std::make_shared(); - dispatcher.RegisterProvider(provider); + dispatcher.RegisterRequestProvider(provider); + spdlog::info("Registered request provider '{}' for method: {}", provider->GetProviderName(), provider->GetMethod()); + } - spdlog::info("Registering {} for method: {}", provider->GetProviderName(), provider->GetMethod()); + // 注册通知处理器的模板函数 + template + typename std::enable_if_t> + RegisterProvider(core::RequestDispatcher& dispatcher) + { + auto provider = std::make_shared(); + dispatcher.RegisterNotificationProvider(provider); + spdlog::info("Registered notification provider '{}' for method: {}", provider->GetProviderName(), provider->GetMethod()); } // 批量注册provider - void RegisterAllProviders(RequestDispatcher& dispatcher); + void RegisterAllProviders(core::RequestDispatcher& dispatcher); } diff --git a/lsp-server/src/provider/cancel_request/cancel_request_provider.cpp b/lsp-server/src/provider/cancel_request/cancel_request_provider.cpp new file mode 100644 index 0000000..73c0db4 --- /dev/null +++ b/lsp-server/src/provider/cancel_request/cancel_request_provider.cpp @@ -0,0 +1,34 @@ +#include "./cancel_request_provider.hpp" + +namespace lsp::providers::cancel_request +{ + std::string CancelRequestProvider::GetMethod() const + { + return "$/cancelRequest"; + } + + std::string CancelRequestProvider::GetProviderName() const + { + return "CancelRequestProvider"; + } + + void CancelRequestProvider::HandleNotification(const protocol::NotificationMessage& notification, ProviderContext& context) + { + try + { + auto params = transform::As(notification.params.value()); + std::string id_to_cancel = transform::debug::GetIdString(params.id); + spdlog::info("Processing cancel request for ID: {}", id_to_cancel); + + if (auto* scheduler = context.GetScheduler()) + { + bool cancelled = scheduler->Cancel(id_to_cancel); + spdlog::info("Cancel request {} result: {}", id_to_cancel, cancelled ? "success" : "not found"); + } + } + catch (const std::exception& e) + { + spdlog::error("Error handling cancel request: {}", e.what()); + } + } +} diff --git a/lsp-server/src/provider/cancel_request/cancel_request_provider.hpp b/lsp-server/src/provider/cancel_request/cancel_request_provider.hpp new file mode 100644 index 0000000..4da1cd8 --- /dev/null +++ b/lsp-server/src/provider/cancel_request/cancel_request_provider.hpp @@ -0,0 +1,16 @@ +#pragma once +#include +#include "../base/provider_interface.hpp" +#include "../../scheduler/request_scheduler.hpp" +#include "../../protocol/transform/facade.hpp" + +namespace lsp::providers::cancel_request +{ + class CancelRequestProvider : public INotificationProvider + { + public: + std::string GetMethod() const override; + std::string GetProviderName() const override; + void HandleNotification(const protocol::NotificationMessage& notification, ProviderContext& context) override; + }; +} diff --git a/lsp-server/src/provider/exit/exit_provider.cpp b/lsp-server/src/provider/exit/exit_provider.cpp new file mode 100644 index 0000000..5762a6b --- /dev/null +++ b/lsp-server/src/provider/exit/exit_provider.cpp @@ -0,0 +1,33 @@ +#include "./exit_provider.hpp" + +namespace lsp::providers::exit +{ + + std::string ExitProvider::GetMethod() const + { + return "exit"; + } + + std::string ExitProvider::GetProviderName() const + { + return "ExitProvider"; + } + + void ExitProvider::HandleNotification(const protocol::NotificationMessage& notification, ProviderContext& context) + { + // static_cast(context); + spdlog::debug("ExitProvider: Providing response for method {}", notification.method); + spdlog::info("Exit notification received"); + + // 触发生命周期事件 + context.TriggerLifecycleEvent(ServerLifecycleEvent::kShuttingDown); + + // 给一些时间完成清理 + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + context.TriggerLifecycleEvent(ServerLifecycleEvent::kShutdown); + + std::exit(0); + } + +} diff --git a/lsp-server/src/provider/exit/exit_provider.hpp b/lsp-server/src/provider/exit/exit_provider.hpp new file mode 100644 index 0000000..5c63cf4 --- /dev/null +++ b/lsp-server/src/provider/exit/exit_provider.hpp @@ -0,0 +1,13 @@ +#pragma once +#include "../base/provider_interface.hpp" + +namespace lsp::providers::exit +{ + class ExitProvider : public INotificationProvider + { + public: + std::string GetMethod() const override; + std::string GetProviderName() const override; + void HandleNotification(const protocol::NotificationMessage& notification, ProviderContext& context) override; + }; +} diff --git a/lsp-server/src/provider/initialize/initialize_provider.cpp b/lsp-server/src/provider/initialize/initialize_provider.cpp index 2797455..1886d30 100644 --- a/lsp-server/src/provider/initialize/initialize_provider.cpp +++ b/lsp-server/src/provider/initialize/initialize_provider.cpp @@ -4,7 +4,17 @@ namespace lsp::providers::initialize { - std::string InitializeProvider::ProvideResponse(const protocol::RequestMessage& request) + std::string InitializeProvider::GetMethod() const + { + return "initialize"; + } + + std::string InitializeProvider::GetProviderName() const + { + return "InitializeProvider"; + } + + std::string InitializeProvider::ProvideResponse(const protocol::RequestMessage& request, ProviderContext& context) { spdlog::debug("InitializeProvider: Providing response for method {}", request.method); protocol::ResponseMessage response; @@ -14,23 +24,13 @@ namespace lsp::providers::initialize auto ec = glz::write_json(response, json); if (ec) { - NotifyLifecycleEvent(ServerLifecycleEvent::kInitializeFailed); + context.TriggerLifecycleEvent(ServerLifecycleEvent::kInitializeFailed); return BuildErrorResponseMessage(request, protocol::ErrorCode::kInternalError, "Internal error"); } - NotifyLifecycleEvent(ServerLifecycleEvent::kInitialized); + context.TriggerLifecycleEvent(ServerLifecycleEvent::kInitialized); return json; } - std::string InitializeProvider::GetMethod() const - { - return "initialize"; - } - - std::string InitializeProvider::GetProviderName() const - { - return "InitializeProvider"; - } - protocol::InitializeResult InitializeProvider::BuildInitializeResult() { protocol::InitializeResult result; diff --git a/lsp-server/src/provider/initialize/initialize_provider.hpp b/lsp-server/src/provider/initialize/initialize_provider.hpp index 75280aa..49eb721 100644 --- a/lsp-server/src/provider/initialize/initialize_provider.hpp +++ b/lsp-server/src/provider/initialize/initialize_provider.hpp @@ -4,11 +4,11 @@ namespace lsp::providers::initialize { using namespace lsp; - class InitializeProvider : public ILifecycleAwareProvider + class InitializeProvider : public IRequestProvider { public: InitializeProvider() = default; - std::string ProvideResponse(const protocol::RequestMessage& request) override; + std::string ProvideResponse(const protocol::RequestMessage& request, ProviderContext& context) override; std::string GetMethod() const override; std::string GetProviderName() const override; diff --git a/lsp-server/src/provider/initialized/initialized_provider.cpp b/lsp-server/src/provider/initialized/initialized_provider.cpp index 248447d..adb054c 100644 --- a/lsp-server/src/provider/initialized/initialized_provider.cpp +++ b/lsp-server/src/provider/initialized/initialized_provider.cpp @@ -4,15 +4,6 @@ namespace lsp::providers::initialized { - std::string InitializedProvider::ProvideResponse(const protocol::RequestMessage& request) - { - spdlog::debug("InitializeProvider: Providing response for method {}", request.method); - std::string json; - glz::obj empty_obj{}; // glaze的对象类型 - auto ec = glz::write_json(empty_obj, json); - return ec ? BuildErrorResponseMessage(request, protocol::ErrorCode::kInternalError, "Internal error") : json; - } - std::string InitializedProvider::GetMethod() const { return "initialized"; @@ -23,4 +14,10 @@ namespace lsp::providers::initialized return "InitializedProvider"; } + void InitializedProvider::HandleNotification(const protocol::NotificationMessage& notification, ProviderContext& context) + { + static_cast(context); // 如果不需要上下文,可以忽略 + spdlog::debug("InitializeProvider: Providing response for method {}", notification.method); + } + } diff --git a/lsp-server/src/provider/initialized/initialized_provider.hpp b/lsp-server/src/provider/initialized/initialized_provider.hpp index 8741697..92a2091 100644 --- a/lsp-server/src/provider/initialized/initialized_provider.hpp +++ b/lsp-server/src/provider/initialized/initialized_provider.hpp @@ -3,12 +3,12 @@ namespace lsp::providers::initialized { - class InitializedProvider : public ILifecycleAwareProvider + class InitializedProvider : public INotificationProvider { public: InitializedProvider() = default; - std::string ProvideResponse(const protocol::RequestMessage& request) override; std::string GetMethod() const override; std::string GetProviderName() const override; + void HandleNotification(const protocol::NotificationMessage& notification, ProviderContext& context) override; }; } diff --git a/lsp-server/src/provider/shutdown/shutdown_provider.cpp b/lsp-server/src/provider/shutdown/shutdown_provider.cpp index 674be96..347323e 100644 --- a/lsp-server/src/provider/shutdown/shutdown_provider.cpp +++ b/lsp-server/src/provider/shutdown/shutdown_provider.cpp @@ -4,14 +4,24 @@ namespace lsp::providers::shutdown { - std::string ShutdownProvider::ProvideResponse(const protocol::RequestMessage& request) + std::string ShutdownProvider::GetMethod() const + { + return "shutdown"; + } + + std::string ShutdownProvider::GetProviderName() const + { + return "ShutdownProvider"; + } + + std::string ShutdownProvider::ProvideResponse(const protocol::RequestMessage& request, ProviderContext& context) { spdlog::debug("ShutdownProvider: Providing response for method {}", request.method); try { // 触发关闭事件 - NotifyLifecycleEvent(ServerLifecycleEvent::kShuttingDown); + context.TriggerLifecycleEvent(ServerLifecycleEvent::kShuttingDown); // 构建响应 - shutdown 返回 null protocol::ResponseMessage response; @@ -34,13 +44,4 @@ namespace lsp::providers::shutdown } } - std::string ShutdownProvider::GetMethod() const - { - return "shutdown"; - } - - std::string ShutdownProvider::GetProviderName() const - { - return "ShutdownProvider"; - } } diff --git a/lsp-server/src/provider/shutdown/shutdown_provider.hpp b/lsp-server/src/provider/shutdown/shutdown_provider.hpp index a057e7c..fc7bb09 100644 --- a/lsp-server/src/provider/shutdown/shutdown_provider.hpp +++ b/lsp-server/src/provider/shutdown/shutdown_provider.hpp @@ -3,13 +3,13 @@ namespace lsp::providers::shutdown { - class ShutdownProvider : public ILifecycleAwareProvider + class ShutdownProvider : public IRequestProvider { public: ShutdownProvider() = default; - - std::string ProvideResponse(const protocol::RequestMessage& request) override; + std::string GetMethod() const override; std::string GetProviderName() const override; + std::string ProvideResponse(const protocol::RequestMessage& request, ProviderContext& context) override; }; } diff --git a/lsp-server/src/provider/text_document/completion_provider.cpp b/lsp-server/src/provider/text_document/completion_provider.cpp index a07b57f..f10f400 100644 --- a/lsp-server/src/provider/text_document/completion_provider.cpp +++ b/lsp-server/src/provider/text_document/completion_provider.cpp @@ -4,13 +4,25 @@ namespace lsp::providers::text_document { - std::string CompletionProvider::ProvideResponse(const protocol::RequestMessage& request) + std::string CompletionProvider::GetMethod() const { + return "textDocument/completion"; + } + + std::string CompletionProvider::GetProviderName() const + { + return "CompletionProvider"; + } + + std::string CompletionProvider::ProvideResponse(const protocol::RequestMessage& request, ProviderContext& context) + { + static_cast(context); spdlog::debug("CompletionProvider: Providing response for method {}", request.method); try { // 验证请求是否包含参数 - if (!request.params.has_value()) { + if (!request.params.has_value()) + { spdlog::warn("{}: Missing params in request", GetProviderName()); return BuildErrorResponseMessage(request, protocol::ErrorCode::kInvalidParams, "Missing params"); } @@ -42,15 +54,6 @@ namespace lsp::providers::text_document } } - std::string CompletionProvider::GetMethod() const - { - return "textDocument/completion"; - } - - std::string CompletionProvider::GetProviderName() const - { - return "CompletionProvider"; - } protocol::CompletionList CompletionProvider::BuildCompletionResponse(const protocol::CompletionParams& params) { diff --git a/lsp-server/src/provider/text_document/completion_provider.hpp b/lsp-server/src/provider/text_document/completion_provider.hpp index 1755a44..2b4c7db 100644 --- a/lsp-server/src/provider/text_document/completion_provider.hpp +++ b/lsp-server/src/provider/text_document/completion_provider.hpp @@ -7,14 +7,14 @@ namespace lsp::providers::text_document { - class CompletionProvider : public ILspProvider + class CompletionProvider : public IRequestProvider { public: CompletionProvider() = default; - std::string ProvideResponse(const protocol::RequestMessage& request) override; std::string GetMethod() const override; std::string GetProviderName() const override; + std::string ProvideResponse(const protocol::RequestMessage& request, ProviderContext& context) override; private: // 构建完整的补全响应 diff --git a/lsp-server/src/provider/text_document/completion_resolver_provider.cpp b/lsp-server/src/provider/text_document/completion_resolver_provider.cpp new file mode 100644 index 0000000..e69de29 diff --git a/lsp-server/src/provider/text_document/completion_resolver_provider.hpp b/lsp-server/src/provider/text_document/completion_resolver_provider.hpp new file mode 100644 index 0000000..e69de29 diff --git a/lsp-server/src/provider/text_document/did_change_provider.cpp b/lsp-server/src/provider/text_document/did_change_provider.cpp index 90c9ae0..7466226 100644 --- a/lsp-server/src/provider/text_document/did_change_provider.cpp +++ b/lsp-server/src/provider/text_document/did_change_provider.cpp @@ -1,19 +1,9 @@ #include #include "./did_change_provider.hpp" -#include "./did_open_provider.hpp" namespace lsp::providers::text_document { - std::string DidChangeProvider::ProvideResponse(const protocol::RequestMessage& request) - { - spdlog::debug("DidChangeProvider: Providing response for method {}", request.method); - std::string json; - glz::obj empty_obj{}; // glaze的对象类型 - auto ec = glz::write_json(empty_obj, json); - return ec ? BuildErrorResponseMessage(request, protocol::ErrorCode::kInternalError, "Internal error") : json; - } - std::string DidChangeProvider::GetMethod() const { return "textDocument/didChange"; @@ -24,4 +14,10 @@ namespace lsp::providers::text_document return "DidChangeProvider"; } + void DidChangeProvider::HandleNotification(const protocol::NotificationMessage& notification, ProviderContext& context) + { + static_cast(context); + spdlog::debug("DidChangeProvider: Providing response for method {}", notification.method); + } + } diff --git a/lsp-server/src/provider/text_document/did_change_provider.hpp b/lsp-server/src/provider/text_document/did_change_provider.hpp index 84b034f..013096e 100644 --- a/lsp-server/src/provider/text_document/did_change_provider.hpp +++ b/lsp-server/src/provider/text_document/did_change_provider.hpp @@ -3,12 +3,12 @@ namespace lsp::providers::text_document { - class DidChangeProvider : public ILspProvider + class DidChangeProvider : public INotificationProvider { public: DidChangeProvider() = default; - std::string ProvideResponse(const protocol::RequestMessage& request) override; std::string GetMethod() const override; std::string GetProviderName() const override; + void HandleNotification(const protocol::NotificationMessage& notification, ProviderContext& context) override; }; } diff --git a/lsp-server/src/provider/text_document/did_open_provider.cpp b/lsp-server/src/provider/text_document/did_open_provider.cpp index ff38bfe..52ba50b 100644 --- a/lsp-server/src/provider/text_document/did_open_provider.cpp +++ b/lsp-server/src/provider/text_document/did_open_provider.cpp @@ -3,15 +3,6 @@ namespace lsp::providers::text_document { - std::string DidOpenProvider::ProvideResponse(const protocol::RequestMessage& request) - { - spdlog::debug("DidOpenProvider: Providing response for method {}", request.method); - std::string json; - glz::obj empty_obj{}; // glaze的对象类型 - auto ec = glz::write_json(empty_obj, json); - return ec ? BuildErrorResponseMessage(request, protocol::ErrorCode::kInternalError, "Internal error") : json; - } - std::string DidOpenProvider::GetMethod() const { return "textDocument/didOpen"; @@ -22,4 +13,10 @@ namespace lsp::providers::text_document return "DidOpenProvider"; } + void DidOpenProvider::HandleNotification(const protocol::NotificationMessage& notification, ProviderContext& context) + { + static_cast(context); + spdlog::debug("DidOpenProvider: Providing response for method {}", notification.method); + } + } diff --git a/lsp-server/src/provider/text_document/did_open_provider.hpp b/lsp-server/src/provider/text_document/did_open_provider.hpp index e482ef1..0801d92 100644 --- a/lsp-server/src/provider/text_document/did_open_provider.hpp +++ b/lsp-server/src/provider/text_document/did_open_provider.hpp @@ -4,12 +4,12 @@ namespace lsp::providers::text_document { - class DidOpenProvider : public ILspProvider + class DidOpenProvider : public INotificationProvider { public: DidOpenProvider() = default; - std::string ProvideResponse(const protocol::RequestMessage& request) override; std::string GetMethod() const override; std::string GetProviderName() const override; + void HandleNotification(const protocol::NotificationMessage& notification, ProviderContext& context) override; }; } diff --git a/lsp-server/src/provider/trace/set_trace_provider.cpp b/lsp-server/src/provider/trace/set_trace_provider.cpp index 7ed8ff1..11faafe 100644 --- a/lsp-server/src/provider/trace/set_trace_provider.cpp +++ b/lsp-server/src/provider/trace/set_trace_provider.cpp @@ -1,18 +1,8 @@ #include #include "./set_trace_provider.hpp" -namespace lsp::providers::trace +namespace lsp::providers::set_trace { - - std::string SetTraceProvider::ProvideResponse(const protocol::RequestMessage& request) - { - spdlog::debug("SetTraceProvider: Providing response for method {}", request.method); - std::string json; - glz::obj empty_obj{}; // glaze的对象类型 - auto ec = glz::write_json(empty_obj, json); - return ec ? BuildErrorResponseMessage(request, protocol::ErrorCode::kInternalError, "Internal error") : json; - } - std::string SetTraceProvider::GetMethod() const { return "$/setTrace"; @@ -23,4 +13,10 @@ namespace lsp::providers::trace return "SetTraceProvider"; } + void SetTraceProvider::HandleNotification(const protocol::NotificationMessage& notification, ProviderContext& context) + { + static_cast(context); + spdlog::debug("SetTraceProvider: Providing response for method {}", notification.method); + } + } diff --git a/lsp-server/src/provider/trace/set_trace_provider.hpp b/lsp-server/src/provider/trace/set_trace_provider.hpp index ba3c217..38a6e1b 100644 --- a/lsp-server/src/provider/trace/set_trace_provider.hpp +++ b/lsp-server/src/provider/trace/set_trace_provider.hpp @@ -1,14 +1,14 @@ #pragma once #include "../base/provider_interface.hpp" -namespace lsp::providers::trace +namespace lsp::providers::set_trace { - class SetTraceProvider : public ILspProvider + class SetTraceProvider : public INotificationProvider { public: SetTraceProvider() = default; - std::string ProvideResponse(const protocol::RequestMessage& request) override; std::string GetMethod() const override; std::string GetProviderName() const override; + void HandleNotification(const protocol::NotificationMessage& notification, ProviderContext& context) override; }; } diff --git a/lsp-server/src/lsp/request_scheduler.cpp b/lsp-server/src/scheduler/request_scheduler.cpp similarity index 95% rename from lsp-server/src/lsp/request_scheduler.cpp rename to lsp-server/src/scheduler/request_scheduler.cpp index 8e7ef49..7bec05b 100644 --- a/lsp-server/src/lsp/request_scheduler.cpp +++ b/lsp-server/src/scheduler/request_scheduler.cpp @@ -1,101 +1,101 @@ -#include -#include "./request_scheduler.hpp" - -namespace lsp -{ - RequestScheduler::RequestScheduler(size_t concurrency) : - executor_(concurrency) - { - spdlog::info("RequestScheduler initialized with {} threads", concurrency); - } - - RequestScheduler::~RequestScheduler() - { - WaitAll(); - } - - void RequestScheduler::Submit(const std::string& request_id, TaskFunc task) - { - auto context = std::make_shared(); - - { - std::lock_guard lock(mutex_); - - // 取消旧任务 - auto it = running_tasks_.find(request_id); - if (it != running_tasks_.end()) - { - it->second->cancelled.store(true); - } - - running_tasks_[request_id] = context; - } - - executor_.async([this, request_id, task = std::move(task), context]() { - try - { - if (context->cancelled.load()) - { - spdlog::debug("Task {} was cancelled", request_id); - return; - } - - auto result = task(); - - if (!context->cancelled.load() && result) - { - SendResponse(*result); - } - } - catch (const std::exception& e) - { - spdlog::error("Task {} failed: {}", request_id, e.what()); - } - - // 清理 - { - std::lock_guard lock(mutex_); - running_tasks_.erase(request_id); - } - }); - } - - bool RequestScheduler::Cancel(const std::string& request_id) - { - std::lock_guard lock(mutex_); - - auto it = running_tasks_.find(request_id); - if (it != running_tasks_.end()) - { - it->second->cancelled.store(true); - return true; - } - - return false; - } - - void RequestScheduler::SetResponseCallback(ResponseCallback callback) - { - std::lock_guard lock(mutex_); - response_callback_ = std::move(callback); - } - - void RequestScheduler::WaitAll() - { - executor_.wait_for_all(); - } - - void RequestScheduler::SendResponse(const std::string& response) - { - ResponseCallback callback; - { - std::lock_guard lock(mutex_); - callback = response_callback_; - } - - if (callback) - callback(response); - else - spdlog::error("No response callback set!"); - } -} +#include +#include "./request_scheduler.hpp" + +namespace lsp::scheduler +{ + RequestScheduler::RequestScheduler(size_t concurrency) : + executor_(concurrency) + { + spdlog::info("RequestScheduler initialized with {} threads", concurrency); + } + + RequestScheduler::~RequestScheduler() + { + WaitAll(); + } + + void RequestScheduler::Submit(const std::string& request_id, TaskFunc task) + { + auto context = std::make_shared(); + + { + std::lock_guard lock(mutex_); + + // 取消旧任务 + auto it = running_tasks_.find(request_id); + if (it != running_tasks_.end()) + { + it->second->cancelled.store(true); + } + + running_tasks_[request_id] = context; + } + + executor_.async([this, request_id, task = std::move(task), context]() { + try + { + if (context->cancelled.load()) + { + spdlog::debug("Task {} was cancelled", request_id); + return; + } + + auto result = task(); + + if (!context->cancelled.load() && result) + { + SendResponse(*result); + } + } + catch (const std::exception& e) + { + spdlog::error("Task {} failed: {}", request_id, e.what()); + } + + // 清理 + { + std::lock_guard lock(mutex_); + running_tasks_.erase(request_id); + } + }); + } + + bool RequestScheduler::Cancel(const std::string& request_id) + { + std::lock_guard lock(mutex_); + + auto it = running_tasks_.find(request_id); + if (it != running_tasks_.end()) + { + it->second->cancelled.store(true); + return true; + } + + return false; + } + + void RequestScheduler::SetResponseCallback(ResponseCallback callback) + { + std::lock_guard lock(mutex_); + response_callback_ = std::move(callback); + } + + void RequestScheduler::WaitAll() + { + executor_.wait_for_all(); + } + + void RequestScheduler::SendResponse(const std::string& response) + { + ResponseCallback callback; + { + std::lock_guard lock(mutex_); + callback = response_callback_; + } + + if (callback) + callback(response); + else + spdlog::error("No response callback set!"); + } +} diff --git a/lsp-server/src/lsp/request_scheduler.hpp b/lsp-server/src/scheduler/request_scheduler.hpp similarity index 95% rename from lsp-server/src/lsp/request_scheduler.hpp rename to lsp-server/src/scheduler/request_scheduler.hpp index 4cac7ac..7f3b1cc 100644 --- a/lsp-server/src/lsp/request_scheduler.hpp +++ b/lsp-server/src/scheduler/request_scheduler.hpp @@ -1,41 +1,41 @@ -// request_scheduler.hpp -#pragma once -#include -#include -#include -#include -#include -#include -#include - -namespace lsp -{ - class RequestScheduler - { - public: - using TaskFunc = std::function()>; - using ResponseCallback = std::function; - - explicit RequestScheduler(size_t concurrency = std::thread::hardware_concurrency()); - ~RequestScheduler(); - - void Submit(const std::string& request_id, TaskFunc task); - bool Cancel(const std::string& request_id); - void SetResponseCallback(ResponseCallback callback); - void WaitAll(); - - private: - struct TaskContext - { - std::atomic cancelled{false}; - }; - - void SendResponse(const std::string& response); - - private: - tf::Executor executor_; - mutable std::mutex mutex_; - std::unordered_map> running_tasks_; - ResponseCallback response_callback_; - }; -} +// request_scheduler.hpp +#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace lsp::scheduler +{ + class RequestScheduler + { + public: + using TaskFunc = std::function()>; + using ResponseCallback = std::function; + + explicit RequestScheduler(size_t concurrency = std::thread::hardware_concurrency()); + ~RequestScheduler(); + + void Submit(const std::string& request_id, TaskFunc task); + bool Cancel(const std::string& request_id); + void SetResponseCallback(ResponseCallback callback); + void WaitAll(); + + private: + struct TaskContext + { + std::atomic cancelled{false}; + }; + + void SendResponse(const std::string& response); + + private: + tf::Executor executor_; + mutable std::mutex mutex_; + std::unordered_map> running_tasks_; + ResponseCallback response_callback_; + }; +} diff --git a/lsp-server/src/services/document.cpp b/lsp-server/src/services/document.cpp new file mode 100644 index 0000000..f93747d --- /dev/null +++ b/lsp-server/src/services/document.cpp @@ -0,0 +1,821 @@ +#include +#include "./document.hpp" + +namespace lsp::services +{ + // ===== Document 实现 ===== + + Document::Document(const protocol::TextDocumentItem& item) : + item_(item), last_modified_time_(std::chrono::system_clock::now()) + { + UpdateInternalState(); + spdlog::trace("Created document: {} (version {}, {} bytes)", item_.uri, item_.version, item_.text.length()); + } + + void Document::SetContent(int32_t newVersion, const std::string& newText) + { + item_.version = newVersion; + item_.text = newText; + is_dirty_ = true; + last_modified_time_ = std::chrono::system_clock::now(); + UpdateInternalState(); + + spdlog::trace("Document {} updated to version {} ({} bytes)", item_.uri, item_.version, item_.text.length()); + } + + void Document::ApplyContentChange(protocol::integer version, const std::vector& changes) + { + // 应用所有变更 + for (const auto& change : changes) + { + ApplyContentChange(change); + } + + item_.version = version; + is_dirty_ = true; + last_modified_time_ = std::chrono::system_clock::now(); + UpdateInternalState(); + + spdlog::trace("Document {} updated to version {} with {} changes", item_.uri, item_.version, changes.size()); + } + + void Document::ApplyContentChange(const protocol::TextDocumentContentChangeEvent& change) + { + // 增量更新 + size_t startOffset = PositionToOffset(change.range.start); + size_t endOffset = PositionToOffset(change.range.end); + + // 替换指定范围 + item_.text = item_.text.substr(0, startOffset) + change.text + item_.text.substr(endOffset); + } + + size_t Document::PositionToOffset(const protocol::Position& position) const + { + if (position.line >= lines_.size()) + { + return item_.text.length(); + } + + size_t offset = line_offsets_[position.line]; + + // 根据编码计算字符偏移 + if (encoding_ == protocol::PositionEncodingKindLiterals::UTF8) + { + // 直接使用字节偏移 + offset += std::min(static_cast(position.character), lines_[position.line].length()); + } + else + { + // UTF-16 或 UTF-32 + offset += CharacterToByteOffset(lines_[position.line], position.character); + } + + return offset; + } + + protocol::Position Document::OffsetToPosition(size_t offset) const + { + protocol::Position pos; + pos.line = 0; + pos.character = 0; + + // 二分查找行号 + auto it = std::upper_bound(line_offsets_.begin(), line_offsets_.end(), offset); + if (it != line_offsets_.begin()) + { + --it; + pos.line = static_cast(std::distance(line_offsets_.begin(), it)); + size_t lineOffset = *it; + size_t byteOffset = offset - lineOffset; + + // 根据编码计算字符位置 + if (encoding_ == protocol::PositionEncodingKindLiterals::UTF8) + { + pos.character = static_cast(byteOffset); + } + else + { + pos.character = ByteOffsetToCharacter(lines_[pos.line], byteOffset); + } + } + + return pos; + } + + std::string Document::GetTextInRange(const protocol::Range& range) const + { + size_t start = PositionToOffset(range.start); + size_t end = PositionToOffset(range.end); + + if (start >= item_.text.length()) + { + return ""; + } + + end = std::min(end, item_.text.length()); + return item_.text.substr(start, end - start); + } + + std::optional Document::GetCharAt(const protocol::Position& position) const + { + if (position.line >= lines_.size()) + { + return std::nullopt; + } + + const std::string& line = lines_[position.line]; + size_t byteOffset = CharacterToByteOffset(line, position.character); + + if (byteOffset >= line.length()) + { + return std::nullopt; + } + + return line[byteOffset]; + } + + std::string Document::GetLine(size_t lineNumber) const + { + if (lineNumber < lines_.size()) + { + return lines_[lineNumber]; + } + return ""; + } + + std::string Document::GetLineAt(const protocol::Position& position) const + { + return GetLine(position.line); + } + + std::string Document::GetWordAt(const protocol::Position& position) const + { + if (position.line >= lines_.size()) + { + return ""; + } + + const std::string& line = lines_[position.line]; + size_t bytePos = CharacterToByteOffset(line, position.character); + + // 找到单词边界 + size_t start = bytePos; + while (start > 0 && IsWordChar(line[start - 1])) + { + --start; + } + + size_t end = bytePos; + while (end < line.length() && IsWordChar(line[end])) + { + ++end; + } + + return line.substr(start, end - start); + } + + protocol::Range Document::GetWordRangeAt(const protocol::Position& position) const + { + if (position.line >= lines_.size()) + { + return protocol::Range{ position, position }; + } + + const std::string& line = lines_[position.line]; + size_t bytePos = CharacterToByteOffset(line, position.character); + + // 找到单词边界 + size_t start = bytePos; + while (start > 0 && IsWordChar(line[start - 1])) + { + --start; + } + + size_t end = bytePos; + while (end < line.length() && IsWordChar(line[end])) + { + ++end; + } + + protocol::Range range; + range.start.line = position.line; + range.start.character = ByteOffsetToCharacter(line, start); + range.end.line = position.line; + range.end.character = ByteOffsetToCharacter(line, end); + + return range; + } + + void Document::UpdateInternalState() + { + UpdateLines(); + UpdateLineOffsets(); + } + + void Document::UpdateLines() + { + lines_.clear(); + + size_t start = 0; + for (size_t i = 0; i < item_.text.length(); ++i) + { + if (item_.text[i] == '\n') + { + lines_.push_back(item_.text.substr(start, i - start)); + start = i + 1; + } + else if (item_.text[i] == '\r') + { + if (i + 1 < item_.text.length() && item_.text[i + 1] == '\n') + { + lines_.push_back(item_.text.substr(start, i - start)); + start = i + 2; + ++i; // Skip \n + } + else + { + lines_.push_back(item_.text.substr(start, i - start)); + start = i + 1; + } + } + } + + // 添加最后一行 + if (start <= item_.text.length()) + { + lines_.push_back(item_.text.substr(start)); + } + } + + void Document::UpdateLineOffsets() + { + line_offsets_.clear(); + line_offsets_.reserve(lines_.size() + 1); + + size_t offset = 0; + line_offsets_.push_back(0); + + for (size_t i = 0; i < item_.text.length(); ++i) + { + if (item_.text[i] == '\n') + { + line_offsets_.push_back(i + 1); + } + else if (item_.text[i] == '\r') + { + if (i + 1 < item_.text.length() && item_.text[i + 1] == '\n') + { + line_offsets_.push_back(i + 2); + ++i; + } + else + { + line_offsets_.push_back(i + 1); + } + } + } + } + + bool Document::IsWordChar(char c) const + { + return std::isalnum(static_cast(c)) || c == '_' || c == '$'; + } + + size_t Document::CharacterToByteOffset(const std::string& line, int32_t character) const + { + if (encoding_ == protocol::PositionEncodingKindLiterals::UTF8) + { + return std::min(static_cast(character), line.length()); + } + + // UTF-16 编码:需要正确计算 + size_t byteOffset = 0; + int32_t charCount = 0; + + while (byteOffset < line.length() && charCount < character) + { + unsigned char c = line[byteOffset]; + + if (encoding_ == protocol::PositionEncodingKindLiterals::UTF16) + { + // UTF-16: 计算代码单元 + if ((c & 0x80) == 0) + { + // ASCII + byteOffset += 1; + charCount += 1; + } + else if ((c & 0xE0) == 0xC0) + { + // 2字节UTF-8 -> 1个UTF-16单元 + byteOffset += 2; + charCount += 1; + } + else if ((c & 0xF0) == 0xE0) + { + // 3字节UTF-8 -> 1个UTF-16单元 + byteOffset += 3; + charCount += 1; + } + else if ((c & 0xF8) == 0xF0) + { + // 4字节UTF-8 -> 2个UTF-16单元(代理对) + byteOffset += 4; + charCount += 2; + } + else + { + byteOffset += 1; // 错误情况 + } + } + else // UTF32 + { + // UTF-32: 每个Unicode代码点算一个 + if ((c & 0x80) == 0) + { + byteOffset += 1; + } + else if ((c & 0xE0) == 0xC0) + { + byteOffset += 2; + } + else if ((c & 0xF0) == 0xE0) + { + byteOffset += 3; + } + else if ((c & 0xF8) == 0xF0) + { + byteOffset += 4; + } + else + { + byteOffset += 1; + } + charCount += 1; + } + } + + return byteOffset; + } + + int32_t Document::ByteOffsetToCharacter(const std::string& line, size_t byteOffset) const + { + if (encoding_ == protocol::PositionEncodingKindLiterals::UTF8) + { + return static_cast(byteOffset); + } + + int32_t charCount = 0; + size_t pos = 0; + + while (pos < byteOffset && pos < line.length()) + { + unsigned char c = line[pos]; + + if (encoding_ == protocol::PositionEncodingKindLiterals::UTF16) + { + if ((c & 0x80) == 0) + { + pos += 1; + charCount += 1; + } + else if ((c & 0xE0) == 0xC0) + { + pos += 2; + charCount += 1; + } + else if ((c & 0xF0) == 0xE0) + { + pos += 3; + charCount += 1; + } + else if ((c & 0xF8) == 0xF0) + { + pos += 4; + charCount += 2; // 代理对 + } + else + { + pos += 1; + charCount += 1; + } + } + else // UTF32 + { + if ((c & 0x80) == 0) + { + pos += 1; + } + else if ((c & 0xE0) == 0xC0) + { + pos += 2; + } + else if ((c & 0xF0) == 0xE0) + { + pos += 3; + } + else if ((c & 0xF8) == 0xF0) + { + pos += 4; + } + else + { + pos += 1; + } + charCount += 1; + } + } + + return charCount; + } + + // ===== DocumentManager 实现 ===== + + void DocumentManager::DidOpenTextDocument(const protocol::DidOpenTextDocumentParams& params) + { + std::unique_lock lock(mutex_); + + // 检查文档大小 + if (config_.max_document_size > 0 && + params.textDocument.text.length() > config_.max_document_size) + { + spdlog::error("Document {} exceeds maximum size ({} > {})", + params.textDocument.uri, + params.textDocument.text.length(), + config_.max_document_size); + return; + } + + // 创建新文档 + auto doc = std::make_shared(params.textDocument); + doc->SetEncoding(config_.default_encoding); + + documents_[params.textDocument.uri] = doc; + + spdlog::info("Opened document: {} (version {}, {} bytes, language: {})", + params.textDocument.uri, + params.textDocument.version, + params.textDocument.text.length(), + params.textDocument.languageId); + } + + void DocumentManager::DidChangeTextDocument(const protocol::DidChangeTextDocumentParams& params) + { + std::unique_lock lock(mutex_); + + auto it = documents_.find(params.textDocument.uri); + if (it == documents_.end()) + { + spdlog::error("Attempt to change non-existent document: {}", + params.textDocument.uri); + return; + } + + auto& doc = it->second; + + // 版本检查 + if (params.textDocument.version) + { + protocol::integer expectedVersion = params.textDocument.version; + if (expectedVersion <= doc->GetVersion()) + { + spdlog::warn("Ignoring stale change for {}: version {} <= current {}", + params.textDocument.uri, + expectedVersion, + doc->GetVersion()); + return; + } + } + + // 应用变更 + if (params.contentChanges.empty()) + { + spdlog::warn("Empty content changes for document: {}", params.textDocument.uri); + return; + } + + // 检查是全文还是增量 + if (params.contentChanges.size() == 1) + { + // 全文更新 + // doc->SetContent(params.textDocument.version(doc->GetVersion() + 1), params.contentChanges[0].text); + } + else + { + // 增量更新 + // doc->ApplyContentChanges( params.textDocument.version(doc->GetVersion() + 1), params.contentChanges); + } + + spdlog::debug("Changed document: {} to version {} ({} changes)", params.textDocument.uri, doc->GetVersion(), params.contentChanges.size()); + } + + void DocumentManager::DidCloseTextDocument(const protocol::DidCloseTextDocumentParams& params) + { + std::unique_lock lock(mutex_); + + auto it = documents_.find(params.textDocument.uri); + if (it == documents_.end()) + { + spdlog::warn("Attempt to close non-existent document: {}", + params.textDocument.uri); + return; + } + + documents_.erase(it); + spdlog::info("Closed document: {}", params.textDocument.uri); + } + + void DocumentManager::DidSaveTextDocument(const protocol::DidSaveTextDocumentParams& params) + { + std::shared_lock lock(mutex_); + + auto it = documents_.find(params.textDocument.uri); + if (it == documents_.end()) + { + spdlog::warn("Attempt to save non-existent document: {}", + params.textDocument.uri); + return; + } + + it->second->SetDirty(false); + + // 如果保存通知包含文本,可以验证同步状态 + if (params.text.has_value()) + { + if (params.text.value() != it->second->GetText()) + { + spdlog::error("Document content mismatch on save for: {}", params.textDocument.uri); + } + } + + spdlog::info("Saved document: {}", params.textDocument.uri); + } + + std::shared_ptr DocumentManager::GetDocument(const std::string& uri) const + { + std::shared_lock lock(mutex_); + + auto it = documents_.find(uri); + if (it != documents_.end()) + { + return it->second; + } + + return nullptr; + } + + std::vector DocumentManager::GetAllUris() const + { + std::shared_lock lock(mutex_); + + std::vector uris; + uris.reserve(documents_.size()); + + for (const auto& [uri, doc] : documents_) + { + uris.push_back(uri); + } + + return uris; + } + + std::vector> DocumentManager::GetAllDocuments() const + { + std::shared_lock lock(mutex_); + + std::vector> docs; + docs.reserve(documents_.size()); + + for (const auto& [uri, doc] : documents_) + { + docs.push_back(doc); + } + + return docs; + } + + std::vector> DocumentManager::GetDocumentsByLanguage( + const std::string& languageId) const + { + std::shared_lock lock(mutex_); + + std::vector> docs; + + for (const auto& [uri, doc] : documents_) + { + if (doc->GetLanguageId() == languageId) + { + docs.push_back(doc); + } + } + + return docs; + } + + bool DocumentManager::HasDocument(const std::string& uri) const + { + std::shared_lock lock(mutex_); + return documents_.find(uri) != documents_.end(); + } + + size_t DocumentManager::GetDocumentCount() const + { + std::shared_lock lock(mutex_); + return documents_.size(); + } + + std::vector DocumentManager::GetDirtyDocuments() const + { + std::shared_lock lock(mutex_); + + std::vector dirtyUris; + + for (const auto& [uri, doc] : documents_) + { + if (doc->IsDirty()) + { + dirtyUris.push_back(uri); + } + } + + return dirtyUris; + } + + std::string DocumentManager::ResolveUri(const std::string& uri) const + { + // 如果已经是绝对URI,直接返回 + if (utils::IsFileUri(uri)) + { + return uri; + } + + // 尝试相对于工作区文件夹解析 + for (const auto& folder : workspace_folders_) + { + std::string folderPath = utils::UriToPath(folder.uri); + std::string resolvedPath = folderPath + "/" + uri; + + // 检查文件是否存在(这里简化处理) + return utils::PathToUri(resolvedPath); + } + + return uri; + } + + // ===== 工具函数实现 ===== + + namespace utils + { + std::string NormalizeUri(const std::string& uri) + { + std::string normalized = uri; + + // 确保使用正斜杠 + std::replace(normalized.begin(), normalized.end(), '\\', '/'); + + // 移除重复的斜杠 + auto newEnd = std::unique(normalized.begin(), normalized.end(), [](char a, char b) { return a == '/' && b == '/'; }); + normalized.erase(newEnd, normalized.end()); + + return normalized; + } + + std::string UriToPath(const std::string& uri) + { + if (uri.substr(0, 7) == "file://") + { + std::string path = uri.substr(7); + +// Windows路径处理 +#ifdef _WIN32 + if (path.length() >= 3 && path[0] == '/' && + std::isalpha(path[1]) && path[2] == ':') + { + path = path.substr(1); + } +#endif + + return path; + } + + return uri; + } + + std::string PathToUri(const std::string& path) + { + std::string uri = "file://"; + +#ifdef _WIN32 + // Windows路径 + if (path.length() >= 2 && std::isalpha(path[0]) && path[1] == ':') + { + uri += "/"; + } +#endif + + uri += path; + return NormalizeUri(uri); + } + + bool IsFileUri(const std::string& uri) + { + return uri.substr(0, 7) == "file://"; + } + + protocol::TextEdit CreateReplace(const protocol::Range& range, const std::string& newText) + { + protocol::TextEdit edit; + edit.range = range; + edit.newText = newText; + return edit; + } + + protocol::TextEdit CreateInsert(const protocol::Position& position, const std::string& text) + { + return CreateReplace(protocol::Range{ position, position }, text); + } + + protocol::TextEdit CreateDelete(const protocol::Range& range) + { + return CreateReplace(range, ""); + } + + bool IsPositionInRange(const protocol::Position& position, const protocol::Range& range) + { + if (position.line < range.start.line || position.line > range.end.line) + { + return false; + } + + if (position.line == range.start.line && position.character < range.start.character) + { + return false; + } + + if (position.line == range.end.line && position.character >= range.end.character) + { + return false; + } + + return true; + } + + bool IsRangeOverlapping(const protocol::Range& a, const protocol::Range& b) + { + return !(a.end.line < b.start.line || + (a.end.line == b.start.line && a.end.character <= b.start.character) || + b.end.line < a.start.line || + (b.end.line == a.start.line && b.end.character <= a.start.character)); + } + + protocol::Range ExtendRange(const protocol::Range& range, int32_t lines) + { + protocol::Range extended = range; + extended.start.line = std::max(static_cast(0), static_cast(extended.start.line - lines)); + extended.end.line += lines; + return extended; + } + + std::string ApplyTextEdits(const std::string& text, + const std::vector& edits) + { + if (edits.empty()) + { + return text; + } + + // 排序编辑(从后向前,避免偏移问题) + std::vector sortedEdits = edits; + std::sort(sortedEdits.begin(), sortedEdits.end(), [](const protocol::TextEdit& a, const protocol::TextEdit& b) { + if (a.range.start.line != b.range.start.line) + { + return a.range.start.line > b.range.start.line; + } + return a.range.start.character > b.range.start.character; + }); + + // 创建临时文档来应用编辑 + protocol::TextDocumentItem tempItem; + tempItem.uri = "temp://"; + tempItem.languageId = ""; + tempItem.version = 0; + tempItem.text = text; + + Document tempDoc(tempItem); + std::string result = text; + + for (const auto& edit : sortedEdits) + { + size_t start = tempDoc.PositionToOffset(edit.range.start); + size_t end = tempDoc.PositionToOffset(edit.range.end); + + result = result.substr(0, start) + edit.newText + result.substr(end); + + // 更新临时文档 + tempDoc.SetContent(0, result); + } + + return result; + } + } +} diff --git a/lsp-server/src/services/document.hpp b/lsp-server/src/services/document.hpp new file mode 100644 index 0000000..54f6e04 --- /dev/null +++ b/lsp-server/src/services/document.hpp @@ -0,0 +1,238 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include "../protocol/protocol.hpp" + +namespace lsp::services +{ + class Document + { + public: + Document(const protocol::TextDocumentItem& item); + + const protocol::DocumentUri& GetUri() const { return item_.uri; } + const protocol::string& GetLanguageId() const { return item_.languageId; } + protocol::integer GetVersion() const { return item_.version; } + const protocol::string& GetText() const { return item_.text; } + const protocol::TextDocumentItem& GetItem() const { return item_; } + + protocol::TextDocumentIdentifier GetIdentifier() const + { + return protocol::TextDocumentIdentifier{item_.uri}; + } + + protocol::VersionedTextDocumentIdentifier GetVersionedIdentifier() const + { + protocol::VersionedTextDocumentIdentifier id; + id.uri = item_.uri; + id.version = item_.version; + return id; + } + + // ===== 位置和范围操作 ===== + size_t PositionToOffset(const protocol::Position& position) const; + protocol::Position OffsetToPosition(size_t offset) const; + std::string GetTextInRange(const protocol::Range& range) const; + + // 内容更新 + void SetContent(protocol::integer version, const protocol::string& new_text); + void ApplyContentChange(protocol::integer version, const std::vector& changes); + + // 获取指定位置的字符 + std::optional GetCharAt(const protocol::Position& position) const; + + // ===== 行操作 ===== + const std::vector& GetLines() const { return lines_; } + size_t GetLineCount() const { return lines_.size(); } + std::string GetLine(size_t lineNumber) const; + std::string GetLineAt(const protocol::Position& position) const; + + // ===== 单词和符号操作 ===== + std::string GetWordAt(const protocol::Position& position) const; + protocol::Range GetWordRangeAt(const protocol::Position& position) const; + + // ===== 实用方法 ===== + // 创建一个Location + protocol::Location CreateLocation(const protocol::Range& range) const + { + return protocol::Location{item_.uri, range}; + } + + // 创建一个TextDocumentPositionParams + protocol::TextDocumentPositionParams CreatePositionParams(const protocol::Position& position) const + { + protocol::TextDocumentPositionParams params; + params.textDocument = GetIdentifier(); + params.position = position; + return params; + } + + // ===== 元数据 ===== + // 文档是否被修改(相对于上次保存) + bool IsDirty() const { return is_dirty_; } + void SetDirty(bool dirty) { is_dirty_ = dirty; } + + // 最后修改时间 + std::chrono::system_clock::time_point GetLastModified() const { return last_modified_time_; } + + void SetEncoding(protocol::PositionEncodingKind encoding) { encoding_ = encoding; } + protocol::PositionEncodingKind GetEncoding() const { return encoding_; } + + private: + // 更新内部缓存 + void UpdateInternalState(); + void UpdateLines(); + void UpdateLineOffsets(); + + // 辅助方法 + bool IsWordChar(char c) const; + size_t CharacterToByteOffset(const std::string& line, std::int32_t character) const; + std::int32_t ByteOffsetToCharacter(const std::string& line, size_t byteOffset) const; + + // 应用单个内容变更 + void ApplyContentChange(const protocol::TextDocumentContentChangeEvent& change); + + private: + protocol::TextDocumentItem item_; + + // 缓存行的信息 + std::vector lines_; + std::vector line_offsets_; + + bool is_dirty_ = false; + std::chrono::system_clock::time_point last_modified_time_; + protocol::PositionEncodingKind encoding_ = protocol::PositionEncodingKindLiterals::UTF16; + }; + + /** + * 文档管理器 - 使用protocol类型作为接口 + */ + class DocumentManager + { + public: + DocumentManager() = default; + ~DocumentManager() = default; + + // 禁止拷贝 + DocumentManager(const DocumentManager&) = delete; + DocumentManager& operator=(const DocumentManager&) = delete; + + // ===== 文档生命周期管理 - 直接使用protocol类型 ===== + + // 处理 textDocument/didOpen + void DidOpenTextDocument(const protocol::DidOpenTextDocumentParams& params); + + // 处理 textDocument/didChange + void DidChangeTextDocument(const protocol::DidChangeTextDocumentParams& params); + + // 处理 textDocument/didClose + void DidCloseTextDocument(const protocol::DidCloseTextDocumentParams& params); + + // 处理 textDocument/didSave + void DidSaveTextDocument(const protocol::DidSaveTextDocumentParams& params); + + // ===== 文档访问 - 支持多种查询方式 ===== + + // 通过URI获取 + std::shared_ptr GetDocument(const std::string& uri) const; + + // 通过标识符获取 + std::shared_ptr GetDocument(const protocol::TextDocumentIdentifier& identifier) const + { + return GetDocument(identifier.uri); + } + + // 通过版本化标识符获取 + std::shared_ptr GetDocument(const protocol::VersionedTextDocumentIdentifier& identifier) const + { + auto doc = GetDocument(identifier.uri); + if (doc && identifier.version && doc->GetVersion() != identifier.version) + { + spdlog::warn("Version mismatch for {}: expected {}, got {}", identifier.uri, identifier.version, doc->GetVersion()); + } + return doc; + } + + // 通过TextDocumentPositionParams获取 + std::shared_ptr GetDocument(const protocol::TextDocumentPositionParams& params) const + { + return GetDocument(params.textDocument); + } + + // ===== 批量操作 ===== + std::vector GetAllUris() const; + std::vector> GetAllDocuments() const; + std::vector> GetDocumentsByLanguage(const std::string& languageId) const; + + // ===== 查询 ===== + bool HasDocument(const std::string& uri) const; + bool IsDocumentOpen(const protocol::TextDocumentIdentifier& identifier) const + { + return HasDocument(identifier.uri); + } + + size_t GetDocumentCount() const; + + // ===== 诊断支持 ===== + // 获取需要诊断的文档(已修改的) + std::vector GetDirtyDocuments() const; + + // ===== 工作区支持 ===== + // 设置工作区文件夹(用于相对路径解析) + void SetWorkspaceFolders(const std::vector& folders) + { + workspace_folders_ = folders; + } + + const std::vector& GetWorkspaceFolders() const + { + return workspace_folders_; + } + + // 解析相对URI + std::string ResolveUri(const std::string& uri) const; + + // ===== 配置 ===== + struct Configuration { + size_t max_document_size = 10 * 1024 * 1024; // 10MB + bool validate_utf8 = true; + protocol::PositionEncodingKind default_encoding = protocol::PositionEncodingKindLiterals::UTF16; + }; + + void SetConfiguration(const Configuration& config) { config_ = config; } + const Configuration& GetConfiguration() const { return config_; } + + private: + mutable std::shared_mutex mutex_; + std::unordered_map> documents_; + std::vector workspace_folders_; + Configuration config_; + }; + + // ===== 工具函数 ===== + namespace utils + { + // URI处理 + std::string NormalizeUri(const std::string& uri); + std::string UriToPath(const std::string& uri); + std::string PathToUri(const std::string& path); + bool IsFileUri(const std::string& uri); + + // 创建TextEdit + protocol::TextEdit CreateReplace(const protocol::Range& range, const std::string& newText); + protocol::TextEdit CreateInsert(const protocol::Position& position, const std::string& text); + protocol::TextEdit CreateDelete(const protocol::Range& range); + + // 范围操作 + bool IsPositionInRange(const protocol::Position& position, const protocol::Range& range); + bool IsRangeOverlapping(const protocol::Range& a, const protocol::Range& b); + protocol::Range ExtendRange(const protocol::Range& range, int32_t lines); + + // 应用编辑 + std::string ApplyTextEdits(const std::string& text, const std::vector& edits); + } +} diff --git a/lsp-server/src/utils/args_parser.cpp b/lsp-server/src/utils/args_parser.cpp index 8014fbf..5490221 100644 --- a/lsp-server/src/utils/args_parser.cpp +++ b/lsp-server/src/utils/args_parser.cpp @@ -103,7 +103,7 @@ namespace lsp::utils << "Usage: " << program_name << " [options]\n" << "\n" << "Options:\n" - << " --help, -h Show this help message and exit\n" + << " --help Show this help message and exit\n" << "\n" << "Logging options:\n" << " --log=LEVEL Set log level (trace, debug, info, warn, error, off)\n"