Administrator
发布于 2022-10-10 / 7 阅读
0

Self-Attention Graph Pooling(ICML2019)

Abstract

将深度学习的框架迁移到结构化的数据近期也是个热点,近期的一些学习都在将卷积+pool迁移到非结构化的数据中【模仿CNN,目前LSTM,GRU还没有人模仿】,将卷积操作迁移到Graph中已经被证明是有效的,但是在Graph中downsampling依旧是一个挑战性的问题。在这篇论文,我们提出了一种基于self-attention的graph pool方法,我们的pool方法包括node featur/graph topology(拓扑)两个特征,为了公平,我们使用一样的训练过程以及模型结构,实验结果证明我们的方法在Graph classification中表现的好。

Introduction

深度学习的方法在数据的识别和增强等方面有突飞猛进。特别,CNNs成功地挖掘了数据的特征,例如 欧几里得空间的images,speech,video。CNNs包括卷积层和下采样层,卷积和池化操作都有 shift-invariance特性(平移不变性)。因此CNNs只需要很小的参数就可以获得较好的结果。

在很多领域,然而,大量的数据,都是以非欧数据的形式储存的,例如 Social Network,Biological network,Molecular structure都是通过Graph中的Node以及Edge的形式表示的,因此很多人尝试将CNN迁移到非欧空间数据上。

在Graph Pool领域,方法远远少于Graph Convolution,之前的方法只考虑了Graph topology,还有其他的方法,希望获得一个更小点的Graph表示,最近,也有一些方法希望学习Graph的结构信息,这些方法允许GNNs用一种End2End的方法。然而,这些池化方法都还有提升空间,譬如需要立方级别的存储复杂度。

由此我们提出了一种 SAGPool模型,是一种 Self-Attention Graph Pooling method,我们的方法可以用一种End2End的方式学习结构层次信息,Self-attention结构可以区分哪些节点应该丢弃,哪些应该保留。因为Self-attention结构使用了Graph convolution来计算attention分数,Node features以及Graph topology都被考虑进去,简而言之,SAGPool继承了之前模型的优点,也是第一个将self-attention 加入Graph pooling中,实现了较高的准确度。

Proposed Method

Self-Attention Graph Pooling

SAGPool的key points是使用GNN来计算self-attention scores。

其实SAGPool结构和Graph U-Nets(ICML 2019)中使用的gpool【gPool使用一个可学习的向量 p 来计算projection scores,然后使用这个scores来选择top ranked nodes。Projection scores通过向量p以及所有节点的特征的dot product。这个scores表明节点的信息保留程度】结构几乎一样,两者的差别在于,SAGPool使用GNN来计算scores,与gpool仅利用node feature相比,SAGPool也将Graph topology融入模型之中。其模型如下图所示

通过初始化Graph convolution来获得self-attention scores,pooling的结果是基于graph features和topology结构。即使Graph很小【Node很少】依旧可以保持部分节点。Pooling比例 k∈(0,1] 是一个保持节点数量的超参数,根据Z的值来选择前[kN]个节点。

Model Architecture

借鉴了JK-net结构,使用一个readout layer用节点特征生成一个固定大小的表示,使用readout layer的输出如下所示:

其中N是节点数量, xi 是节点i的特征向量,||代表连接操作。

上面的结构图包括俩结构,从左到右分别用poolg和poolh代表

Global pooling结构包含三个卷积层,将它们的输出连接起来。Node features在readout layer+pooling layer之下流动,Graph feature representions之后传输到线形层做分类。

Hierarchical pooling architecture 在这个设置下,如Fig 2所示那样,做一次卷积,做一次pooling,最后将三次pooling的结果加起来使用MLP来分类。

Experimental

这边需要注意的是当前代码使用的gpool来计算节点重要性的对比实验

gPool 选择top-ranked 节点,这与我们的方法是类似的,区别在于我们的模型考虑了图的拓扑结构,这有助于提高在分类任务上的结果。

针对poolg和poolh很难说哪种结构更好,POOLg最小化了信息的loss,在节点数较少的数据集【NCI1,NCI109,FRANKENSTEIN】上表现更好,然而,POOLh在节点数多的数据集上表现更好【D&D,PROTEINS】,因为POOLh在提取large scale graphs信息方面更拿手。不过,SAGPool比其他的模型更好。