tsl-devkit/lsp-server/test/test_tree_sitter/test/extract_sql.py

259 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
递归读取文件,提取 SQL 语句:
- select/vselect/sselect ... end; (select/vselect/sselect 和 end 之间不能有分号)
- update ... end; (update 和 end 之间不能有分号)
- delete from ... ; (delete from 到第一个分号)
- insert ... ; (insert 到第一个分号)
- 不匹配字符串和注释中的关键字
- 按文件分组输出
"""
import os
import re
from pathlib import Path
# Matches every string literal and comment in a single left-to-right scan.
# Scanning all four token kinds together (rather than one pass per kind) means
# a quote inside a comment ("// don't"), an apostrophe inside a double-quoted
# string ("it's"), or "//" inside a block comment is consumed by the token that
# actually encloses it instead of being mistaken for the start of another one.
_LITERAL_OR_COMMENT_RE = re.compile(
    r"'(?:[^'\\]|\\.)*'"   # single-quoted string literal
    r'|"(?:[^"\\]|\\.)*"'  # double-quoted string literal
    r"|//[^\n]*"           # line comment
    r"|/\*[\s\S]*?\*/"     # block comment (may span lines)
)

# Statement patterns, in the order the result dict is returned.
# `(?!\s*\()` rejects function-call forms such as `select(...)`.
# `[^;]*?` forbids a semicolon between the keyword and the terminator.
# `\bend` stops identifiers ending in "end" (e.g. `legend;`) from closing a statement.
# `\bselect` cannot match inside "sselect"/"vselect" because there is no word
# boundary between the prefix letter and "select".
_STATEMENT_PATTERNS = {
    "select": re.compile(r"\bselect\b(?!\s*\()[^;]*?\bend\s*;", re.IGNORECASE),
    "sselect": re.compile(r"\bsselect\b(?!\s*\()[^;]*?\bend\s*;", re.IGNORECASE),
    "vselect": re.compile(r"\bvselect\b(?!\s*\()[^;]*?\bend\s*;", re.IGNORECASE),
    "update": re.compile(r"\bupdate\b(?!\s*\()[^;]*?\bend\s*;", re.IGNORECASE),
    "delete": re.compile(r"\bdelete\b(?!\s*\()\s+from\b[^;]*?;", re.IGNORECASE),
    "insert": re.compile(r"\binsert\b(?!\s*\()[^;]*?;", re.IGNORECASE),
}


def _blank_literal_or_comment(match):
    """Return a same-length replacement that hides a literal's or comment's text."""
    text = match.group(0)
    if text[0] in "'\"":
        # Keep the delimiters, blank the contents: keywords inside are ignored.
        return text[0] + " " * (len(text) - 2) + text[0]
    # Comments become pure whitespace.
    return " " * len(text)


def extract_sql_statements(content):
    """Extract SQL-like statements from source text, grouped by kind.

    Recognized forms:
    - select/sselect/vselect ... end;  (no ';' between keyword and `end`)
    - update ... end;                  (no ';' between keyword and `end`)
    - delete from ... ;                (up to the first ';')
    - insert ... ;                     (up to the first ';')

    Keywords inside string literals and `//` or `/* */` comments are ignored.
    Matching runs on a copy of *content* in which those spans are blanked with
    the same length, so match offsets map directly back to *content*. The
    returned text therefore keeps the original strings and comments.

    Args:
        content: full text of one source file.

    Returns:
        dict with keys 'select', 'sselect', 'vselect', 'update', 'delete',
        'insert' (in that order). Each maps to a list of stripped statement
        strings in order of appearance.
    """
    cleaned = _LITERAL_OR_COMMENT_RE.sub(_blank_literal_or_comment, content)
    return {
        stmt_type: [
            content[match.start():match.end()].strip()
            for match in pattern.finditer(cleaned)
        ]
        for stmt_type, pattern in _STATEMENT_PATTERNS.items()
    }
def _iter_source_files(directory, output_file, file_extensions):
    """Yield paths of files under *directory* (following symlinks) to scan.

    The output file itself is skipped. When *file_extensions* is non-empty,
    only names ending in one of them (case-insensitive) are yielded.
    """
    output_abspath = os.path.abspath(output_file)
    suffixes = tuple(ext.lower() for ext in file_extensions) if file_extensions else None
    for root, _dirs, files in os.walk(directory, followlinks=True):
        for filename in files:
            file_path = os.path.join(root, filename)
            # Never re-read our own report (it may share the scanned extension).
            if os.path.abspath(file_path) == output_abspath:
                continue
            if suffixes and not filename.lower().endswith(suffixes):
                continue
            yield file_path


def _read_text_with_fallback(file_path, encodings=("utf-8", "gbk", "gb2312", "latin-1")):
    """Return the file decoded with the first encoding that succeeds, else None.

    Only UnicodeDecodeError triggers the next attempt; OS errors propagate.
    Because 'latin-1' can decode any byte sequence, with the default list this
    effectively never returns None.
    """
    for encoding in encodings:
        try:
            with open(file_path, "r", encoding=encoding) as f:
                return f.read()
        except UnicodeDecodeError:
            continue
    return None


def _write_sql_report(output_file, files_statements, totals, total_statements,
                      processed_files, stmt_types):
    """Write (overwrite) the UTF-8 report: global summary, then one section per file."""
    rule = "//" + "=" * 80 + "\n"
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(f"// 共提取到 {total_statements} 条SQL语句\n")
        f.write(f"// 处理了 {processed_files} 个文件\n")
        for stmt_type in stmt_types:
            f.write(f"// {stmt_type.upper()}: {totals[stmt_type]}\n")
        f.write(rule)
        for file_path, statements in sorted(files_statements.items()):
            file_total = sum(len(stmts) for stmts in statements.values())
            breakdown = ", ".join(
                f"{stmt_type.upper()}: {len(statements[stmt_type])}" for stmt_type in stmt_types
            )
            f.write("\n")
            f.write(rule)
            f.write(f"// 文件: {file_path}\n")
            f.write(f"// 共 {file_total} 条语句 ({breakdown})\n")
            f.write(rule)
            # Statements are grouped by type in stmt_types order, one per line.
            for stmt_type in stmt_types:
                for stmt in statements[stmt_type]:
                    f.write(stmt)
                    f.write("\n")


def read_files_recursively(directory, output_file, file_extensions=None):
    """Scan *directory* recursively, extract SQL statements, write a grouped report.

    Each file is decoded by trying utf-8, gbk, gb2312 and latin-1 in turn, then
    passed to extract_sql_statements(). Per-file counts are printed to stdout.
    If at least one statement was found, *output_file* is written (UTF-8) with a
    summary header followed by the statements grouped by file. Otherwise the
    output file is not touched. Failures on individual files are printed and
    counted; they do not stop the scan.

    Args:
        directory: root directory to scan (symlinked directories are followed).
        output_file: path of the report to write. It is excluded from the scan.
        file_extensions: list of suffixes such as ['.sql', '.txt'] to restrict
            which files are read. None or empty means all files.
    """
    # Output order used for the console log and the report.
    stmt_types = ("select", "sselect", "vselect", "update", "delete", "insert")

    # {file_path: {stmt_type: [statement, ...]}} for files with >= 1 match.
    files_statements = {}
    processed_files = 0
    error_files = []

    for file_path in _iter_source_files(directory, output_file, file_extensions):
        try:
            content = _read_text_with_fallback(file_path)
            if content is None:
                print(f"警告: 无法读取文件 {file_path} (编码问题)")
                error_files.append(file_path)
                continue

            matches = extract_sql_statements(content)
            total_matches = sum(len(stmts) for stmts in matches.values())
            if total_matches > 0:
                files_statements[file_path] = matches
                breakdown = ", ".join(
                    f"{stmt_type}: {len(matches[stmt_type])}" for stmt_type in stmt_types
                )
                print(f"{file_path}: 找到 {total_matches} 条语句 ({breakdown})")
            processed_files += 1
        except Exception as e:
            # Per-file boundary: report the failure and keep scanning.
            print(f"错误: 处理文件 {file_path} 时出错: {e}")
            error_files.append(file_path)

    totals = {
        stmt_type: sum(len(stmts[stmt_type]) for stmts in files_statements.values())
        for stmt_type in stmt_types
    }
    total_statements = sum(totals.values())

    if total_statements > 0:
        _write_sql_report(output_file, files_statements, totals, total_statements,
                          processed_files, stmt_types)
        print(f"\n✓ 成功提取 {total_statements} 条SQL语句")
        for stmt_type in stmt_types:
            print(f" - {stmt_type.upper()}: {totals[stmt_type]}")
        print(f"✓ 结果已保存到: {output_file}")
    else:
        print("\n! 没有找到匹配的SQL语句")

    if error_files:
        print(f"\n! 有 {len(error_files)} 个文件处理失败")
def main(argv=None):
    """Command-line entry point: scan a directory tree and write the SQL report.

    With no arguments the previous hard-coded settings are used:
    directory '/mnt/c/Programs/Tinysoft/TSLGen2/funcext', output
    'test_sql_statements.tsf', extensions ['.tsf'].

    Args:
        argv: argument list without the program name. None means sys.argv[1:]
            (argparse's default).

    Returns:
        None. Prints an error and returns early if the source path is not an
        existing directory.
    """
    import argparse

    parser = argparse.ArgumentParser(description="递归提取源文件中的 SQL 语句")
    parser.add_argument(
        "source_directory",
        nargs="?",
        default="/mnt/c/Programs/Tinysoft/TSLGen2/funcext",
        help="要扫描的目录",
    )
    parser.add_argument(
        "-o", "--output",
        default="test_sql_statements.tsf",
        help="输出文件路径",
    )
    parser.add_argument(
        "-e", "--ext",
        action="append",
        dest="file_extensions",
        metavar="EXT",
        help="只处理指定扩展名的文件,可重复指定 (默认: .tsf)",
    )
    parser.add_argument(
        "--all-files",
        action="store_true",
        help="处理所有文件,忽略扩展名过滤",
    )
    args = parser.parse_args(argv)

    source_directory = args.source_directory
    output_file = args.output
    # None tells read_files_recursively to accept every file.
    if args.all_files:
        file_extensions = None
    else:
        file_extensions = args.file_extensions or ['.tsf']

    print(f"开始扫描目录: {os.path.abspath(source_directory)}")
    print(f"输出文件: {output_file}")
    print(f"文件类型: {file_extensions if file_extensions else '所有文件'}")
    print("-" * 80)

    # isdir, not exists: a regular file would make os.walk yield nothing silently.
    if not os.path.isdir(source_directory):
        print(f"错误: 目录不存在: {source_directory}")
        return

    read_files_recursively(source_directory, output_file, file_extensions)
# Script entry point: run the extractor only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()