WeChatMsg/MemoAI/qwen2-0.5b/train.ipynb
2024-06-14 10:58:52 +00:00

420 lines
11 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"id": "de53995b-32ed-4722-8cac-ba104c8efacb",
"metadata": {},
"source": [
"# 导入环境"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "52fac949-4150-4091-b0c3-2968ab5e385c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from datasets import Dataset\n",
"import pandas as pd\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e098d9eb",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"df = pd.read_json('train.json')\n",
"ds = Dataset.from_pandas(df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8ac92d42-efae-49b1-a00e-ccaa75b98938",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"ds[:3]"
]
},
{
"cell_type": "markdown",
"id": "380d9f69-9e98-4d2d-b044-1d608a057b0b",
"metadata": {},
"source": [
"# 下载模型"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "312d6439-1932-44a3-b592-9adbdb7ab702",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from modelscope import snapshot_download\n",
"model_dir = snapshot_download('qwen/Qwen2-0.5B-Instruct', cache_dir='qwen2-0.5b/')"
]
},
{
"cell_type": "markdown",
"id": "51d05e5d-d14e-4f03-92be-9a9677d41918",
"metadata": {},
"source": [
"# 处理数据集"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74ee5a67-2e55-4974-b90e-cbf492de500a",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained('./qwen2-0.5b/qwen/Qwen2-0___5B-Instruct/', use_fast=False, trust_remote_code=True)\n",
"tokenizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2503a5fa-9621-4495-9035-8e7ef6525691",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def process_func(example, max_length=384):\n",
"    \"\"\"Turn one chat sample into causal-LM training features.\n",
"\n",
"    The prompt (system persona + user turn, ChatML markup) is masked with -100 so\n",
"    the loss is computed only on the assistant reply and its closing EOS token.\n",
"\n",
"    Args:\n",
"        example: dataset row providing 'instruction', 'input' and 'output'.\n",
"        max_length: maximum sequence length; longer samples are truncated. Kept\n",
"            generous because Chinese text can take several tokens per character.\n",
"\n",
"    Returns:\n",
"        dict of equal-length 'input_ids', 'attention_mask' and 'labels' lists.\n",
"    \"\"\"\n",
"    instruction = tokenizer(f\"<|im_start|>system\\n现在你需要扮演我,和我的微信好友快乐聊天!<|im_end|>\\n<|im_start|>user\\n{example['instruction'] + example['input']}<|im_end|>\\n<|im_start|>assistant\\n\", add_special_tokens=False)\n",
"    response = tokenizer(f\"{example['output']}\", add_special_tokens=False)\n",
"    # Close the reply with the tokenizer's EOS token. For Qwen2-Instruct this is expected\n",
"    # to be <|im_end|>, the marker that ends every turn in the prompt above (verify via\n",
"    # tokenizer.eos_token). Fall back to the pad token if no EOS token is defined.\n",
"    eos_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id\n",
"    input_ids = instruction[\"input_ids\"] + response[\"input_ids\"] + [eos_id]\n",
"    # The EOS token must be attended to as well, so its mask entry is 1.\n",
"    attention_mask = instruction[\"attention_mask\"] + response[\"attention_mask\"] + [1]\n",
"    # -100 removes the prompt positions from the loss.\n",
"    labels = [-100] * len(instruction[\"input_ids\"]) + response[\"input_ids\"] + [eos_id]\n",
"    if len(input_ids) > max_length:  # truncate all three sequences in lockstep\n",
"        input_ids = input_ids[:max_length]\n",
"        attention_mask = attention_mask[:max_length]\n",
"        labels = labels[:max_length]\n",
"    return {\n",
"        \"input_ids\": input_ids,\n",
"        \"attention_mask\": attention_mask,\n",
"        \"labels\": labels,\n",
"    }"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "84f870d6-73a9-4b0f-8abf-687b32224ad8",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"tokenized_id = ds.map(process_func, remove_columns=ds.column_names)\n",
"tokenized_id"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f7e15a0-4d9a-4935-9861-00cc472654b1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"tokenizer.decode(tokenized_id[0]['input_ids'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97f16f66-324a-454f-8cc3-ef23b100ecff",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Decode only the supervised part of the first sample (the one shown above): positions\n",
"# labelled -100 are the masked prompt, so dropping them leaves the target reply + EOS.\n",
"tokenizer.decode([token_id for token_id in tokenized_id[0][\"labels\"] if token_id != -100])"
]
},
{
"cell_type": "markdown",
"id": "424823a8-ed0d-4309-83c8-3f6b1cdf274c",
"metadata": {},
"source": [
"# 创建模型"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "170764e5-d899-4ef4-8c53-36f6dec0d198",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Base model in bfloat16; device_map=\"auto\" lets transformers choose device placement.\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    './qwen2-0.5b/qwen/Qwen2-0___5B-Instruct',\n",
"    torch_dtype=torch.bfloat16,\n",
"    device_map=\"auto\",\n",
")\n",
"model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2323eac7-37d5-4288-8bc5-79fac7113402",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model.enable_input_require_grads()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f808b05c-f2cb-48cf-a80d-0c42be6051c7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model.dtype"
]
},
{
"cell_type": "markdown",
"id": "13d71257-3c1c-4303-8ff8-af161ebc2cf1",
"metadata": {},
"source": [
"# 配置 LoRA"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d304ae2-ab60-4080-a80d-19cac2e3ade3",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from peft import LoraConfig, TaskType, get_peft_model\n",
"\n",
"# LoRA adapters on the attention (q/k/v/o) and MLP (gate/up/down) projections.\n",
"config = LoraConfig(\n",
"    task_type=TaskType.CAUSAL_LM,\n",
"    inference_mode=False,  # training mode\n",
"    r=8,  # LoRA rank\n",
"    lora_alpha=32,  # LoRA alpha: scales the adapter update (see the LoRA paper)\n",
"    lora_dropout=0.1,  # dropout probability\n",
"    target_modules=[\n",
"        \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
"        \"gate_proj\", \"up_proj\", \"down_proj\",\n",
"    ],\n",
")\n",
"config"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c2489c5-eaab-4e1f-b06a-c3f914b4bf8e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model = get_peft_model(model, config)\n",
"config"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ebf5482b-fab9-4eb3-ad88-c116def4be12",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"model.print_trainable_parameters()"
]
},
{
"cell_type": "markdown",
"id": "ca055683-837f-4865-9c57-9164ba60c00f",
"metadata": {},
"source": [
"# 配置训练参数"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7e76bbff-15fd-4995-a61d-8364dc5e9ea0",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Effective batch per device: 4 samples x 4 accumulation steps = 16 per optimizer step.\n",
"args = TrainingArguments(\n",
"    output_dir=\"./output/\",  # checkpoint-* folders are written here\n",
"    num_train_epochs=3,\n",
"    learning_rate=1e-4,\n",
"    per_device_train_batch_size=4,\n",
"    gradient_accumulation_steps=4,\n",
"    gradient_checkpointing=True,  # trade extra compute for lower activation memory\n",
"    logging_steps=10,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f142cb9c-ad99-48e6-ba86-6df198f9ed96",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# DataCollatorForSeq2Seq pads each batch to its longest sample; labels are padded\n",
"# with its label_pad_token_id (-100 by default) so padding is ignored by the loss.\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=args,\n",
"    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),\n",
"    train_dataset=tokenized_id,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aec9bc36-b297-45af-99e1-d4c4d82be081",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"trainer.train()"
]
},
{
"cell_type": "markdown",
"id": "8abb2327-458e-4e96-ac98-2141b5b97c8e",
"metadata": {},
"source": [
"# 加载并合并模型(路径可能与此处不同,lora_path 应指向 output 目录下最后一个 checkpoint)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd2a415a-a9ad-49ea-877f-243558a83bfc",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"from transformers.trainer_utils import get_last_checkpoint\n",
"import torch\n",
"from peft import PeftModel\n",
"\n",
"model_path = './qwen2-0.5b/qwen/Qwen2-0___5B-Instruct'\n",
"# Use the newest checkpoint-* folder the Trainer wrote; assign lora_path by hand to pick another.\n",
"lora_path = get_last_checkpoint('./output')\n",
"assert lora_path is not None, \"no checkpoint-* directory found under ./output\"\n",
"\n",
"# Load the tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n",
"\n",
"# Load the base model\n",
"model = AutoModelForCausalLM.from_pretrained(model_path, device_map=\"auto\", torch_dtype=torch.bfloat16, trust_remote_code=True).eval()\n",
"\n",
"# Attach the LoRA adapter and fold it into the base weights. Without merge_and_unload(),\n",
"# save_pretrained() on the PeftModel would write only the adapter, not a standalone model.\n",
"model = PeftModel.from_pretrained(model, model_id=lora_path)\n",
"model = model.merge_and_unload()\n",
"\n",
"prompt = \"在干啥呢?\"\n",
"# The persona goes in the system turn so the prompt matches the training format in process_func.\n",
"messages = [\n",
"    {\"role\": \"system\", \"content\": \"现在你需要扮演我,和我的微信好友快乐聊天!\"},\n",
"    {\"role\": \"user\", \"content\": prompt},\n",
"]\n",
"inputs = tokenizer.apply_chat_template(\n",
"    messages,\n",
"    add_generation_prompt=True,\n",
"    tokenize=True,\n",
"    return_tensors=\"pt\",\n",
"    return_dict=True,\n",
").to(model.device)  # follow wherever device_map placed the model instead of assuming 'cuda'\n",
"\n",
"# Sampling with top_k=1 always picks the most likely token, i.e. it is effectively greedy.\n",
"gen_kwargs = {\"max_length\": 2500, \"do_sample\": True, \"top_k\": 1}\n",
"with torch.no_grad():\n",
"    outputs = model.generate(**inputs, **gen_kwargs)\n",
"    # Keep only the newly generated tokens (drop the echoed prompt).\n",
"    outputs = outputs[:, inputs['input_ids'].shape[1]:]\n",
"    print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n",
"\n",
"# Save the merged model and its tokenizer as a standalone model directory\n",
"save_directory = './model_merge'\n",
"model.save_pretrained(save_directory)\n",
"tokenizer.save_pretrained(save_directory)"
]
},
{
"cell_type": "markdown",
"id": "b67e5e0a-2566-4483-9bce-92b5be8b4b34",
"metadata": {},
"source": [
"# 然后把模型上传到modelscope开始下一步"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dafe4f24-af5c-407e-abbc-eefd9d44cb15",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}