网站维护 网站建设属于什么,四川招标投标网,织梦后台如何做网站地图,网站快速收录平台文章目录 一、导入相关包二、加载数据集三、数据集预处理四、创建模型五、创建评估函数六、配置训练参数七、创建训练器八、模型训练九、模型预测 !pip install transformers datasets evaluate accelerate 一、导入相关包
import evaluate
from datasets import DatasetDict,… 文章目录 一、导入相关包二、加载数据集三、数据集预处理四、创建模型五、创建评估函数六、配置训练参数七、创建训练器八、模型训练九、模型预测 !pip install transformers datasets evaluate accelerate 一、导入相关包
import evaluate
from datasets import DatasetDict, load_dataset
from transformers import AutoTokenizer, AutoModelForMultipleChoice, TrainingArguments, Trainer二、加载数据集
# c3 DatasetDict.load_from_disk(./c3/) 从本地加载
# c3 load_from_disk(./c3/) 同上
c3 load_dataset(clue,c3)
c3DatasetDict({test: Dataset({features: [id, context, question, choice, answer],num_rows: 1625})train: Dataset({features: [id, context, question, choice, answer],num_rows: 11869})validation: Dataset({features: [id, context, question, choice, answer],num_rows: 3816})
})c3[train][:10]{id: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],context: [[男你今天晚上有时间吗?我们一起去看电影吧?, 女你喜欢恐怖片和爱情片但是我喜欢喜剧片科幻片一般。所以……],[男足球比赛是明天上午八点开始吧?, 女因为天气不好比赛改到后天下午三点了。],[女今天下午的讨论会开得怎么样?, 男我觉得发言的人太少了。],[男我记得你以前很爱吃巧克力最近怎么不吃了是在减肥吗?, 女是啊我希望自己能瘦一点儿。],[女过几天刘明就要从英国回来了。我还真有点儿想他了记得那年他是刚过完中秋节走的。,男可不是嘛!自从我去日本留学就再也没见过他算一算都五年了。,女从2000年我们在学校第一次见面到现在已经快十年了。我还真想看看刘明变成什么样了!,男你还别说刘明肯定跟英国绅士一样也许还能带回来一个英国女朋友呢。],[男好久不见了最近忙什么呢?,女最近我们单位要搞一个现代艺术展览正忙着准备呢。,男你们不是出版公司吗?为什么搞艺术展览?,女对啊这次展览是我们出版的一套艺术丛书的重要宣传活动。],[男会议结束后你记得把空调和灯都关了。, 女好的我知道了明天见。],[男你出国读书的事定了吗?, 女思前想后还拿不定主意呢。],[男这件衣服我要了在哪儿交钱?, 女前边右拐就有一个收银台可以交现金也可以刷卡。],[男小李啊你是我见过的最爱干净的学生。,女谢谢教授夸奖。不过您是怎么看出来的?,男不管我叫你做什么你总是推得干干净净。,女教授我……]],question: [女的最喜欢哪种电影?,根据对话可以知道什么?,关于这次讨论会我们可以知道什么?,女的为什么不吃巧克力了?,现在大概是哪一年?,女的的公司为什么要做现代艺术展览?,他们最可能是什么关系?,女的是什么意思?,他们最可能在什么地方?,教授认为小李怎么样?],choice: [[恐怖片, 爱情片, 喜剧片, 科幻片],[今天天气不好, 比赛时间变了, 校长忘了时间],[会是昨天开的, 男的没有参加, 讨论得不热烈, 参加的人很少],[刷牙了, 要减肥, 口渴了, 吃饱了],[2005年, 2010年, 2008年, 2009年],[传播文化, 宣传新书, 推广现代艺术, 体现企业文化],[同事, 司机和客人, 医生和病人],[不想出国, 出国太难, 还在犹豫, 不想决定],[医院, 迪厅, 商场, 饭馆],[卫生习惯非常好, 做事的能力不够, 找借口拒绝做事, 记不住该做的事]],answer: [喜剧片,比赛时间变了,讨论得不热烈,要减肥,2010年,宣传新书,同事,还在犹豫,商场,找借口拒绝做事]}# dataset本质上是一个字典删除test键
c3.pop(test) # 删除test数据集Dataset({features: [id, context, question, choice, answer],num_rows: 1625
})
...# 因为是字典下列操作也支持c3.keys()
c3.values()
c3.items()c3DatasetDict({train: Dataset({features: [id, context, question, choice, answer],num_rows: 11869})validation: Dataset({features: [id, context, question, choice, answer],num_rows: 3816})
})三、数据集预处理
tokenizer AutoTokenizer.from_pretrained(hfl/chinese-macbert-base)
tokenizerBertTokenizerFast(name_or_pathhfl/chinese-macbert-base, vocab_size21128, model_max_length1000000000000000019884624838656, is_fastTrue, padding_sideright, truncation_sideright, special_tokens{unk_token: [UNK], sep_token: [SEP], pad_token: [PAD], cls_token: [CLS], mask_token: [MASK]}, clean_up_tokenization_spacesTrue), added_tokens_decoder{0: AddedToken([PAD], rstripFalse, lstripFalse, single_wordFalse, normalizedFalse, specialTrue),100: AddedToken([UNK], rstripFalse, lstripFalse, single_wordFalse, normalizedFalse, specialTrue),101: AddedToken([CLS], rstripFalse, lstripFalse, single_wordFalse, normalizedFalse, specialTrue),102: AddedToken([SEP], rstripFalse, lstripFalse, single_wordFalse, normalizedFalse, specialTrue),103: AddedToken([MASK], rstripFalse, lstripFalse, single_wordFalse, normalizedFalse, specialTrue),
}def process_function(examples):# examples, dict, keys: [context, quesiton, choice, answer]# 假设examples有1000个context []question_choice []labels []for idx in range(len(examples[context])):ctx \n.join(examples[context][idx])question examples[question][idx]choices examples[choice][idx]for choice in choices:context.append(ctx)question_choice.append(question choice)# 不足四个选项补全四个选项if len(choices) 4:for _ in range(4 - len(choices)):context.append(ctx)question_choice.append(question 不知道)labels.append(choices.index(examples[answer][idx]))tokenized_examples tokenizer(context, question_choice, truncationonly_first, max_length256, paddingmax_length) # input_ids: 4000 * 256,tokenized_examples {k: [v[i: i 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()} # 1000 * 4 *256tokenized_examples[labels] labelsreturn tokenized_examplesres c3[train].select(range(10)).map(process_function, batchedTrue)
resDataset({features: [id, context, question, choice, answer, input_ids, token_type_ids, attention_mask, labels],num_rows: 10
})import numpy as np
np.array(res[input_ids]).shape(10, 4, 256)tokenized_c3 c3.map(process_function, batchedTrue)
tokenized_c3DatasetDict({train: Dataset({features: [id, context, question, choice, answer, input_ids, token_type_ids, attention_mask, labels],num_rows: 11869})validation: Dataset({features: [id, context, question, choice, answer, input_ids, token_type_ids, attention_mask, labels],num_rows: 3816})
})四、创建模型
model AutoModelForMultipleChoice.from_pretrained(hfl/chinese-macbert-base)五、创建评估函数
import numpy as np # 切记这里predictions是np数组
accuracy evaluate.load(accuracy)def compute_metric(pred):predictions, labels predpredictions np.argmax(predictions, axis-1)return accuracy.compute(predictionspredictions, referenceslabels)六、配置训练参数
fp16True用混合精度训练 混合精度训练需要 GPU 支持特别是 NVIDIA 的 Volta 和 Turing 架构以及更高版本的 GPU。如果您在没有这些硬件的环境中启用了混合精度训练可能会遇到错误。好处更少的显存、更快的训练速度坏处损失精度
args TrainingArguments(output_dir./muliple_choice,per_device_train_batch_size16,per_device_eval_batch_size16,num_train_epochs3,logging_steps50,evaluation_strategyepoch,save_strategyepoch,load_best_model_at_endTrue,fp16True # 用混合精度训练可以加速训练
)七、创建训练器
trainer Trainer(modelmodel,argsargs,train_datasettokenized_c3[train],eval_datasettokenized_c3[validation],compute_metricscompute_metric
)八、模型训练
trainer.train()九、模型预测
多项选择任务 pipeline并没有现成的封装需要自己写推理
from typing import Any
import torchclass MultipleChoicePipeline:def __init__(self, model, tokenizer) - None:self.model modelself.tokenizer tokenizerself.device model.devicedef preprocess(self, context, quesiton, choices):cs, qcs [], []for choice in choices:cs.append(context)qcs.append(quesiton choice)return tokenizer(cs, qcs, truncationonly_first, max_length256, return_tensorspt)def predict(self, inputs):# inputs扩充一个batch维度inputs {k: v.unsqueeze(0).to(self.device) for k, v in inputs.items()}return self.model(**inputs).logitsdef postprocess(self, logits, choices):predition torch.argmax(logits, dim-1).cpu().item()return choices[predition]def __call__(self, context, question, choices) - Any:inputs self.preprocess(context, question, choices)logits self.predict(inputs)result self.postprocess(logits, choices)return result单条预测
pipe MultipleChoicePipeline(model, tokenizer)注意这里不限于选项的个数训练的时候限制了 4 个推理的时候可以任意个数
pipe(小明在北京上班, 小明在哪里上班, [北京, 上海, 河北, 海南, 河北, 海南])北京pipe(小明在北京上班, 小明在哪里上班, [北京, 上海])北京