Files
SearchCompany/tool/csv_tool.py
2025-10-04 01:18:05 +08:00

301 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import csv
import os
import time
from typing import List, Dict, Any, Optional
class CSVTool:
    """Utility for appending, de-duplicating, and querying rows in a CSV file."""

    def __init__(self, csv_file_name: str, headers: List[str]):
        """
        Initialize the CSV tool.

        Args:
            csv_file_name (str): Path of the CSV file to manage.
            headers (List[str]): Column headers, in output order.
        """
        self.csv_file_name = csv_file_name
        self.headers = headers

    def init_csv_file(self) -> None:
        """
        Ensure the CSV file exists and starts with the expected header row.

        Creates the file with headers when missing, rewrites it when empty or
        whitespace-only, and prepends the header row when the first line does
        not match ``self.headers`` (existing non-empty lines are preserved).
        """
        # File missing: create it with the header row and stop — no need to
        # re-read what we just wrote (the original fell through and re-read).
        if not os.path.exists(self.csv_file_name):
            with open(self.csv_file_name, 'w', encoding='utf-8', newline='') as f:
                csv.writer(f).writerow(self.headers)
            return
        try:
            with open(self.csv_file_name, 'r', encoding='utf-8') as f:
                content = f.read()
            # Empty or whitespace-only file: rewrite with just the header row.
            if not content.strip():
                with open(self.csv_file_name, 'w', encoding='utf-8', newline='') as f:
                    csv.writer(f).writerow(self.headers)
                return
            with open(self.csv_file_name, 'r', encoding='utf-8') as f:
                first_row = next(csv.reader(f), None)
            # First row is not the expected header: rewrite the file with the
            # header first, keeping every non-empty original line as data.
            if first_row != self.headers:
                non_empty_lines = [line for line in content.strip().split('\n') if line.strip()]
                with open(self.csv_file_name, 'w', encoding='utf-8', newline='') as f:
                    csv.writer(f).writerow(self.headers)
                    if non_empty_lines:
                        f.write('\n'.join(non_empty_lines) + '\n')
        except Exception as e:
            print(f"检查/更新表头时出错: {e}")

    def get_existing_data(self, unique_titles: List[str]) -> set:
        """
        Read existing rows and build the set of unique-key tuples for dedup.

        Args:
            unique_titles (List[str]): Column names forming the uniqueness key.

        Returns:
            set: One tuple per existing row, aligned position-for-position
            with ``unique_titles`` (columns absent from the header yield "").
        """
        existing_data = set()
        if not os.path.exists(self.csv_file_name):
            return existing_data
        try:
            with open(self.csv_file_name, 'r', encoding='utf-8') as f:
                reader = csv.reader(f)
                header_row = next(reader, None)  # header line
                if header_row is None:
                    return existing_data
                # Resolve each unique column to its index. Use -1 (-> "") for
                # missing columns so key tuples keep the same length as the
                # ones save_data() builds: the original *skipped* missing
                # columns, which made every membership test fail and silently
                # broke de-duplication.
                unique_indices = []
                for title in unique_titles:
                    try:
                        unique_indices.append(header_row.index(title))
                    except ValueError:
                        print(f"警告: 表头中未找到列 '{title}'")
                        unique_indices.append(-1)
                for row in reader:
                    if len(row) >= len(header_row):  # skip truncated rows
                        unique_values = tuple(
                            row[i] if 0 <= i < len(row) else "" for i in unique_indices
                        )
                        existing_data.add(unique_values)
        except Exception as e:
            print(f"读取CSV文件时出错: {e}")
        return existing_data

    def save_data(self, data_list: List[Dict[str, Any]], unique_titles: List[str], create_time: bool = True) -> int:
        """
        Append rows to the CSV file, skipping rows whose unique key exists.

        Args:
            data_list (List[Dict[str, Any]]): Rows to save, keyed by header name.
            unique_titles (List[str]): Column names forming the uniqueness key.
            create_time (bool): When True, fill the 'create_time' column with
                the current local timestamp. Defaults to True.

        Returns:
            int: Number of rows actually written.
        """
        if not data_list:
            print('数据列表为空,没有数据可写入')
            return 0
        self.init_csv_file()
        existing_data = self.get_existing_data(unique_titles)
        # One timestamp per batch so rows written together agree (the
        # original re-read the clock for every row).
        timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        rows_to_write = []
        for data_node in data_list:
            # Uniqueness key for this row, aligned with get_existing_data().
            unique_values = tuple(data_node.get(title, "") for title in unique_titles)
            if unique_values in existing_data:
                continue  # already present — skip
            row_data = [
                timestamp if (header == 'create_time' and create_time)
                else data_node.get(header, "")
                for header in self.headers
            ]
            rows_to_write.append(row_data)
            # Track within-batch keys too, so duplicates inside data_list
            # are written only once.
            existing_data.add(unique_values)
        if rows_to_write:
            with open(self.csv_file_name, 'a', encoding='utf-8', newline='') as f:
                csv.writer(f).writerows(rows_to_write)
            print(f"成功写入 {len(rows_to_write)} 行数据到 {self.csv_file_name}")
        else:
            print("没有新数据需要写入")
        return len(rows_to_write)

    def query_data(self, filter_func=None) -> List[Dict[str, str]]:
        """
        Query rows from the CSV file.

        Args:
            filter_func (callable, optional): Predicate taking a row dict and
                returning True to keep the row. None keeps every row.

        Returns:
            List[Dict[str, str]]: Matching rows as header->value dicts.
        """
        if not os.path.exists(self.csv_file_name):
            return []
        result = []
        try:
            with open(self.csv_file_name, 'r', encoding='utf-8') as f:
                reader = csv.reader(f)
                header_row = next(reader, None)  # header line
                if header_row is None:
                    return result
                for row in reader:
                    if len(row) >= len(header_row):  # skip truncated rows
                        row_dict = dict(zip(header_row, row))
                        if filter_func is None or filter_func(row_dict):
                            result.append(row_dict)
        except Exception as e:
            print(f"查询CSV文件时出错: {e}")
        return result

    def query_by_conditions(self, **kwargs) -> List[Dict[str, str]]:
        """
        Query rows matching all given column=value conditions.

        Args:
            **kwargs: column=value pairs; columns absent from a row are
                treated as matching (same as the original behavior).

        Returns:
            List[Dict[str, str]]: Matching rows.
        """
        def filter_func(row_dict):
            # Missing key defaults to `value`, so absent columns never reject.
            return all(row_dict.get(key, value) == value for key, value in kwargs.items())
        return self.query_data(filter_func)

    def get_all_data(self) -> List[Dict[str, str]]:
        """
        Return every data row in the file.

        Returns:
            List[Dict[str, str]]: All rows as header->value dicts.
        """
        return self.query_data()
# Backward-compatibility wrapper
def save_to_csv(filter_list: List[Dict[str, Any]],
                csv_file_name: str = 'company_search_result_data.csv',
                headers: List[str] = None,
                unique_titles: List[str] = None) -> bool:
    """
    Append results to a CSV file (backward-compatible helper).

    Args:
        filter_list: Rows to write.
        csv_file_name: Target CSV file name.
        headers: Column headers; defaults to the company-search schema.
        unique_titles: Columns used for uniqueness checks.

    Returns:
        bool: True when at least one new row was written.
    """
    default_headers = ['title', 'url', 'web_site_type', 'request_url', 'company_name', 'create_time']
    default_uniques = ['company_name', 'web_site_type']
    effective_headers = default_headers if headers is None else headers
    effective_uniques = default_uniques if unique_titles is None else unique_titles
    try:
        written = CSVTool(csv_file_name, effective_headers).save_data(filter_list, effective_uniques)
    except Exception as e:
        print(f"保存CSV时出错: {e}")
        return False
    return written > 0
# Usage example:
if __name__ == "__main__":
    demo_headers = ['title', 'url', 'web_site_type', 'request_url', 'company_name', 'create_time']
    # Sample rows to exercise save / query round-trips.
    sample_data = [
        {
            'title': '测试公司1',
            'url': 'https://example.com/1',
            'web_site_type': 'aiqicha',
            'request_url': 'https://bing.com/search?q=测试公司1',
            'company_name': '测试公司1',
        },
        {
            'title': '测试公司2',
            'url': 'https://example.com/2',
            'web_site_type': 'qcc',
            'request_url': 'https://bing.com/search?q=测试公司2',
            'company_name': '测试公司2',
        },
    ]
    demo_tool = CSVTool(csv_file_name='test_data.csv', headers=demo_headers)
    # Write with de-duplication on (company_name, web_site_type).
    demo_tool.save_data(sample_data, unique_titles=['company_name', 'web_site_type'])
    # Read everything back.
    print("所有数据:", demo_tool.get_all_data())
    # Filtered lookup by column value.
    print("查询结果:", demo_tool.query_by_conditions(web_site_type='aiqicha'))