上节课,我们做了一个简单的数据预处理,通过观察直方图,定义好了文本的长度参数。同时,如果对 Bert 不熟悉的同学,还需要看一下前面补充的内容。

现在,假设大家已经看了、并且掌握了前面 Huggingface 的内容,我们接着往下讲自定义 Dataset 和 Bert 分词的内容。

代码示例

1、新建文件

# utils.py
from torch.utils import data
from config import *
import torch
from transformers import BertTokenizer

from transformers import logging
logging.set_verbosity_error()

2、自定义Dataset类

class Dataset(data.Dataset):
    def __init__(self, type='train'):
        super().__init__()
        if type == 'train':
            sample_path = TRAIN_SAMPLE_PATH
        elif type == 'dev':
            sample_path = DEV_SAMPLE_PATH
        elif type == 'test':
            sample_path = TEST_SAMPLE_PATH

        self.lines = open(sample_path).readlines()
        self.tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, index):
        text, label = self.lines[index].split('\t')
        tokened = self.tokenizer(text)
        input_ids = tokened['input_ids']
        mask = tokened['attention_mask']
        if len(input_ids) < TEXT_LEN:
            pad_len = (TEXT_LEN - len(input_ids))
            input_ids += [BERT_PAD_ID] * pad_len
            mask += [0] * pad_len
        target = int(label)
        return torch.tensor(input_ids[:TEXT_LEN]), torch.tensor(mask[:TEXT_LEN]), torch.tensor(target)

3、调用测试

if __name__ == '__main__':
    dataset = Dataset()
    loader = data.DataLoader(dataset, batch_size=2)
    print(iter(loader).next())

本文链接:http://edu.ichenhua.cn/edu/note/504

版权声明:本文为「陈华编程」原创课程讲义,请给与知识创作者起码的尊重,未经许可不得传播或转售!