计算机系统应用教程网站

网站首页 > 技术文章 正文

「五分钟机器学习」 神经网络的基本介绍

btikc 2024-10-11 11:26:33 技术文章 4 ℃ 0 评论

大家好,我是爱讲故事的某某某。 欢迎来到今天的[五分钟机器学习] 神经网络的基本介绍

本期内容将继续视频内容【五分钟机器学习】神经网络——一个小人国投票的故事,并给出全连接神经网络模型的数学推导。还没有看过的小伙伴欢迎去补番。

本期专栏的主要内容如下:

  1. 神经网络的前向传播(Forward Propagation)
  2. 神经网络的反向传播(Backward Propagation)

神经网络的训练主要分为两个部分:

  1. 前向传播生成基于当前试验参数的预测值
  2. 根据LOSS反向传播,计算梯度并更新模型参数

为了理解神经网络中的这两个主要步骤,请看下面的数学推导实例。在这个例子中,我们随机定义了一个全连接神经网络,且网络中包含两层Hidden Layer。

神经网络的前向传播(Forward Propagation)

前向传播的主要目的在于基于将输入X通过各种层的计算(weights和Bias等)得出最终的预测结果,并计算Loss。对于Fig1 中的神经网络,如果想要生成最终的预测结果Y_hat。总共分为以下几个步骤:

假定这个神经网络是用于RegressionTask, 那么SSE Loss为

如果是一个Classification Task,那么Cross-entropy Loss (其中C表示数据集中类别的数量):


神经网络的反向传播(Backward Propagation)

当我们有了基于当前模型参数的前向传播的结果,我们需要利用LOSS和Y_hat进行反向传播,来更新模型内部的参数使LOSS更低。通常来说,为了更新参数,我们需要用到Gradient Descent这个方法,计算Loss和待更新参数之间的偏导数。比如如果你要更新W2,你需要计算dL/dW2。但是问题在于,在神经网络中,前项的参数往往和你的LOSS不能直接关联,也就是你不能直接计算他们的偏导数。比如你不能直接找到dL/dW1这个项。为此,我们需要利用反向传播,通过中间项找到他们之间的关系。

我们以dL/dW2这项为例。要找到这组关联偏导数,我们需要从Eq5 入手。在这个公式中,Y表示Ground Truth,Y_hat表示为模型的输出。而为了找到L和W2之间的关系,可以看到Y是个无关项,我们需要从Y_hat 入手。

而为了找到Y_hat 和W2之间的关系,我们来到Eq4。在这个公式中,delta表示激活函数,是一个固定的计算逻辑,并不和W2相关,所以他是无关项。而Z2的结果将受到W2的影响,所以这里面我们关注Z2

现在,我们通过Y_hat作为中间相,找到了L和Z2的关系。我们继续这个流程,来到Eq3。

可以看到在这个公式中Z2和W2是直接相关的。所以到此我们已经找到了L和W2的关系,他表示为:

对于Eq6这个公式,我们为了找到L和W2的关系,引入了两个中间项Y_hat (也就是A2)和Z2。然后分别通过中间项之间的关联及Z2和W2之间的关联,表示了最终的偏导数。在这个公式中,dL/dA2表示为Loss的偏导数。以Cross-Entropy Loss 为例,这一项的结果为:

而dA2/dZ2这一项,表示的位Activation的偏导数。常见的Activation有很多,比如Sigmoid,Tanh,和ReLU。 对于Sigmoid函数,其偏导表示为:

最后一项dZ2/dW2,回顾Eq3,这两项直接相关,也就是说:

到此,我们已经完成了对于Eq7中每一项的计算,求得了W2的导数。最后我们立刻用梯度下降的方法,更新W2的参数值:

同样的你也可以利用上面的逻辑更新其他参数比如W1:

当你有了这些偏导数之后,就可以利用Eq11中的公式去更新所有的参数了。


以上就是今天的 [五分钟机器学习] 神经网络的基本介绍 的主要内容了。

如果你觉得本期内容有所帮助,欢迎素质三连。

您的支持将是我继续发电的最大动力~

我是某某某

Tags:

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表