mirror of
https://github.com/LC044/WeChatMsg
synced 2025-02-22 19:02:17 +08:00
194 lines
6.4 KiB
Python
194 lines
6.4 KiB
Python
import json
|
|
import random
|
|
import os
|
|
|
|
from app.DataBase import msg_db
|
|
from app.person import Me
|
|
from .exporter import ExporterBase
|
|
|
|
|
|
def merge_content(conversions_list) -> list:
|
|
"""
|
|
合并一组对话中连续发送的句子
|
|
@param conversions_list:
|
|
@return:
|
|
"""
|
|
merged_data = []
|
|
current_role = None
|
|
current_content = ""
|
|
str_time = ''
|
|
for item in conversions_list:
|
|
if 'str_time' in item:
|
|
str_time = item['str_time']
|
|
else:
|
|
str_time = ''
|
|
if current_role is None:
|
|
current_role = item["role"]
|
|
current_content = item["content"]
|
|
elif current_role == item["role"]:
|
|
current_content += "\n" + item["content"]
|
|
else:
|
|
# merged_data.append({"role": current_role, "content": current_content, 'str_time': str_time})
|
|
merged_data.append({"role": current_role, "content": current_content})
|
|
current_role = item["role"]
|
|
current_content = item["content"]
|
|
str_time = item.get('str_time')
|
|
|
|
# 处理最后一组
|
|
if current_role is not None:
|
|
# merged_data.append({"role": current_role, "content": current_content,'str_time': str_time})
|
|
merged_data.append({"role": current_role, "content": current_content})
|
|
return merged_data
|
|
|
|
|
|
def system_prompt():
|
|
system = {
|
|
"role": "system",
|
|
# "content": f"你是{Me().name},一个聪明、热情、善良的男大学生,后面的对话来自{self.contact.remark}(!!!注意:对方的身份十分重要,你务必记住对方的身份,因为跟不同的人对话要用不同的态度、语气),你要认真地回答他"
|
|
"content": f"你是{Me().name},一个聪明、热情、善良的人,后面的对话来自你的朋友,你要认真地回答他"
|
|
}
|
|
return system
|
|
|
|
|
|
def message_to_conversion(group):
|
|
conversions = [system_prompt()]
|
|
while len(group) and group[-1][4] == 0:
|
|
group.pop()
|
|
for message in group:
|
|
is_send = message[4]
|
|
if len(conversions) == 1 and is_send:
|
|
continue
|
|
if is_send:
|
|
json_msg = {
|
|
"role": "assistant",
|
|
"content": message[7]
|
|
}
|
|
else:
|
|
json_msg = {
|
|
"role": "user",
|
|
"content": message[7]
|
|
}
|
|
json_msg['str_time'] = message[8]
|
|
conversions.append(json_msg)
|
|
if len(conversions) == 1:
|
|
return []
|
|
return merge_content(conversions)
|
|
|
|
|
|
class JsonExporter(ExporterBase):
|
|
def split_by_time(self, length=300):
|
|
messages = msg_db.get_messages_by_type(self.contact.wxid, type_=1, time_range=self.time_range)
|
|
start_time = 0
|
|
res = []
|
|
i = 0
|
|
while i < len(messages):
|
|
message = messages[i]
|
|
timestamp = message[5]
|
|
is_send = message[4]
|
|
group = [
|
|
system_prompt()
|
|
]
|
|
while i < len(messages) and timestamp - start_time < length:
|
|
if is_send:
|
|
json_msg = {
|
|
"role": "assistant",
|
|
"content": message[7]
|
|
}
|
|
else:
|
|
json_msg = {
|
|
"role": "user",
|
|
"content": message[7]
|
|
}
|
|
group.append(json_msg)
|
|
i += 1
|
|
if i >= len(messages):
|
|
break
|
|
message = messages[i]
|
|
timestamp = message[5]
|
|
is_send = message[4]
|
|
while is_send:
|
|
json_msg = {
|
|
"role": "assistant",
|
|
"content": message[7]
|
|
}
|
|
group.append(json_msg)
|
|
i += 1
|
|
if i >= len(messages):
|
|
break
|
|
message = messages[i]
|
|
timestamp = message[5]
|
|
is_send = message[4]
|
|
start_time = timestamp
|
|
res.append(
|
|
{
|
|
"conversations": group
|
|
}
|
|
)
|
|
res_ = []
|
|
for item in res:
|
|
conversations = item['conversations']
|
|
res_.append({
|
|
'conversations': merge_content(conversations)
|
|
})
|
|
return res_
|
|
|
|
def split_by_intervals(self, max_diff_seconds=300):
|
|
messages = msg_db.get_messages_by_type(self.contact.wxid, type_=1, time_range=self.time_range)
|
|
res = []
|
|
i = 0
|
|
current_group = []
|
|
while i < len(messages):
|
|
message = messages[i]
|
|
timestamp = message[5]
|
|
is_send = message[4]
|
|
while is_send and i + 1 < len(messages):
|
|
i += 1
|
|
message = messages[i]
|
|
is_send = message[4]
|
|
current_group = [messages[i]]
|
|
i += 1
|
|
while i < len(messages) and messages[i][5] - current_group[-1][5] <= max_diff_seconds:
|
|
current_group.append(messages[i])
|
|
i += 1
|
|
while i < len(messages) and messages[i][4]:
|
|
current_group.append(messages[i])
|
|
i += 1
|
|
res.append(current_group)
|
|
res_ = []
|
|
for group in res:
|
|
conversations = message_to_conversion(group)
|
|
if conversations:
|
|
res_.append({
|
|
'conversations': conversations
|
|
})
|
|
return res_
|
|
|
|
def to_json(self):
|
|
print(f"【开始导出 json {self.contact.remark}】")
|
|
origin_path = self.origin_path
|
|
os.makedirs(origin_path, exist_ok=True)
|
|
filename = os.path.join(origin_path, f"{self.contact.remark}")
|
|
|
|
# res = self.split_by_time()
|
|
res = self.split_by_intervals(60)
|
|
# 打乱列表顺序
|
|
random.shuffle(res)
|
|
|
|
# 计算切分比例
|
|
split_ratio = 0.2 # 20% for the second list
|
|
|
|
# 计算切分点
|
|
split_point = int(len(res) * split_ratio)
|
|
|
|
# 分割列表
|
|
train_data = res[split_point:]
|
|
dev_data = res[:split_point]
|
|
with open(f'{filename}_train.json', "w", encoding="utf-8") as f:
|
|
json.dump(train_data, f, ensure_ascii=False, indent=4)
|
|
with open(f'{filename}_dev.json', "w", encoding="utf-8") as f:
|
|
json.dump(dev_data, f, ensure_ascii=False, indent=4)
|
|
self.okSignal.emit(1)
|
|
|
|
def run(self):
|
|
self.to_json()
|