网站首页 > 技术文章 正文
Keras是用于Python编程语言的神经网络库,能够与Theano,R或TensorFlow等许多深度学习工具一起运行,并允许快速迭代以进行神经网络的实验或原型设计。
无论您是在Keras中对神经网络模型进行原型设计以了解其将如何执行所需任务,还是对已构建和测试的模型进行微调,都需要为机器学习模型考虑许多参数。这些机器学习模型参数称为超参数。在层中使用的激活函数就是超参数的的示例。机器学习模型中的层数,每层神经元数或卷积神经网络中核的大小都可以视为超级参数。
超参数没有固定公式,不同的问题将需要不同的方法。更改模型的每个参数可能会影响其性能,只有实验才能确定哪种组合最适合您的模型和数据。
在本文中,我们将研究使用机器学习库Scikit-Learn执行超参数调整以优化Keras模型所需的步骤。我们将构建一个简单的神经网络,并使用Scikit-Learn库中的RandomizedSearchCV对象寻找最佳优化器、批量大小和激活。
准备工作
我们将在示例中使用的库是TensorFlow,其中包括Keras和Scikit Learn。
from tensorflow.keras import Sequential from tensorflow.keras.layers import Dense,Flatten from tensorflow.keras.datasets import mnist from tensorflow.keras.wrappers.scikit_learn import KerasClassifier from sklearn.model_selection import RandomizedSearchCV
我们还将使用numpy和matplotlib库:
import numpy as np import matplotlib.pyplot as plt
准备数据
首先,让我们使用一个数据集,对其进行格式化并构建我们的机器学习模型。在这里,我们进行归一化并打印其形状以确保我们为模型使用正确的输入:
(X_train, y_trn), (X_test, y_tst) = mnist.load_data() X_trn = X_train[..., np.newaxis].astype(np.float32) / 255. X_tst = X_test[..., np.newaxis].astype(np.float32) / 255. print(X_train.shape,y_trn.shape) print(X_test.shape,y_tst.shape)
mnist数据集是一组28x28像素的手写数字图片。
我们的数据如下所示:
def preview(data,result): """Shows 12 elements of picture dataset""" fig = plt.figure() for i in range(12): plt.subplot(2,6,i+1) plt.imshow(data[i], interpolation='none') plt.title("label:{}".format(result[i])) plt.xticks([]) plt.yticks([]) preview(X_train[12:],y_trn[12:])
建立模型
为了使用scikit-learn调整Keras模型的参数,我们需要能够使用不同的参数重建模型。为此,我们创建一个函数来基于我们的超参数构建模型:
def build_model(var_activation='relu',var_optimizer='adam'): """ Uses arguments to build Keras model. """ model = Sequential() model.add(Flatten(input_shape=[28, 28, 1])) model.add(Dense(64,activation=var_activation)) model.add(Dense(32,activation=var_activation)) model.add(Dense(16,activation=var_activation)) model.add(Dense(10,activation='softmax')) model.compile(loss="sparse_categorical_crossentropy", optimizer=var_optimizer, metrics=["accuracy"]) return model
这是我们的模型在默认参数下的样子:
model_default = build_model() model_default.summary()
设置变量
我们想使用Adam算法和随机梯度下降来测试模型的性能,并测试不同层的激活函数和批量大小来训练模型。让我们创建参数列表并将它们存储为字典。字典中的键是在我们的模型中使用的变量的名称:
_activations=['tanh','relu','selu'] _optimizers=['sgd','adam'] _batch_size=[16,32,64] params=dict(var_activation=_activations, var_optimizer=_optimizers, batch_size=_batch_size) print(params)
注意,' batch_size '不是build_model函数中的变量,而是.fit()调用中稍后将使用的变量,以训练我们创建的模型。
根据Keras模型创建scikit学习估算器
现在我们有了数据,构建模型的功能以及要测试的参数,我们可以使用sklearn库根据我们的函数和超参数测试不同的模型。我们可以使用sklearn.model_selection模块中的GridSearchCV或RandomizedSearchCV对象来迭代超参数的不同组合,并输出得分最高的模型。GridSearchCV对象将遍历超参数的所有可能组合,而RandomizedSearchCV对象将随机采样许多可能的组合以训练模型。尽管使用随机搜索可能并不总是提供最佳的可能模型,但由随机搜索要快得多,资源消耗也少得多,这使得随机模型搜索对于测试和原型设计非常有用。要使用RandomizedSearchCV,我们首先需要使我们的Keras模型与sklearn库兼容,我们将对scikitlearn使用Keras包装器:KerasClassifier。
model = KerasClassifier(build_fn=build_model,epochs=4,batch_size=16)
在拟合我们的RandomizedSearch对象之前,我们使用numpy.random.seed()设置随机种子。将种子设置为随机数生成器将使我们的模型权重初始化与每次迭代相同,从而使我们的搜索更有意义。但是,如果我们的超参数包含层数或层中节点数,则将无济于事,因为我们将比较完全不同的模型。
np.random.seed(42)
使用RandomizedSearchCV
创建KerasClassifier后,我们将创建RandomizedSearchCV对象,并使用.fit()方法开始搜索最佳模型。RandomizedSearchCV允许我们使用参数n_iter明确控制尝试的组合数量。
rscv = RandomizedSearchCV(model, param_distributions=params, cv=3, n_iter=10) rscv_results = rscv.fit(X_trn,y_trn)
这是我们搜索的结果:
print('Best score is: {} using {}'.format(rscv_results.best_score_, rscv_results.best_params_))
结论
超参数调优可用于微调所选模型,或搜索最适合该任务的模型。它还可以帮助评估模型的学习速度。上面的方法可以进一步扩展,包括使用来自scikit-learn库的GridSearchCV对象进行更详尽的搜索,或者为模型的结构添加参数,如层数。可以添加回调以防止过拟合测试模型。
猜你喜欢
- 2024-10-12 神经网络调试:梯度可视化 神经网络 梯度
- 2024-10-12 2021年4月下旬,百度机器学习/数据挖掘/NLP算法工程师面试8道题
- 2024-10-12 PyTorch 0.2发布:更多NumPy特性,高阶梯度等
- 2024-10-12 Tensorflow中的卷积神经网络 tensorflow 卷积神经网络
- 2024-10-12 深度学习中的激活函数总结 激活函数原理
- 2024-10-12 「周末AI课堂」SELU和ResNet(代码篇)机器学习你会遇到的“坑”
- 2024-10-12 想降低云服务的花销?或许深度强化学习能帮到你|论文
- 2024-10-12 SELU和ResNet(代码篇)|机器学习你会遇到的“坑”
- 2024-10-12 深度强化学习还能帮你省钱!这项研究要用RL控制云服务开销
- 2024-10-12 SELU和ResNet(理论篇)机器学习你会遇到的“坑”
你 发表评论:
欢迎- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)