mirror of
https://github.com/LC044/WeChatMsg
synced 2025-04-05 11:48:01 +08:00
优化:实现数据库连接池和批量操作,减少I/O开销
1. 新增 db_pool.py,实现数据库连接池管理;
2. 修改 __init__.py,导入连接池并提供关闭连接池的方法;
3. 优化 msg.py 中的数据库操作,提高并发性能。
说明:在处理大量数据导入时,推荐使用 batch_insert_messages 方法;在应用退出时应调用 close_db 以确保资源释放;对于大量查询操作,可以进一步优化查询 SQL 和索引。
This commit is contained in:
parent
fc1e2fa7a5
commit
9d599199fe
@ -13,6 +13,7 @@ from .media_msg import MediaMsg
|
||||
from .misc import Misc
|
||||
from .msg import Msg
|
||||
from .msg import MsgType
|
||||
from .db_pool import db_pool, close_db_pool
|
||||
|
||||
misc_db = Misc()
|
||||
msg_db = Msg()
|
||||
@ -22,14 +23,18 @@ media_msg_db = MediaMsg()
|
||||
|
||||
|
||||
def close_db():
    """Close every module-level database wrapper, then the connection pool.

    The wrappers are closed in a fixed order (misc, msg, micro_msg,
    hard_link, media_msg); an exception raised by any of them propagates
    and stops the sequence.
    """
    for database in (misc_db, msg_db, micro_msg_db, hard_link_db, media_msg_db):
        database.close()

    # The shared connection pool is shut down last.
    close_db_pool()
|
||||
|
||||
|
||||
def init_db():
|
||||
"""初始化所有数据库连接"""
|
||||
misc_db.init_database()
|
||||
msg_db.init_database()
|
||||
micro_msg_db.init_database()
|
||||
@ -37,4 +42,4 @@ def init_db():
|
||||
media_msg_db.init_database()
|
||||
|
||||
|
||||
__all__ = ['misc_db', 'micro_msg_db', 'msg_db', 'hard_link_db', 'MsgType', "media_msg_db", "close_db"]
|
||||
__all__ = ['misc_db', 'micro_msg_db', 'msg_db', 'hard_link_db', 'MsgType', "media_msg_db", "close_db", "db_pool"]
|
||||
|
256
app/DataBase/db_pool.py
Normal file
256
app/DataBase/db_pool.py
Normal file
@ -0,0 +1,256 @@
|
||||
import os
|
||||
import sqlite3
|
||||
import threading
|
||||
import queue
|
||||
import time
|
||||
from typing import Dict, Optional, List, Tuple
|
||||
|
||||
class DatabaseConnectionPool:
    """Singleton pool of SQLite connections, keyed by database file path.

    Each database path gets its own bounded queue of idle connections plus a
    registry of the connections that are currently checked out.  Connections
    are opened with ``check_same_thread=False`` so that a connection created
    by one thread can later be handed to another.

    Attributes:
        max_connections: Upper bound on checked-out connections per database.
        timeout: Seconds ``get_connection`` waits for a connection to be
            released once the pool is exhausted.
        pools: Idle connections, one ``queue.Queue`` per database path.
        in_use: Checked-out connections per database path, mapped to the
            thread that acquired them.
        connection_locks: One lock per database path guarding ``in_use``.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """Return the single shared instance, creating it on first use."""
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(DatabaseConnectionPool, cls).__new__(cls)
                cls._instance._initialized = False
        return cls._instance

    def __init__(self, max_connections=5, timeout=5):
        """Configure the pool.

        Only the first construction has any effect; later calls return the
        already-configured singleton untouched.

        Args:
            max_connections: Maximum checked-out connections per database.
            timeout: Seconds to wait for a free connection when exhausted.
        """
        if self._initialized:
            return

        self._initialized = True
        self.max_connections = max_connections
        self.timeout = timeout
        self.pools: Dict[str, queue.Queue] = {}
        self.in_use: Dict[str, Dict[sqlite3.Connection, threading.Thread]] = {}
        self.connection_locks: Dict[str, threading.Lock] = {}

    def _create_connection(self, db_path: str) -> sqlite3.Connection:
        """Open and configure a new connection to ``db_path``.

        Raises:
            FileNotFoundError: If the database file does not exist (the pool
                never creates database files implicitly).
        """
        if not os.path.exists(db_path):
            raise FileNotFoundError(f"数据库文件不存在: {db_path}")

        conn = sqlite3.connect(db_path, check_same_thread=False)
        try:
            # Enforce foreign-key constraints on this connection.
            conn.execute('PRAGMA foreign_keys = ON')
            # Relaxed fsync policy, intended to be paired with WAL below.
            conn.execute('PRAGMA synchronous = NORMAL')
            # Write-ahead logging for better read/write concurrency.
            conn.execute('PRAGMA journal_mode = WAL')
            # Page-cache size for this connection.
            conn.execute('PRAGMA cache_size = 10000')
        except Exception:
            # Do not leak a half-configured connection.
            conn.close()
            raise
        return conn

    def _get_pool(self, db_path: str) -> queue.Queue:
        """Return the idle-connection queue for ``db_path``, creating it if needed.

        On creation up to two connections are opened eagerly; failures to do
        so are reported but not raised, so the error surfaces later from
        ``get_connection``.
        """
        pool = self.pools.get(db_path)
        if pool is not None:
            return pool

        with self._lock:
            pool = self.pools.get(db_path)
            if pool is None:
                pool = queue.Queue(maxsize=self.max_connections)

                for _ in range(min(2, self.max_connections)):
                    try:
                        pool.put(self._create_connection(db_path))
                    except Exception as e:
                        print(f"预创建连接失败: {e}")

                self.in_use.setdefault(db_path, {})
                self.connection_locks.setdefault(db_path, threading.Lock())
                # Publish the pool last: readers that skip the lock must
                # never observe a pool whose bookkeeping is not ready yet.
                self.pools[db_path] = pool
        return pool

    def _try_acquire(self, db_path: str, pool: queue.Queue) -> Optional[sqlite3.Connection]:
        """Take an idle connection or open a new one, without blocking.

        The lookup, the capacity check and the ``in_use`` registration all
        happen under the per-database lock so concurrent callers cannot both
        pass the capacity check.

        Returns:
            A registered connection, or ``None`` when the pool is exhausted.
        """
        with self.connection_locks[db_path]:
            checked_out = self.in_use[db_path]
            try:
                conn = pool.get_nowait()
            except queue.Empty:
                if len(checked_out) >= self.max_connections:
                    return None
                conn = self._create_connection(db_path)
            checked_out[conn] = threading.current_thread()
            return conn

    def get_connection(self, db_path: str) -> sqlite3.Connection:
        """Check a connection out of the pool.

        An idle connection is reused when available; otherwise a new one is
        opened immediately as long as fewer than ``max_connections`` are
        checked out.  Only when the pool is exhausted does the call block,
        for at most ``timeout`` seconds, waiting for a release.

        Args:
            db_path: Path of the database file.

        Returns:
            sqlite3.Connection: A connection registered as in use.

        Raises:
            FileNotFoundError: If a new connection is needed and the
                database file does not exist.
            TimeoutError: If no connection became available in time.
        """
        pool = self._get_pool(db_path)

        conn = self._try_acquire(db_path, pool)
        if conn is not None:
            return conn

        try:
            conn = pool.get(block=True, timeout=self.timeout)
        except queue.Empty:
            # Capacity may have been freed without anything being queued
            # (e.g. a connection was closed instead of returned).
            conn = self._try_acquire(db_path, pool)
            if conn is None:
                raise TimeoutError(f"无法获取数据库连接,连接池已满: {db_path}") from None
            return conn

        with self.connection_locks[db_path]:
            self.in_use[db_path][conn] = threading.current_thread()
        return conn

    def release_connection(self, db_path: str, conn: sqlite3.Connection):
        """Return a checked-out connection to the pool.

        Connections that are not registered as in use (for example a second
        release of the same connection) are ignored, so one connection can
        never be queued twice or closed while it sits idle in the pool.

        Args:
            db_path: Path of the database file.
            conn: The connection to release.
        """
        pool = self.pools.get(db_path)
        if pool is None:
            # The pool for this database is gone (e.g. after close_all).
            conn.close()
            return

        with self.connection_locks[db_path]:
            if self.in_use[db_path].pop(conn, None) is None:
                return
            try:
                pool.put(conn, block=False)
            except queue.Full:
                # More connections exist than the queue can hold: drop it.
                conn.close()

    def execute_batch(self, db_path: str, sql: str, params_list: List[tuple], commit=True) -> List[Optional[Tuple]]:
        """Run the same SQL statement once per parameter tuple.

        Args:
            db_path: Path of the database file.
            sql: SQL statement with placeholders.
            params_list: One parameter tuple per execution.
            commit: When true, the executions are wrapped in one explicit
                transaction that is committed on success and rolled back on
                failure.  When false, no commit or rollback is issued here.

        Returns:
            list: One entry per execution, in order: the fetched rows when
            the statement produced a result set, otherwise ``None``.

        Raises:
            Exception: Whatever the connection or cursor raised, re-raised
                after the optional rollback.
        """
        conn = None
        results = []

        try:
            conn = self.get_connection(db_path)
            cursor = conn.cursor()

            if commit:
                conn.execute("BEGIN TRANSACTION")

            for params in params_list:
                cursor.execute(sql, params)
                if cursor.description:  # statement produced a result set
                    results.append(cursor.fetchall())
                else:
                    results.append(None)

            if commit:
                conn.commit()

            return results
        except Exception:
            if conn is not None and commit:
                conn.rollback()
            raise
        finally:
            if conn is not None:
                self.release_connection(db_path, conn)

    def execute_query(self, db_path: str, sql: str, params=None) -> List[Tuple]:
        """Run a query and return all of its rows.

        Args:
            db_path: Path of the database file.
            sql: SQL query.
            params: Optional query parameters.

        Returns:
            list: All rows returned by the query.
        """
        conn = None
        try:
            conn = self.get_connection(db_path)
            cursor = conn.cursor()

            if params:
                cursor.execute(sql, params)
            else:
                cursor.execute(sql)

            return cursor.fetchall()
        finally:
            if conn is not None:
                self.release_connection(db_path, conn)

    def execute_update(self, db_path: str, sql: str, params=None) -> int:
        """Run a single modifying statement and commit it.

        Args:
            db_path: Path of the database file.
            sql: SQL statement.
            params: Optional statement parameters.

        Returns:
            int: The cursor's ``rowcount`` after execution.

        Raises:
            Exception: Whatever the connection or cursor raised, re-raised
                after a rollback.
        """
        conn = None
        try:
            conn = self.get_connection(db_path)
            cursor = conn.cursor()

            if params:
                cursor.execute(sql, params)
            else:
                cursor.execute(sql)

            conn.commit()
            return cursor.rowcount
        except Exception:
            if conn is not None:
                conn.rollback()
            raise
        finally:
            if conn is not None:
                self.release_connection(db_path, conn)

    def close_all(self):
        """Close every idle and checked-out connection and drop all pools.

        Closing is best-effort: a connection that fails to close is reported
        and skipped so the remaining connections are still closed.
        """
        with self._lock:
            for db_path, pool in self.pools.items():
                # Idle connections.
                while True:
                    try:
                        conn = pool.get_nowait()
                    except queue.Empty:
                        break
                    self._close_quietly(conn)

                # Connections that are still checked out.
                with self.connection_locks[db_path]:
                    for conn in list(self.in_use[db_path]):
                        self._close_quietly(conn)
                    self.in_use[db_path].clear()

            self.pools.clear()

    @staticmethod
    def _close_quietly(conn: sqlite3.Connection) -> None:
        """Close ``conn``, reporting instead of raising on SQLite errors."""
        try:
            conn.close()
        except sqlite3.Error as e:
            print(f"关闭连接失败: {e}")
|
||||
|
||||
# Module-level pool instance shared by the database wrappers.
db_pool = DatabaseConnectionPool()


def close_db_pool():
    """Close all connections held by the shared pool via ``db_pool.close_all()``."""
    db_pool.close_all()
|
@ -5,8 +5,9 @@ import threading
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, date
|
||||
from typing import Tuple
|
||||
from typing import Tuple, List, Optional, Dict, Any
|
||||
|
||||
from app.DataBase.db_pool import db_pool
|
||||
from app.log import logger
|
||||
from app.util.compress_content import parser_reply
|
||||
from app.util.protocbuf.msg_pb2 import MessageBytesExtra
|
||||
@ -140,8 +141,6 @@ class MsgType:
|
||||
|
||||
class Msg:
|
||||
def __init__(self):
|
||||
self.DB = None
|
||||
self.cursor = None
|
||||
self.open_flag = False
|
||||
self.init_database()
|
||||
|
||||
@ -151,9 +150,6 @@ class Msg:
|
||||
if path:
|
||||
db_path = path
|
||||
if os.path.exists(db_path):
|
||||
self.DB = sqlite3.connect(db_path, check_same_thread=False)
|
||||
# '''创建游标'''
|
||||
self.cursor = self.DB.cursor()
|
||||
self.open_flag = True
|
||||
if lock.locked():
|
||||
lock.release()
|
||||
@ -200,48 +196,137 @@ class Msg:
|
||||
a[10]: BytesExtra,
|
||||
a[11]: CompressContent,
|
||||
a[12]: DisplayContent,
|
||||
a[13]: 联系人的类(如果是群聊就有,不是的话没有这个字段)
|
||||
"""
|
||||
if not self.open_flag:
|
||||
return None
|
||||
if time_range:
|
||||
start_time, end_time = convert_to_timestamp(time_range)
|
||||
sql = f'''
|
||||
select localId,TalkerId,Type,SubType,IsSender,CreateTime,Status,StrContent,strftime('%Y-%m-%d %H:%M:%S',CreateTime,'unixepoch','localtime') as StrTime,MsgSvrID,BytesExtra,CompressContent,DisplayContent
|
||||
from MSG
|
||||
where StrTalker=?
|
||||
{'AND CreateTime>' + str(start_time) + ' AND CreateTime<' + str(end_time) if time_range else ''}
|
||||
order by CreateTime
|
||||
return []
|
||||
|
||||
begin_time, end_time = convert_to_timestamp(time_range)
|
||||
|
||||
sql = '''
|
||||
SELECT
|
||||
localId,
|
||||
TalkerId,
|
||||
Type,
|
||||
SubType,
|
||||
IsSender,
|
||||
CreateTime,
|
||||
Status,
|
||||
StrContent,
|
||||
strftime('%Y-%m-%d %H:%M:%S', datetime(CreateTime, 'unixepoch', 'localtime')),
|
||||
MsgSvrID,
|
||||
BytesExtra,
|
||||
CompressContent,
|
||||
DisplayContent
|
||||
FROM MSG
|
||||
WHERE StrTalker = ?
|
||||
AND (? = 0 OR CreateTime >= ?)
|
||||
AND (? = 0 OR CreateTime <= ?)
|
||||
ORDER BY CreateTime DESC
|
||||
'''
|
||||
|
||||
params = (username_, begin_time, begin_time, end_time, end_time)
|
||||
|
||||
try:
|
||||
lock.acquire(True)
|
||||
self.cursor.execute(sql, [username_])
|
||||
result = self.cursor.fetchall()
|
||||
finally:
|
||||
lock.release()
|
||||
return parser_chatroom_message(result) if username_.__contains__('@chatroom') else result
|
||||
# result.sort(key=lambda x: x[5])
|
||||
# return self.add_sender(result)
|
||||
results = db_pool.execute_query(db_path, sql, params)
|
||||
|
||||
# 处理群聊信息
|
||||
if results and username_.startswith('chatroom'):
|
||||
results = parser_chatroom_message(results)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天记录失败: {e}\n{traceback.format_exc()}")
|
||||
return []
|
||||
|
||||
def batch_insert_messages(self, messages_data: List[Dict[str, Any]]) -> bool:
    """Insert many rows into the MSG table through the connection pool.

    Args:
        messages_data: One dict per message, keyed by MSG column name.
            Missing keys fall back to a per-column default (empty string,
            0, ``None``, or the current Unix time for ``CreateTime``).

    Returns:
        bool: True when the batch was executed; False when the database is
        not open, the input is empty, or the insert raised (the error is
        logged).
    """
    if not (self.open_flag and messages_data):
        return False

    insert_sql = '''
        INSERT INTO MSG (
            MsgId, TalkerId, Type, SubType, IsSender, CreateTime,
            Status, StrContent, StrTalker, MsgSvrID, BytesExtra,
            CompressContent, DisplayContent
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    '''

    def as_row(record):
        # Column order must match the INSERT column list above.
        field = record.get
        return (
            field('MsgId', ''),
            field('TalkerId', ''),
            field('Type', 0),
            field('SubType', 0),
            field('IsSender', 0),
            field('CreateTime', int(time.time())),
            field('Status', 0),
            field('StrContent', ''),
            field('StrTalker', ''),
            field('MsgSvrID', ''),
            field('BytesExtra', None),
            field('CompressContent', None),
            field('DisplayContent', None),
        )

    rows = [as_row(record) for record in messages_data]

    try:
        db_pool.execute_batch(db_path, insert_sql, rows)
    except Exception as e:
        logger.error(f"批量插入消息失败: {e}\n{traceback.format_exc()}")
        return False
    return True
|
||||
|
||||
def get_messages_all(self, time_range=None):
|
||||
if time_range:
|
||||
start_time, end_time = convert_to_timestamp(time_range)
|
||||
sql = f'''
|
||||
select localId,TalkerId,Type,SubType,IsSender,CreateTime,Status,StrContent,strftime('%Y-%m-%d %H:%M:%S',CreateTime,'unixepoch','localtime') as StrTime,MsgSvrID,BytesExtra,StrTalker,Reserved1,CompressContent
|
||||
from MSG
|
||||
{'WHERE CreateTime>' + str(start_time) + ' AND CreateTime<' + str(end_time) if time_range else ''}
|
||||
order by CreateTime
|
||||
'''
|
||||
"""
|
||||
获取所有聊天记录
|
||||
@param time_range:
|
||||
@return:
|
||||
"""
|
||||
if not self.open_flag:
|
||||
return None
|
||||
return []
|
||||
|
||||
begin_time, end_time = convert_to_timestamp(time_range)
|
||||
|
||||
sql = '''
|
||||
SELECT
|
||||
localId,
|
||||
TalkerId,
|
||||
Type,
|
||||
SubType,
|
||||
IsSender,
|
||||
CreateTime,
|
||||
Status,
|
||||
StrContent,
|
||||
strftime('%Y-%m-%d %H:%M:%S', datetime(CreateTime, 'unixepoch', 'localtime')),
|
||||
MsgSvrID,
|
||||
BytesExtra,
|
||||
CompressContent,
|
||||
DisplayContent,
|
||||
StrTalker
|
||||
FROM MSG
|
||||
WHERE (? = 0 OR CreateTime >= ?)
|
||||
AND (? = 0 OR CreateTime <= ?)
|
||||
ORDER BY CreateTime DESC
|
||||
'''
|
||||
|
||||
params = (begin_time, begin_time, end_time, end_time)
|
||||
|
||||
try:
|
||||
lock.acquire(True)
|
||||
self.cursor.execute(sql)
|
||||
result = self.cursor.fetchall()
|
||||
finally:
|
||||
lock.release()
|
||||
result.sort(key=lambda x: x[5])
|
||||
return result
|
||||
return db_pool.execute_query(db_path, sql, params)
|
||||
except Exception as e:
|
||||
logger.error(f"获取所有聊天记录失败: {e}\n{traceback.format_exc()}")
|
||||
return []
|
||||
|
||||
def get_messages_group_by_day(
|
||||
self,
|
||||
@ -865,13 +950,8 @@ class Msg:
|
||||
return sum_type_1 + sum_type_49
|
||||
|
||||
def close(self):
|
||||
if self.open_flag:
|
||||
try:
|
||||
lock.acquire(True)
|
||||
self.open_flag = False
|
||||
self.DB.close()
|
||||
finally:
|
||||
lock.release()
|
||||
"""关闭数据库连接,不再需要显式关闭,由连接池管理"""
|
||||
self.open_flag = False
|
||||
|
||||
def __del__(self):
|
||||
self.close()
|
||||
|
Loading…
Reference in New Issue
Block a user