
Large language models: fine-tuning chatGLM and LLAMA on Chinese instructions



The code walked through in this post lives at:

/27182812/ChatGLM-LLaMA-chinese-insturct

How Chinese instruction data performs on chatGLM and LLAMA

Data

JSON preprocessing

Instruction tokenizer

Compared with the earlier ChatGLM-Tuning post, here both functions are placed inside a single class in dataprocess; at first glance, the changes required are almost the same. A rough sketch of that class layout is given below.
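A minimal sketch of what grouping prompt formatting and tokenization into one class can look like. This is hypothetical: the class name, method names, and default length are made up for illustration and are not the repo's actual dataprocess code.

class InstructionDataProcessor:
    """Hypothetical helper that bundles prompt building and tokenization."""

    def __init__(self, tokenizer, cutoff_len=256):
        self.tokenizer = tokenizer
        self.cutoff_len = cutoff_len

    def format_prompt(self, data_point):
        # Assemble an Alpaca-style prompt from instruction / input / output fields.
        if data_point.get("input"):
            return (
                f"### Instruction:\n{data_point['instruction']}\n\n"
                f"### Input:\n{data_point['input']}\n\n"
                f"### Response:\n{data_point['output']}"
            )
        return (
            f"### Instruction:\n{data_point['instruction']}\n\n"
            f"### Response:\n{data_point['output']}"
        )

    def tokenize(self, data_point):
        # Truncate/pad to a fixed length so examples can be batched directly.
        result = self.tokenizer(
            self.format_prompt(data_point),
            truncation=True,
            max_length=self.cutoff_len,
            padding="max_length",
        )
        return {
            "input_ids": result["input_ids"],
            "attention_mask": result["attention_mask"],
        }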

Fine-tuning

For chatGLM, the script is finetune.sh; for LLAMA, it is test_llama1.py.

The chatGLM procedure is almost identical to the earlier post, so the focus here is on LLAMA.

Data

def generate_prompt(data_point):
    # sorry about the formatting disaster gotta move fast
    if data_point["input"]:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Input:
{data_point["input"]}

### Response:
{data_point["output"]}"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{data_point["instruction"]}

### Response:
{data_point["output"]}"""


def tokenize(prompt):
    # there's probably a way to do this with the tokenizer settings
    # but again, gotta move fast
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=CUTOFF_LEN + 1,
        padding="max_length",
    )
    return {
        "input_ids": result["input_ids"][:-1],
        "attention_mask": result["attention_mask"][:-1],
    }
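CUTOFF_LEN here, and the LORA_*/batch-size constants used in the next snippets, are defined at the top of test_llama1.py and are not shown in this excerpt. For reference, the commonly used alpaca-lora defaults look like the following; these are illustrative values only, and the repo may use different ones.

# Illustrative hyperparameters, following common alpaca-lora defaults;
# the actual values in test_llama1.py may differ.
MICRO_BATCH_SIZE = 4                 # per-device batch size
BATCH_SIZE = 128                     # effective batch size
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
EPOCHS = 3
LEARNING_RATE = 3e-4
CUTOFF_LEN = 256                     # max prompt length in tokens
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05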

Model

from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    load_in_8bit=True,
    device_map="auto",
)
tokenizer = LlamaTokenizer.from_pretrained(
    "decapoda-research/llama-7b-hf", add_eos_token=True
)
model = prepare_model_for_int8_training(model)
config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token
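Because target_modules is limited to q_proj and v_proj, LoRA adapters are attached only to the attention query and value projections while the 8-bit base weights stay frozen. A quick way to confirm how small the trainable fraction is, as a usage sketch rather than part of the repo's script:

# PeftModel exposes a helper that reports trainable vs. total parameter counts;
# with r=8 on q_proj/v_proj only, the trainable share is a tiny fraction of the 7B model.
model.print_trainable_parameters()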

Fine-tuning

import transformers

data = data.shuffle().map(lambda x: tokenize(generate_prompt(x)))
trainer = transformers.Trainer(
    model=model,
    train_dataset=data["train"],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=MICRO_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=100,
        num_train_epochs=EPOCHS,
        learning_rate=LEARNING_RATE,
        fp16=True,
        logging_steps=20,
        output_dir="qys-alpaca-chinese",
        save_total_limit=3,
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False
trainer.train(resume_from_checkpoint=False)
# trainer.train()
model.save_pretrained("qys-alpaca-chinese")
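The data object fed into shuffle().map(...) above is not shown in this excerpt. A minimal sketch of loading it with the datasets library, assuming the same instruction JSON file used later for inference (data/zh-data01.json):

from datasets import load_dataset

# Loading a JSON file this way yields a DatasetDict with a "train" split,
# which matches the data["train"] access in the Trainer above.
data = load_dataset("json", data_files="data/zh-data01.json")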

Inference

For chatGLM, the script is infer.py; for LLAMA, it is generate_llama1.py.

Inference code

import json

import torch
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(
    model, "./qys-alpaca-chinese", torch_dtype=torch.float16
)


def generate_prompt(instruction, input=None):
    if input:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:"""


instructions = json.load(open("data/zh-data01.json"))
answers = []
with torch.no_grad():
    for idx, item in enumerate(instructions[12:18]):
        feature = format_example(item)
        input_text = feature['context']
        print(input_text)
        inputs = tokenizer(input_text, return_tensors="pt")
        input_ids = inputs["input_ids"].cuda()
        generation_config = GenerationConfig(
            temperature=0.1,
            top_p=0.75,
            top_k=40,
            num_beams=4,
        )
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=256,
        )
        s = generation_output.sequences[0]
        output = tokenizer.decode(s)
        print(output.strip())
        print("--------------------------------------------")
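format_example is defined elsewhere in the repo and is not shown here. Judging from the feature['context'] access above, it presumably builds an inference prompt that omits the answer, roughly along these lines; this is a hypothetical reconstruction mirroring the ChatGLM-Tuning style, not the repo's verbatim code.

def format_example(example):
    # Build the prompt the model should complete; the answer is left out
    # so that generation fills it in.
    context = f"Instruction: {example['instruction']}\n"
    if example.get("input"):
        context += f"Input: {example['input']}\n"
    context += "Answer: "
    return {"context": context, "target": example.get("output", "")}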
