TextCNN项目 P14-1 自定义损失函数和模型评估函数(2)
上节课中,我们把训练数据的目标值做了修改,接下来就需要修改训练逻辑这部分。主要是两块,一个是损失函数和一个分类有点区别,另外一个是评估函数也有点不同。接下来,我们就从这两个部分进行修改。
代码示例
1、配置项
首先是损失函数,因为多标签,我们预测的其实是一个01序列,而且0比1要多。所以我们对0的位置,做一个降权。
另外预测值,是一组 (0,1) 范围内的数值,我们需要设置一个阈值,大于阈值的取1,小于阈值的取0。
内容不可见,请联系管理员开通权限。
2、自定义损失函数
内容不可见,请联系管理员开通权限。
3、模型评估函数
评估函数的修改其实很简单,就是把预测和真实序列解析出来,完全一样的就算预测正确。
内容不可见,请联系管理员开通权限。
4、Kaggle训练
以上流程修改完后,也需要把代码复制到Kaggle上进行训练,然后把缓存模型下载下来备用。这一步已经讲过很多次,课上就不演示了,大家课后自己完成。
内容不可见,请联系管理员开通权限。
5、模型测试
训练完成后,可以将模型下载到本地,也可将直接在Kaggle上测试。
内容不可见,请联系管理员开通权限。
6、模型预测
训练的过程大家自己完成,模型预测部分,我们也同样需要把序列解析出来,然后找到对应的 label 标签即可。
内容不可见,请联系管理员开通权限。
到目前为止,我们多标签的改造就全部完成了。从上面的例子可以看出,多标签的预测虽然更符合业务需求,但还存在一个问题,就是可能会出现为空的情况。对于这种情况,我们的处理方案是,如果多标签预测为空,就再接一个单标签分类做兜底。
另外,这两个模型我们是单独训练的,但是在实际业务场景中,我们需要把他们整合起来,封装成一个模块,供业务方调用,这个封装过程,我们就放到下节课再给大家介绍。
本文链接:http://edu.ichenhua.cn/edu/note/566
版权声明:本文为「陈华编程」原创课程讲义,请给与知识创作者起码的尊重,未经许可不得传播或转售!