267 lines
8.4 KiB
Python
267 lines
8.4 KiB
Python
import csv
|
||
import os
|
||
import time
|
||
from typing import List, Dict, Any, Optional
|
||
|
||
|
||
class CSVTool:
    """Append-only CSV persistence helper with duplicate filtering.

    Rows are written under a fixed header. A tuple of "unique" column
    values (compared as strings, since CSV stores everything as text)
    identifies each row for de-duplication across runs.
    """

    def __init__(self, csv_file_name: str, headers: List[str]):
        """Initialize the CSV tool.

        Args:
            csv_file_name (str): Path of the CSV file to manage.
            headers (List[str]): Column names, in write order.
        """
        self.csv_file_name = csv_file_name
        self.headers = headers

    def init_csv_file(self) -> None:
        """Create the CSV file and write the header row if the file does not exist."""
        if not os.path.exists(self.csv_file_name):
            with open(self.csv_file_name, 'w', encoding='utf-8', newline='') as f:
                csv.writer(f).writerow(self.headers)

    def get_existing_data(self, unique_titles: List[str]) -> set:
        """Read existing rows and collect their unique-key tuples for dedup checks.

        Args:
            unique_titles (List[str]): Column names that together identify a row.

        Returns:
            set: One tuple of unique-column string values per existing row.
                 A title missing from the header maps to "" so every tuple
                 keeps the same arity as ``unique_titles`` and stays
                 comparable with the tuples built in ``save_data``.
        """
        existing_data = set()

        if not os.path.exists(self.csv_file_name):
            return existing_data

        try:
            with open(self.csv_file_name, 'r', encoding='utf-8') as f:
                reader = csv.reader(f)
                header_row = next(reader, None)  # consume the header line

                if header_row is None:
                    return existing_data

                # Map each unique title to its column index; None when absent.
                # Using a placeholder (instead of skipping) keeps key tuples
                # the same length as the ones save_data builds.
                unique_indices: List[Optional[int]] = []
                for title in unique_titles:
                    try:
                        unique_indices.append(header_row.index(title))
                    except ValueError:
                        print(f"警告: 表头中未找到列 '{title}'")
                        unique_indices.append(None)

                for row in reader:
                    if len(row) >= len(header_row):  # skip truncated rows
                        unique_values = tuple(
                            row[i] if i is not None and i < len(row) else ""
                            for i in unique_indices
                        )
                        existing_data.add(unique_values)

        except Exception as e:
            print(f"读取CSV文件时出错: {e}")

        return existing_data

    def save_data(self, data_list: List[Dict[str, Any]], unique_titles: List[str], create_time: bool = True) -> int:
        """Append rows to the CSV file, skipping duplicates.

        Args:
            data_list (List[Dict[str, Any]]): Records to save; keys matching
                ``self.headers`` are written, missing keys become "".
            unique_titles (List[str]): Column names used for the dedup key.
            create_time (bool): When True, fill a 'create_time' column (if
                present in the headers) with the current local timestamp.

        Returns:
            int: Number of rows actually written (duplicates excluded).
        """
        if not data_list:
            print('数据列表为空,没有数据可写入')
            return 0

        # Make sure the file and header row exist.
        self.init_csv_file()

        # Keys already present in the file, for de-duplication.
        existing_data = self.get_existing_data(unique_titles)

        rows_to_write = []
        written_count = 0

        for data_node in data_list:
            # Normalize key values to str: the file side is always str after
            # a CSV round-trip, so raw ints/floats would never match it.
            unique_values = tuple(str(data_node.get(title, "")) for title in unique_titles)

            if unique_values in existing_data:
                continue  # already present in file or in this batch

            row_data = []
            for header in self.headers:
                if header == 'create_time' and create_time:
                    # Stamp the row with the current local time.
                    row_data.append(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
                else:
                    row_data.append(data_node.get(header, ""))

            rows_to_write.append(row_data)
            # Record the key so duplicates within this batch are skipped too.
            existing_data.add(unique_values)
            written_count += 1

        if rows_to_write:
            with open(self.csv_file_name, 'a', encoding='utf-8', newline='') as f:
                writer = csv.writer(f)
                writer.writerows(rows_to_write)

            print(f"成功写入 {written_count} 行数据到 {self.csv_file_name}")
        else:
            print("没有新数据需要写入")

        return written_count

    def query_data(self, filter_func=None) -> List[Dict[str, str]]:
        """Read rows from the CSV file, optionally filtered.

        Args:
            filter_func (callable, optional): Predicate taking a row dict and
                returning True to keep the row. None keeps every row.

        Returns:
            List[Dict[str, str]]: Matching rows as header->value dicts.
        """
        if not os.path.exists(self.csv_file_name):
            return []

        result = []

        try:
            with open(self.csv_file_name, 'r', encoding='utf-8') as f:
                reader = csv.reader(f)
                header_row = next(reader, None)  # consume the header line

                if header_row is None:
                    return result

                for row in reader:
                    if len(row) >= len(header_row):  # skip truncated rows
                        row_dict = dict(zip(header_row, row))

                        if filter_func is None or filter_func(row_dict):
                            result.append(row_dict)

        except Exception as e:
            print(f"查询CSV文件时出错: {e}")

        return result

    def query_by_conditions(self, **kwargs) -> List[Dict[str, str]]:
        """Query rows matching every given column=value condition (exact string match).

        Args:
            **kwargs: Conditions as column=value pairs; columns absent from a
                row are ignored for that row.

        Returns:
            List[Dict[str, str]]: Matching rows.
        """

        def filter_func(row_dict):
            for key, value in kwargs.items():
                if key in row_dict and row_dict[key] != value:
                    return False
            return True

        return self.query_data(filter_func)

    def get_all_data(self) -> List[Dict[str, str]]:
        """Return every row in the file.

        Returns:
            List[Dict[str, str]]: All rows as header->value dicts.
        """
        return self.query_data()
|
||
|
||
|
||
# 保持向后兼容的函数
|
||
def save_to_csv(filter_list: List[Dict[str, Any]],
                csv_file_name: str = 'company_search_result_data.csv',
                headers: List[str] = None,
                unique_titles: List[str] = None) -> bool:
    """Append results to a CSV file (backward-compatible wrapper around CSVTool).

    Args:
        filter_list: Records to write.
        csv_file_name: Target CSV file path.
        headers: Column names; a company-search default is used when None.
        unique_titles: Dedup key columns; defaults when None.

    Returns:
        bool: True when at least one row was written, False otherwise
        (including when every record was a duplicate or an error occurred).
    """
    default_headers = ['title', 'url', 'web_site_type', 'request_url', 'company_name', 'create_time']
    default_unique = ['company_name', 'web_site_type']

    if headers is None:
        headers = default_headers
    if unique_titles is None:
        unique_titles = default_unique

    try:
        tool = CSVTool(csv_file_name, headers)
        return tool.save_data(filter_list, unique_titles) > 0
    except Exception as e:
        print(f"保存CSV时出错: {e}")
        return False
|
||
|
||
|
||
# 使用示例:
|
||
# Usage example:
if __name__ == "__main__":
    # Build two demo records; company_name doubles as the title.
    sample_data = []
    for idx, site in ((1, 'aiqicha'), (2, 'qcc')):
        name = f"测试公司{idx}"
        sample_data.append({
            'title': name,
            'url': f"https://example.com/{idx}",
            'web_site_type': site,
            'request_url': f"https://bing.com/search?q={name}",
            'company_name': name,
        })

    # Create the CSV tool instance.
    csv_tool = CSVTool(
        csv_file_name='test_data.csv',
        headers=['title', 'url', 'web_site_type', 'request_url', 'company_name', 'create_time'],
    )

    # Save, then read everything back.
    csv_tool.save_data(sample_data, unique_titles=['company_name', 'web_site_type'])

    all_data = csv_tool.get_all_data()
    print("所有数据:", all_data)

    # Conditional query.
    filtered_data = csv_tool.query_by_conditions(web_site_type='aiqicha')
    print("查询结果:", filtered_data)
|