优化:实现数据库连接池和批量操作,减少I/O开销

1. 新增 db_pool.py 实现数据库连接池管理
2. 修改 __init__.py 导入连接池并提供关闭连接池方法
3. 优化 msg.py 中的数据库操作,提高并发性能

在处理大量数据导入时,推荐使用batch_insert_messages方法
在应用退出时应调用close_db确保资源释放
对于大量查询操作,可以进一步优化查询SQL和索引
This commit is contained in:
Lecheeel 2025-03-27 00:02:33 +08:00
parent fc1e2fa7a5
commit 9d599199fe
3 changed files with 389 additions and 48 deletions

View File

@ -13,6 +13,7 @@ from .media_msg import MediaMsg
from .misc import Misc
from .msg import Msg
from .msg import MsgType
from .db_pool import db_pool, close_db_pool

misc_db = Misc()
msg_db = Msg()
@ -22,14 +23,18 @@ media_msg_db = MediaMsg()
def close_db(): def close_db():
"""关闭所有数据库连接"""
misc_db.close() misc_db.close()
msg_db.close() msg_db.close()
micro_msg_db.close() micro_msg_db.close()
hard_link_db.close() hard_link_db.close()
media_msg_db.close() media_msg_db.close()
# 关闭数据库连接池
close_db_pool()
def init_db(): def init_db():
"""初始化所有数据库连接"""
misc_db.init_database() misc_db.init_database()
msg_db.init_database() msg_db.init_database()
micro_msg_db.init_database() micro_msg_db.init_database()
@ -37,4 +42,4 @@ def init_db():
media_msg_db.init_database() media_msg_db.init_database()
__all__ = ['misc_db', 'micro_msg_db', 'msg_db', 'hard_link_db', 'MsgType', "media_msg_db", "close_db"] __all__ = ['misc_db', 'micro_msg_db', 'msg_db', 'hard_link_db', 'MsgType', "media_msg_db", "close_db", "db_pool"]

256
app/DataBase/db_pool.py Normal file
View File

@ -0,0 +1,256 @@
import os
import sqlite3
import threading
import queue
import time
from typing import Dict, Optional, List, Tuple
class DatabaseConnectionPool:
    """
    Thread-safe SQLite connection pool (process-wide singleton).

    Keeps up to ``max_connections`` open connections per database file so
    callers avoid the cost of repeatedly opening and closing sqlite files.
    Borrow with :meth:`get_connection`, return with :meth:`release_connection`,
    or use the ``execute_*`` helpers which do both.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # Singleton: every instantiation returns the same object.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(DatabaseConnectionPool, cls).__new__(cls)
                cls._instance._initialized = False
        return cls._instance

    def __init__(self, max_connections=5, timeout=5):
        # __init__ runs on *every* DatabaseConnectionPool() call because of
        # the singleton __new__; only the first call may configure state.
        if self._initialized:
            return
        self._initialized = True
        self.max_connections = max_connections
        self.timeout = timeout  # seconds to wait for an idle connection
        # Per-db-path queue of idle connections.
        self.pools: Dict[str, queue.Queue] = {}
        # Per-db-path map of checked-out connection -> borrowing thread.
        self.in_use: Dict[str, Dict[sqlite3.Connection, threading.Thread]] = {}
        # Per-db-path lock guarding the ``in_use`` bookkeeping.
        self.connection_locks: Dict[str, threading.Lock] = {}

    def _create_connection(self, db_path: str) -> sqlite3.Connection:
        """Open a new sqlite connection with performance-oriented PRAGMAs.

        Raises:
            FileNotFoundError: *db_path* does not exist — we never want
                sqlite3.connect to silently create an empty database here.
        """
        if not os.path.exists(db_path):
            raise FileNotFoundError(f"数据库文件不存在: {db_path}")
        # check_same_thread=False: the pool hands connections between
        # threads; exclusivity is guaranteed by the borrow/release protocol.
        conn = sqlite3.connect(db_path, check_same_thread=False)
        conn.execute('PRAGMA foreign_keys = ON')      # enforce FK constraints
        conn.execute('PRAGMA synchronous = NORMAL')   # durability/speed trade-off
        conn.execute('PRAGMA journal_mode = WAL')     # better concurrent access
        conn.execute('PRAGMA cache_size = 10000')     # larger page cache
        return conn

    def _get_pool(self, db_path: str) -> queue.Queue:
        """Return (lazily creating and pre-warming) the idle queue for *db_path*."""
        if db_path not in self.pools:
            with self._lock:
                # Re-check under the lock so two racing threads don't both
                # build the pool for the same path.
                if db_path not in self.pools:
                    pool = queue.Queue(maxsize=self.max_connections)
                    self.in_use[db_path] = {}
                    self.connection_locks[db_path] = threading.Lock()
                    # Pre-warm with a couple of connections (best effort).
                    for _ in range(min(2, self.max_connections)):
                        try:
                            pool.put(self._create_connection(db_path))
                        except Exception as e:
                            print(f"预创建连接失败: {e}")
                    self.pools[db_path] = pool
        return self.pools[db_path]

    def get_connection(self, db_path: str) -> sqlite3.Connection:
        """
        Borrow a connection for *db_path*.

        Blocks up to ``self.timeout`` seconds for an idle connection; if none
        frees up but fewer than ``max_connections`` are checked out, an
        overflow connection is created.

        Args:
            db_path: database file path.

        Returns:
            sqlite3.Connection: an open connection owned by the caller until
            passed back to :meth:`release_connection`.

        Raises:
            TimeoutError: no connection became available and the pool is full.
        """
        pool = self._get_pool(db_path)
        try:
            conn = pool.get(block=True, timeout=self.timeout)
        except queue.Empty:
            with self.connection_locks[db_path]:
                if len(self.in_use[db_path]) < self.max_connections:
                    conn = self._create_connection(db_path)
                else:
                    # Suppress the queue.Empty context — callers only care
                    # about the timeout.
                    raise TimeoutError(f"无法获取数据库连接,连接池已满: {db_path}") from None
        with self.connection_locks[db_path]:
            self.in_use[db_path][conn] = threading.current_thread()
        return conn

    def release_connection(self, db_path: str, conn: sqlite3.Connection):
        """
        Return a borrowed connection to the pool.

        Args:
            db_path: database file path the connection was borrowed for.
            conn: the connection to give back. Unknown paths and overflow
                connections (pool already full) are simply closed.
        """
        if db_path not in self.pools:
            conn.close()
            return
        with self.connection_locks[db_path]:
            if conn in self.in_use[db_path]:
                del self.in_use[db_path][conn]
            try:
                self.pools[db_path].put(conn, block=False)
            except queue.Full:
                # Overflow connection — drop it instead of growing the pool.
                conn.close()

    def execute_batch(self, db_path: str, sql: str, params_list: List[tuple], commit=True) -> List[Optional[Tuple]]:
        """
        Execute the same SQL statement once per parameter tuple.

        Runs the whole batch inside a single transaction (when *commit* is
        true), which is dramatically faster than per-row commits.

        Args:
            db_path: database file path.
            sql: SQL statement with ``?`` placeholders.
            params_list: one parameter tuple per execution.
            commit: wrap the batch in BEGIN/COMMIT and roll back on error.

        Returns:
            list: per-statement ``fetchall()`` results, or ``None`` for
            statements that return no rows (e.g. INSERT).
        """
        conn = None
        results: List[Optional[Tuple]] = []
        try:
            conn = self.get_connection(db_path)
            cursor = conn.cursor()
            # Skip the explicit BEGIN if sqlite already opened a transaction
            # implicitly — issuing BEGIN twice is an error.
            if commit and not conn.in_transaction:
                conn.execute("BEGIN TRANSACTION")
            for params in params_list:
                cursor.execute(sql, params)
                results.append(cursor.fetchall() if cursor.description else None)
            if commit:
                conn.commit()
            return results
        except Exception:
            if conn and commit:
                conn.rollback()
            raise  # bare raise preserves the original traceback
        finally:
            if conn:
                self.release_connection(db_path, conn)

    def execute_query(self, db_path: str, sql: str, params=None) -> List[Tuple]:
        """
        Execute a read-only query and fetch all rows.

        Args:
            db_path: database file path.
            sql: SELECT statement.
            params: optional parameter tuple for ``?`` placeholders.

        Returns:
            list: all result rows.
        """
        conn = None
        try:
            conn = self.get_connection(db_path)
            cursor = conn.cursor()
            if params:
                cursor.execute(sql, params)
            else:
                cursor.execute(sql)
            return cursor.fetchall()
        finally:
            if conn:
                self.release_connection(db_path, conn)

    def execute_update(self, db_path: str, sql: str, params=None) -> int:
        """
        Execute a single write statement and commit it.

        Args:
            db_path: database file path.
            sql: INSERT/UPDATE/DELETE statement.
            params: optional parameter tuple for ``?`` placeholders.

        Returns:
            int: number of affected rows (``cursor.rowcount``).
        """
        conn = None
        try:
            conn = self.get_connection(db_path)
            cursor = conn.cursor()
            if params:
                cursor.execute(sql, params)
            else:
                cursor.execute(sql)
            conn.commit()
            return cursor.rowcount
        except Exception:
            if conn:
                conn.rollback()
            raise  # bare raise preserves the original traceback
        finally:
            if conn:
                self.release_connection(db_path, conn)

    def close_all(self):
        """Close every pooled connection (idle and checked-out) and reset."""
        with self._lock:
            for db_path, pool in self.pools.items():
                # Drain and close idle connections.
                while not pool.empty():
                    try:
                        pool.get(block=False).close()
                    except queue.Empty:
                        break
                # Close connections still checked out; holders will fail fast
                # instead of writing to a half-shut-down database.
                with self.connection_locks[db_path]:
                    for conn in list(self.in_use[db_path].keys()):
                        try:
                            conn.close()
                        except Exception:
                            # best-effort shutdown; never swallow with bare except
                            pass
                    self.in_use[db_path].clear()
            self.pools.clear()
# Module-level singleton shared by every database wrapper in the package.
db_pool = DatabaseConnectionPool()


def close_db_pool():
    """Shut down the shared pool, closing every cached sqlite connection."""
    db_pool.close_all()

View File

@ -5,8 +5,9 @@ import threading
import traceback
from collections import defaultdict
from datetime import datetime, date
from typing import Tuple, List, Optional, Dict, Any

from app.DataBase.db_pool import db_pool
from app.log import logger
from app.util.compress_content import parser_reply
from app.util.protocbuf.msg_pb2 import MessageBytesExtra
@ -140,8 +141,6 @@ class MsgType:
class Msg: class Msg:
def __init__(self):
    # Connections are owned by the shared pool now; the wrapper only tracks
    # whether its database file has been located.
    self.open_flag = False
    self.init_database()
@ -151,9 +150,6 @@ class Msg:
if path: if path:
db_path = path db_path = path
if os.path.exists(db_path): if os.path.exists(db_path):
self.DB = sqlite3.connect(db_path, check_same_thread=False)
# '''创建游标'''
self.cursor = self.DB.cursor()
self.open_flag = True self.open_flag = True
if lock.locked(): if lock.locked():
lock.release() lock.release()
@ -200,48 +196,137 @@ class Msg:
a[10]: BytesExtra, a[10]: BytesExtra,
a[11]: CompressContent, a[11]: CompressContent,
a[12]: DisplayContent, a[12]: DisplayContent,
a[13]: 联系人的类如果是群聊就有不是的话没有这个字段
""" """
if not self.open_flag: if not self.open_flag:
return None return []
if time_range:
start_time, end_time = convert_to_timestamp(time_range) begin_time, end_time = convert_to_timestamp(time_range)
sql = f'''
select localId,TalkerId,Type,SubType,IsSender,CreateTime,Status,StrContent,strftime('%Y-%m-%d %H:%M:%S',CreateTime,'unixepoch','localtime') as StrTime,MsgSvrID,BytesExtra,CompressContent,DisplayContent sql = '''
from MSG SELECT
where StrTalker=? localId,
{'AND CreateTime>' + str(start_time) + ' AND CreateTime<' + str(end_time) if time_range else ''} TalkerId,
order by CreateTime Type,
SubType,
IsSender,
CreateTime,
Status,
StrContent,
strftime('%Y-%m-%d %H:%M:%S', datetime(CreateTime, 'unixepoch', 'localtime')),
MsgSvrID,
BytesExtra,
CompressContent,
DisplayContent
FROM MSG
WHERE StrTalker = ?
AND (? = 0 OR CreateTime >= ?)
AND (? = 0 OR CreateTime <= ?)
ORDER BY CreateTime DESC
''' '''
params = (username_, begin_time, begin_time, end_time, end_time)
try: try:
lock.acquire(True) results = db_pool.execute_query(db_path, sql, params)
self.cursor.execute(sql, [username_])
result = self.cursor.fetchall() # 处理群聊信息
finally: if results and username_.startswith('chatroom'):
lock.release() results = parser_chatroom_message(results)
return parser_chatroom_message(result) if username_.__contains__('@chatroom') else result
# result.sort(key=lambda x: x[5]) return results
# return self.add_sender(result) except Exception as e:
logger.error(f"获取聊天记录失败: {e}\n{traceback.format_exc()}")
return []
def batch_insert_messages(self, messages_data: List[Dict[str, Any]]) -> bool:
    """
    Bulk-insert message rows in a single pooled transaction.

    Args:
        messages_data: one dict per message; missing keys fall back to
            empty-string/zero/None defaults (CreateTime defaults to now).

    Returns:
        bool: True when the batch was written; False when the database is
        not open, the list is empty, or the insert failed (error is logged).
    """
    if not self.open_flag or not messages_data:
        return False

    sql = '''
        INSERT INTO MSG (
            MsgId, TalkerId, Type, SubType, IsSender, CreateTime,
            Status, StrContent, StrTalker, MsgSvrID, BytesExtra,
            CompressContent, DisplayContent
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    '''
    # Normalise every dict into a positional tuple matching the column list.
    params_list = [
        (
            msg.get('MsgId', ''),
            msg.get('TalkerId', ''),
            msg.get('Type', 0),
            msg.get('SubType', 0),
            msg.get('IsSender', 0),
            msg.get('CreateTime', int(time.time())),
            msg.get('Status', 0),
            msg.get('StrContent', ''),
            msg.get('StrTalker', ''),
            msg.get('MsgSvrID', ''),
            msg.get('BytesExtra', None),
            msg.get('CompressContent', None),
            msg.get('DisplayContent', None),
        )
        for msg in messages_data
    ]
    try:
        # One transaction for the whole batch — the point of this method.
        db_pool.execute_batch(db_path, sql, params_list)
        return True
    except Exception as e:
        logger.error(f"批量插入消息失败: {e}\n{traceback.format_exc()}")
        return False
def get_messages_all(self, time_range=None):
    """
    Return all message rows, newest first.

    Args:
        time_range: optional (start, end) pair accepted by
            convert_to_timestamp; presumably it yields (0, 0) when absent,
            which disables the filter — TODO confirm against its definition.

    Returns:
        list: raw MSG rows, or [] when the database is not open or on error.
    """
    if not self.open_flag:
        return []

    begin_time, end_time = convert_to_timestamp(time_range)
    # The (? = 0 OR ...) trick disables each bound when its timestamp is 0,
    # so one prepared statement serves both filtered and unfiltered calls.
    sql = '''
        SELECT
            localId, TalkerId, Type, SubType, IsSender, CreateTime,
            Status, StrContent,
            strftime('%Y-%m-%d %H:%M:%S', datetime(CreateTime, 'unixepoch', 'localtime')),
            MsgSvrID, BytesExtra, CompressContent, DisplayContent, StrTalker
        FROM MSG
        WHERE (? = 0 OR CreateTime >= ?)
          AND (? = 0 OR CreateTime <= ?)
        ORDER BY CreateTime DESC
    '''
    params = (begin_time, begin_time, end_time, end_time)
    try:
        return db_pool.execute_query(db_path, sql, params)
    except Exception as e:
        logger.error(f"获取所有聊天记录失败: {e}\n{traceback.format_exc()}")
        return []
def get_messages_group_by_day( def get_messages_group_by_day(
self, self,
@ -865,13 +950,8 @@ class Msg:
return sum_type_1 + sum_type_49 return sum_type_1 + sum_type_49
def close(self): def close(self):
if self.open_flag: """关闭数据库连接,不再需要显式关闭,由连接池管理"""
try:
lock.acquire(True)
self.open_flag = False self.open_flag = False
self.DB.close()
finally:
lock.release()
def __del__(self):
    # Best-effort cleanup when the wrapper is garbage-collected.
    self.close()