TextCNN项目 P2 训练数据导入和句子长度统计
上节课,给大家介绍了 TextCNN 的模型结构,这节课就正式进入代码部分。本节课有两个任务,一是导入数据集,二是要统计待分类的文本长度,因为 TextCNN 在卷积之后,要做批量最大池化操作,所以要求文本长度一致,不够的填充PAD,太长的要进行截取。
代码示例
1、添加配置项
# config.py TRAIN_SAMPLE_PATH = './data/input/train.txt' DEV_SAMPLE_PATH = './data/input/dev.txt' TEST_SAMPLE_PATH = './data/input/test.txt' LABEL_PATH = './data/input/class.txt' BERT_PAD_ID = 0 TEXT_LEN = 35
2、统计句子长度
from config import * import matplotlib.pyplot as plt def count_text_len(): text_len = [] with open(TRAIN_SAMPLE_PATH) as f: for line in f.readlines(): text, _ = line.split('\t') text_len.append(len(text)) plt.hist(text_len) plt.show() print(max(text_len)) if __name__ == '__main__': count_text_len()
做完简单的数据预处理之后,下一步就要定义 Dataset 类,加载数据了。但是数据加载,涉及到 Bert 分词,可能有的同学对 Bert 的使用还不熟悉,所以下面两节课,我们对 Bert 的基本使用做一个简单介绍,已经会用的同学可以直接跳过。
本文链接:http://edu.ichenhua.cn/edu/note/503
版权声明:本文为「陈华编程」原创课程讲义,请给与知识创作者起码的尊重,未经许可不得传播或转售!