mirror of
https://github.com/LC044/WeChatMsg
synced 2025-04-08 05:28:03 +08:00
306 lines
11 KiB
Python
306 lines
11 KiB
Python
import json
|
||
import random
|
||
import os
|
||
|
||
from wxManager import Me, MessageType
|
||
from exporter.exporter import ExporterBase, remove_privacy_info, get_new_filename
|
||
|
||
|
||
class JsonStrategy:
    """Integer codes naming the ways a chat history can be split into samples."""

    # Split by the time range measured from the first message of a group.
    SPLIT_BY_TIME = 0
    # Split by the time interval between adjacent messages.
    SPLIT_BY_INTERVALS = 1
    # Split with a sliding window.
    SLIDING_WINDOW = 2
|
||
|
||
|
||
class AssistantUser:
    """Integer codes naming which side of the chat plays the AI assistant."""

    # The account owner (oneself) is the AI assistant.
    SELF = 0
    # The contact (friend) is the AI assistant.
    CONTACT = 1
|
||
|
||
|
||
class JsonConfig:
    """Options that control how a chat history is exported as a JSON dataset."""

    # System prompt template. An empty prompt means no system message is added;
    # '{{name}}' and '{{remark}}' placeholders are substituted by the exporter.
    prompt: str = ''
    # Whether to randomly shuffle the exported samples.
    shuffle: bool = True
    # Share of the data used as the training set (percent).
    train_ratio: int = 80
    # Target dataset format; other options: GLM4, ChatGLM3.
    model: str = 'Alpaca'
    # Top-level JSON key under which a conversation is stored, per model.
    model_keys = {
        'GLM4': 'messages',
        'ChatGLM3': 'conversations'
    }
    # Splitting strategy used for the export (a JsonStrategy code).
    strategy: int = JsonStrategy.SPLIT_BY_INTERVALS
    # Maximum gap allowed between two adjacent messages (seconds).
    intervals: int = 120
    # Gap allowed between the first and the last message of a group
    # (presumably seconds as well -- confirm against the timestamp unit).
    span: int = 300
    # Sliding-window size.
    window_size: int = 10
    # Sliding-window step.
    step: int = 3
    # Which participant plays the assistant role (an AssistantUser code).
    assistant = AssistantUser.SELF

    def get_model_keys(self):
        """Return the JSON key for the configured model.

        @return: the entry of ``model_keys`` for ``self.model``, or
                 ``'messages'`` when the model has no entry.
        """
        try:
            return self.model_keys[self.model]
        except KeyError:
            return 'messages'
|
||
|
||
|
||
def modify(output, history):
    """Return ``output`` unchanged.

    ``merge_content`` passes very short assistant replies through this
    function; ``history`` is accepted but not used here.

    @param output: the reply text.
    @param history: the messages merged so far (ignored).
    @return: ``output`` as given.
    """
    return output
|
||
|
||
|
||
def merge_content(conversions_list) -> list:
    """
    Merge consecutive messages sent by the same role within one conversation.

    Adjacent items that share a ``"role"`` are collapsed into a single
    ``{"role": ..., "content": ...}`` dict whose content is the individual
    contents joined with a full-width comma (",").  Only the ``"role"`` and
    ``"content"`` keys are carried over; any other key on the input items
    (for example ``"str_time"``) is dropped.

    @param conversions_list: message dicts in chronological order, each with
                             ``"role"`` and ``"content"`` keys.
    @return: a new list of merged message dicts; empty for empty input.
    """
    merged_data = []
    current_role = None
    current_content = ""
    for item in conversions_list:
        role, content = item["role"], item["content"]
        if current_role is None:
            # Very first message: open the first run.
            current_role, current_content = role, content
        elif role == current_role:
            # Same speaker keeps talking: extend the current run.
            current_content += "," + content
        else:
            # Speaker changed: close the finished run and open a new one.
            # Very short assistant replies are routed through `modify` first.
            if len(current_content) < 3 and current_role == 'assistant':
                current_content = modify(current_content, merged_data)
            merged_data.append({"role": current_role, "content": current_content})
            current_role, current_content = role, content

    # Flush the final run.
    # NOTE(review): the final run is not passed through `modify`, unlike the
    # earlier runs -- confirm whether that is intended.
    if current_role is not None:
        merged_data.append({"role": current_role, "content": current_content})
    return merged_data
|
||
|
||
|
||
def is_first_msg(conversions):
    """Tell whether no chat message has been collected yet.

    @param conversions: the message dicts gathered so far.
    @return: True when ``conversions`` is empty or holds exactly one entry
             whose role is ``'system'``; False otherwise.
    """
    if conversions:
        only_entry_is_system = len(conversions) == 1 and conversions[0]['role'] == 'system'
        return only_entry_is_system
    return True
|
||
|
||
|
||
def conversion_to_history(conversations):
    """Build the history list of [first, second] content pairs for a conversation.

    An optional leading system entry is skipped, and the trailing entries
    (the final exchange) are left out; the remaining entries are paired up
    two by two in order.

    @param conversations: list of message dicts with ``'role'`` and
                          ``'content'`` keys.
    @return: list of two-element lists ``[content_i, content_i_plus_1]``;
             empty when there is nothing before the final exchange or when
             ``conversations`` is empty.
    """
    # An empty conversation has no history (and no element 0 to inspect).
    if not conversations:
        return []
    has_system_prompt = conversations[0].get('role') == 'system'
    if has_system_prompt:
        first, stop = 1, len(conversations) - 3
    else:
        first, stop = 0, len(conversations) - 2
    return [
        [conversations[i].get('content'), conversations[i + 1].get('content')]
        for i in range(first, stop, 2)
    ]
|
||
|
||
|
||
class JsonExporter(ExporterBase):
|
||
def __init__(
|
||
self,
|
||
database,
|
||
contact,
|
||
output_dir,
|
||
type_, # 导出文件类型
|
||
message_types: set[MessageType] = None, # 导出的消息类型
|
||
time_range=None, # 导出的日期范围
|
||
group_members: set[str] = None, # 群聊中只导出这些人的聊天记录
|
||
progress_callback=None, # 进度回调函数,func(progress:float)
|
||
finish_callback=None, # 导出完成回调函数
|
||
json_config: JsonConfig = None
|
||
):
|
||
super().__init__(database, contact, output_dir, type_, message_types, time_range, group_members,
|
||
progress_callback, finish_callback) # 调用父类的构造函数
|
||
if json_config:
|
||
self.json_config: JsonConfig = json_config
|
||
else:
|
||
self.json_config = JsonConfig()
|
||
|
||
def is_user(self, is_send):
|
||
"""
|
||
判断一条消息是否是user角色发送的
|
||
@param is_send:
|
||
@return:
|
||
"""
|
||
return is_send ^ (self.json_config.assistant == AssistantUser.SELF)
|
||
|
||
def system_prompt(self):
|
||
system = {
|
||
"role": "system",
|
||
"content": self.json_config.prompt.replace(
|
||
'{{name}}', Me().name
|
||
).replace(
|
||
'{{remark}}', self.contact.remark
|
||
)
|
||
}
|
||
return system
|
||
|
||
def message_to_conversion(self, group):
|
||
conversions = [self.system_prompt()] if self.json_config.prompt else []
|
||
# 确保最后一条消息是assistant发出的
|
||
while len(group) and self.is_user(group[-1].is_sender):
|
||
group.pop()
|
||
for message in group:
|
||
is_send = message.is_sender
|
||
text = remove_privacy_info(message.content)
|
||
# 确保第一条消息必须是user发出的
|
||
if is_first_msg(conversions) and not self.is_user(is_send):
|
||
continue
|
||
if self.is_user(is_send):
|
||
json_msg = {
|
||
"role": "user",
|
||
"content": text
|
||
}
|
||
else:
|
||
json_msg = {
|
||
"role": "assistant",
|
||
"content": text
|
||
}
|
||
json_msg['str_time'] = message.str_time
|
||
conversions.append(json_msg)
|
||
if len(conversions) == 1:
|
||
return []
|
||
return merge_content(conversions)
|
||
|
||
def split_by_time(self, length=300):
|
||
"""
|
||
通过第一条消息和最后一条消息的时间间隔分割数据集
|
||
@param length:
|
||
@return:
|
||
"""
|
||
messages = self.database.get_messages_by_type(self.contact.wxid, type_=MessageType.Text,
|
||
time_range=self.time_range)
|
||
start_time = 0
|
||
res = []
|
||
i = 0
|
||
while i < len(messages):
|
||
message = messages[i]
|
||
timestamp = message.timestamp
|
||
is_send = message.is_sender
|
||
group = []
|
||
while i < len(messages) and timestamp - start_time < length:
|
||
group.append(message)
|
||
i += 1
|
||
if i >= len(messages):
|
||
break
|
||
message = messages[i]
|
||
timestamp = message.timestamp
|
||
is_send = message.is_sender
|
||
while not self.is_user(is_send):
|
||
group.append(message)
|
||
i += 1
|
||
if i >= len(messages):
|
||
break
|
||
message = messages[i]
|
||
timestamp = message.timestamp
|
||
is_send = message.is_sender
|
||
start_time = timestamp
|
||
if len(group) > 4:
|
||
res.append(group)
|
||
return res
|
||
|
||
def split_by_intervals(self, max_diff_seconds=300):
|
||
"""
|
||
通过相邻两条消息的时间间隔分割数据集
|
||
@param max_diff_seconds:
|
||
@return:
|
||
"""
|
||
messages = self.database.get_messages_by_type(self.contact.wxid, type_=MessageType.Text,
|
||
time_range=self.time_range)
|
||
res = []
|
||
i = 0
|
||
current_group = []
|
||
while i < len(messages):
|
||
message = messages[i]
|
||
timestamp = message.timestamp
|
||
is_send = message.is_sender
|
||
while not self.is_user(is_send) and i + 1 < len(messages):
|
||
i += 1
|
||
message = messages[i]
|
||
is_send = message.is_sender
|
||
current_group = [messages[i]]
|
||
i += 1
|
||
while i < len(messages) and messages[i].timestamp - current_group[-1].timestamp <= max_diff_seconds:
|
||
current_group.append(messages[i])
|
||
i += 1
|
||
while i < len(messages) and not self.is_user(messages[i].is_sender):
|
||
current_group.append(messages[i])
|
||
i += 1
|
||
if len(current_group) > 4:
|
||
res.append(current_group)
|
||
return res
|
||
|
||
def split_by_window(self, window_size=10, step=3):
|
||
"""
|
||
滑动窗口切分数据集
|
||
@param window_size:
|
||
@param step:
|
||
@return:
|
||
"""
|
||
messages = self.database.get_messages_by_type(self.contact.wxid, type_=MessageType.Text,
|
||
time_range=self.time_range)
|
||
res = []
|
||
i = 0
|
||
while i < len(messages):
|
||
message = messages[i]
|
||
timestamp = message.timestamp
|
||
is_send = message.is_sender
|
||
current_group = []
|
||
j = i
|
||
while not self.is_user(is_send) and j + 1 < len(messages) and j - i < window_size:
|
||
j += 1
|
||
message = messages[j]
|
||
is_send = message.is_sender
|
||
current_group = [messages[j]]
|
||
j += 1
|
||
while j < len(messages) and j - i < window_size:
|
||
current_group.append(messages[j])
|
||
j += 1
|
||
res.append(current_group)
|
||
i += step
|
||
return res
|
||
|
||
def export(self):
|
||
print(f"【开始导出 json {self.contact.remark}】")
|
||
origin_path = self.origin_path
|
||
filename = os.path.join(origin_path, f"{self.contact.remark}.json")
|
||
filename = get_new_filename(filename)
|
||
messages_groups = []
|
||
match self.json_config.strategy:
|
||
case JsonStrategy.SPLIT_BY_INTERVALS:
|
||
messages_groups = self.split_by_intervals(self.json_config.intervals)
|
||
case JsonStrategy.SPLIT_BY_TIME:
|
||
messages_groups = self.split_by_time(self.json_config.span)
|
||
case JsonStrategy.SLIDING_WINDOW:
|
||
messages_groups = self.split_by_window(self.json_config.window_size, self.json_config.step)
|
||
dataset = []
|
||
self.update_progress_callback(0.5)
|
||
for group in messages_groups:
|
||
conversations = self.message_to_conversion(group)
|
||
if conversations:
|
||
if self.json_config.model == 'Alpaca':
|
||
has_system_prompt = conversations[0].get('role') == 'system'
|
||
dataset.append(
|
||
{
|
||
'system': conversations[0].get('content') if has_system_prompt else '',
|
||
'instruction': conversations[-2].get('content'),
|
||
'input': '',
|
||
'output': conversations[-1].get('content'),
|
||
'history': conversion_to_history(conversations),
|
||
}
|
||
)
|
||
else:
|
||
dataset.append({
|
||
self.json_config.get_model_keys(): conversations
|
||
})
|
||
if self.json_config.shuffle:
|
||
# 打乱列表顺序
|
||
random.shuffle(dataset)
|
||
with open(filename, "w", encoding="utf-8") as f:
|
||
json.dump(dataset, f, ensure_ascii=False, indent=4)
|
||
print(f"【完成导出 json {self.contact.remark}】")
|
||
self.update_progress_callback(1)
|
||
self.finish_callback(self.exporter_id)
|