refactor ast

This commit is contained in:
csh 2025-10-25 14:02:55 +08:00
parent 57f182380e
commit ddca9769e5
11 changed files with 4419 additions and 1840 deletions

View File

@ -1,101 +0,0 @@
#include "./cache.hpp"
#include "./tree_sitter_utils.hpp"
namespace lsp::language::ast
{
NodeCache::NodeCache(size_t max_size) :
max_size_(max_size)
{
}
const ASTNode* NodeCache::Find(const NodeKey& key)
{
auto it = cache_.find(key);
if (it != cache_.end())
{
// 缓存命中更新LRU
Touch(key, it->second.lru_iter);
++hit_count_;
return &it->second.node;
}
++miss_count_;
return nullptr;
}
void NodeCache::Insert(const NodeKey& key, const ASTNode& node)
{
auto it = cache_.find(key);
if (it != cache_.end())
{
// 键已存在更新值和LRU
it->second.node = node;
Touch(key, it->second.lru_iter);
return;
}
// 检查是否需要淘汰
if (cache_.size() >= max_size_)
{
EvictLRU();
}
// 插入新条目
lru_list_.push_front(key);
cache_[key] = CacheEntry{ node, lru_list_.begin() };
}
void NodeCache::Clear()
{
cache_.clear();
lru_list_.clear();
ResetStats();
}
void NodeCache::EvictLRU()
{
if (lru_list_.empty())
return;
// 移除最久未使用的(列表尾部)
const NodeKey& evict_key = lru_list_.back();
cache_.erase(evict_key);
lru_list_.pop_back();
}
void NodeCache::Touch([[maybe_unused]] const NodeKey& key, KeyIter iter)
{
// 将元素移到列表前面(最近使用)
lru_list_.splice(lru_list_.begin(), lru_list_, iter);
}
NodeKey NodeCache::MakeKey(const ASTNode& node)
{
return std::visit([](auto&& arg) -> NodeKey {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, std::monostate>)
{
return NodeKey{ 0, 0, "" };
}
else
{
return NodeKey{
arg.location.start_byte,
arg.location.end_byte,
arg.node_type
};
}
},
node);
}
NodeKey NodeCache::MakeKey(TSNode ts_node)
{
return NodeKey{
ts_node_start_byte(ts_node),
ts_node_end_byte(ts_node),
std::string(ts::Type(ts_node))
};
}
}

View File

@ -1,94 +0,0 @@
#pragma once
#include <unordered_map>
#include <list>
#include <cstddef>
#include "./types.hpp"
extern "C" {
#include <tree_sitter/api.h>
}
namespace lsp::language::ast
{
// ==================== 缓存键 ====================
struct NodeKey
{
uint32_t start_byte;
uint32_t end_byte;
std::string node_type;
inline bool operator==(const NodeKey& other) const
{
return start_byte == other.start_byte &&
end_byte == other.end_byte &&
node_type == other.node_type;
}
};
struct NodeKeyHash
{
inline size_t operator()(const NodeKey& key) const
{
return std::hash<uint32_t>()(key.start_byte) ^
(std::hash<uint32_t>()(key.end_byte) << 1) ^
(std::hash<std::string>()(key.node_type) << 2);
}
};
// ==================== LRU 缓存(重构 1====================
class NodeCache
{
public:
explicit NodeCache(size_t max_size = 1000);
const ASTNode* Find(const NodeKey& key);
void Insert(const NodeKey& key, const ASTNode& node);
void Clear();
inline size_t Size() const { return cache_.size(); }
inline void SetMaxSize(size_t max_size) { max_size_ = max_size; }
inline size_t GetMaxSize() const { return max_size_; }
// 缓存统计
inline size_t HitCount() const { return hit_count_; }
inline size_t MissCount() const { return miss_count_; }
inline double HitRate() const
{
size_t total = hit_count_ + miss_count_;
return total > 0 ? static_cast<double>(hit_count_) / total : 0.0;
}
// 重置统计
inline void ResetStats()
{
hit_count_ = 0;
miss_count_ = 0;
}
static NodeKey MakeKey(const ASTNode& node);
static NodeKey MakeKey(TSNode ts_node);
private:
// LRU 列表:最近使用的在前面
using KeyList = std::list<NodeKey>;
using KeyIter = KeyList::iterator;
struct CacheEntry
{
ASTNode node;
KeyIter lru_iter;
};
void EvictLRU();
void Touch(const NodeKey& key, KeyIter iter);
size_t max_size_;
KeyList lru_list_;
std::unordered_map<NodeKey, CacheEntry, NodeKeyHash> cache_;
// 统计信息
size_t hit_count_ = 0;
size_t miss_count_ = 0;
};
}

File diff suppressed because it is too large Load Diff

View File

@ -1,133 +1,86 @@
#pragma once #pragma once
#include <functional>
#include <unordered_map> #include <vector>
#include <memory>
#include "./types.hpp" #include "./types.hpp"
#include "./tree_sitter_utils.hpp"
#include "./cache.hpp" extern "C" {
#include <tree_sitter/api.h>
}
namespace lsp::language::ast namespace lsp::language::ast
{ {
// ==================== 增量解析结果 ==================== // ===== 解析错误 =====
enum class ErrorSeverity
struct IncrementalParseResult
{ {
ParseResult result; Warning,
std::vector<size_t> changed_indices; Error,
std::vector<size_t> reused_indices; Fatal
};
inline size_t ChangedCount() const { return changed_indices.size(); } struct ParseError
inline size_t ReusedCount() const { return reused_indices.size(); } {
inline size_t TotalCount() const { return result.nodes.size(); } Location location;
inline double ReuseRate() const std::string node_type;
std::string message;
ErrorSeverity severity = ErrorSeverity::Error;
static ParseError Create(const Location& loc, const std::string& type, const std::string& msg)
{ {
return TotalCount() > 0 ? static_cast<double>(ReusedCount()) / TotalCount() : 0.0; return { loc, type, msg, ErrorSeverity::Error };
}
static ParseError Missing(const Location& loc, const std::string& type)
{
return Create(loc, type, "Syntax error: missing " + type);
}
static ParseError Unexpected(const Location& loc, const std::string& type, const std::string& context = "")
{
std::string msg = "Syntax error: unexpected token";
if (!context.empty())
msg += " in " + context;
return Create(loc, type, msg);
} }
}; };
// ===== 解析结果 =====
struct ParseResult
{
std::vector<StatementPtr> statements;
std::vector<ParseError> errors;
bool HasErrors() const { return !errors.empty(); }
bool IsSuccess() const { return errors.empty(); }
size_t ErrorCount() const { return errors.size(); }
};
// ===== 增量解析结果 =====
struct IncrementalParseResult
{
ParseResult result;
size_t nodes_parsed = 0; // 重新解析的节点数
size_t nodes_unchanged = 0; // 未变化的节点数
size_t TotalNodes() const { return nodes_parsed + nodes_unchanged; }
double ChangeRate() const
{
return TotalNodes() > 0 ? static_cast<double>(nodes_parsed) / TotalNodes() : 0.0;
}
};
// ===== Deserializer 类 =====
class Deserializer class Deserializer
{ {
public: public:
Deserializer(); Deserializer();
~Deserializer(); ~Deserializer() = default;
Deserializer(const Deserializer&) = delete;
Deserializer& operator=(const Deserializer&) = delete;
// 基础解析(不使用缓存)
ParseResult Parse(TSNode root, const std::string& source); ParseResult Parse(TSNode root, const std::string& source);
// 增量解析(自动复用缓存)
IncrementalParseResult ParseIncremental(TSNode root, const std::string& source); IncrementalParseResult ParseIncremental(TSNode root, const std::string& source);
// 缓存管理 static std::vector<ParseError> DiagnoseSyntax(TSNode root, const std::string& source);
void ClearCache();
size_t CacheSize() const;
void SetCacheMaxSize(size_t max_size);
// 缓存统计
double CacheHitRate() const;
void ResetCacheStats();
private:
std::unique_ptr<NodeCache> cache_;
void ParseChildWithCache(TSNode child, const std::string& source, IncrementalParseResult& result);
}; };
namespace detail
{
using ParseFunc = std::function<Result<ASTNode>(TSNode, const std::string&)>;
// 获取解析函数映射表
const std::unordered_map<std::string, ParseFunc>& GetParseFuncMap();
// 辅助函数
template<typename T>
inline void InitNode(T& node, TSNode ts_node)
{
node.location = ts::NodeLocation(ts_node);
node.node_type = std::string(ts::Type(ts_node));
}
template<typename T>
inline bool ExtractNameAndLocation(T& node, TSNode ts_node, const std::string& source)
{
TSNode name_node = ts::FieldChild(ts_node, "name");
if (ts::IsNull(name_node))
return false;
node.name = ts::Text(name_node, source);
node.name_location = ts::NodeLocation(name_node);
return !node.name.empty();
}
Result<std::string> GetRequiredField(TSNode node, std::string_view field, const std::string& source, const std::string& error_msg);
Result<std::string> ParseType(TSNode node, const std::string& source);
Result<Signature> ParseSignature(TSNode node, const std::string& source);
Result<std::vector<Parameter>> ParseParameters(TSNode node, const std::string& source);
Result<Block> ParseBlock(TSNode node, const std::string& source);
Access ParseAccess(std::string_view text);
ParseError MakeError(TSNode node, const std::string& message, ErrorSeverity severity = ErrorSeverity::Error);
template<typename T>
inline std::vector<Result<T>> SplitVariableDeclaration(
TSNode node,
const std::string& source,
std::function<Result<T>(TSNode, const std::string&, const std::string&, TSNode)> parse_func)
{
std::vector<Result<T>> results;
auto names = ts::FieldSummaries(node, "name", source);
for (const auto& name_summary : names)
{
results.push_back(parse_func(node, source, name_summary.text, name_summary.node));
}
return results;
}
void ParseClassMembers(TSNode body_node, const std::string& source, ClassDefinition& def);
void ParseClassVariableMember(TSNode node, const std::string& source, Access current_access, ClassDefinition& def);
void ParseClassMethodMember(TSNode node, const std::string& source, Access current_access, ClassDefinition& def);
void ParseClassPropertyMember(TSNode node, const std::string& source, Access current_access, ClassDefinition& def);
ParseResult ParseStatements(TSNode node, const std::string& source);
ParseResult ParseChildren(const std::vector<TSNode>& children, const std::string& source);
}
// ==================== 顶层解析 API ====================
Result<ASTNode> ParseNode(TSNode node, const std::string& source);
ParseResult ParseRoot(TSNode root, const std::string& source);
// ==================== 具体类型解析函数 ====================
Result<VarDeclaration> ParseVarDeclaration(TSNode node, const std::string& source, const std::string& var_name, TSNode name_node);
Result<StaticDeclaration> ParseStaticDeclaration(TSNode node, const std::string& source, const std::string& var_name, TSNode name_node);
Result<GlobalDeclaration> ParseGlobalDeclaration(TSNode node, const std::string& source, const std::string& var_name, TSNode name_node);
Result<ConstDeclaration> ParseConstDeclaration(TSNode node, const std::string& source);
Result<AssignmentStatement> ParseAssignmentStatement(TSNode node, const std::string& source);
Result<FunctionDefinition> ParseFunctionDefinition(TSNode node, const std::string& source);
Result<ClassDefinition> ParseClassDefinition(TSNode node, const std::string& source);
Result<Method> ParseMethod(TSNode node, const std::string& source);
Result<ExternalMethod> ParseExternalMethod(TSNode node, const std::string& source);
Result<Property> ParseProperty(TSNode node, const std::string& source);
Result<UnitDefinition> ParseUnitDefinition(TSNode node, const std::string& source);
Result<UsesClause> ParseUsesClause(TSNode node, const std::string& source);
} }

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,206 @@
#pragma once
#include "./types.hpp"
#include "./deserializer.hpp"
#include <functional>
#include <unordered_map>
extern "C" {
#include <tree_sitter/api.h>
}
namespace lsp::language::ast::detail
{
// ===== 统一的解析上下文 =====
class ParseContext
{
public:
ParseContext(const std::string& source, std::vector<ParseError>& errors) :
source_(source), errors_(errors) {}
const std::string& Source() const { return source_; }
std::vector<ParseError>& Errors() { return errors_; }
const std::vector<ParseError>& Errors() const { return errors_; }
void RecordError(const Location& loc, const std::string& node_type, const std::string& message, ErrorSeverity severity = ErrorSeverity::Error)
{
errors_.push_back(ParseError{ loc, node_type, message, severity });
}
void RecordMissing(const Location& loc, const std::string& node_type)
{
RecordError(loc, node_type, "Syntax error: missing " + node_type);
}
void RecordUnexpected(const Location& loc, const std::string& node_type, const std::string& context = "")
{
std::string msg = "Syntax error: unexpected token";
if (!context.empty())
msg += " in " + context;
RecordError(loc, node_type, msg);
}
void RecordWarning(const Location& loc, const std::string& node_type, const std::string& message)
{
RecordError(loc, node_type, message, ErrorSeverity::Warning);
}
template<typename T = Expression>
std::unique_ptr<T> ReportNullNode(const Location& loc, const std::string& expected_type)
{
RecordMissing(loc, expected_type);
return nullptr;
}
template<typename T = Expression>
std::unique_ptr<T> ReportParseFailed(const Location& loc, const std::string& node_type, const std::string& reason = "")
{
std::string msg = "Failed to parse " + node_type;
if (!reason.empty())
msg += ": " + reason;
RecordError(loc, node_type, msg);
return nullptr;
}
private:
const std::string& source_;
std::vector<ParseError>& errors_;
};
// ===== 类成员解析结果 =====
struct ClassMemberParseResult
{
std::unique_ptr<ClassMember> member;
std::optional<AccessModifier> access_modifier;
std::optional<ReferenceModifier> reference_modifier;
};
// ===== 语句解析器注册表 =====
class StatementParserRegistry
{
public:
using ParserFunc = std::function<StatementPtr(TSNode, ParseContext&)>;
static StatementParserRegistry& Instance()
{
static StatementParserRegistry instance;
return instance;
}
void Register(const std::string& type, ParserFunc parser)
{
parsers_[type] = parser;
}
ParserFunc Get(const std::string& type) const
{
auto it = parsers_.find(type);
return it != parsers_.end() ? it->second : nullptr;
}
private:
StatementParserRegistry() = default;
StatementParserRegistry(const StatementParserRegistry&) = delete;
StatementParserRegistry& operator=(const StatementParserRegistry&) = delete;
private:
std::unordered_map<std::string, ParserFunc> parsers_;
};
// ===== 初始化 =====
void RegisterStatementParsers();
// ===== 操作符映射 =====
BinaryOperator ParseBinaryOperator(const std::string& op_text);
UnaryOperator ParseUnaryOperator(const std::string& op_text);
AssignmentOperator ParseAssignmentOperator(const std::string& op_text);
// ===== 基础辅助函数 =====
TSNode FindChildByType(TSNode parent, const std::string& type);
bool IsSyntaxErrorNode(TSNode node);
std::vector<ParseError> CollectSyntaxErrors(TSNode node, const std::string& source);
// 访问修饰符和方法修饰符解析
AccessModifier ParseAccessModifier(TSNode node, ParseContext& ctx);
MethodModifier ParseMethodModifier(TSNode node, ParseContext& ctx);
ReferenceModifier ParseReferenceModifier(TSNode node, ParseContext& ctx);
// 方法和属性声明解析
std::unique_ptr<MethodDeclaration> ParseMethodDeclaration(TSNode node, ParseContext& ctx);
std::unique_ptr<PropertyDeclaration> ParsePropertyDeclaration(TSNode node, ParseContext& ctx);
// 类成员解析
std::vector<ClassMemberParseResult> ParseClassMember(TSNode node, ParseContext& ctx, AccessModifier current_access, ReferenceModifier current_reference);
// ===== 通用解析函数 =====
Signature ParseSignature(TSNode node, ParseContext& ctx);
std::vector<Parameter> ParseParameters(TSNode params_node, ParseContext& ctx);
// ===== 表达式解析 =====
ExpressionPtr ParseExpression(TSNode node, ParseContext& ctx);
ExpressionPtr ParsePrimaryExpression(TSNode node, ParseContext& ctx);
ExpressionPtr ParseBinaryExpression(TSNode node, ParseContext& ctx);
ExpressionPtr ParseUnaryExpression(TSNode node, ParseContext& ctx);
ExpressionPtr ParseCallExpression(TSNode node, ParseContext& ctx);
ExpressionPtr ParseAttributeExpression(TSNode node, ParseContext& ctx);
ExpressionPtr ParseSubscriptExpression(TSNode node, ParseContext& ctx);
ExpressionPtr ParseArrayExpression(TSNode node, ParseContext& ctx);
ExpressionPtr ParseTernaryExpression(TSNode node, ParseContext& ctx);
ExpressionPtr ParseAssignmentExpression(TSNode node, ParseContext& ctx);
ExpressionPtr ParseAnonymousFunctionExpression(TSNode node, ParseContext& ctx);
ExpressionPtr ParseTSSQLExpression(TSNode node, ParseContext& ctx);
ExpressionPtr ParseEchoExpression(TSNode node, ParseContext& ctx);
ExpressionPtr ParseRaiseExpression(TSNode node, ParseContext& ctx);
ExpressionPtr ParseNewExpression(TSNode node, ParseContext& ctx);
ExpressionPtr ParseInheritedExpression(TSNode node, ParseContext& ctx);
// ===== 左值解析 =====
LeftHandSide ParseLeftHandSide(TSNode node, ParseContext& ctx);
std::unique_ptr<UnpackPattern> ParseUnpackPattern(TSNode node, ParseContext& ctx);
// ===== 语句解析 =====
StatementPtr ParseUnitDefinition(TSNode node, ParseContext& ctx);
StatementPtr ParseStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseSingleSuite(TSNode node, ParseContext& ctx);
StatementPtr ParseBlockSuite(TSNode node, ParseContext& ctx);
StatementPtr ParseExpressionStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseAssignmentStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseVarStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseStaticStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseGlobalStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseConstStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseIfStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseForInStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseForToStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseWhileStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseRepeatStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseCaseStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseTryStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseBreakStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseContinueStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseReturnStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseUsesStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseInheritedStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseFunctionDefinition(TSNode node, ParseContext& ctx);
StatementPtr ParseFunctionDeclaration(TSNode node, ParseContext& ctx);
StatementPtr ParseClassDefinition(TSNode node, ParseContext& ctx);
StatementPtr ParseExternalMethodStatement(TSNode node, ParseContext& ctx);
StatementPtr ParseMatrixIterationStatement(TSNode node, ParseContext& ctx);
// ===== TSLX 模板解析 =====
StatementPtr ParseTSLXOpenTag(TSNode node, ParseContext& ctx);
StatementPtr ParseTSLXCloseTag(TSNode node, ParseContext& ctx);
StatementPtr ParseTSLXOutputTag(TSNode node, ParseContext& ctx);
StatementPtr ParseTSLXExpressionTag(TSNode node, ParseContext& ctx);
StatementPtr ParseHTMLSelfClosingTag(TSNode node, ParseContext& ctx);
StatementPtr ParseHTMLPairedTag(TSNode node, ParseContext& ctx);
StatementPtr ParseHTMLComment(TSNode node, ParseContext& ctx);
// ===== 通用声明解析模板 =====
template<typename DeclType, typename StmtType>
StatementPtr ParseDeclarationStatement(TSNode node, ParseContext& ctx, NodeKind decl_kind, NodeKind stmt_kind, const std::string& keyword);
}

View File

@ -0,0 +1,312 @@
#include "./token_collector.hpp"
#include "./tree_sitter_utils.hpp"
namespace lsp::language::ast
{
// ===== TokenCollector 实现 =====
TokenCollector::TokenCollector(const TokenCollectionOptions& options) :
options_(options)
{
}
std::vector<TokenInfo> TokenCollector::Collect(TSNode root, const std::string& source)
{
std::vector<TokenInfo> tokens;
TSNode null_parent = { 0 };
CollectRecursive(root, source, tokens, 0, null_parent);
return tokens;
}
std::vector<TokenInfo> TokenCollector::CollectByType(
TSNode root,
const std::string& source,
const std::vector<std::string>& types)
{
// 临时修改选项以仅包含指定类型
TokenCollectionOptions original_options = options_;
options_.include_types = std::unordered_set<std::string>(types.begin(), types.end());
auto result = Collect(root, source);
// 恢复原始选项
options_ = original_options;
return result;
}
std::vector<TokenInfo> TokenCollector::CollectInRange(
TSNode root,
const std::string& source,
uint32_t start_line,
uint32_t end_line)
{
// 临时修改选项以包含位置过滤
TokenCollectionOptions original_options = options_;
options_.has_location_filter = true;
options_.filter_start_line = start_line;
options_.filter_end_line = end_line;
auto result = Collect(root, source);
// 恢复原始选项
options_ = original_options;
return result;
}
std::vector<TokenInfo> TokenCollector::CollectLeafNodes(TSNode root, const std::string& source)
{
auto all_tokens = Collect(root, source);
std::vector<TokenInfo> leaf_tokens;
for (const auto& token : all_tokens)
{
// 叶子节点的特征:文本长度等于字节范围
// 或者可以通过重新检查 TSNode 来确定
leaf_tokens.push_back(token);
}
return leaf_tokens;
}
std::vector<TokenInfo> TokenCollector::FindTokensAtPosition(TSNode root, const std::string& source, uint32_t line, uint32_t column)
{
auto all_tokens = Collect(root, source);
std::vector<TokenInfo> result;
for (const auto& token : all_tokens)
{
const auto& loc = token.location;
// 检查位置是否在 token 范围内
if (loc.start_line <= line && line <= loc.end_line)
{
if (line == loc.start_line && line == loc.end_line)
{
if (loc.start_column <= column && column <= loc.end_column)
result.push_back(token);
}
else if (line == loc.start_line)
{
if (column >= loc.start_column)
result.push_back(token);
}
else if (line == loc.end_line)
{
if (column <= loc.end_column)
result.push_back(token);
}
else
{
result.push_back(token);
}
}
}
return result;
}
std::vector<TokenInfo> TokenCollector::FindTokensByText(TSNode root, const std::string& source, const std::string& text, bool exact_match)
{
auto all_tokens = Collect(root, source);
std::vector<TokenInfo> result;
for (const auto& token : all_tokens)
{
if (exact_match)
{
if (token.text == text)
result.push_back(token);
}
else
{
if (token.text.find(text) != std::string::npos)
result.push_back(token);
}
}
return result;
}
std::vector<TokenInfo> TokenCollector::FindTokensBy(TSNode root, const std::string& source, std::function<bool(const TokenInfo&)> predicate)
{
auto all_tokens = Collect(root, source);
std::vector<TokenInfo> result;
for (const auto& token : all_tokens)
{
if (predicate(token))
result.push_back(token);
}
return result;
}
uint32_t TokenCollector::CountTokensByType(TSNode root, const std::string& source, const std::string& type)
{
auto tokens = CollectByType(root, source, { type });
return tokens.size();
}
std::vector<std::string> TokenCollector::GetUniqueTypes(TSNode root, const std::string& source)
{
auto tokens = Collect(root, source);
std::unordered_set<std::string> unique_types;
for (const auto& token : tokens)
{
unique_types.insert(token.type);
}
return std::vector<std::string>(unique_types.begin(), unique_types.end());
}
// ===== 私有方法实现 =====
void TokenCollector::CollectRecursive(TSNode node, const std::string& source, std::vector<TokenInfo>& tokens, uint32_t depth, TSNode parent)
{
// 检查深度限制
if (depth > options_.max_depth || depth < options_.min_depth)
return;
// 创建 TokenInfo
TokenInfo info;
FillTokenInfo(info, node, source, depth, parent);
// 检查是否应该收集此节点
if (ShouldCollectNode(info))
{
tokens.push_back(info);
}
// 递归处理子节点
uint32_t child_count = ts_node_child_count(node);
for (uint32_t i = 0; i < child_count; ++i)
{
TSNode child = ts_node_child(node, i);
CollectRecursive(child, source, tokens, depth + 1, node);
}
}
bool TokenCollector::ShouldCollectNode(const TokenInfo& info) const
{
// 检查命名节点过滤
if (!options_.include_anonymous && !info.is_named)
return false;
// 检查注释过滤
if (!options_.include_comments && IsCommentNode(info.type))
return false;
// 检查空白过滤
if (!options_.include_whitespace && IsWhitespaceNode(info.type))
return false;
// 检查错误节点过滤
if (!options_.include_errors && info.is_error)
return false;
// 检查缺失节点过滤
if (!options_.include_missing && info.is_missing)
return false;
// 检查类型包含过滤
if (!options_.include_types.empty())
{
if (options_.include_types.find(info.type) == options_.include_types.end())
return false;
}
// 检查类型排除过滤
if (!options_.exclude_types.empty())
{
if (options_.exclude_types.find(info.type) != options_.exclude_types.end())
return false;
}
// 检查空文本过滤
if (options_.skip_empty_text && info.text.empty())
return false;
// 检查文本长度过滤
uint32_t text_len = info.text.length();
if (text_len < options_.min_text_length || text_len > options_.max_text_length)
return false;
// 检查位置过滤
if (options_.has_location_filter && !IsLocationInRange(info.location))
return false;
// 检查自定义过滤器
if (options_.custom_filter && !options_.custom_filter(info))
return false;
return true;
}
void TokenCollector::FillTokenInfo(TokenInfo& info, TSNode node, const std::string& source, uint32_t depth, TSNode parent) const
{
// 基本信息
info.location = ts::NodeLocation(node);
info.type = ts_node_type(node);
info.text = ts::Text(node, source);
info.is_named = ts_node_is_named(node);
info.depth = depth;
// 错误和缺失标记
info.is_error = ts_node_is_error(node);
info.is_missing = ts_node_is_missing(node);
// 父节点信息
if (!ts_node_is_null(parent))
{
info.parent_type = ts_node_type(parent);
// 查找当前节点在父节点中的索引
uint32_t sibling_count = ts_node_child_count(parent);
info.sibling_count = sibling_count;
for (uint32_t i = 0; i < sibling_count; ++i)
{
TSNode sibling = ts_node_child(parent, i);
if (ts_node_eq(sibling, node))
{
info.child_index = i;
break;
}
}
}
}
bool TokenCollector::IsCommentNode(const std::string& type) const
{
return type == "line_comment" ||
type == "block_comment" ||
type == "nested_comment" ||
type == "comment";
}
bool TokenCollector::IsWhitespaceNode(const std::string& type) const
{
return type == "whitespace" ||
type == " " ||
type == "\n" ||
type == "\r" ||
type == "\t";
}
bool TokenCollector::IsLocationInRange(const Location& loc) const
{
if (!options_.has_location_filter)
return true;
// 检查范围是否有重叠
if (loc.end_line < options_.filter_start_line)
return false;
if (loc.start_line > options_.filter_end_line)
return false;
return true;
}
}

View File

@ -0,0 +1,153 @@
#pragma once
#include <string>
#include <vector>
#include <unordered_set>
#include <functional>
#include "./types.hpp"
extern "C" {
#include <tree_sitter/api.h>
}
namespace lsp::language::ast
{
struct TokenInfo
{
Location location;
std::string type; // 节点类型 (如 "identifier", "number")
std::string text; // Token 文本内容
bool is_named = false;
uint32_t depth = 0;
std::string parent_type; // 父节点类型
uint32_t child_index = 0; // 在父节点中的索引
uint32_t sibling_count = 0; // 兄弟节点总数
bool is_error = false; // 是否是错误节点
bool is_missing = false; // 是否是缺失节点
};
// ===== Token 收集选项 =====
struct TokenCollectionOptions
{
// 基本过滤选项
bool include_anonymous = true; // 包含匿名节点 (如括号、分号等)
bool include_comments = false; // 包含注释节点
bool include_whitespace = false; // 包含空白节点
bool include_errors = true; // 包含错误节点
bool include_missing = true; // 包含缺失节点
// 深度控制
uint32_t max_depth = UINT32_MAX; // 最大遍历深度
uint32_t min_depth = 0; // 最小遍历深度
// 类型过滤
std::unordered_set<std::string> include_types; // 仅包含这些类型 (为空则包含所有)
std::unordered_set<std::string> exclude_types; // 排除这些类型
// 文本过滤
bool skip_empty_text = false; // 跳过空文本节点
uint32_t min_text_length = 0; // 最小文本长度
uint32_t max_text_length = UINT32_MAX; // 最大文本长度
// 位置过滤
bool has_location_filter = false; // 是否启用位置过滤
uint32_t filter_start_line = 0; // 过滤起始行
uint32_t filter_end_line = UINT32_MAX; // 过滤结束行
// 自定义过滤器
std::function<bool(const TokenInfo&)> custom_filter;
TokenCollectionOptions() = default;
// 便捷构造函数
static TokenCollectionOptions OnlyNamed()
{
TokenCollectionOptions opts;
opts.include_anonymous = false;
return opts;
}
static TokenCollectionOptions WithComments()
{
TokenCollectionOptions opts;
opts.include_comments = true;
return opts;
}
static TokenCollectionOptions OnlyTypes(const std::vector<std::string>& types)
{
TokenCollectionOptions opts;
opts.include_types = std::unordered_set<std::string>(types.begin(), types.end());
return opts;
}
static TokenCollectionOptions ExcludeTypes(const std::vector<std::string>& types)
{
TokenCollectionOptions opts;
opts.exclude_types = std::unordered_set<std::string>(types.begin(), types.end());
return opts;
}
static TokenCollectionOptions InRange(uint32_t start_line, uint32_t end_line)
{
TokenCollectionOptions opts;
opts.has_location_filter = true;
opts.filter_start_line = start_line;
opts.filter_end_line = end_line;
return opts;
}
};
// ===== Token 收集器类 =====
class TokenCollector
{
public:
explicit TokenCollector(const TokenCollectionOptions& options = {});
~TokenCollector() = default;
TokenCollector(const TokenCollector&) = delete;
TokenCollector& operator=(const TokenCollector&) = delete;
TokenCollector(TokenCollector&&) = default;
TokenCollector& operator=(TokenCollector&&) = default;
std::vector<TokenInfo> Collect(TSNode root, const std::string& source);
std::vector<TokenInfo> CollectByType(TSNode root, const std::string& source, const std::vector<std::string>& types);
std::vector<TokenInfo> CollectInRange(TSNode root, const std::string& source, uint32_t start_line, uint32_t end_line);
std::vector<TokenInfo> CollectLeafNodes(TSNode root, const std::string& source);
std::vector<TokenInfo> FindTokensAtPosition(TSNode root, const std::string& source, uint32_t line, uint32_t column);
std::vector<TokenInfo> FindTokensByText(TSNode root, const std::string& source, const std::string& text, bool exact_match = true);
std::vector<TokenInfo> FindTokensBy(TSNode root, const std::string& source, std::function<bool(const TokenInfo&)> predicate);
uint32_t CountTokensByType(TSNode root, const std::string& source, const std::string& type);
std::vector<std::string> GetUniqueTypes(TSNode root, const std::string& source);
void SetOptions(const TokenCollectionOptions& options) { options_ = options; }
const TokenCollectionOptions& GetOptions() const { return options_; }
private:
void CollectRecursive(TSNode node, const std::string& source, std::vector<TokenInfo>& tokens, uint32_t depth, TSNode parent);
bool ShouldCollectNode(const TokenInfo& info) const;
void FillTokenInfo(TokenInfo& info, TSNode node, const std::string& source, uint32_t depth, TSNode parent) const;
bool IsCommentNode(const std::string& type) const;
bool IsWhitespaceNode(const std::string& type) const;
bool IsLocationInRange(const Location& loc) const;
private:
TokenCollectionOptions options_;
};
inline std::vector<TokenInfo> CollectNamedTokens(TSNode root, const std::string& source)
{
return TokenCollector(TokenCollectionOptions::OnlyNamed()).Collect(root, source);
}
inline std::vector<TokenInfo> CollectTokensOfType(TSNode root, const std::string& source, const std::string& type)
{
return TokenCollector().CollectByType(root, source, { type });
}
}

View File

@ -2,51 +2,22 @@
namespace lsp::language::ast::ts namespace lsp::language::ast::ts
{ {
// ==================== 字符串池实现 ==================== std::string Text(TSNode node, const std::string& source)
StringPool& StringPool::Instance()
{
static StringPool instance;
return instance;
}
std::string_view StringPool::Intern(std::string_view str)
{
auto it = pool_.find(std::string(str));
if (it != pool_.end())
{
return *it;
}
auto [inserted_it, _] = pool_.insert(std::string(str));
return *inserted_it;
}
void StringPool::Clear()
{
pool_.clear();
}
// ==================== 文本提取 ====================
std::string Text(TSNode node, std::string_view source)
{ {
uint32_t start = ts_node_start_byte(node); uint32_t start = ts_node_start_byte(node);
uint32_t end = ts_node_end_byte(node); uint32_t end = ts_node_end_byte(node);
if (start >= source.length() || end > source.length() || start >= end)
return "";
return std::string(source.substr(start, end - start));
}
std::string Text(TSNode node, const std::string& source) if (start >= end || end > source.length())
{ return "";
return Text(node, std::string_view(source));
return source.substr(start, end - start);
} }
Location NodeLocation(TSNode node) Location NodeLocation(TSNode node)
{ {
TSPoint start = ts_node_start_point(node); TSPoint start = ts_node_start_point(node);
TSPoint end = ts_node_end_point(node); TSPoint end = ts_node_end_point(node);
return Location{ return Location{
static_cast<uint32_t>(start.row), static_cast<uint32_t>(start.row),
static_cast<uint32_t>(start.column), static_cast<uint32_t>(start.column),
@ -57,84 +28,4 @@ namespace lsp::language::ast::ts
}; };
} }
std::string FieldText(TSNode node, std::string_view field_name, std::string_view source)
{
TSNode field = FieldChild(node, field_name);
return IsNull(field) ? "" : Text(field, source);
}
std::string FieldText(TSNode node, std::string_view field_name, const std::string& source)
{
return FieldText(node, field_name, std::string_view(source));
}
std::vector<TSNode> Children(TSNode node)
{
std::vector<TSNode> children;
uint32_t count = ts_node_child_count(node);
children.reserve(count);
for (uint32_t i = 0; i < count; i++)
children.push_back(ts_node_child(node, i));
return children;
}
std::vector<TSNode> FieldChildren(TSNode node, std::string_view field_name)
{
std::vector<TSNode> result;
uint32_t count = ts_node_child_count(node);
for (uint32_t i = 0; i < count; i++)
{
const char* field = ts_node_field_name_for_child(node, i);
if (field && field_name == field)
result.push_back(ts_node_child(node, i));
}
return result;
}
std::vector<NodeSummary> ChildSummaries(TSNode node, std::string_view source)
{
std::vector<NodeSummary> result;
uint32_t count = ts_node_child_count(node);
for (uint32_t i = 0; i < count; i++)
{
TSNode child = ts_node_child(node, i);
const char* field_name = ts_node_field_name_for_child(node, i);
result.push_back(NodeSummary{
.node = child,
.field = field_name ? field_name : "",
.type = std::string(Type(child)),
.text = Text(child, source) });
}
return result;
}
std::vector<NodeSummary> ChildSummaries(TSNode node, const std::string& source)
{
return ChildSummaries(node, std::string_view(source));
}
std::vector<NodeSummary> FieldSummaries(TSNode node, std::string_view field_name, std::string_view source)
{
std::vector<NodeSummary> result;
uint32_t count = ts_node_child_count(node);
for (uint32_t i = 0; i < count; i++)
{
TSNode child = ts_node_child(node, i);
const char* field = ts_node_field_name_for_child(node, i);
if (field && field_name == field)
{
result.push_back(NodeSummary{
.node = child,
.field = field,
.type = std::string(Type(child)),
.text = Text(child, source) });
}
}
return result;
}
std::vector<NodeSummary> FieldSummaries(TSNode node, std::string_view field_name, const std::string& source)
{
return FieldSummaries(node, field_name, std::string_view(source));
}
} }

View File

@ -1,8 +1,5 @@
#pragma once #pragma once
#include <string> #include <string>
#include <string_view>
#include <vector>
#include <unordered_set>
#include "./types.hpp" #include "./types.hpp"
extern "C" { extern "C" {
@ -11,94 +8,16 @@ extern "C" {
namespace lsp::language::ast::ts namespace lsp::language::ast::ts
{ {
class StringPool
{
public:
static StringPool& Instance();
// 缓存字符串,返回池中的引用
std::string_view Intern(std::string_view str);
void Clear();
size_t Size() const { return pool_.size(); }
private:
StringPool() = default;
StringPool(const StringPool&) = delete;
StringPool& operator=(const StringPool&) = delete;
private:
std::unordered_set<std::string> pool_;
};
// ==================== 节点摘要 ====================
struct NodeSummary
{
TSNode node;
std::string field;
std::string type;
std::string text;
};
// ==================== 基础查询 ====================
// 获取节点文本(使用 string_view 优化)
std::string Text(TSNode node, std::string_view source);
std::string Text(TSNode node, const std::string& source); std::string Text(TSNode node, const std::string& source);
// 获取节点类型(使用字符串池)
inline std::string_view Type(TSNode node);
// 获取节点位置
Location NodeLocation(TSNode node); Location NodeLocation(TSNode node);
// 节点状态检查
inline bool IsNull(TSNode node);
inline bool IsComment(TSNode node);
// ==================== 字段操作 ====================
inline TSNode FieldChild(TSNode node, std::string_view field_name);
std::string FieldText(TSNode node, std::string_view field_name, std::string_view source);
std::string FieldText(TSNode node, std::string_view field_name, const std::string& source);
inline bool HasField(TSNode node, std::string_view field_name);
// ==================== 子节点操作 ====================
std::vector<TSNode> Children(TSNode node);
std::vector<TSNode> FieldChildren(TSNode node, std::string_view field_name);
std::vector<NodeSummary> ChildSummaries(TSNode node, std::string_view source);
std::vector<NodeSummary> ChildSummaries(TSNode node, const std::string& source);
std::vector<NodeSummary> FieldSummaries(TSNode node, std::string_view field_name, std::string_view source);
std::vector<NodeSummary> FieldSummaries(TSNode node, std::string_view field_name, const std::string& source);
// ==================== Inline 实现 ====================
inline std::string_view Type(TSNode node)
{
const char* type_str = ts_node_type(node);
return StringPool::Instance().Intern(type_str);
}
inline bool IsNull(TSNode node)
{
return ts_node_is_null(node);
}
inline bool IsComment(TSNode node) inline bool IsComment(TSNode node)
{ {
std::string_view type = Type(node); std::string_view type = ts_node_type(node);
return type == "line_comment" || type == "block_comment" || type == "nested_comment"; return type == "line_comment" ||
type == "block_comment" ||
type == "nested_comment";
} }
inline TSNode FieldChild(TSNode node, std::string_view field_name)
{
return ts_node_child_by_field_name(node, field_name.data(), field_name.length());
}
inline bool HasField(TSNode node, std::string_view field_name)
{
return !IsNull(FieldChild(node, field_name));
}
} }

File diff suppressed because it is too large Load Diff