diff --git a/lsp-server/src/CMakeLists.txt b/lsp-server/src/CMakeLists.txt index d570b39..9e326b9 100644 --- a/lsp-server/src/CMakeLists.txt +++ b/lsp-server/src/CMakeLists.txt @@ -60,13 +60,15 @@ set(SOURCES manager/symbol.cppm manager/detail/text_document.cppm manager/manager_hub.cppm + language/semantic/graph/types.cppm language/semantic/graph/call.cppm language/semantic/graph/inheritance.cppm language/semantic/graph/reference.cppm language/semantic/interface.cppm - language/semantic/semantic_model.cppm + language/semantic/type_system.types.cppm language/semantic/type_system.cppm language/semantic/name_resolver.cppm + language/semantic/semantic_model.cppm language/semantic/analyzer.cppm language/semantic/token_collector.cppm tree-sitter/parser.c @@ -132,11 +134,13 @@ target_sources( language/symbol/symbol.cppm language/semantic/interface.cppm language/semantic/semantic.cppm + language/semantic/type_system.types.cppm language/semantic/type_system.cppm language/semantic/name_resolver.cppm language/semantic/semantic_model.cppm language/semantic/analyzer.cppm language/semantic/token_collector.cppm + language/semantic/graph/types.cppm language/semantic/graph/call.cppm language/semantic/graph/inheritance.cppm language/semantic/graph/reference.cppm diff --git a/lsp-server/src/core/server.cppm b/lsp-server/src/core/server.cppm index feeaad5..e18c775 100644 --- a/lsp-server/src/core/server.cppm +++ b/lsp-server/src/core/server.cppm @@ -80,8 +80,8 @@ export namespace lsp::core private: RequestDispatcher dispatcher_; - scheduler::AsyncExecutor async_executor_; manager::ManagerHub manager_hub_; + scheduler::AsyncExecutor async_executor_; std::string interpreter_path_; std::atomic is_initialized_ = false; @@ -92,7 +92,8 @@ export namespace lsp::core namespace lsp::core { - LspServer::LspServer(std::size_t concurrency, std::string interpreter_path) : async_executor_(concurrency), + LspServer::LspServer(std::size_t concurrency, std::string interpreter_path) : manager_hub_(), + async_executor_(concurrency), interpreter_path_(std::move(interpreter_path)) { spdlog::info("Initializing LSP server with {} worker threads", concurrency); diff --git a/lsp-server/src/language/semantic/analyzer.cppm b/lsp-server/src/language/semantic/analyzer.cppm index 5978adb..f82bb30 100644 --- a/lsp-server/src/language/semantic/analyzer.cppm +++ b/lsp-server/src/language/semantic/analyzer.cppm @@ -4,11 +4,142 @@ export module lsp.language.semantic:analyzer; import std; -import :interface; +import :semantic_model; +import :type_system; import lsp.language.ast; import lsp.language.symbol; import lsp.utils.string; +export namespace lsp::language::semantic +{ + class Analyzer : public ast::ASTVisitor + { + public: + explicit Analyzer(symbol::SymbolTable& symbol_table, SemanticModel& semantic_model); + + using ExternalSymbolProvider = std::function(const std::string&)>; + void SetExternalSymbolProvider(ExternalSymbolProvider provider); + + void Analyze(ast::ASTNode& root); + + void VisitProgram(ast::Program& node) override; + void VisitUnitDefinition(ast::UnitDefinition& node) override; + void VisitClassDefinition(ast::ClassDefinition& node) override; + void VisitFunctionDefinition(ast::FunctionDefinition& node) override; + void VisitFunctionDeclaration(ast::FunctionDeclaration& node) override; + void VisitMethodDeclaration(ast::MethodDeclaration& node) override; + void VisitExternalMethodDefinition(ast::ExternalMethodDefinition& node) override; + + void VisitVarDeclaration(ast::VarDeclaration& node) override; + void VisitStaticDeclaration(ast::StaticDeclaration& node) override; + void VisitGlobalDeclaration(ast::GlobalDeclaration& node) override; + void VisitConstDeclaration(ast::ConstDeclaration& node) override; + void VisitFieldDeclaration(ast::FieldDeclaration& node) override; + void VisitClassMember(ast::ClassMember& node) override; + void VisitPropertyDeclaration(ast::PropertyDeclaration& node) override; + + void VisitBlockStatement(ast::BlockStatement& node) override; + void VisitIfStatement(ast::IfStatement& node) override; + void VisitForInStatement(ast::ForInStatement& node) override; + void VisitForToStatement(ast::ForToStatement& node) override; + void VisitWhileStatement(ast::WhileStatement& node) override; + void VisitRepeatStatement(ast::RepeatStatement& node) override; + void VisitCaseStatement(ast::CaseStatement& node) override; + void VisitTryStatement(ast::TryStatement& node) override; + void VisitLabelStatement(ast::LabelStatement& node) override; + void VisitGotoStatement(ast::GotoStatement& node) override; + + void VisitUsesStatement(ast::UsesStatement& node) override; + + void VisitIdentifier(ast::Identifier& node) override; + void VisitCallExpression(ast::CallExpression& node) override; + void VisitAttributeExpression(ast::AttributeExpression& node) override; + void VisitAssignmentExpression(ast::AssignmentExpression& node) override; + + void VisitLiteral(ast::Literal& node) override; + void VisitBinaryExpression(ast::BinaryExpression& node) override; + void VisitTernaryExpression(ast::TernaryExpression& node) override; + void VisitSubscriptExpression(ast::SubscriptExpression& node) override; + void VisitArrayExpression(ast::ArrayExpression& node) override; + void VisitAnonymousFunctionExpression(ast::AnonymousFunctionExpression& node) override; + + void VisitUnaryPlusExpression(ast::UnaryPlusExpression& node) override; + void VisitUnaryMinusExpression(ast::UnaryMinusExpression& node) override; + void VisitPrefixIncrementExpression(ast::PrefixIncrementExpression& node) override; + void VisitPrefixDecrementExpression(ast::PrefixDecrementExpression& node) override; + void VisitPostfixIncrementExpression(ast::PostfixIncrementExpression& node) override; + void VisitPostfixDecrementExpression(ast::PostfixDecrementExpression& node) override; + void VisitLogicalNotExpression(ast::LogicalNotExpression& node) override; + void VisitBitwiseNotExpression(ast::BitwiseNotExpression& node) override; + void VisitDerivativeExpression(ast::DerivativeExpression& node) override; + void VisitMatrixTransposeExpression(ast::MatrixTransposeExpression& node) override; + void VisitExprOperatorExpression(ast::ExprOperatorExpression& node) override; + void VisitFunctionPointerExpression(ast::FunctionPointerExpression& node) override; + + void VisitNewExpression(ast::NewExpression& node) override; + void VisitEchoExpression(ast::EchoExpression& node) override; + void VisitRaiseExpression(ast::RaiseExpression& node) override; + void VisitInheritedExpression(ast::InheritedExpression& node) override; + void VisitRdoExpression(ast::RdoExpression& node) override; + void VisitParenthesizedExpression(ast::ParenthesizedExpression& node) override; + + void VisitExpressionStatement(ast::ExpressionStatement& node) override; + void VisitBreakStatement(ast::BreakStatement& node) override; + void VisitContinueStatement(ast::ContinueStatement& node) override; + void VisitReturnStatement(ast::ReturnStatement& node) override; + void VisitTSSQLExpression(ast::TSSQLExpression& node) override; + void VisitColumnReference(ast::ColumnReference& node) override; + void VisitUnpackPattern(ast::UnpackPattern& node) override; + void VisitMatrixIterationStatement(ast::MatrixIterationStatement& node) override; + + void VisitCompilerDirective(ast::CompilerDirective& node) override; + void VisitConditionalDirective(ast::ConditionalDirective& node) override; + void VisitConditionalBlock(ast::ConditionalBlock& node) override; + void VisitTSLXBlock(ast::TSLXBlock& node) override; + + void VisitParameter(ast::Parameter& node) override; + + private: + void VisitStatements(const std::vector& statements); + void VisitExpression(ast::Expression& expr); + void ProcessLValue(const ast::LValue& lvalue); + std::optional ResolveParentClass(const ast::ClassDefinition::ParentClass& parent); + std::optional ScopeAt(const ast::Location& location) const; + std::optional ResolveByName(const std::string& name); + std::optional FindMethodInClass(symbol::SymbolId class_id, const std::string& method_name) const; + std::optional FindScopeOwnedBy(symbol::SymbolId owner_id) const; + + std::optional ResolveIdentifier(const std::string& name, const ast::Location& location); + std::optional ResolveFromUses(const std::string& name); + void TrackReference(symbol::SymbolId symbol_id, const ast::Location& location, bool is_write = false); + void TrackCall(symbol::SymbolId callee, const ast::Location& location); + std::shared_ptr InferExpressionType(ast::Expression& expr); + std::shared_ptr GetDeclaredTypeForSymbol(symbol::SymbolId symbol_id); + std::optional ResolveClassSymbol(const std::string& name, const ast::Location& location); + std::optional ResolveLValueSymbol(const ast::LValue& lvalue); + void RegisterParameterTypes(symbol::SymbolId function_id, const std::vector>& parameters); + + private: + struct UnitContext + { + std::string unit_name; + std::vector interface_imports; + std::vector implementation_imports; + }; + + symbol::SymbolTable& symbol_table_; + SemanticModel& semantic_model_; + + std::optional current_function_id_; + std::optional current_class_id_; + std::optional current_unit_context_; + std::optional current_unit_section_; + std::vector file_imports_; + ExternalSymbolProvider external_symbol_provider_; + std::unordered_map imported_symbols_; + }; +} + namespace lsp::language::semantic { namespace diff --git a/lsp-server/src/language/semantic/graph/call.cppm b/lsp-server/src/language/semantic/graph/call.cppm index c583dcc..b4635db 100644 --- a/lsp-server/src/language/semantic/graph/call.cppm +++ b/lsp-server/src/language/semantic/graph/call.cppm @@ -4,9 +4,31 @@ export module lsp.language.semantic:graph.call; import std; -import :interface; +import :graph.types; import lsp.language.ast; +export namespace lsp::language::semantic::graph +{ + using symbol::SymbolId; + + class Call : public ISemanticGraph + { + public: + void OnSymbolRemoved(SymbolId id) override; + void Clear() override; + + void AddCall(SymbolId caller, SymbolId callee, const ast::Location& location); + + const std::vector& callers(SymbolId id) const; + + const std::vector& callees(SymbolId id) const; + + private: + std::unordered_map> callers_map_; + std::unordered_map> callees_map_; + }; +} + namespace lsp::language::semantic::graph { diff --git a/lsp-server/src/language/semantic/graph/inheritance.cppm b/lsp-server/src/language/semantic/graph/inheritance.cppm index 1574419..fd3f8d7 100644 --- a/lsp-server/src/language/semantic/graph/inheritance.cppm +++ b/lsp-server/src/language/semantic/graph/inheritance.cppm @@ -4,7 +4,31 @@ export module lsp.language.semantic:graph.inheritance; import std; -import :interface; +import :graph.types; +export namespace lsp::language::semantic::graph +{ + using symbol::SymbolId; + + class Inheritance : public ISemanticGraph + { + public: + void OnSymbolRemoved(SymbolId id) override; + void Clear() override; + + void AddInheritance(SymbolId derived, SymbolId base); + + const std::vector& base_classes(SymbolId id) const; + + const std::vector& derived_classes(SymbolId id) const; + + bool IsSubclassOf(SymbolId derived, SymbolId base) const; + + private: + std::unordered_map> base_classes_; + std::unordered_map> derived_classes_; + }; +} + namespace lsp::language::semantic::graph { diff --git a/lsp-server/src/language/semantic/graph/reference.cppm b/lsp-server/src/language/semantic/graph/reference.cppm index efa6129..f743dc4 100644 --- a/lsp-server/src/language/semantic/graph/reference.cppm +++ b/lsp-server/src/language/semantic/graph/reference.cppm @@ -4,9 +4,30 @@ export module lsp.language.semantic:graph.reference; import std; -import :interface; +import :graph.types; import lsp.language.ast; +export namespace lsp::language::semantic::graph +{ + using symbol::SymbolId; + + class Reference : public ISemanticGraph + { + public: + void OnSymbolRemoved(SymbolId id) override; + void Clear() override; + + void AddReference(SymbolId symbol_id, const ast::Location& location, bool is_definition = false, bool is_write = false); + + const std::vector& references(SymbolId id) const; + + std::optional FindDefinitionLocation(SymbolId id) const; + + private: + std::unordered_map> references_; + }; +} + namespace lsp::language::semantic::graph { diff --git a/lsp-server/src/language/semantic/graph/types.cppm b/lsp-server/src/language/semantic/graph/types.cppm new file mode 100644 index 0000000..b03dc09 --- /dev/null +++ b/lsp-server/src/language/semantic/graph/types.cppm @@ -0,0 +1,47 @@ +module; + +export module lsp.language.semantic:graph.types; + +import std; + +import lsp.language.ast; +import lsp.language.symbol; + +export namespace lsp::language::semantic +{ + struct Reference + { + ast::Location location; + symbol::SymbolId symbol_id; + bool is_definition; + bool is_write; + }; + + struct Call + { + symbol::SymbolId caller; + symbol::SymbolId callee; + ast::Location call_site; + }; + + struct Inheritance + { + symbol::SymbolId derived; + symbol::SymbolId base; + }; + + class ISemanticGraph + { + public: + virtual ~ISemanticGraph() = default; + + virtual void OnSymbolRemoved(symbol::SymbolId id) = 0; + + virtual void Clear() = 0; + }; +} + +export namespace lsp::language::semantic::graph +{ + using symbol::SymbolId; +} diff --git a/lsp-server/src/language/semantic/interface.cppm b/lsp-server/src/language/semantic/interface.cppm index 9547cda..4720d58 100644 --- a/lsp-server/src/language/semantic/interface.cppm +++ b/lsp-server/src/language/semantic/interface.cppm @@ -1,815 +1,5 @@ module; export module lsp.language.semantic:interface; -import tree_sitter; -import std; - -import lsp.language.ast; -import lsp.language.symbol; -import lsp.utils.string; -import lsp.protocol; - -export namespace lsp::language::semantic -{ - struct Reference - { - ast::Location location; - symbol::SymbolId symbol_id; - bool is_definition; - bool is_write; - }; - - struct Call - { - symbol::SymbolId caller; - symbol::SymbolId callee; - ast::Location call_site; - }; - - struct Inheritance - { - symbol::SymbolId derived; - symbol::SymbolId base; - }; - - class ISemanticGraph - { - public: - virtual ~ISemanticGraph() = default; - - virtual void OnSymbolRemoved(symbol::SymbolId id) = 0; - - virtual void Clear() = 0; - }; - - enum class TypeKind - { - kPrimitive, - kClass, - kArray, - kFunction, - kOptional, - kVoid, - kUnknown, - kError - }; - - enum class PrimitiveTypeKind - { - kInt, - kFloat, - kString, - kBool, - kChar - }; - - struct TypeCompatibility - { - bool is_compatible = false; - int conversion_cost = -1; - bool requires_cast = false; - - static TypeCompatibility Exact() - { - return { true, 0, false }; - } - - static TypeCompatibility Implicit(int cost) - { - return { true, cost, false }; - } - - static TypeCompatibility ExplicitCast(int cost) - { - return { true, cost, true }; - } - - static TypeCompatibility Incompatible() - { - return { false, -1, false }; - } - }; - - class Type; - - class PrimitiveType - { - public: - explicit PrimitiveType(PrimitiveTypeKind kind) : kind_(kind) {} - - PrimitiveTypeKind kind() const { return kind_; } - std::string ToString() const; - - bool operator==(const PrimitiveType& other) const - { - return kind_ == other.kind_; - } - - private: - PrimitiveTypeKind kind_; - }; - - class ClassType - { - public: - explicit ClassType(symbol::SymbolId class_id) : class_id_(class_id) {} - - symbol::SymbolId class_id() const { return class_id_; } - - bool operator==(const ClassType& other) const - { - return class_id_ == other.class_id_; - } - - private: - symbol::SymbolId class_id_; - }; - - class ArrayType - { - public: - explicit ArrayType(std::shared_ptr element_type) : element_type_(std::move(element_type)) {} - - const Type& element_type() const { return *element_type_; } - std::shared_ptr element_type_ptr() const { return element_type_; } - - private: - std::shared_ptr element_type_; - }; - - class FunctionType - { - public: - FunctionType(std::vector> param_types, - std::shared_ptr return_type) : param_types_(std::move(param_types)), - return_type_(std::move(return_type)) {} - - const std::vector>& param_types() const - { - return param_types_; - } - - const Type& return_type() const { return *return_type_; } - std::shared_ptr return_type_ptr() const { return return_type_; } - - private: - std::vector> param_types_; - std::shared_ptr return_type_; - }; - - class OptionalType - { - public: - explicit OptionalType(std::shared_ptr inner_type) : inner_type_(std::move(inner_type)) {} - - const Type& inner_type() const { return *inner_type_; } - std::shared_ptr inner_type_ptr() const { return inner_type_; } - - private: - std::shared_ptr inner_type_; - }; - - class VoidType - { - public: - bool operator==(const VoidType&) const { return true; } - }; - - class UnknownType - { - public: - bool operator==(const UnknownType&) const { return true; } - }; - - class ErrorType - { - public: - explicit ErrorType(std::string message = "") : message_(std::move(message)) {} - - const std::string& message() const { return message_; } - - bool operator==(const ErrorType&) const { return true; } - - private: - std::string message_; - }; - - using TypeData = std::variant< - PrimitiveType, - ClassType, - ArrayType, - FunctionType, - OptionalType, - VoidType, - UnknownType, - ErrorType>; - - class Type - { - public: - explicit Type(TypeData data) : data_(std::move(data)) {} - - template - bool Is() const - { - return std::holds_alternative(data_); - } - - template - const T* As() const - { - return std::get_if(&data_); - } - - template - T* As() - { - return std::get_if(&data_); - } - - TypeKind kind() const; - std::string ToString() const; - - bool Equals(const Type& other) const; - - const TypeData& data() const { return data_; } - - private: - TypeData data_; - }; - - class TypeSystem - { - public: - TypeSystem(); - - std::shared_ptr GetIntType() const { return int_type_; } - std::shared_ptr GetFloatType() const { return float_type_; } - std::shared_ptr GetStringType() const { return string_type_; } - std::shared_ptr GetBoolType() const { return bool_type_; } - std::shared_ptr GetCharType() const { return char_type_; } - std::shared_ptr GetVoidType() const { return void_type_; } - std::shared_ptr GetUnknownType() const { return unknown_type_; } - std::shared_ptr GetErrorType() const { return error_type_; } - - std::shared_ptr CreateClassType(symbol::SymbolId class_id); - std::shared_ptr CreateArrayType(std::shared_ptr element_type); - std::shared_ptr CreateFunctionType( - std::vector> param_types, - std::shared_ptr return_type); - std::shared_ptr CreateOptionalType(std::shared_ptr inner_type); - - void RegisterClassType(const std::string& type_name, symbol::SymbolId class_id); - std::shared_ptr GetTypeByName(const std::string& type_name) const; - - std::shared_ptr GetSymbolType(symbol::SymbolId symbol_id) const; - - void RegisterSymbolType(symbol::SymbolId symbol_id, std::shared_ptr type); - - TypeCompatibility CheckCompatibility(const Type& from, const Type& to) const; - - bool IsAssignable(const Type& from, const Type& to) const; - - bool RequiresExplicitCast(const Type& from, const Type& to) const; - - std::shared_ptr InferBinaryExpressionType( - const Type& left, - const Type& right, - const std::string& op) const; - - std::shared_ptr InferUnaryExpressionType( - const Type& operand, - const std::string& op) const; - - std::shared_ptr InferLiteralType(const std::string& literal_value) const; - - void SetInheritanceChecker( - std::function checker) - { - is_subclass_of_ = std::move(checker); - } - - private: - std::shared_ptr int_type_; - std::shared_ptr float_type_; - std::shared_ptr string_type_; - std::shared_ptr bool_type_; - std::shared_ptr char_type_; - std::shared_ptr void_type_; - std::shared_ptr unknown_type_; - std::shared_ptr error_type_; - - std::unordered_map> type_by_name_; - - std::unordered_map> symbol_types_; - - std::function - is_subclass_of_; - - TypeCompatibility CheckPrimitiveCompatibility( - const PrimitiveType& from, - const PrimitiveType& to) const; - - TypeCompatibility CheckClassCompatibility( - const ClassType& from, - const ClassType& to) const; - - TypeCompatibility CheckArrayCompatibility( - const ArrayType& from, - const ArrayType& to) const; - - TypeCompatibility CheckFunctionCompatibility( - const FunctionType& from, - const FunctionType& to) const; - }; - - struct NameResolutionResult - { - symbol::SymbolId symbol_id = symbol::kInvalidSymbolId; - bool is_ambiguous = false; - std::vector candidates; - - bool IsResolved() const - { - return symbol_id != symbol::kInvalidSymbolId; - } - - static NameResolutionResult Success(symbol::SymbolId id) - { - return { id, false, { id } }; - } - - static NameResolutionResult Ambiguous(std::vector symbols) - { - return { - symbols.empty() ? symbol::kInvalidSymbolId : symbols[0], - true, - std::move(symbols) - }; - } - - static NameResolutionResult NotFound() - { - return { symbol::kInvalidSymbolId, false, {} }; - } - }; - - struct OverloadCandidate - { - symbol::SymbolId symbol_id; - int match_score = 0; - std::vector arg_conversions; - - bool operator<(const OverloadCandidate& other) const - { - return match_score > other.match_score; - } - }; - - class NameResolver - { - public: - explicit NameResolver(const symbol::SymbolTable& symbol_table, - const TypeSystem& type_system) : symbol_table_(symbol_table), - type_system_(type_system) {} - - NameResolutionResult ResolveName( - const std::string& name, - symbol::ScopeId scope_id, - bool search_parent = true) const; - - NameResolutionResult ResolveNameAtLocation( - const std::string& name, - const ast::Location& location) const; - - NameResolutionResult ResolveMemberAccess( - symbol::SymbolId object_symbol_id, - const std::string& member_name) const; - - NameResolutionResult ResolveClassMember( - symbol::SymbolId class_id, - const std::string& member_name, - bool static_only = false) const; - - NameResolutionResult ResolveFunctionCall( - const std::string& function_name, - const std::vector>& arg_types, - symbol::ScopeId scope_id) const; - - NameResolutionResult ResolveMethodCall( - symbol::SymbolId object_symbol_id, - const std::string& method_name, - const std::vector>& arg_types) const; - - NameResolutionResult ResolveQualifiedName( - const std::string& qualifier, - const std::string& name, - symbol::ScopeId scope_id) const; - - std::optional GetSymbolScope( - [[maybe_unused]] symbol::SymbolId symbol_id) const; - - bool IsSymbolVisibleInScope( - [[maybe_unused]] symbol::SymbolId symbol_id, - [[maybe_unused]] symbol::ScopeId scope_id) const; - - private: - const symbol::SymbolTable& symbol_table_; - const TypeSystem& type_system_; - - std::vector SearchScopeChain( - const std::string& name, - symbol::ScopeId start_scope) const; - - std::optional FindScopeOwnedBy(symbol::SymbolId owner) const; - - NameResolutionResult SelectBestOverload( - const std::vector& candidates) const; - - OverloadCandidate CalculateOverloadScore( - symbol::SymbolId candidate_id, - const std::vector>& arg_types) const; - - std::vector> GetParameterTypes( - symbol::SymbolId symbol_id) const; - - std::optional GetOwnerClassId( - [[maybe_unused]] symbol::SymbolId symbol_id) const; - - bool CheckMemberAccessibility( - [[maybe_unused]] const symbol::Symbol& member, - [[maybe_unused]] symbol::ScopeId access_scope) const; - }; - - namespace graph - { - using symbol::SymbolId; - - class Reference : public ISemanticGraph - { - public: - void OnSymbolRemoved(SymbolId id) override; - void Clear() override; - - void AddReference(SymbolId symbol_id, const ast::Location& location, bool is_definition = false, bool is_write = false); - - const std::vector& references(SymbolId id) const; - - std::optional FindDefinitionLocation(SymbolId id) const; - - private: - std::unordered_map> references_; - }; - - class Call : public ISemanticGraph - { - public: - void OnSymbolRemoved(SymbolId id) override; - void Clear() override; - - void AddCall(SymbolId caller, SymbolId callee, const ast::Location& location); - - const std::vector& callers(SymbolId id) const; - - const std::vector& callees(SymbolId id) const; - - private: - std::unordered_map> callers_map_; - std::unordered_map> callees_map_; - }; - - class Inheritance : public ISemanticGraph - { - public: - void OnSymbolRemoved(SymbolId id) override; - void Clear() override; - - void AddInheritance(SymbolId derived, SymbolId base); - - const std::vector& base_classes(SymbolId id) const; - - const std::vector& derived_classes(SymbolId id) const; - - bool IsSubclassOf(SymbolId derived, SymbolId base) const; - - private: - std::unordered_map> base_classes_; - std::unordered_map> derived_classes_; - }; - } // namespace graph - - class SemanticModel - { - public: - struct UnitImportSet - { - std::vector interface_imports; - std::vector implementation_imports; - }; - - explicit SemanticModel(const symbol::SymbolTable& symbol_table); - - void Clear(); - void OnSymbolRemoved(symbol::SymbolId id); - void AddReference(symbol::SymbolId symbol_id, const ast::Location& location, bool is_definition = false, bool is_write = false); - void AddInheritance(symbol::SymbolId derived, symbol::SymbolId base); - void AddCall(symbol::SymbolId caller, symbol::SymbolId callee, const ast::Location& location); - - graph::Reference& references() { return reference_graph_; } - graph::Inheritance& inheritance() { return inheritance_graph_; } - graph::Call& calls() { return call_graph_; } - - const graph::Reference& references() const { return reference_graph_; } - const graph::Inheritance& inheritance() const { return inheritance_graph_; } - const graph::Call& calls() const { return call_graph_; } - - TypeSystem& type_system() { return type_system_; } - const TypeSystem& type_system() const { return type_system_; } - - NameResolver& name_resolver() { return *name_resolver_; } - const NameResolver& name_resolver() const { return *name_resolver_; } - - void RegisterUnitImports(const std::string& unit_name, UnitImportSet imports); - std::optional GetUnitImports(const std::string& unit_name) const; - - std::shared_ptr GetSymbolType(symbol::SymbolId symbol_id) const - { - return type_system_.GetSymbolType(symbol_id); - } - - void SetSymbolType(symbol::SymbolId symbol_id, std::shared_ptr type) - { - type_system_.RegisterSymbolType(symbol_id, std::move(type)); - } - - bool IsSubclassOf(symbol::SymbolId derived, symbol::SymbolId base) const - { - return inheritance_graph_.IsSubclassOf(derived, base); - } - - private: - const symbol::SymbolTable& symbol_table_; - - graph::Reference reference_graph_; - graph::Inheritance inheritance_graph_; - graph::Call call_graph_; - - TypeSystem type_system_; - - std::unique_ptr name_resolver_; - std::unordered_map unit_imports_; - }; - - class Analyzer : public ast::ASTVisitor - { - public: - explicit Analyzer(symbol::SymbolTable& symbol_table, SemanticModel& semantic_model); - - using ExternalSymbolProvider = std::function(const std::string&)>; - void SetExternalSymbolProvider(ExternalSymbolProvider provider); - - void Analyze(ast::ASTNode& root); - - void VisitProgram(ast::Program& node) override; - void VisitUnitDefinition(ast::UnitDefinition& node) override; - void VisitClassDefinition(ast::ClassDefinition& node) override; - void VisitFunctionDefinition(ast::FunctionDefinition& node) override; - void VisitFunctionDeclaration(ast::FunctionDeclaration& node) override; - void VisitMethodDeclaration(ast::MethodDeclaration& node) override; - void VisitExternalMethodDefinition(ast::ExternalMethodDefinition& node) override; - - void VisitVarDeclaration(ast::VarDeclaration& node) override; - void VisitStaticDeclaration(ast::StaticDeclaration& node) override; - void VisitGlobalDeclaration(ast::GlobalDeclaration& node) override; - void VisitConstDeclaration(ast::ConstDeclaration& node) override; - void VisitFieldDeclaration(ast::FieldDeclaration& node) override; - void VisitClassMember(ast::ClassMember& node) override; - void VisitPropertyDeclaration(ast::PropertyDeclaration& node) override; - - void VisitBlockStatement(ast::BlockStatement& node) override; - void VisitIfStatement(ast::IfStatement& node) override; - void VisitForInStatement(ast::ForInStatement& node) override; - void VisitForToStatement(ast::ForToStatement& node) override; - void VisitWhileStatement(ast::WhileStatement& node) override; - void VisitRepeatStatement(ast::RepeatStatement& node) override; - void VisitCaseStatement(ast::CaseStatement& node) override; - void VisitTryStatement(ast::TryStatement& node) override; - void VisitLabelStatement(ast::LabelStatement& node) override; - void VisitGotoStatement(ast::GotoStatement& node) override; - - void VisitUsesStatement(ast::UsesStatement& node) override; - - void VisitIdentifier(ast::Identifier& node) override; - void VisitCallExpression(ast::CallExpression& node) override; - void VisitAttributeExpression(ast::AttributeExpression& node) override; - void VisitAssignmentExpression(ast::AssignmentExpression& node) override; - - void VisitLiteral(ast::Literal& node) override; - void VisitBinaryExpression(ast::BinaryExpression& node) override; - void VisitTernaryExpression(ast::TernaryExpression& node) override; - void VisitSubscriptExpression(ast::SubscriptExpression& node) override; - void VisitArrayExpression(ast::ArrayExpression& node) override; - void VisitAnonymousFunctionExpression(ast::AnonymousFunctionExpression& node) override; - - void VisitUnaryPlusExpression(ast::UnaryPlusExpression& node) override; - void VisitUnaryMinusExpression(ast::UnaryMinusExpression& node) override; - void VisitPrefixIncrementExpression(ast::PrefixIncrementExpression& node) override; - void VisitPrefixDecrementExpression(ast::PrefixDecrementExpression& node) override; - void VisitPostfixIncrementExpression(ast::PostfixIncrementExpression& node) override; - void VisitPostfixDecrementExpression(ast::PostfixDecrementExpression& node) override; - void VisitLogicalNotExpression(ast::LogicalNotExpression& node) override; - void VisitBitwiseNotExpression(ast::BitwiseNotExpression& node) override; - void VisitDerivativeExpression(ast::DerivativeExpression& node) override; - void VisitMatrixTransposeExpression(ast::MatrixTransposeExpression& node) override; - void VisitExprOperatorExpression(ast::ExprOperatorExpression& node) override; - void VisitFunctionPointerExpression(ast::FunctionPointerExpression& node) override; - - void VisitNewExpression(ast::NewExpression& node) override; - void VisitEchoExpression(ast::EchoExpression& node) override; - void VisitRaiseExpression(ast::RaiseExpression& node) override; - void VisitInheritedExpression(ast::InheritedExpression& node) override; - void VisitRdoExpression(ast::RdoExpression& node) override; - void VisitParenthesizedExpression(ast::ParenthesizedExpression& node) override; - - void VisitExpressionStatement(ast::ExpressionStatement& node) override; - void VisitBreakStatement(ast::BreakStatement& node) override; - void VisitContinueStatement(ast::ContinueStatement& node) override; - void VisitReturnStatement(ast::ReturnStatement& node) override; - void VisitTSSQLExpression(ast::TSSQLExpression& node) override; - void VisitColumnReference(ast::ColumnReference& node) override; - void VisitUnpackPattern(ast::UnpackPattern& node) override; - void VisitMatrixIterationStatement(ast::MatrixIterationStatement& node) override; - - void VisitCompilerDirective(ast::CompilerDirective& node) override; - void VisitConditionalDirective(ast::ConditionalDirective& node) override; - void VisitConditionalBlock(ast::ConditionalBlock& node) override; - void VisitTSLXBlock(ast::TSLXBlock& node) override; - - void VisitParameter(ast::Parameter& node) override; - - private: - void VisitStatements(const std::vector& statements); - void VisitExpression(ast::Expression& expr); - void ProcessLValue(const ast::LValue& lvalue); - std::optional ResolveParentClass(const ast::ClassDefinition::ParentClass& parent); - std::optional ScopeAt(const ast::Location& location) const; - std::optional ResolveByName(const std::string& name); - std::optional FindMethodInClass(symbol::SymbolId class_id, const std::string& method_name) const; - std::optional FindScopeOwnedBy(symbol::SymbolId owner_id) const; - - std::optional ResolveIdentifier(const std::string& name, const ast::Location& location); - std::optional ResolveFromUses(const std::string& name); - void TrackReference(symbol::SymbolId symbol_id, const ast::Location& location, bool is_write = false); - void TrackCall(symbol::SymbolId callee, const ast::Location& location); - std::shared_ptr InferExpressionType(ast::Expression& expr); - std::shared_ptr GetDeclaredTypeForSymbol(symbol::SymbolId symbol_id); - std::optional ResolveClassSymbol(const std::string& name, const ast::Location& location); - std::optional ResolveLValueSymbol(const ast::LValue& lvalue); - void RegisterParameterTypes(symbol::SymbolId function_id, const std::vector>& parameters); - - private: - struct UnitContext - { - std::string unit_name; - std::vector interface_imports; - std::vector implementation_imports; - }; - - symbol::SymbolTable& symbol_table_; - SemanticModel& semantic_model_; - - std::optional current_function_id_; - std::optional current_class_id_; - std::optional current_unit_context_; - std::optional current_unit_section_; - std::vector file_imports_; - ExternalSymbolProvider external_symbol_provider_; - std::unordered_map imported_symbols_; - }; - - struct TokenInfo - { - ast::Location location; - std::string type; - std::string text; - bool is_named = false; - std::uint32_t depth = 0; - - std::string parent_type; - std::uint32_t child_index = 0; - std::uint32_t sibling_count = 0; - - bool is_error = false; - bool is_missing = false; - }; - - struct TokenCollectionOptions - { - bool include_anonymous = true; - bool include_comments = false; - bool include_whitespace = false; - bool include_errors = true; - bool include_missing = true; - - std::uint32_t max_depth = UINT32_MAX_VALUE; - std::uint32_t min_depth = 0; - - std::unordered_set include_types; - std::unordered_set exclude_types; - - bool skip_empty_text = false; - std::uint32_t min_text_length = 0; - std::uint32_t max_text_length = UINT32_MAX_VALUE; - - bool has_location_filter = false; - std::uint32_t filter_start_line = 0; - std::uint32_t filter_end_line = UINT32_MAX_VALUE; - - std::function custom_filter; - - static TokenCollectionOptions OnlyNamed() - { - TokenCollectionOptions opts; - opts.include_anonymous = false; - return opts; - } - - static TokenCollectionOptions WithComments() - { - TokenCollectionOptions opts; - opts.include_comments = true; - return opts; - } - - static TokenCollectionOptions OnlyTypes(const std::vector& types) - { - TokenCollectionOptions opts; - opts.include_types = std::unordered_set(types.begin(), types.end()); - return opts; - } - - static TokenCollectionOptions ExcludeTypes(const std::vector& types) - { - TokenCollectionOptions opts; - opts.exclude_types = std::unordered_set(types.begin(), types.end()); - return opts; - } - - static TokenCollectionOptions InRange(std::uint32_t start_line, std::uint32_t end_line) - { - TokenCollectionOptions opts; - opts.has_location_filter = true; - opts.filter_start_line = start_line; - opts.filter_end_line = end_line; - return opts; - } - }; - - class TokenCollector - { - public: - explicit TokenCollector(const TokenCollectionOptions& options = {}); - ~TokenCollector() = default; - TokenCollector(const TokenCollector&) = delete; - TokenCollector& operator=(const TokenCollector&) = delete; - TokenCollector(TokenCollector&&) = default; - TokenCollector& operator=(TokenCollector&&) = default; - - std::vector Collect(TSNode root, const std::string& source); - std::vector CollectByType(TSNode root, const std::string& source, const std::vector& types); - std::vector CollectInRange(TSNode root, const std::string& source, std::uint32_t start_line, std::uint32_t end_line); - std::vector CollectLeafNodes(TSNode root, const std::string& source); - - std::vector FindTokensAtPosition(TSNode root, const std::string& source, std::uint32_t line, std::uint32_t column); - std::vector FindTokensByText(TSNode root, const std::string& source, const std::string& text, bool exact_match = true); - std::vector FindTokensBy(TSNode root, const std::string& source, std::function predicate); - - std::uint32_t CountTokensByType(TSNode root, const std::string& source, const std::string& type); - std::vector GetUniqueTypes(TSNode root, const std::string& source); - - void SetOptions(const TokenCollectionOptions& options) { options_ = options; } - const TokenCollectionOptions& GetOptions() const { return options_; } - - private: - void CollectRecursive(TSNode node, const std::string& source, std::vector& tokens, std::uint32_t depth, TSNode parent); - bool ShouldCollectNode(const TokenInfo& info) const; - void FillTokenInfo(TokenInfo& info, TSNode node, const std::string& source, std::uint32_t depth, TSNode parent) const; - bool IsCommentNode(const std::string& type) const; - bool IsWhitespaceNode(const std::string& type) const; - bool IsLocationInRange(const ast::Location& loc) const; - - private: - TokenCollectionOptions options_; - }; - - inline std::vector CollectNamedTokens(TSNode root, const std::string& source) - { - return TokenCollector(TokenCollectionOptions::OnlyNamed()).Collect(root, source); - } - - inline std::vector CollectTokensOfType(TSNode root, const std::string& source, const std::string& type) - { - return TokenCollector().CollectByType(root, source, { type }); - } -} // namespace lsp::language::semantic +// Legacy placeholder kept for compatibility with existing imports. diff --git a/lsp-server/src/language/semantic/name_resolver.cppm b/lsp-server/src/language/semantic/name_resolver.cppm index aae1c5a..4870d1c 100644 --- a/lsp-server/src/language/semantic/name_resolver.cppm +++ b/lsp-server/src/language/semantic/name_resolver.cppm @@ -4,10 +4,131 @@ export module lsp.language.semantic:name_resolver; import std; -import :interface; +import :type_system; import lsp.language.ast; import lsp.language.symbol; +export namespace lsp::language::semantic +{ + struct NameResolutionResult + { + symbol::SymbolId symbol_id = symbol::kInvalidSymbolId; + bool is_ambiguous = false; + std::vector candidates; + + bool IsResolved() const + { + return symbol_id != symbol::kInvalidSymbolId; + } + + static NameResolutionResult Success(symbol::SymbolId id) + { + return { id, false, { id } }; + } + + static NameResolutionResult Ambiguous(std::vector symbols) + { + return { + symbols.empty() ? symbol::kInvalidSymbolId : symbols[0], + true, + std::move(symbols) + }; + } + + static NameResolutionResult NotFound() + { + return { symbol::kInvalidSymbolId, false, {} }; + } + }; + + struct OverloadCandidate + { + symbol::SymbolId symbol_id; + int match_score = 0; + std::vector arg_conversions; + + bool operator<(const OverloadCandidate& other) const + { + return match_score > other.match_score; + } + }; + + class NameResolver + { + public: + explicit NameResolver(const symbol::SymbolTable& symbol_table, + const TypeSystem& type_system) : symbol_table_(symbol_table), + type_system_(type_system) {} + + NameResolutionResult ResolveName( + const std::string& name, + symbol::ScopeId scope_id, + bool search_parent = true) const; + + NameResolutionResult ResolveNameAtLocation( + const std::string& name, + const ast::Location& location) const; + + NameResolutionResult ResolveMemberAccess( + symbol::SymbolId object_symbol_id, + const std::string& member_name) const; + + NameResolutionResult ResolveClassMember( + symbol::SymbolId class_id, + const std::string& member_name, + bool static_only = false) const; + + NameResolutionResult ResolveFunctionCall( + const std::string& function_name, + const std::vector>& arg_types, + symbol::ScopeId scope_id) const; + + NameResolutionResult ResolveMethodCall( + symbol::SymbolId object_symbol_id, + const std::string& method_name, + const std::vector>& arg_types) const; + + NameResolutionResult ResolveQualifiedName( + const std::string& qualifier, + const std::string& name, + symbol::ScopeId scope_id) const; + + std::optional GetSymbolScope( + [[maybe_unused]] symbol::SymbolId symbol_id) const; + + bool IsSymbolVisibleInScope( + [[maybe_unused]] symbol::SymbolId symbol_id, + [[maybe_unused]] symbol::ScopeId scope_id) const; + + private: + const symbol::SymbolTable& symbol_table_; + const TypeSystem& type_system_; + + std::vector SearchScopeChain( + const std::string& name, + symbol::ScopeId start_scope) const; + + std::optional FindScopeOwnedBy(symbol::SymbolId owner) const; + + NameResolutionResult SelectBestOverload( + const std::vector& candidates) const; + + OverloadCandidate CalculateOverloadScore( + symbol::SymbolId candidate_id, + const std::vector>& arg_types) const; + + std::vector> GetParameterTypes( + symbol::SymbolId symbol_id) const; + + std::optional GetOwnerClassId( + [[maybe_unused]] symbol::SymbolId symbol_id) const; + + bool CheckMemberAccessibility( + [[maybe_unused]] const symbol::Symbol& member, + [[maybe_unused]] symbol::ScopeId access_scope) const; + }; +} + namespace lsp::language::semantic { diff --git a/lsp-server/src/language/semantic/semantic.cppm b/lsp-server/src/language/semantic/semantic.cppm index 209e7e9..f945929 100644 --- a/lsp-server/src/language/semantic/semantic.cppm +++ b/lsp-server/src/language/semantic/semantic.cppm @@ -6,11 +6,13 @@ import std; // 聚合导出语义模块的各个分区 export import :interface; +export import :type_system.types; export import :type_system; export import :name_resolver; export import :semantic_model; export import :analyzer; export import :token_collector; +export import :graph.types; export import :graph.call; export import :graph.inheritance; export import :graph.reference; diff --git a/lsp-server/src/language/semantic/semantic_model.cppm b/lsp-server/src/language/semantic/semantic_model.cppm index 4bf3fec..ca46734 100644 --- a/lsp-server/src/language/semantic/semantic_model.cppm +++ b/lsp-server/src/language/semantic/semantic_model.cppm @@ -4,10 +4,80 @@ export module lsp.language.semantic:semantic_model; import std; -import :interface; +import :graph.call; +import :graph.inheritance; +import :graph.reference; +import :name_resolver; +import :type_system; import lsp.language.ast; import lsp.language.symbol; import lsp.protocol.types; +import lsp.utils.string; + +export namespace lsp::language::semantic +{ + class SemanticModel + { + public: + struct UnitImportSet + { + std::vector interface_imports; + std::vector implementation_imports; + }; + + explicit SemanticModel(const symbol::SymbolTable& symbol_table); + + void Clear(); + void OnSymbolRemoved(symbol::SymbolId id); + void AddReference(symbol::SymbolId symbol_id, const ast::Location& location, bool is_definition = false, bool is_write = false); + void AddInheritance(symbol::SymbolId derived, symbol::SymbolId base); + void AddCall(symbol::SymbolId caller, symbol::SymbolId callee, const ast::Location& location); + + graph::Reference& references() { return reference_graph_; } + graph::Inheritance& inheritance() { return inheritance_graph_; } + graph::Call& calls() { return call_graph_; } + + const graph::Reference& references() const { return reference_graph_; } + const graph::Inheritance& inheritance() const { return inheritance_graph_; } + const graph::Call& calls() const { return call_graph_; } + + TypeSystem& type_system() { return type_system_; } + const TypeSystem& type_system() const { return type_system_; } + + NameResolver& name_resolver() { return *name_resolver_; } + const NameResolver& name_resolver() const { return *name_resolver_; } + + void RegisterUnitImports(const std::string& unit_name, UnitImportSet imports); + std::optional GetUnitImports(const std::string& unit_name) const; + + std::shared_ptr GetSymbolType(symbol::SymbolId symbol_id) const + { + return type_system_.GetSymbolType(symbol_id); + } + + void SetSymbolType(symbol::SymbolId symbol_id, std::shared_ptr type) + { + type_system_.RegisterSymbolType(symbol_id, std::move(type)); + } + + bool IsSubclassOf(symbol::SymbolId derived, symbol::SymbolId base) const + { + return inheritance_graph_.IsSubclassOf(derived, base); + } + + private: + const symbol::SymbolTable& symbol_table_; + + graph::Reference reference_graph_; + graph::Inheritance inheritance_graph_; + graph::Call call_graph_; + + TypeSystem type_system_; + + std::unique_ptr name_resolver_; + std::unordered_map unit_imports_; + }; +} namespace lsp::language::semantic { diff --git a/lsp-server/src/language/semantic/token_collector.cppm b/lsp-server/src/language/semantic/token_collector.cppm index 7c287fa..1b3b959 100644 --- a/lsp-server/src/language/semantic/token_collector.cppm +++ b/lsp-server/src/language/semantic/token_collector.cppm @@ -6,9 +6,136 @@ import tree_sitter; import std; -import :interface; import lsp.language.ast; +export namespace lsp::language::semantic +{ + struct TokenInfo + { + ast::Location location; + std::string type; + std::string text; + bool is_named = false; + std::uint32_t depth = 0; + + std::string parent_type; + std::uint32_t child_index = 0; + std::uint32_t sibling_count = 0; + + bool is_error = false; + bool is_missing = false; + }; + + struct TokenCollectionOptions + { + bool include_anonymous = true; + bool include_comments = false; + bool include_whitespace = false; + bool include_errors = true; + bool include_missing = true; + + std::uint32_t max_depth = UINT32_MAX_VALUE; + std::uint32_t min_depth = 0; + + std::unordered_set include_types; + std::unordered_set exclude_types; + + bool skip_empty_text = false; + std::uint32_t min_text_length = 0; + std::uint32_t max_text_length = UINT32_MAX_VALUE; + + bool has_location_filter = false; + std::uint32_t filter_start_line = 0; + std::uint32_t filter_end_line = UINT32_MAX_VALUE; + + std::function custom_filter; + + static TokenCollectionOptions OnlyNamed() + { + TokenCollectionOptions opts; + opts.include_anonymous = false; + return opts; + } + + static TokenCollectionOptions WithComments() + { + TokenCollectionOptions opts; + opts.include_comments = true; + return opts; + } + + static TokenCollectionOptions OnlyTypes(const std::vector& types) + { + TokenCollectionOptions opts; + opts.include_types = std::unordered_set(types.begin(), types.end()); + return opts; + } + + static TokenCollectionOptions ExcludeTypes(const std::vector& types) + { + TokenCollectionOptions opts; + opts.exclude_types = std::unordered_set(types.begin(), types.end()); + return opts; + } + + static TokenCollectionOptions InRange(std::uint32_t start_line, std::uint32_t end_line) + { + TokenCollectionOptions opts; + opts.has_location_filter = true; + opts.filter_start_line = start_line; + opts.filter_end_line = end_line; + return opts; + } + }; + + class TokenCollector + { + public: + explicit TokenCollector(const TokenCollectionOptions& options = {}); + ~TokenCollector() = default; + TokenCollector(const TokenCollector&) = delete; + TokenCollector& operator=(const TokenCollector&) = delete; + TokenCollector(TokenCollector&&) = default; + TokenCollector& operator=(TokenCollector&&) = default; + + std::vector Collect(TSNode root, const std::string& source); + std::vector CollectByType(TSNode root, const std::string& source, const std::vector& types); + std::vector CollectInRange(TSNode root, const std::string& source, std::uint32_t start_line, std::uint32_t end_line); + std::vector CollectLeafNodes(TSNode root, const std::string& source); + + std::vector FindTokensAtPosition(TSNode root, const std::string& source, std::uint32_t line, std::uint32_t column); + std::vector FindTokensByText(TSNode root, const std::string& source, const std::string& text, bool exact_match = true); + std::vector FindTokensBy(TSNode root, const std::string& source, std::function predicate); + + std::uint32_t CountTokensByType(TSNode root, const std::string& source, const std::string& type); + std::vector GetUniqueTypes(TSNode root, const std::string& source); + + void SetOptions(const TokenCollectionOptions& options) { options_ = options; } + const TokenCollectionOptions& GetOptions() const { return options_; } + + private: + void CollectRecursive(TSNode node, const std::string& source, std::vector& tokens, std::uint32_t depth, TSNode parent); + bool ShouldCollectNode(const TokenInfo& info) const; + void FillTokenInfo(TokenInfo& info, TSNode node, const std::string& source, std::uint32_t depth, TSNode parent) const; + bool IsCommentNode(const std::string& type) const; + bool IsWhitespaceNode(const std::string& type) const; + bool IsLocationInRange(const ast::Location& loc) const; + + private: + TokenCollectionOptions options_; + }; + + inline std::vector CollectNamedTokens(TSNode root, const std::string& source) + { + return TokenCollector(TokenCollectionOptions::OnlyNamed()).Collect(root, source); + } + + inline std::vector CollectTokensOfType(TSNode root, const std::string& source, const std::string& type) + { + return TokenCollector().CollectByType(root, source, { type }); + } +} + namespace lsp::language::semantic { TokenCollector::TokenCollector(const TokenCollectionOptions& options) : diff --git a/lsp-server/src/language/semantic/type_system.cppm b/lsp-server/src/language/semantic/type_system.cppm index 12a1ec4..b7836aa 100644 --- a/lsp-server/src/language/semantic/type_system.cppm +++ b/lsp-server/src/language/semantic/type_system.cppm @@ -4,11 +4,99 @@ export module lsp.language.semantic:type_system; import std; -import :interface; +export import :type_system.types; import lsp.language.symbol; import lsp.utils.string; +export namespace lsp::language::semantic +{ + class TypeSystem + { + public: + TypeSystem(); + + std::shared_ptr GetIntType() const { return int_type_; } + std::shared_ptr GetFloatType() const { return float_type_; } + std::shared_ptr GetStringType() const { return string_type_; } + std::shared_ptr GetBoolType() const { return bool_type_; } + std::shared_ptr GetCharType() const { return char_type_; } + std::shared_ptr GetVoidType() const { return void_type_; } + std::shared_ptr GetUnknownType() const { return unknown_type_; } + std::shared_ptr GetErrorType() const { return error_type_; } + + std::shared_ptr CreateClassType(symbol::SymbolId class_id); + std::shared_ptr CreateArrayType(std::shared_ptr element_type); + std::shared_ptr CreateFunctionType( + std::vector> param_types, + std::shared_ptr return_type); + std::shared_ptr CreateOptionalType(std::shared_ptr inner_type); + + void RegisterClassType(const std::string& type_name, symbol::SymbolId class_id); + std::shared_ptr GetTypeByName(const std::string& type_name) const; + + std::shared_ptr GetSymbolType(symbol::SymbolId symbol_id) const; + + void RegisterSymbolType(symbol::SymbolId symbol_id, std::shared_ptr type); + + TypeCompatibility CheckCompatibility(const Type& from, const Type& to) const; + + bool IsAssignable(const Type& from, const Type& to) const; + + bool RequiresExplicitCast(const Type& from, const Type& to) const; + + std::shared_ptr InferBinaryExpressionType( + const Type& left, + const Type& right, + const std::string& op) const; + + std::shared_ptr InferUnaryExpressionType( + const Type& operand, + const std::string& op) const; + + std::shared_ptr InferLiteralType(const std::string& literal_value) const; + + void SetInheritanceChecker( + std::function checker) + { + is_subclass_of_ = std::move(checker); + } + + private: + std::shared_ptr int_type_; + std::shared_ptr float_type_; + std::shared_ptr string_type_; + std::shared_ptr bool_type_; + std::shared_ptr char_type_; + std::shared_ptr void_type_; + std::shared_ptr unknown_type_; + std::shared_ptr error_type_; + + std::unordered_map> type_by_name_; + + std::unordered_map> symbol_types_; + + std::function + is_subclass_of_; + + TypeCompatibility CheckPrimitiveCompatibility( + const PrimitiveType& from, + const PrimitiveType& to) const; + + TypeCompatibility CheckClassCompatibility( + const ClassType& from, + const ClassType& to) const; + + TypeCompatibility CheckArrayCompatibility( + const ArrayType& from, + const ArrayType& to) const; + + TypeCompatibility CheckFunctionCompatibility( + const FunctionType& from, + const FunctionType& to) const; + }; +} + namespace lsp::language::semantic { diff --git a/lsp-server/src/language/semantic/type_system.types.cppm b/lsp-server/src/language/semantic/type_system.types.cppm new file mode 100644 index 0000000..45d9bf4 --- /dev/null +++ b/lsp-server/src/language/semantic/type_system.types.cppm @@ -0,0 +1,206 @@ +module; + +export module lsp.language.semantic:type_system.types; + +import std; + +import lsp.language.symbol; + +export namespace lsp::language::semantic +{ + enum class TypeKind + { + kPrimitive, + kClass, + kArray, + kFunction, + kOptional, + kVoid, + kUnknown, + kError + }; + + enum class PrimitiveTypeKind + { + kInt, + kFloat, + kString, + kBool, + kChar + }; + + struct TypeCompatibility + { + bool is_compatible = false; + int conversion_cost = -1; + bool requires_cast = false; + + static TypeCompatibility Exact() + { + return { true, 0, false }; + } + + static TypeCompatibility Implicit(int cost) + { + return { true, cost, false }; + } + + static TypeCompatibility ExplicitCast(int cost) + { + return { true, cost, true }; + } + + static TypeCompatibility Incompatible() + { + return { false, -1, false }; + } + }; + + class Type; + + class PrimitiveType + { + public: + explicit PrimitiveType(PrimitiveTypeKind kind) : kind_(kind) {} + + PrimitiveTypeKind kind() const { return kind_; } + std::string ToString() const; + + bool operator==(const PrimitiveType& other) const + { + return kind_ == other.kind_; + } + + private: + PrimitiveTypeKind kind_; + }; + + class ClassType + { + public: + explicit ClassType(symbol::SymbolId class_id) : class_id_(class_id) {} + + symbol::SymbolId class_id() const { return class_id_; } + + bool operator==(const ClassType& other) const + { + return class_id_ == other.class_id_; + } + + private: + symbol::SymbolId class_id_; + }; + + class ArrayType + { + public: + explicit ArrayType(std::shared_ptr element_type) : element_type_(std::move(element_type)) {} + + const Type& element_type() const { return *element_type_; } + std::shared_ptr element_type_ptr() const { return element_type_; } + + private: + std::shared_ptr element_type_; + }; + + class FunctionType + { + public: + FunctionType(std::vector> param_types, + std::shared_ptr return_type) : param_types_(std::move(param_types)), + return_type_(std::move(return_type)) {} + + const std::vector>& param_types() const + { + return param_types_; + } + + const Type& return_type() const { return *return_type_; } + std::shared_ptr return_type_ptr() const { return return_type_; } + + private: + std::vector> param_types_; + std::shared_ptr return_type_; + }; + + class OptionalType + { + public: + explicit OptionalType(std::shared_ptr inner_type) : inner_type_(std::move(inner_type)) {} + + const Type& inner_type() const { return *inner_type_; } + std::shared_ptr inner_type_ptr() const { return inner_type_; } + + private: + std::shared_ptr inner_type_; + }; + + class VoidType + { + public: + bool operator==(const VoidType&) const { return true; } + }; + + class UnknownType + { + public: + bool operator==(const UnknownType&) const { return true; } + }; + + class ErrorType + { + public: + explicit ErrorType(std::string message = "") : message_(std::move(message)) {} + + const std::string& message() const { return message_; } + + bool operator==(const ErrorType&) const { return true; } + + private: + std::string message_; + }; + + using TypeData = std::variant< + PrimitiveType, + ClassType, + ArrayType, + FunctionType, + OptionalType, + VoidType, + UnknownType, + ErrorType>; + + class Type + { + public: + explicit Type(TypeData data) : data_(std::move(data)) {} + + template + bool Is() const + { + return std::holds_alternative(data_); + } + + template + const T* As() const + { + return std::get_if(&data_); + } + + template + T* As() + { + return std::get_if(&data_); + } + + TypeKind kind() const; + std::string ToString() const; + + bool Equals(const Type& other) const; + + const TypeData& data() const { return data_; } + + private: + TypeData data_; + }; +} diff --git a/lsp-server/test/test_provider/CMakeLists.txt b/lsp-server/test/test_provider/CMakeLists.txt index fc59810..6d0def0 100644 --- a/lsp-server/test/test_provider/CMakeLists.txt +++ b/lsp-server/test/test_provider/CMakeLists.txt @@ -83,11 +83,13 @@ target_sources( ../../src/language/symbol/symbol.cppm ../../src/language/semantic/interface.cppm ../../src/language/semantic/semantic.cppm + ../../src/language/semantic/type_system.types.cppm ../../src/language/semantic/analyzer.cppm ../../src/language/semantic/semantic_model.cppm ../../src/language/semantic/type_system.cppm ../../src/language/semantic/name_resolver.cppm ../../src/language/semantic/token_collector.cppm + ../../src/language/semantic/graph/types.cppm ../../src/language/semantic/graph/call.cppm ../../src/language/semantic/graph/reference.cppm ../../src/language/semantic/graph/inheritance.cppm diff --git a/lsp-server/test/test_semantic/CMakeLists.txt b/lsp-server/test/test_semantic/CMakeLists.txt index d62de96..2433483 100644 --- a/lsp-server/test/test_semantic/CMakeLists.txt +++ b/lsp-server/test/test_semantic/CMakeLists.txt @@ -27,10 +27,12 @@ set(SOURCES ../../src/language/symbol/index/location.cppm ../../src/language/symbol/index/scope.cppm ../../src/language/semantic/interface.cppm + ../../src/language/semantic/type_system.types.cppm ../../src/language/semantic/analyzer.cppm ../../src/language/semantic/semantic_model.cppm ../../src/language/semantic/type_system.cppm ../../src/language/semantic/name_resolver.cppm + ../../src/language/semantic/graph/types.cppm ../../src/language/semantic/graph/call.cppm ../../src/language/semantic/graph/reference.cppm ../../src/language/semantic/graph/inheritance.cppm @@ -93,11 +95,13 @@ target_sources( ../../src/language/symbol/symbol.cppm ../../src/language/semantic/interface.cppm ../../src/language/semantic/semantic.cppm + ../../src/language/semantic/type_system.types.cppm ../../src/language/semantic/type_system.cppm ../../src/language/semantic/name_resolver.cppm ../../src/language/semantic/semantic_model.cppm ../../src/language/semantic/analyzer.cppm ../../src/language/semantic/token_collector.cppm + ../../src/language/semantic/graph/types.cppm ../../src/language/semantic/graph/call.cppm ../../src/language/semantic/graph/inheritance.cppm ../../src/language/semantic/graph/reference.cppm