keras多任务多loss回传的思考
如果有一个多任务多loss的网络,那么在训练时,loss是如何工作的呢?
比如下面:
model = Model(inputs = input, outputs = [y1, y2])
l1 = 0.5
l2 = 0.3
model.compile(loss = [loss1, loss2], loss_weights=[l1, l2], ...)
其实我们最终得到的loss为
final_loss = l1 * loss1 + l2 * loss2
我们最终的优化效果是最小化final_loss。
问题来了,在训练过程中,是否loss2只更新得到y2的网络通路,还是loss2会更新所有的网络层呢?
此问题的关键在梯度回传上,即反向传播算法。
对于x1参数的更新:
对于x2参数的更新:
对于x2参数的更新:
所以loss1只对x1和x2有影响,而loss2只对x1和x3有影响。
参考:https://*.com/questions/49404309/how-does-keras-handle-multiple-losses