model.fit结合dataset实现输入数据的正样本加权

最近写代码想实现给我的loss正样本加权,因为我的数据集中正样本占比只有15.79%,为了保证tensorflow的IO性能,我采取使用dataset构建高效的数据输入流水线。
正在我忧愁要不要自己写weight_loss函数的时候开心地发现model.fit提供了参数class_weight,只需要输入字典结构的class_weight就可自动实现对已有的loss进行分类别的加权。
(不要以为weight_loss普通的loss加个系数就够了,loss的定义是当前epoch到目前为止所有batch的loss,而不是当前这个batch的loss,详情见之前的一篇博客
注意区分与sample_weight的关系,sample_weight表示对不同的样本的权重,比如要对一个个的batch样本的loss分别加权。class_weight和sample_weight不仅是用在loss中,在metrics中也有这两个参数,不过一般不怎么对metrics使用。

由于我的label是用True和False标记的,所以我在实验过程中传入class_weight={True:4}结果发现遇到这个错误:
model.fit结合dataset实现输入数据的正样本加权
没看懂这是什么意思,于是google了不少答案,却发现没什么人遇到过我这个问题,只发现githu上的一个issue很相似,大概就是说tf.keras中model.fit的class_weight对dataset的支持似乎有问题,并且作者提交了PR,下面一堆人期待官方尽快解决这个bug。github该问题

本来以为我只能自己写weight_loss函数了 ???? 。但是看到作者提交的PR已经被merge到tensorflow了,按道理这个bug不应该存在。于是,在仔细看看了作者的bug,发现和我的不太一样,这时候明白了肯定是我自己的问题。于是沿着报错的地方,打开源码看了看,果然有问题,出错地方的源码如下:
model.fit结合dataset实现输入数据的正样本加权python里面说的是999行出错,逐行代码测试发现,原来由于我的label是True和False,这样999处weight__vector[True]=4肯定有问题。于是果断在dataset解析函数里面改了label类型为int8。还有就是上面只设置了正样本是不合适的,因为从998行可以看出没有设置的class默认为nan,这是有问题的,也就是官方强制要求我们为每个class指明权重,于是改了class_weight如下:
class_weight={1:4,0:1}
完美,问题解决!

总结:
1.不要在没必要的地方标新立异,这里label其实大家都是用0-1存储,我用了True-False,可能tensorflow官方没料想到有人这样做哈哈;
2.有些错误是共性的,可以借鉴,有些错误是个性的,该看源码得自己看。