优化:实现数据库连接池和批量操作,减少I/O开销

1. 新增 db_pool.py 实现数据库连接池管理
2. 修改 __init__.py 导入连接池并提供关闭连接池方法
3. 优化 msg.py 中的数据库操作,提高并发性能

在处理大量数据导入时,推荐使用 batch_insert_messages 方法。
在应用退出时,应调用 close_db 以确保连接池等资源被释放。
对于大量查询操作,后续还可以进一步优化查询 SQL 和索引。
This commit is contained in:
Lecheeel 2025-03-27 00:02:33 +08:00
parent fc1e2fa7a5
commit 9d599199fe
3 changed files with 389 additions and 48 deletions

View File

@ -13,6 +13,7 @@ from .media_msg import MediaMsg
from .misc import Misc
from .msg import Msg
from .msg import MsgType
from .db_pool import db_pool, close_db_pool
misc_db = Misc()
msg_db = Msg()
@ -22,14 +23,18 @@ media_msg_db = MediaMsg()
def close_db():
    """Close every database handle, then shut down the shared connection pool."""
    # Same order as before: misc, msg, micro_msg, hard_link, media_msg.
    handles = (misc_db, msg_db, micro_msg_db, hard_link_db, media_msg_db)
    for handle in handles:
        handle.close()
    # Finally release the connections held by the connection pool.
    close_db_pool()
def init_db():
"""初始化所有数据库连接"""
misc_db.init_database()
msg_db.init_database()
micro_msg_db.init_database()
@ -37,4 +42,4 @@ def init_db():
media_msg_db.init_database()
__all__ = ['misc_db', 'micro_msg_db', 'msg_db', 'hard_link_db', 'MsgType', "media_msg_db", "close_db"]
__all__ = ['misc_db', 'micro_msg_db', 'msg_db', 'hard_link_db', 'MsgType', "media_msg_db", "close_db", "db_pool"]

256
app/DataBase/db_pool.py Normal file
View File

@ -0,0 +1,256 @@
import os
import sqlite3
import threading
import queue
import time
from typing import Dict, Optional, List, Tuple
class DatabaseConnectionPool:
    """Pool of SQLite connections, keyed by database file path.

    The class is a singleton: ``__new__`` always returns the same instance and
    ``__init__`` only applies its arguments on the first construction, so
    ``max_connections`` / ``timeout`` passed to later constructions are ignored.

    For every database path the pool keeps:

    * ``pools[db_path]`` - a bounded queue of idle connections,
    * ``in_use[db_path]`` - checked-out connections mapped to the thread that
      took them,
    * ``connection_locks[db_path]`` - a lock guarding ``in_use[db_path]``.
    """

    _instance = None
    # Guards singleton creation and structural changes to ``pools``.
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(DatabaseConnectionPool, cls).__new__(cls)
                cls._instance._initialized = False
        return cls._instance

    def __init__(self, max_connections=5, timeout=5):
        """Configure the pool (effective on the first construction only).

        Args:
            max_connections: Upper bound on simultaneously checked-out
                connections per database, and capacity of the idle queue.
            timeout: Seconds to wait for a connection to be released once the
                pool is at capacity.
        """
        # Make sure the singleton is only initialised once.
        if self._initialized:
            return
        self._initialized = True
        self.max_connections = max_connections
        self.timeout = timeout
        self.pools: Dict[str, queue.Queue] = {}
        self.in_use: Dict[str, Dict[sqlite3.Connection, threading.Thread]] = {}
        self.connection_locks: Dict[str, threading.Lock] = {}

    def _create_connection(self, db_path: str) -> sqlite3.Connection:
        """Open and configure a new connection to ``db_path``.

        Raises:
            FileNotFoundError: If the database file does not exist.
        """
        if not os.path.exists(db_path):
            raise FileNotFoundError(f"数据库文件不存在: {db_path}")
        # Pooled connections are handed to arbitrary threads, hence
        # check_same_thread=False.
        conn = sqlite3.connect(db_path, check_same_thread=False)
        try:
            # Enforce foreign-key constraints.
            conn.execute('PRAGMA foreign_keys = ON')
            # NORMAL syncs less often than FULL: faster writes at the cost of
            # weaker durability guarantees on power loss.
            conn.execute('PRAGMA synchronous = NORMAL')
            # Write-ahead logging lets readers proceed while a writer is active.
            conn.execute('PRAGMA journal_mode = WAL')
            # Page cache size, in pages.
            conn.execute('PRAGMA cache_size = 10000')
        except Exception:
            # Do not leak the handle when configuration fails.
            conn.close()
            raise
        return conn

    def _get_pool(self, db_path: str) -> queue.Queue:
        """Return the idle-connection queue for ``db_path``, creating it on first use."""
        pool = self.pools.get(db_path)
        if pool is None:
            with self._lock:
                pool = self.pools.get(db_path)
                if pool is None:
                    pool = queue.Queue(maxsize=self.max_connections)
                    self.in_use[db_path] = {}
                    self.connection_locks[db_path] = threading.Lock()
                    # Pre-create a couple of connections (best effort).
                    for _ in range(min(2, self.max_connections)):
                        try:
                            pool.put(self._create_connection(db_path))
                        except Exception as e:
                            print(f"预创建连接失败: {e}")
                    # Publish the queue last so that threads taking the
                    # unlocked fast path never see a pool whose bookkeeping
                    # dictionaries are still missing.
                    self.pools[db_path] = pool
        return pool

    def _try_create_tracked(self, db_path: str) -> Optional[sqlite3.Connection]:
        """Create and register a connection if the per-database limit allows it.

        The capacity check and the ``in_use`` registration happen under the
        same lock so concurrent callers cannot exceed ``max_connections``.

        Returns:
            The new, already registered connection, or ``None`` at capacity.
        """
        with self.connection_locks[db_path]:
            if len(self.in_use[db_path]) >= self.max_connections:
                return None
            conn = self._create_connection(db_path)
            self.in_use[db_path][conn] = threading.current_thread()
            return conn

    def get_connection(self, db_path: str) -> sqlite3.Connection:
        """Check a connection out of the pool.

        An idle connection is reused when one is available. Otherwise a new
        connection is created immediately while fewer than ``max_connections``
        are checked out; only at capacity does the call wait (up to
        ``timeout`` seconds) for another thread to release one.

        Args:
            db_path: Path of the database file.

        Returns:
            sqlite3.Connection: A connection registered as in use.

        Raises:
            TimeoutError: No connection became available in time.
            FileNotFoundError: A new connection was needed but the database
                file does not exist.
        """
        pool = self._get_pool(db_path)
        try:
            conn = pool.get(block=False)
        except queue.Empty:
            # Nothing idle: grow the pool right away if we are allowed to.
            conn = self._try_create_tracked(db_path)
            if conn is not None:
                return conn
            try:
                conn = pool.get(block=True, timeout=self.timeout)
            except queue.Empty:
                # Capacity may have been freed without a connection being
                # queued (e.g. after close_all()), so try once more.
                conn = self._try_create_tracked(db_path)
                if conn is None:
                    raise TimeoutError(f"无法获取数据库连接,连接池已满: {db_path}") from None
                return conn
        # Record which thread holds the reused connection.
        with self.connection_locks[db_path]:
            self.in_use[db_path][conn] = threading.current_thread()
        return conn

    def release_connection(self, db_path: str, conn: sqlite3.Connection):
        """Return a checked-out connection to the pool.

        Connections that are not currently tracked as in use (a double
        release, or a connection that ``close_all`` already closed) are
        ignored so they cannot end up in the idle queue twice or in a closed
        state.

        Args:
            db_path: Path of the database file.
            conn: The connection to release.
        """
        pool = self.pools.get(db_path)
        if pool is None:
            # The pool for this database is gone; nothing to return it to.
            conn.close()
            return
        with self.connection_locks[db_path]:
            tracked = conn in self.in_use[db_path]
            if tracked:
                del self.in_use[db_path][conn]
        if not tracked:
            return
        try:
            # Hand the connection back to the idle queue.
            pool.put(conn, block=False)
        except queue.Full:
            # The idle queue is full: close the surplus connection.
            conn.close()

    def execute_batch(self, db_path: str, sql: str, params_list: List[tuple], commit=True) -> List[Optional[List[Tuple]]]:
        """Run the same SQL statement once per parameter tuple.

        Args:
            db_path: Path of the database file.
            sql: SQL statement to execute.
            params_list: One parameter tuple per execution.
            commit: When true, run everything in one transaction that is
                committed on success and rolled back on error. When false,
                nothing is committed or rolled back here, so any pending
                changes stay on the pooled connection.
                NOTE(review): callers cannot reach that connection afterwards;
                confirm whether commit=False is actually needed.

        Returns:
            list: One entry per execution - the fetched rows if the statement
            produced a result set, otherwise ``None``.
        """
        conn = self.get_connection(db_path)
        try:
            cursor = conn.cursor()
            # Start an explicit transaction unless one is already open on this
            # pooled connection (a nested BEGIN would raise).
            if commit and not conn.in_transaction:
                conn.execute("BEGIN TRANSACTION")
            results = []
            for params in params_list:
                cursor.execute(sql, params)
                # cursor.description is only set for statements returning rows.
                results.append(cursor.fetchall() if cursor.description else None)
            if commit:
                conn.commit()
            return results
        except Exception:
            if commit:
                conn.rollback()
            raise
        finally:
            self.release_connection(db_path, conn)

    def execute_query(self, db_path: str, sql: str, params=None) -> List[Tuple]:
        """Run a query and return all rows.

        Args:
            db_path: Path of the database file.
            sql: SQL query.
            params: Optional query parameters.

        Returns:
            list: All rows returned by the query.
        """
        conn = self.get_connection(db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(sql, params or ())
            return cursor.fetchall()
        finally:
            self.release_connection(db_path, conn)

    def execute_update(self, db_path: str, sql: str, params=None) -> int:
        """Run a single data-modifying statement and commit it.

        Args:
            db_path: Path of the database file.
            sql: SQL statement.
            params: Optional statement parameters.

        Returns:
            int: ``cursor.rowcount`` after the statement.
        """
        conn = self.get_connection(db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(sql, params or ())
            conn.commit()
            return cursor.rowcount
        except Exception:
            conn.rollback()
            raise
        finally:
            self.release_connection(db_path, conn)

    def close_all(self):
        """Close every idle and checked-out connection and forget all pools."""
        with self._lock:
            for db_path, pool in self.pools.items():
                # Close all idle connections.
                while True:
                    try:
                        pool.get(block=False).close()
                    except queue.Empty:
                        break
                # Close connections that are still checked out (best effort).
                with self.connection_locks[db_path]:
                    for conn in list(self.in_use[db_path]):
                        try:
                            conn.close()
                        except Exception:
                            pass
                    self.in_use[db_path].clear()
            # Drop the queues; they are rebuilt lazily on next use.
            self.pools.clear()
# Module-wide shared connection pool instance.
db_pool = DatabaseConnectionPool()


def close_db_pool():
    """Close every connection managed by the shared ``db_pool`` instance."""
    db_pool.close_all()

View File

@ -5,8 +5,9 @@ import threading
import traceback
from collections import defaultdict
from datetime import datetime, date
from typing import Tuple
from typing import Tuple, List, Optional, Dict, Any
from app.DataBase.db_pool import db_pool
from app.log import logger
from app.util.compress_content import parser_reply
from app.util.protocbuf.msg_pb2 import MessageBytesExtra
@ -140,8 +141,6 @@ class MsgType:
class Msg:
def __init__(self):
self.DB = None
self.cursor = None
self.open_flag = False
self.init_database()
@ -151,9 +150,6 @@ class Msg:
if path:
db_path = path
if os.path.exists(db_path):
self.DB = sqlite3.connect(db_path, check_same_thread=False)
# '''创建游标'''
self.cursor = self.DB.cursor()
self.open_flag = True
if lock.locked():
lock.release()
@ -200,48 +196,137 @@ class Msg:
a[10]: BytesExtra,
a[11]: CompressContent,
a[12]: DisplayContent,
a[13]: 联系人的类如果是群聊就有不是的话没有这个字段
"""
if not self.open_flag:
return None
if time_range:
start_time, end_time = convert_to_timestamp(time_range)
sql = f'''
select localId,TalkerId,Type,SubType,IsSender,CreateTime,Status,StrContent,strftime('%Y-%m-%d %H:%M:%S',CreateTime,'unixepoch','localtime') as StrTime,MsgSvrID,BytesExtra,CompressContent,DisplayContent
from MSG
where StrTalker=?
{'AND CreateTime>' + str(start_time) + ' AND CreateTime<' + str(end_time) if time_range else ''}
order by CreateTime
return []
begin_time, end_time = convert_to_timestamp(time_range)
sql = '''
SELECT
localId,
TalkerId,
Type,
SubType,
IsSender,
CreateTime,
Status,
StrContent,
strftime('%Y-%m-%d %H:%M:%S', datetime(CreateTime, 'unixepoch', 'localtime')),
MsgSvrID,
BytesExtra,
CompressContent,
DisplayContent
FROM MSG
WHERE StrTalker = ?
AND (? = 0 OR CreateTime >= ?)
AND (? = 0 OR CreateTime <= ?)
ORDER BY CreateTime DESC
'''
params = (username_, begin_time, begin_time, end_time, end_time)
try:
lock.acquire(True)
self.cursor.execute(sql, [username_])
result = self.cursor.fetchall()
finally:
lock.release()
return parser_chatroom_message(result) if username_.__contains__('@chatroom') else result
# result.sort(key=lambda x: x[5])
# return self.add_sender(result)
results = db_pool.execute_query(db_path, sql, params)
# 处理群聊信息
if results and username_.startswith('chatroom'):
results = parser_chatroom_message(results)
return results
except Exception as e:
logger.error(f"获取聊天记录失败: {e}\n{traceback.format_exc()}")
return []
def batch_insert_messages(self, messages_data: List[Dict[str, Any]]) -> bool:
    """Insert many message rows into the MSG table in a single batch.

    Args:
        messages_data: One dict per message, keyed by column name. Missing
            keys fall back to the defaults used below.

    Returns:
        bool: True when the batch was executed, False when the database is
        not open, the input is empty, or the insert raised.
    """
    if not (self.open_flag and messages_data):
        return False

    # Insert statement: one placeholder per column.
    sql = '''
        INSERT INTO MSG (
            MsgId, TalkerId, Type, SubType, IsSender, CreateTime,
            Status, StrContent, StrTalker, MsgSvrID, BytesExtra,
            CompressContent, DisplayContent
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    '''

    # Build one parameter tuple per message, in column order.
    params_list = []
    for record in messages_data:
        field = record.get
        params_list.append((
            field('MsgId', ''),
            field('TalkerId', ''),
            field('Type', 0),
            field('SubType', 0),
            field('IsSender', 0),
            # Default creation time is "now", evaluated per message.
            field('CreateTime', int(time.time())),
            field('Status', 0),
            field('StrContent', ''),
            field('StrTalker', ''),
            field('MsgSvrID', ''),
            field('BytesExtra', None),
            field('CompressContent', None),
            field('DisplayContent', None),
        ))

    try:
        db_pool.execute_batch(db_path, sql, params_list)
    except Exception as e:
        logger.error(f"批量插入消息失败: {e}\n{traceback.format_exc()}")
        return False
    return True
def get_messages_all(self, time_range=None):
if time_range:
start_time, end_time = convert_to_timestamp(time_range)
sql = f'''
select localId,TalkerId,Type,SubType,IsSender,CreateTime,Status,StrContent,strftime('%Y-%m-%d %H:%M:%S',CreateTime,'unixepoch','localtime') as StrTime,MsgSvrID,BytesExtra,StrTalker,Reserved1,CompressContent
from MSG
{'WHERE CreateTime>' + str(start_time) + ' AND CreateTime<' + str(end_time) if time_range else ''}
order by CreateTime
'''
"""
获取所有聊天记录
@param time_range:
@return:
"""
if not self.open_flag:
return None
return []
begin_time, end_time = convert_to_timestamp(time_range)
sql = '''
SELECT
localId,
TalkerId,
Type,
SubType,
IsSender,
CreateTime,
Status,
StrContent,
strftime('%Y-%m-%d %H:%M:%S', datetime(CreateTime, 'unixepoch', 'localtime')),
MsgSvrID,
BytesExtra,
CompressContent,
DisplayContent,
StrTalker
FROM MSG
WHERE (? = 0 OR CreateTime >= ?)
AND (? = 0 OR CreateTime <= ?)
ORDER BY CreateTime DESC
'''
params = (begin_time, begin_time, end_time, end_time)
try:
lock.acquire(True)
self.cursor.execute(sql)
result = self.cursor.fetchall()
finally:
lock.release()
result.sort(key=lambda x: x[5])
return result
return db_pool.execute_query(db_path, sql, params)
except Exception as e:
logger.error(f"获取所有聊天记录失败: {e}\n{traceback.format_exc()}")
return []
def get_messages_group_by_day(
self,
@ -865,13 +950,8 @@ class Msg:
return sum_type_1 + sum_type_49
def close(self):
if self.open_flag:
try:
lock.acquire(True)
self.open_flag = False
self.DB.close()
finally:
lock.release()
"""关闭数据库连接,不再需要显式关闭,由连接池管理"""
self.open_flag = False
def __del__(self):
    """Finalizer: delegate to ``close()`` when the instance is destroyed."""
    self.close()