大家好,今天要讲的内容是多分类中的交叉熵损失函数。
交叉熵误差,cross entropy error,用来评估模型输出的概率分布和真实概率分布的差异情况,一般用于解决分类问题。它有两种定义形式,分别对应二分类与多分类问题。
在二分类问题中,E=- [y * log(p) + (1 - y) * log(1 - p)],其中y是样本的真实标记,p是模型的预测概率。
在多分类问题中,E=- Σ [y_i * log(p_i)],y_i和p_i对应第i个类别的真实标记与预测概率,n是类别个数。
多分类中的交叉熵误差
在多分类问题中,如果每个类别之间的定义是互斥的,那么任何样本都只能被标记为一种类别。
我们使用向量y来表示样本的标记值,如果有n种类别,那么y就是一个n乘1的列向量,向量中只有1个元素是1,其余元素是0。
定义m个样本、n种类别的交叉熵误差E。它等于m个样本的平均交叉熵误差。
在公式中,y-i-k和p-i-k代表了,第i个样本,第k个类别的真实标记和第k个类别的模型预测概率。
如果第i个样本被标记为第k个类别,那么在向量y-i中,第k个元素y-i-k就等于1,其余均是0。
多分类问题的交叉熵损失,只与真实类别对应的模型预测概率有关。这里我们单独来看第i个样本的误差Ei。
如果该样本的真实类别是第k个类别,那么在西格玛求和的过程中,只有第k项是存在的,其余项均是0。因此Ei=-(y-i-k)*log(p-i-k)。
上述计算说明了,在交叉熵损失函数中,只有真实类别对应的那一项会被计算在内,其他类别的项,在求和过程中均为0。
所以即便模型对其他类别的预测概率不准确,但只要对真实类别的预测概率较高,损失函数的值仍然较低。
均方误差与交叉熵误差的举例对比
下面是一个图片分类的例子,图片的类别有猫、狗、牛三种情况。
我们训练出了两个模型a和b,使用a和b,对3个测试样本进行预测,得到了相应的预测结果。
我们要基于这个结果,来说明均方误差与交叉熵误差,两种损失函数,在多分类问题上的表现区别。
观察模型的预测结果可以发现,对于每个样本,a和b两个模型,都会输出三种类别的3个概率。每个概率会对应标记值0或1。
例如,模型a预测第1个样本,得到3个概率是0.3、0.3和0.4,它们对应的标记是0、0、1。模型预测正确。
从数据来看,a和b两个模型,都正确的预测了样本1和样本2,错误预测了样本3。
但是模型a在预测样本1和2时,三种类别的概率差异并不明显,只是很微弱的判断正确,并且样本3的错误还非常明显。
对于模型b,在判断前两个样本时,正确类别对应的预测概率很大,都是0.7,同时第3个样本的错误也不明显。下面就使用均方误差和交叉熵误差,分别对这两个模型进行衡量。
首先使用均方误差,计算a和b两个模型的差异。对于每个样本,都需要计算所有类别的预测概率和标记值差的平方和,再求平均。
例如,模型a对于三个样本的误差是0.54、0.54和1.34,平均误差是0.81,模型b对三个样本的误差是0.14、0.14和0.74,平均误差0.34,可以得出模型b更好。
接下来使用交叉熵误差,计算两个模型的区别。实际上,对于每个样本,我们只关注该样本的真实标记类别的预测概率。
例如,模型a预测第1个样本时,类别1和2的交叉熵都是0,类别3是负log0.4,因此得到结果0.91。
按照这种方式,计算a和b两个模型对于3个样本的误差,然后求出平均值。模型a的平均交叉熵误差是1.37,b是0.63,同样会得到模型b比a好的结论。
比较两种方法的计算结果,可以发现,a和b两个模型的均方误差的差异是0.81-0.34=0.47。交叉熵误差的差异是1.37-0.63=0.74。
因此,基于交叉熵误差比较两个模型,可以得到更大的差异,这也就说明了,在该多分类问题中,如果使用交叉熵误差,可以更好的区分出模型的差异。
那么到这里,多分类中的交叉熵损失函数就讲完了,感谢大家的观看,我们下节课再会。
本文暂时没有评论,来添加一个吧(●'◡'●)