当前位置 博文首页 > 韦全敏的博客:图神经网络框架DGL教程-第3章:构建图神经网络(G

    韦全敏的博客:图神经网络框架DGL教程-第3章:构建图神经网络(G

    作者:[db:作者] 时间:2021-07-08 15:38

    更多图神经网络和深度学习内容请关注:
    在这里插入图片描述

    第3章:构建图神经网络(GNN)模块

    DGL NN模块是用户构建GNN模型的基本模块。根据DGL所使用的后端深度神经网络框架, DGL NN模块的父类取决于后端所使用的深度神经网络框架。对于PyTorch后端, 它应该继承 PyTorch的NN模块;对于MXNet后端,它应该继承 MXNet Gluon的NN块; 对于TensorFlow后端,它应该继承 Tensorflow的Keras层。 在DGL NN模块中,构造函数中的参数注册和前向传播函数中使用的张量操作与后端框架一样。这种方式使得DGL的代码可以无缝嵌入到后端框架的代码中。 DGL和这些深度神经网络框架的主要差异是其独有的消息传递操作。

    DGL已经集成了很多常用的 Conv Layers、 Dense Conv Layers、 Global Pooling Layers 和 Utility Modules。欢迎给DGL贡献更多的模块!

    本章将使用PyTorch作为后端,用 SAGEConv 作为例子来介绍如何构建用户自己的DGL NN模块。

    3.1 DGL NN模块的构造函数

    构造函数__init__完成以下几个任务:

    • 设置选项。
    • 注册可学习的参数或者子模块。
    • 初始化参数。
    import torch.nn as nn
    
    from dgl.utils import expand_as_pair
    
    class SAGEConv(nn.Module):
        def __init__(self,
                     in_feats,
                     out_feats,
                     aggregator_type,
                     bias=True,
                     norm=None,
                     activation=None):
            super(SAGEConv, self).__init__()
    
            self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)#函数可以返回一个二维元组。
            self._out_feats = out_feats
            self._aggre_type = aggregator_type
            self.norm = norm
            self.activation = activation
    
    Using backend: pytorch
    

    在构造函数中,用户首先需要设置数据的维度。对于一般的PyTorch模块,维度通常包括输入的维度、输出的维度和隐层的维度。 对于图神经网络输入维度可被分为源节点特征维度和目标节点特征维度

    除了数据维度,图神经网络的一个典型选项是聚合类型(self._aggre_type)。对于特定目标节点,聚合类型决定了如何聚合不同边上的信息。 常用的聚合类型包括 meansummaxmin。一些模块可能会使用更加复杂的聚合函数,比如 lstm

    上面代码里的 norm 是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化: h v = h v / ∥ h v ∥ 2 h_v = h_v / \lVert h_v \rVert_2 hv?=hv?/hv?2?

    # 聚合类型:mean、max_pool、lstm、gcn
    if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
        raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
    if aggregator_type == 'max_pool':
        self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
    if aggregator_type == 'lstm':
        self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
    if aggregator_type in ['mean', 'max_pool', 'lstm']:
        self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
    self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
    self.reset_parameters()
    

    注册参数和子模块。在SAGEConv中,子模块根据聚合类型而有所不同。这些模块是纯PyTorch NN模块,例如 nn.Linearnn.LSTM 等。 构造函数的最后调用了 reset_parameters() 进行权重初始化。

    def reset_parameters(self):
        """重新初始化可学习的参数"""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'max_pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
    

    完整代码

    import torch.nn as nn
    
    from dgl.utils import expand_as_pair
    
    class SAGEConv(nn.Module):
        def __init__(self,
                     in_feats,
                     out_feats,
                     aggregator_type,
                     bias=True,
                     norm=None,
                     activation=None):
            super(SAGEConv, self).__init__()
    
            self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
            self._out_feats = out_feats
            self._aggre_type = aggregator_type
            self.norm = norm
            self.activation = activation
            
            # 聚合类型:mean、max_pool、lstm、gcn
            if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
                raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
            if aggregator_type == 'max_pool':
                self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
            if aggregator_type == 'lstm':
                self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
            if aggregator_type in ['mean', 'max_pool', 'lstm']:
                self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
            self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
            self.reset_parameters()
        def reset_parameters(self):
            """重新初始化可学习的参数"""
            gain = nn.init.calculate_gain('relu')
            if self._aggre_type == 'max_pool':
                nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
            if self._aggre_type == 'lstm':
                self.lstm.reset_parameters()
            if self._aggre_type != 'gcn':
                nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
            nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
    

    3.2 编写DGL NN模块的forward函数

    在NN模块中, forward() 函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比, DGL NN模块额外增加了1个参数 dgl.DGLGraphforward() 函数的内容一般可以分为3项操作:

    • 检测输入图对象是否符合规范。
    • 消息传递和聚合。
    • 聚合后,更新特征作为输出。

    下文展示了SAGEConv示例中的 forward() 函数。

    输入图对象的规范检测

    def forward(self, graph, feat):
        with graph.local_scope():
            # 指定图类型,然后根据图类型扩展输入特征
            feat_src, feat_dst = expand_as_pair(feat, graph)
    

    graph.local_scope():限定语句块内为局部作用域,对数据特征的操作不影响原始图的特征,常用于forward方法中,
    用法:

    def foo(g):
        with g.local_scope():
            g.edata['h'] = torch.ones((g.num_edges(), 3))
            g.edata['h2'] = torch.ones((g.num_edges(), 3))
            return g.edata['h']
    

    in-place操作会影响原始图数据
    如:

    def foo(g):
        with g.local_scope():
            # in-place operation
            g.edata['h'] += 1
            return g.edata['h']
    

    参考dgl API

    forward() 函数需要处理输入的许多极端情况,这些情况可能导致计算和消息传递中的值无效。 比如在 GraphConv 等conv模块中,DGL会检查输入图中是否有入度为0的节点。 当1个节点入度为0时, mailbox 将为空,并且聚合函数的输出值全为0, 这可能会导致模型性能不佳。但是,在 SAGEConv 模块中,被聚合的特征将会与节点的初始特征拼接起来, forward() 函数的输出不会全为0。在这种情况下,无需进行此类检验。

    DGL NN模块可在不同类型的图输入中重复使用,包括:同构图、异构图(1.5 异构图)和子图块(第6章:在大图上的随机(批次)训练)。

    SAGEConv的数学公式如下:

    h N ( d s t ) ( l + 1 ) = a g g r e g a t e ( { h s r c l , ? s r c ∈ N ( d s t ) } ) h_{\mathcal{N}(dst)}^{(l+1)} = \mathrm{aggregate} \left(\{h_{src}^{l}, \forall src \in \mathcal{N}(dst) \}\right)