计算机系统应用教程网站

网站首页 > 技术文章 正文

使用TF2与Keras实现经典GNN的开源库——Spektral

btikc 2024-10-12 11:00:06 技术文章 3 ℃ 0 评论

机器之心报道

参与:Racoon

这里有一个简单但又不失灵活性的开源 GNN 库推荐给你。

Spektral 是一个基于 Keras API 和 TensorFlow 2,用于图深度学习的开源 Python 库。该项目的主要目的是提供一个简单但又不失灵活性的图神经网络(graph neural networks,GNNs) 框架。

我们可以使用 Spektral 来进行网络节点分类、预测分子特性、使用 GAN 生成新的拓扑图、节点聚类、预测链接以及其他任意数据是使用拓扑图来描述的任务。

Spektral 中实现了多种目前经典的图深度学习层:

Graph Convolutional Networks (GCN)

Chebyshev networks (ChebNets)

GraphSAGE

ARMA convolutions

Edge-Conditioned Convolutions (ECC)

Graph attention networks (GAT)

Approximated Personalized Propagation of Neural Predictions (APPNP)

Graph Isomorphism Networks (GIN)

也包含如下多种池化层:

DiffPool

MinCUT pooling

Top-K pooling

Self-Attention Graph (SAG) pooling

Global sum, average, and max pooling

Global gated attention pooling

项目地址:https://github.com/danielegrattarola/spektral/

效果展示

我们先来看一下这个项目的效果怎么样。以下是使用 Spektral 编写的图神经网络在 MNIST 数据集上的训练结果:

验证结果如下:

我们将网络权重可视化后,可得到下面这样的效果:

下图展示了两个图卷积层的可视化效果。我们可以此来观察图神经网络是否能够学习到,与传统卷积神经网络类似的特征。

上手实测

Spektral 是依据 Keras API 的指导准则设计的,为的是对初学者友好的同时为专家及研究人员提供较好的灵活性。layers.convolutional 和 layers.pooling 是 Spektral 中最重要的两个模块,里面提供了多种用于构建 GNN 的经典网络层。由于 Spektral 是作为 Keras 的一个扩展被设计出的,这使得我们能够将任意一个 Spektral 层加入现有的 Keras 模型中,而不用进行任何更改。

安装方法

Spektral 支持 Python 3.5 及以上的版本,并在 Ubuntu 16.04+与 MacOS 上进行了测试,暂时不支持 Windows(抱歉了)。这里我们以 Ubuntu 为例,安装相关依赖项:

sudo apt install graphviz libgraphviz-dev libcgraph6

安装 Spektral 最简单的方式是通过 PyPi 来进行安装:

pip install spektral

使用如下命令从源安装 Spektral:

git clone https://github.com/danielegrattarola/spektral.gitcd spektralpython setup.py install # Or 'pip install .‘

机器之心友情提示:由于 TensorFlow API 的变化是个迷,推荐使用 TensorFlow 2.2 版本,并从源安装 Spektral(不要问我是怎么知道的)。

使用 GNN 处理 Cora 数据集中的分类问题

我们以 2017 年的那篇 ICLR 论文「SEMI-SUPERVISED CLASSIFICATION WITH GRAPH CONVOLUTIONAL NETWORKS」中所提出的图卷积神经网络(Graph Convolutional Network,GCN) 为例,为大家介绍如何使用 Spektral 简单、快速地搭建并训练图神经网络。

这里对 GCN 的训练问题属于转导推理(transductive learning),即在训练时将所有节点与边用作输入,但其中仅有一部分输入带有标签。训练的目标是让网络能够预测那些没有标签的样本。下图表示 GCN 的示意图。

左图为多层 GCN 示意图。右图为使用 t-SNE 对一个两层 GCN 中隐含层激活的可视化结果。

我们使用 Cora 数据集对 GCN 进行训练,该数据集由 7 个类别的机器学习领域论文构成,分别是:

Case_Based

Genetic_Algorithms

Neural_Networks

Probabilistic_Methods

Reinforcement_Learning

Rule_Learning

Theory

Cora 数据集总共包含 2708 篇论文,其中每篇论文至少引用了该数据集中另外一篇论文,或者被其他论文所引用。在消除停词以及除去文档频率小于 10 的词汇后,最终词汇表中共有 1433 个词汇。

使用 Spektral 中的 datasets.citation 模块,让我们能够方便地下载并读取如:Cora、Citeseer 和 Pubmed 这类引文数据集。以下代码展示了如何读取 Cora 数据集:

from spektral.datasets *import* citationA, X, y, train_mask, val_mask, test_mask = citation.load_data('cora')

N = A.shape[0]F = X.shape[-1]n_classes = y.shape[-1]

其中 A 为形状为 (N, N) 的网络邻接矩阵,X 为形状为 (N, F) 的节点特征,y 表示形状为 (N, n_classes) 的标签。

搭建 GNN

这里我们使用 GraphConv 网络层以及其他一些 Keras 的 API 来搭建 GCN:

from spektral.layers import GraphConvfrom tensorflow.keras.models import Modelfrom tensorflow.keras.layers import Input, Dropout

搭建 GCN 的方式与搭建其他 Keras 模型没有任何区别,只是需要注意 GraphConv 层的输入为 X 与 A:

X_in = Input(shape=(F, ))A_in = Input((N, ), sparse=*True*)

X_1 = GraphConv(16, 'relu')([X_in, A_in])X_1 = Dropout(0.5)(X_1)X_2 = GraphConv(n_classes, 'softmax')([X_1, A_in])

model = Model(inputs=[X_in, A_in], outputs=X_2)

至此,我们已经完成了 GCN 的搭建,是不是非常简单呢?

训练 GNN

在训练 GCN 之前,我们首先需要对邻接矩阵进行预处理,preprocess() 这一静态类方法提供了每一层需要的预处理方法。在这一 GCN 的例子中,我们使用如下方法进行预处理:

A = GraphConv.preprocess(A).astype('f4')

至此全部准备工作就绪,使用如下代码对模型进行编译:

model.compile(optimizer='adam', loss='categorical_crossentropy', weighted_metrics=['acc'])model.summary()

输出如下:

接下来我们就可以使用 Keras 中提供的 fit() 方法来训练模型了:

# Prepare dataX = X.toarray()A = A.astype('f4')validation_data = ([X, A], y, val_mask)# Train modelmodel.fit([X, A], y, epochs=200, sample_weight=train_mask, validation_data=validation_data, batch_size=N,shuffle=False)

以下为训练过程的输出:

验证模型

同样地,我们可以便捷地使用 Keras 中提供的方法对模型进行验证:

# Evaluate modeleval_results = model.evaluate([X, A], y, sample_weight=test_mask, batch_size=N)print('Done.\n''Test loss: {}\n''Test accuracy: {}'.format(eval_results))

结果如下:

下图为 GCN 论文中的分类结果:

可以看到论文中 GCN 在 Cora 数据集中的分类准确率为 81.5%,而我们训练的模型准确率为 74.9%。机器之心实测经过一些简单的超参数调整(如增加 epoch),几乎能达到与论文中一样的准确率,感兴趣的读者可自行测试一番。

参考连接:

https://www.kaggle.com/kmader/mnist-graph-deep-learning

https://zhuanlan.zhihu.com/p/78452993

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表