网站首页 > 技术文章 正文
机器之心报道
参与:Racoon
这里有一个简单但又不失灵活性的开源 GNN 库推荐给你。
Spektral 是一个基于 Keras API 和 TensorFlow 2,用于图深度学习的开源 Python 库。该项目的主要目的是提供一个简单但又不失灵活性的图神经网络(graph neural networks,GNNs) 框架。
我们可以使用 Spektral 来进行网络节点分类、预测分子特性、使用 GAN 生成新的拓扑图、节点聚类、预测链接以及其他任意数据是使用拓扑图来描述的任务。
Spektral 中实现了多种目前经典的图深度学习层:
Graph Convolutional Networks (GCN)
Chebyshev networks (ChebNets)
GraphSAGE
ARMA convolutions
Edge-Conditioned Convolutions (ECC)
Graph attention networks (GAT)
Approximated Personalized Propagation of Neural Predictions (APPNP)
Graph Isomorphism Networks (GIN)
也包含如下多种池化层:
DiffPool
MinCUT pooling
Top-K pooling
Self-Attention Graph (SAG) pooling
Global sum, average, and max pooling
Global gated attention pooling
项目地址:https://github.com/danielegrattarola/spektral/
效果展示
我们先来看一下这个项目的效果怎么样。以下是使用 Spektral 编写的图神经网络在 MNIST 数据集上的训练结果:
验证结果如下:
我们将网络权重可视化后,可得到下面这样的效果:
下图展示了两个图卷积层的可视化效果。我们可以此来观察图神经网络是否能够学习到,与传统卷积神经网络类似的特征。
上手实测
Spektral 是依据 Keras API 的指导准则设计的,为的是对初学者友好的同时为专家及研究人员提供较好的灵活性。layers.convolutional 和 layers.pooling 是 Spektral 中最重要的两个模块,里面提供了多种用于构建 GNN 的经典网络层。由于 Spektral 是作为 Keras 的一个扩展被设计出的,这使得我们能够将任意一个 Spektral 层加入现有的 Keras 模型中,而不用进行任何更改。
安装方法
Spektral 支持 Python 3.5 及以上的版本,并在 Ubuntu 16.04+与 MacOS 上进行了测试,暂时不支持 Windows(抱歉了)。这里我们以 Ubuntu 为例,安装相关依赖项:
sudo apt install graphviz libgraphviz-dev libcgraph6
安装 Spektral 最简单的方式是通过 PyPi 来进行安装:
pip install spektral
使用如下命令从源安装 Spektral:
git clone https://github.com/danielegrattarola/spektral.gitcd spektralpython setup.py install # Or 'pip install .‘
机器之心友情提示:由于 TensorFlow API 的变化是个迷,推荐使用 TensorFlow 2.2 版本,并从源安装 Spektral(不要问我是怎么知道的)。
使用 GNN 处理 Cora 数据集中的分类问题
我们以 2017 年的那篇 ICLR 论文「SEMI-SUPERVISED CLASSIFICATION WITH GRAPH CONVOLUTIONAL NETWORKS」中所提出的图卷积神经网络(Graph Convolutional Network,GCN) 为例,为大家介绍如何使用 Spektral 简单、快速地搭建并训练图神经网络。
这里对 GCN 的训练问题属于转导推理(transductive learning),即在训练时将所有节点与边用作输入,但其中仅有一部分输入带有标签。训练的目标是让网络能够预测那些没有标签的样本。下图表示 GCN 的示意图。
左图为多层 GCN 示意图。右图为使用 t-SNE 对一个两层 GCN 中隐含层激活的可视化结果。
我们使用 Cora 数据集对 GCN 进行训练,该数据集由 7 个类别的机器学习领域论文构成,分别是:
Case_Based
Genetic_Algorithms
Neural_Networks
Probabilistic_Methods
Reinforcement_Learning
Rule_Learning
Theory
Cora 数据集总共包含 2708 篇论文,其中每篇论文至少引用了该数据集中另外一篇论文,或者被其他论文所引用。在消除停词以及除去文档频率小于 10 的词汇后,最终词汇表中共有 1433 个词汇。
使用 Spektral 中的 datasets.citation 模块,让我们能够方便地下载并读取如:Cora、Citeseer 和 Pubmed 这类引文数据集。以下代码展示了如何读取 Cora 数据集:
from spektral.datasets *import* citationA, X, y, train_mask, val_mask, test_mask = citation.load_data('cora')
N = A.shape[0]F = X.shape[-1]n_classes = y.shape[-1]
其中 A 为形状为 (N, N) 的网络邻接矩阵,X 为形状为 (N, F) 的节点特征,y 表示形状为 (N, n_classes) 的标签。
搭建 GNN
这里我们使用 GraphConv 网络层以及其他一些 Keras 的 API 来搭建 GCN:
from spektral.layers import GraphConvfrom tensorflow.keras.models import Modelfrom tensorflow.keras.layers import Input, Dropout
搭建 GCN 的方式与搭建其他 Keras 模型没有任何区别,只是需要注意 GraphConv 层的输入为 X 与 A:
X_in = Input(shape=(F, ))A_in = Input((N, ), sparse=*True*)
X_1 = GraphConv(16, 'relu')([X_in, A_in])X_1 = Dropout(0.5)(X_1)X_2 = GraphConv(n_classes, 'softmax')([X_1, A_in])
model = Model(inputs=[X_in, A_in], outputs=X_2)
至此,我们已经完成了 GCN 的搭建,是不是非常简单呢?
训练 GNN
在训练 GCN 之前,我们首先需要对邻接矩阵进行预处理,preprocess() 这一静态类方法提供了每一层需要的预处理方法。在这一 GCN 的例子中,我们使用如下方法进行预处理:
A = GraphConv.preprocess(A).astype('f4')
至此全部准备工作就绪,使用如下代码对模型进行编译:
model.compile(optimizer='adam', loss='categorical_crossentropy', weighted_metrics=['acc'])model.summary()
输出如下:
接下来我们就可以使用 Keras 中提供的 fit() 方法来训练模型了:
# Prepare dataX = X.toarray()A = A.astype('f4')validation_data = ([X, A], y, val_mask)# Train modelmodel.fit([X, A], y, epochs=200, sample_weight=train_mask, validation_data=validation_data, batch_size=N,shuffle=False)
以下为训练过程的输出:
验证模型
同样地,我们可以便捷地使用 Keras 中提供的方法对模型进行验证:
# Evaluate modeleval_results = model.evaluate([X, A], y, sample_weight=test_mask, batch_size=N)print('Done.\n''Test loss: {}\n''Test accuracy: {}'.format(eval_results))
结果如下:
下图为 GCN 论文中的分类结果:
可以看到论文中 GCN 在 Cora 数据集中的分类准确率为 81.5%,而我们训练的模型准确率为 74.9%。机器之心实测经过一些简单的超参数调整(如增加 epoch),几乎能达到与论文中一样的准确率,感兴趣的读者可自行测试一番。
参考连接:
https://www.kaggle.com/kmader/mnist-graph-deep-learning
https://zhuanlan.zhihu.com/p/78452993
猜你喜欢
- 2024-10-12 「AAAI oral」阿里北大提出新attention建模框架
- 2024-10-12 CVPR 2020 | 港中文、上交大、商汤联合提出两种轨迹预测新方法
- 2024-10-12 东北石油大学研究者提出电能质量扰动识别的新方法
- 2024-10-12 「独家解读」谷歌会议app背景模糊和替换技术解析
- 2024-10-12 如何在深度学习模型内部做特征选择?
- 2024-10-12 深度时空网络、记忆网络与特征表达学习在 CTR 预估中的应用
- 2024-10-12 揭秘 BERT 火爆背后的功臣——Attention
- 2024-10-12 MViT:性能杠杠的多尺度ViT | ICCV 2021
- 2024-10-12 CTR预估系列(5)–阿里Deep Interest Network理论
- 2024-10-12 「论文阅读」 Residual Attention: Multi-Label Recognition
你 发表评论:
欢迎- 最近发表
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)