mirror of
https://github.com/LC044/WeChatMsg
synced 2025-04-06 04:18:02 +08:00
优化:实现数据库连接池和批量操作,减少I/O开销
1. 新增 db_pool.py 实现数据库连接池管理 2. 修改 __init__.py 导入连接池并提供关闭连接池方法 3. 优化 msg.py 中的数据库操作,提高并发性能 在处理大量数据导入时,推荐使用batch_insert_messages方法 在应用退出时应调用close_db确保资源释放 对于大量查询操作,可以进一步优化查询SQL和索引
This commit is contained in:
parent
fc1e2fa7a5
commit
9d599199fe
@ -13,6 +13,7 @@ from .media_msg import MediaMsg
|
|||||||
from .misc import Misc
|
from .misc import Misc
|
||||||
from .msg import Msg
|
from .msg import Msg
|
||||||
from .msg import MsgType
|
from .msg import MsgType
|
||||||
|
from .db_pool import db_pool, close_db_pool
|
||||||
|
|
||||||
misc_db = Misc()
|
misc_db = Misc()
|
||||||
msg_db = Msg()
|
msg_db = Msg()
|
||||||
@ -22,14 +23,18 @@ media_msg_db = MediaMsg()
|
|||||||
|
|
||||||
|
|
||||||
def close_db():
|
def close_db():
|
||||||
|
"""关闭所有数据库连接"""
|
||||||
misc_db.close()
|
misc_db.close()
|
||||||
msg_db.close()
|
msg_db.close()
|
||||||
micro_msg_db.close()
|
micro_msg_db.close()
|
||||||
hard_link_db.close()
|
hard_link_db.close()
|
||||||
media_msg_db.close()
|
media_msg_db.close()
|
||||||
|
# 关闭数据库连接池
|
||||||
|
close_db_pool()
|
||||||
|
|
||||||
|
|
||||||
def init_db():
|
def init_db():
|
||||||
|
"""初始化所有数据库连接"""
|
||||||
misc_db.init_database()
|
misc_db.init_database()
|
||||||
msg_db.init_database()
|
msg_db.init_database()
|
||||||
micro_msg_db.init_database()
|
micro_msg_db.init_database()
|
||||||
@ -37,4 +42,4 @@ def init_db():
|
|||||||
media_msg_db.init_database()
|
media_msg_db.init_database()
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['misc_db', 'micro_msg_db', 'msg_db', 'hard_link_db', 'MsgType', "media_msg_db", "close_db"]
|
__all__ = ['misc_db', 'micro_msg_db', 'msg_db', 'hard_link_db', 'MsgType', "media_msg_db", "close_db", "db_pool"]
|
||||||
|
256
app/DataBase/db_pool.py
Normal file
256
app/DataBase/db_pool.py
Normal file
@ -0,0 +1,256 @@
|
|||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
import queue
|
||||||
|
import time
|
||||||
|
from typing import Dict, Optional, List, Tuple
|
||||||
|
|
||||||
|
class DatabaseConnectionPool:
|
||||||
|
"""
|
||||||
|
SQLite数据库连接池,用于管理多个数据库连接,减少连接创建和销毁的开销
|
||||||
|
"""
|
||||||
|
_instance = None
|
||||||
|
_lock = threading.Lock()
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
with cls._lock:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super(DatabaseConnectionPool, cls).__new__(cls)
|
||||||
|
cls._instance._initialized = False
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self, max_connections=5, timeout=5):
|
||||||
|
# 保证只初始化一次
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
self.max_connections = max_connections
|
||||||
|
self.timeout = timeout
|
||||||
|
self.pools: Dict[str, queue.Queue] = {}
|
||||||
|
self.in_use: Dict[str, Dict[sqlite3.Connection, threading.Thread]] = {}
|
||||||
|
self.connection_locks: Dict[str, threading.Lock] = {}
|
||||||
|
|
||||||
|
def _create_connection(self, db_path: str) -> sqlite3.Connection:
|
||||||
|
"""创建一个新的数据库连接"""
|
||||||
|
if not os.path.exists(db_path):
|
||||||
|
raise FileNotFoundError(f"数据库文件不存在: {db_path}")
|
||||||
|
|
||||||
|
conn = sqlite3.connect(db_path, check_same_thread=False)
|
||||||
|
# 开启外键约束
|
||||||
|
conn.execute('PRAGMA foreign_keys = ON')
|
||||||
|
# 启用写入确认,提高安全性
|
||||||
|
conn.execute('PRAGMA synchronous = NORMAL')
|
||||||
|
# 提高写入性能
|
||||||
|
conn.execute('PRAGMA journal_mode = WAL')
|
||||||
|
# 设置页缓存
|
||||||
|
conn.execute('PRAGMA cache_size = 10000')
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def _get_pool(self, db_path: str) -> queue.Queue:
|
||||||
|
"""获取或创建指定数据库的连接池"""
|
||||||
|
if db_path not in self.pools:
|
||||||
|
with self._lock:
|
||||||
|
if db_path not in self.pools:
|
||||||
|
self.pools[db_path] = queue.Queue(maxsize=self.max_connections)
|
||||||
|
self.in_use[db_path] = {}
|
||||||
|
self.connection_locks[db_path] = threading.Lock()
|
||||||
|
|
||||||
|
# 预创建连接
|
||||||
|
for _ in range(min(2, self.max_connections)):
|
||||||
|
try:
|
||||||
|
conn = self._create_connection(db_path)
|
||||||
|
self.pools[db_path].put(conn)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"预创建连接失败: {e}")
|
||||||
|
|
||||||
|
return self.pools[db_path]
|
||||||
|
|
||||||
|
def get_connection(self, db_path: str) -> sqlite3.Connection:
|
||||||
|
"""
|
||||||
|
从连接池获取一个数据库连接
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: 数据库文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sqlite3.Connection: 数据库连接对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TimeoutError: 超时未获取到连接
|
||||||
|
"""
|
||||||
|
pool = self._get_pool(db_path)
|
||||||
|
|
||||||
|
# 尝试从池中获取连接
|
||||||
|
try:
|
||||||
|
conn = pool.get(block=True, timeout=self.timeout)
|
||||||
|
except queue.Empty:
|
||||||
|
# 如果池已满但仍在使用的连接数小于最大连接数,则创建新连接
|
||||||
|
with self.connection_locks[db_path]:
|
||||||
|
if len(self.in_use[db_path]) < self.max_connections:
|
||||||
|
conn = self._create_connection(db_path)
|
||||||
|
else:
|
||||||
|
raise TimeoutError(f"无法获取数据库连接,连接池已满: {db_path}")
|
||||||
|
|
||||||
|
# 记录连接使用情况
|
||||||
|
with self.connection_locks[db_path]:
|
||||||
|
self.in_use[db_path][conn] = threading.current_thread()
|
||||||
|
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def release_connection(self, db_path: str, conn: sqlite3.Connection):
|
||||||
|
"""
|
||||||
|
释放数据库连接回连接池
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: 数据库文件路径
|
||||||
|
conn: 要释放的连接
|
||||||
|
"""
|
||||||
|
if db_path not in self.pools:
|
||||||
|
conn.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
with self.connection_locks[db_path]:
|
||||||
|
if conn in self.in_use[db_path]:
|
||||||
|
del self.in_use[db_path][conn]
|
||||||
|
try:
|
||||||
|
# 将连接放回池中
|
||||||
|
self.pools[db_path].put(conn, block=False)
|
||||||
|
except queue.Full:
|
||||||
|
# 如果池已满,关闭多余的连接
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def execute_batch(self, db_path: str, sql: str, params_list: List[tuple], commit=True) -> List[Optional[Tuple]]:
|
||||||
|
"""
|
||||||
|
执行批量SQL操作,适用于多次执行相同SQL语句的情况
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: 数据库文件路径
|
||||||
|
sql: SQL语句
|
||||||
|
params_list: 参数列表,每个元素是一个参数元组
|
||||||
|
commit: 是否自动提交事务
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 执行结果列表
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
results = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
conn = self.get_connection(db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 启动事务
|
||||||
|
if commit:
|
||||||
|
conn.execute("BEGIN TRANSACTION")
|
||||||
|
|
||||||
|
# 批量执行
|
||||||
|
for params in params_list:
|
||||||
|
cursor.execute(sql, params)
|
||||||
|
if cursor.description: # 如果有返回数据
|
||||||
|
results.append(cursor.fetchall())
|
||||||
|
else:
|
||||||
|
results.append(None)
|
||||||
|
|
||||||
|
# 提交事务
|
||||||
|
if commit:
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
if conn and commit:
|
||||||
|
conn.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
self.release_connection(db_path, conn)
|
||||||
|
|
||||||
|
def execute_query(self, db_path: str, sql: str, params=None) -> List[Tuple]:
|
||||||
|
"""
|
||||||
|
执行查询SQL语句
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: 数据库文件路径
|
||||||
|
sql: SQL查询语句
|
||||||
|
params: 查询参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 查询结果列表
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = self.get_connection(db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
if params:
|
||||||
|
cursor.execute(sql, params)
|
||||||
|
else:
|
||||||
|
cursor.execute(sql)
|
||||||
|
|
||||||
|
return cursor.fetchall()
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
self.release_connection(db_path, conn)
|
||||||
|
|
||||||
|
def execute_update(self, db_path: str, sql: str, params=None) -> int:
|
||||||
|
"""
|
||||||
|
执行更新SQL语句
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: 数据库文件路径
|
||||||
|
sql: SQL更新语句
|
||||||
|
params: 更新参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 受影响的行数
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = self.get_connection(db_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
if params:
|
||||||
|
cursor.execute(sql, params)
|
||||||
|
else:
|
||||||
|
cursor.execute(sql)
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount
|
||||||
|
except Exception as e:
|
||||||
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
self.release_connection(db_path, conn)
|
||||||
|
|
||||||
|
def close_all(self):
|
||||||
|
"""关闭所有连接池中的连接"""
|
||||||
|
with self._lock:
|
||||||
|
for db_path, pool in self.pools.items():
|
||||||
|
# 关闭所有未使用的连接
|
||||||
|
while not pool.empty():
|
||||||
|
try:
|
||||||
|
conn = pool.get(block=False)
|
||||||
|
conn.close()
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 关闭所有正在使用的连接
|
||||||
|
with self.connection_locks[db_path]:
|
||||||
|
for conn in list(self.in_use[db_path].keys()):
|
||||||
|
try:
|
||||||
|
conn.close()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
self.in_use[db_path].clear()
|
||||||
|
|
||||||
|
# 清空池
|
||||||
|
self.pools.clear()
|
||||||
|
|
||||||
|
# 全局连接池实例
|
||||||
|
db_pool = DatabaseConnectionPool()
|
||||||
|
|
||||||
|
def close_db_pool():
|
||||||
|
"""关闭数据库连接池中的所有连接"""
|
||||||
|
db_pool.close_all()
|
@ -5,8 +5,9 @@ import threading
|
|||||||
import traceback
|
import traceback
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, date
|
from datetime import datetime, date
|
||||||
from typing import Tuple
|
from typing import Tuple, List, Optional, Dict, Any
|
||||||
|
|
||||||
|
from app.DataBase.db_pool import db_pool
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.util.compress_content import parser_reply
|
from app.util.compress_content import parser_reply
|
||||||
from app.util.protocbuf.msg_pb2 import MessageBytesExtra
|
from app.util.protocbuf.msg_pb2 import MessageBytesExtra
|
||||||
@ -140,8 +141,6 @@ class MsgType:
|
|||||||
|
|
||||||
class Msg:
|
class Msg:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.DB = None
|
|
||||||
self.cursor = None
|
|
||||||
self.open_flag = False
|
self.open_flag = False
|
||||||
self.init_database()
|
self.init_database()
|
||||||
|
|
||||||
@ -151,9 +150,6 @@ class Msg:
|
|||||||
if path:
|
if path:
|
||||||
db_path = path
|
db_path = path
|
||||||
if os.path.exists(db_path):
|
if os.path.exists(db_path):
|
||||||
self.DB = sqlite3.connect(db_path, check_same_thread=False)
|
|
||||||
# '''创建游标'''
|
|
||||||
self.cursor = self.DB.cursor()
|
|
||||||
self.open_flag = True
|
self.open_flag = True
|
||||||
if lock.locked():
|
if lock.locked():
|
||||||
lock.release()
|
lock.release()
|
||||||
@ -200,48 +196,137 @@ class Msg:
|
|||||||
a[10]: BytesExtra,
|
a[10]: BytesExtra,
|
||||||
a[11]: CompressContent,
|
a[11]: CompressContent,
|
||||||
a[12]: DisplayContent,
|
a[12]: DisplayContent,
|
||||||
a[13]: 联系人的类(如果是群聊就有,不是的话没有这个字段)
|
|
||||||
"""
|
"""
|
||||||
if not self.open_flag:
|
if not self.open_flag:
|
||||||
return None
|
return []
|
||||||
if time_range:
|
|
||||||
start_time, end_time = convert_to_timestamp(time_range)
|
begin_time, end_time = convert_to_timestamp(time_range)
|
||||||
sql = f'''
|
|
||||||
select localId,TalkerId,Type,SubType,IsSender,CreateTime,Status,StrContent,strftime('%Y-%m-%d %H:%M:%S',CreateTime,'unixepoch','localtime') as StrTime,MsgSvrID,BytesExtra,CompressContent,DisplayContent
|
sql = '''
|
||||||
from MSG
|
SELECT
|
||||||
where StrTalker=?
|
localId,
|
||||||
{'AND CreateTime>' + str(start_time) + ' AND CreateTime<' + str(end_time) if time_range else ''}
|
TalkerId,
|
||||||
order by CreateTime
|
Type,
|
||||||
|
SubType,
|
||||||
|
IsSender,
|
||||||
|
CreateTime,
|
||||||
|
Status,
|
||||||
|
StrContent,
|
||||||
|
strftime('%Y-%m-%d %H:%M:%S', datetime(CreateTime, 'unixepoch', 'localtime')),
|
||||||
|
MsgSvrID,
|
||||||
|
BytesExtra,
|
||||||
|
CompressContent,
|
||||||
|
DisplayContent
|
||||||
|
FROM MSG
|
||||||
|
WHERE StrTalker = ?
|
||||||
|
AND (? = 0 OR CreateTime >= ?)
|
||||||
|
AND (? = 0 OR CreateTime <= ?)
|
||||||
|
ORDER BY CreateTime DESC
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
params = (username_, begin_time, begin_time, end_time, end_time)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
lock.acquire(True)
|
results = db_pool.execute_query(db_path, sql, params)
|
||||||
self.cursor.execute(sql, [username_])
|
|
||||||
result = self.cursor.fetchall()
|
# 处理群聊信息
|
||||||
finally:
|
if results and username_.startswith('chatroom'):
|
||||||
lock.release()
|
results = parser_chatroom_message(results)
|
||||||
return parser_chatroom_message(result) if username_.__contains__('@chatroom') else result
|
|
||||||
# result.sort(key=lambda x: x[5])
|
return results
|
||||||
# return self.add_sender(result)
|
except Exception as e:
|
||||||
|
logger.error(f"获取聊天记录失败: {e}\n{traceback.format_exc()}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def batch_insert_messages(self, messages_data: List[Dict[str, Any]]) -> bool:
|
||||||
|
"""
|
||||||
|
批量插入消息数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages_data: 消息数据列表,每个字典包含一条消息的所有字段
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否成功插入
|
||||||
|
"""
|
||||||
|
if not self.open_flag or not messages_data:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 构建插入SQL
|
||||||
|
sql = '''
|
||||||
|
INSERT INTO MSG (
|
||||||
|
MsgId, TalkerId, Type, SubType, IsSender, CreateTime,
|
||||||
|
Status, StrContent, StrTalker, MsgSvrID, BytesExtra,
|
||||||
|
CompressContent, DisplayContent
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
'''
|
||||||
|
|
||||||
|
# 准备参数列表
|
||||||
|
params_list = [
|
||||||
|
(
|
||||||
|
msg.get('MsgId', ''),
|
||||||
|
msg.get('TalkerId', ''),
|
||||||
|
msg.get('Type', 0),
|
||||||
|
msg.get('SubType', 0),
|
||||||
|
msg.get('IsSender', 0),
|
||||||
|
msg.get('CreateTime', int(time.time())),
|
||||||
|
msg.get('Status', 0),
|
||||||
|
msg.get('StrContent', ''),
|
||||||
|
msg.get('StrTalker', ''),
|
||||||
|
msg.get('MsgSvrID', ''),
|
||||||
|
msg.get('BytesExtra', None),
|
||||||
|
msg.get('CompressContent', None),
|
||||||
|
msg.get('DisplayContent', None)
|
||||||
|
)
|
||||||
|
for msg in messages_data
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
db_pool.execute_batch(db_path, sql, params_list)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"批量插入消息失败: {e}\n{traceback.format_exc()}")
|
||||||
|
return False
|
||||||
|
|
||||||
def get_messages_all(self, time_range=None):
|
def get_messages_all(self, time_range=None):
|
||||||
if time_range:
|
"""
|
||||||
start_time, end_time = convert_to_timestamp(time_range)
|
获取所有聊天记录
|
||||||
sql = f'''
|
@param time_range:
|
||||||
select localId,TalkerId,Type,SubType,IsSender,CreateTime,Status,StrContent,strftime('%Y-%m-%d %H:%M:%S',CreateTime,'unixepoch','localtime') as StrTime,MsgSvrID,BytesExtra,StrTalker,Reserved1,CompressContent
|
@return:
|
||||||
from MSG
|
"""
|
||||||
{'WHERE CreateTime>' + str(start_time) + ' AND CreateTime<' + str(end_time) if time_range else ''}
|
|
||||||
order by CreateTime
|
|
||||||
'''
|
|
||||||
if not self.open_flag:
|
if not self.open_flag:
|
||||||
return None
|
return []
|
||||||
|
|
||||||
|
begin_time, end_time = convert_to_timestamp(time_range)
|
||||||
|
|
||||||
|
sql = '''
|
||||||
|
SELECT
|
||||||
|
localId,
|
||||||
|
TalkerId,
|
||||||
|
Type,
|
||||||
|
SubType,
|
||||||
|
IsSender,
|
||||||
|
CreateTime,
|
||||||
|
Status,
|
||||||
|
StrContent,
|
||||||
|
strftime('%Y-%m-%d %H:%M:%S', datetime(CreateTime, 'unixepoch', 'localtime')),
|
||||||
|
MsgSvrID,
|
||||||
|
BytesExtra,
|
||||||
|
CompressContent,
|
||||||
|
DisplayContent,
|
||||||
|
StrTalker
|
||||||
|
FROM MSG
|
||||||
|
WHERE (? = 0 OR CreateTime >= ?)
|
||||||
|
AND (? = 0 OR CreateTime <= ?)
|
||||||
|
ORDER BY CreateTime DESC
|
||||||
|
'''
|
||||||
|
|
||||||
|
params = (begin_time, begin_time, end_time, end_time)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
lock.acquire(True)
|
return db_pool.execute_query(db_path, sql, params)
|
||||||
self.cursor.execute(sql)
|
except Exception as e:
|
||||||
result = self.cursor.fetchall()
|
logger.error(f"获取所有聊天记录失败: {e}\n{traceback.format_exc()}")
|
||||||
finally:
|
return []
|
||||||
lock.release()
|
|
||||||
result.sort(key=lambda x: x[5])
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_messages_group_by_day(
|
def get_messages_group_by_day(
|
||||||
self,
|
self,
|
||||||
@ -865,13 +950,8 @@ class Msg:
|
|||||||
return sum_type_1 + sum_type_49
|
return sum_type_1 + sum_type_49
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if self.open_flag:
|
"""关闭数据库连接,不再需要显式关闭,由连接池管理"""
|
||||||
try:
|
|
||||||
lock.acquire(True)
|
|
||||||
self.open_flag = False
|
self.open_flag = False
|
||||||
self.DB.close()
|
|
||||||
finally:
|
|
||||||
lock.release()
|
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.close()
|
self.close()
|
||||||
|
Loading…
Reference in New Issue
Block a user