计算机系统应用教程网站

网站首页 > 技术文章 正文

(8) 超越记忆:长短期记忆网络(LSTM)

btikc 2024-11-18 09:08:52 技术文章 27 ℃ 0 评论

简介

你还记得我们上一篇文章中提到的循环神经网络(RNN)吗?(回顾一下:(7) 记忆世界:循环神经网络(RNN))它们非常擅长处理序列数据,但是当我们需要处理的信息跨度很大时,RNN就显得力不从心了。这时候,长短期记忆网络(LSTM)就派上用场了。

LSTM是RNN的一种,它的特点是能够学习长期依赖信息。LSTM由Hochreiter和Schmidhuber在1997年首次提出,并在后续的工作中得到了许多人的改进和推广。LSTM在各种问题上都表现得非常好,现在已经被广泛使用。

LSTM的工作原理

LSTM的关键是细胞状态,这是一种在网络中流动的信息。LSTM有能力向细胞状态中添加或删除信息,这个过程是由所谓的门控制的。

门是一种让信息有选择地通过的方式。它由一个sigmoid神经网络层和一个逐点乘法操作组成。sigmoid层输出0到1之间的数值,描述了每个组件应该让多少信息通过。0表示“不让任何信息通过”,1表示“让所有信息通过”。

LSTM有三个这样的门,用来保护和控制细胞状态。

遗忘门

第一步是决定我们要从细胞状态中丢弃什么信息。这个决定由一个名为“遗忘门”的sigmoid层来做。它查看上一个隐藏状态和当前输入,然后对细胞状态进行逐元素的乘法。这样,如果遗忘门的输出是1,那么我们就保留那个元素。如果遗忘门的输出是0,那么我们就忘记那个元素。

输入门

下一步是决定我们要在细胞状态中存储什么新信息。这也是由一个sigmoid层来完成的,我们称之为“输入门”。同时,一个tanh层会创建新的候选值向量,这些值可能会被添加到状态中。

更新细胞状态

现在是时候更新旧的细胞状态了。我们将旧状态与遗忘门的输出相乘,丢弃我们决定忘记的任何信息。然后,我们添加输入门的输出与新的候选值的乘积。这就是我们新的细胞状态。

输出门

最后,我们需要决定输出什么。这个输出将会基于我们的细胞状态,但是会是一个过滤后的版本。我们首先运行一个sigmoid层来确定我们将输出细胞状态的哪些部分,然后我们将细胞状态通过tanh函数(将值压缩到-1和1之间)并乘以sigmoid门的输出,这样我们只输出我们决定输出的部分。

LSTM的代码实现

在PyTorch中,我们可以很容易地创建一个LSTM网络。下面是一个简单的例子:

在这个例子中,我们首先定义了一个LSTM网络,它有一个LSTM层和一个线性输出层。在前向传播函数中,我们将输入数据reshape成LSTM需要的格式,然后通过LSTM层和输出层得到最终的输出。

这只是一个非常基础的LSTM网络,实际使用中,我们可能会根据需要添加更多的层或者使用更复杂的结构。

基于LSTM的时间序列实例

在这个例子中,我们将使用PyTorch库和LSTM(长短期记忆)网络来预测时间序列数据。我们的目标是预测未来的乘客数量。

首先,我们需要导入必要的库,并加载我们的数据集。在这个例子中,我们将使用Python的Seaborn库内置的飞行数据集。

然后,我们需要对数据进行预处理。这包括将乘客数量列的类型更改为浮点数,将数据集划分为训练集和测试集,以及对数据进行归一化处理。

接下来,我们需要定义我们的LSTM模型。在这个模型中,我们将使用一个隐藏层,该层包含100个神经元。

然后,我们可以开始训练我们的模型。

最后,我们可以使用训练好的模型进行预测,并将预测结果与实际结果进行比较。

这个例子中的代码是基于LSTM的时间序列预测的完整实现。现在,让我们详细解释一下这个过程。

首先,我们导入了必要的库,并加载了数据集。这个数据集是一个包含了每月乘客数量的时间序列数据。

然后,我们对数据进行了预处理。这包括将乘客数量列的类型更改为浮点数,将数据集划分为训练集和测试集,以及对数据进行归一化处理。归一化是一个重要的步骤,因为它可以帮助我们的模型更好地学习和预测数据。

接下来,我们定义了我们的LSTM模型。这个模型有一个隐藏层,包含100个神经元。在模型的前向传播过程中,我们将输入序列传递给LSTM层,然后将LSTM层的输出传递给线性层,最后返回线性层的输出。

然后,我们开始训练我们的模型。我们使用了均方误差损失函数和Adam优化器。在每个训练周期中,我们都会遍历训练数据,将序列和对应的标签传递给模型,然后计算预测值和真实值之间的损失。我们使用这个损失来更新模型的参数。

最后,我们使用训练好的模型进行预测。我们首先从训练数据中获取最后12个月的数据,然后使用这些数据作为输入,让模型预测下一个月的乘客数量。我们将预测的结果添加到输入数据中,然后再次进行预测,这样循环12次,得到未来12个月的乘客数量预测。最后,我们将预测的结果从归一化的形式转换回原始的乘客数量。

这个例子展示了如何使用LSTM进行时间序列预测。虽然这个例子中的预测结果可能并不完全准确,但是它成功地捕捉到了数据的整体趋势。

总结

LSTM是一种强大的工具,它解决了RNN在处理长期依赖问题时的困难。通过使用门控制信息的流动,LSTM能够在保持长期信息的同时,忘记不再需要的信息。这使得LSTM在许多需要处理序列数据的任务中,比如语音识别、语言模型、翻译等,都取得了非常好的效果。

Tags:

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表