# Source code for deepsphere.layers.chebyshev

"""Chebyshev convolution layer. For the moment taken as-is from Michaël Defferrard's implementation. For v0.15 we will rewrite parts of this layer.
"""
# pylint: disable=W0221

import math

import torch
from torch import nn


def cheb_conv(laplacian, inputs, weight):
    """Chebyshev convolution.

    Expands ``inputs`` in the Chebyshev basis of ``laplacian``
    (T_0 = x, T_1 = L x, T_k = 2 L T_{k-1} - T_{k-2}) and linearly combines
    the ``K`` terms and ``Fin`` input features into ``Fout`` output features.

    Args:
        laplacian (:obj:`torch.sparse.Tensor`): The laplacian corresponding to the current sampling of the sphere,
            of shape ``[V, V]``.
        inputs (:obj:`torch.Tensor`): The current input data being forwarded, of shape ``[B, V, Fin]``.
        weight (:obj:`torch.Tensor`): The weights of the current layer, of shape ``[K, Fin, Fout]``.

    Returns:
        :obj:`torch.Tensor`: Inputs after applying Chebyshev convolution, of shape ``[B, V, Fout]``.

    Raises:
        ValueError: If ``weight`` holds no Chebyshev term (``K < 1``) or its
            input-feature dimension does not match ``inputs``.
    """
    B, V, Fin = inputs.shape
    K, weight_fin, Fout = weight.shape
    # B = batch size
    # V = nb vertices
    # Fin = nb input features
    # Fout = nb output features
    # K = number of Chebyshev terms (polynomial order K - 1)
    if K < 1:
        raise ValueError(f"weight must hold at least one Chebyshev term, got K={K}")
    if weight_fin != Fin:
        raise ValueError(
            f"weight expects {weight_fin} input features but inputs have {Fin}"
        )

    # transform to Chebyshev basis
    x0 = inputs.permute(1, 2, 0).contiguous()  # V x Fin x B
    x0 = x0.view([V, Fin * B])  # V x Fin*B
    terms = [x0]
    # T_1 only exists when the order is >= 1; with K == 1 only T_0 is kept,
    # otherwise the stacked tensor would not match the K-sized view below.
    if K > 1:
        x1 = torch.sparse.mm(laplacian, x0)  # V x Fin*B
        terms.append(x1)
        for _ in range(2, K):
            x2 = 2 * torch.sparse.mm(laplacian, x1) - x0
            terms.append(x2)
            x0, x1 = x1, x2
    # Stack once instead of concatenating per step (which re-copies all earlier terms).
    inputs = torch.stack(terms, 0)  # K x V x Fin*B
    inputs = inputs.view([K, V, Fin, B])  # K x V x Fin x B
    inputs = inputs.permute(3, 1, 2, 0).contiguous()  # B x V x Fin x K
    inputs = inputs.view([B * V, Fin * K])  # B*V x Fin*K

    # Linearly compose Fin features to get Fout features.
    # NOTE: the columns of ``inputs`` are ordered (fin, k) while the rows of
    # ``weight.view`` are ordered (k, fin), so weight[k, f] is not literally the
    # coefficient of T_k for feature f. Since the weights are learned this is a
    # fixed relabelling of parameters; changing it would alter the outputs of
    # already-trained weights, so it is kept as-is.
    weight = weight.view(Fin * K, Fout)
    inputs = inputs.matmul(weight)  # B*V x Fout
    inputs = inputs.view([B, V, Fout])  # B x V x Fout
    return inputs
class ChebConv(torch.nn.Module):
    """Graph convolutional layer.

    Holds a ``[kernel_size, in_channels, out_channels]`` weight and an optional
    ``[out_channels]`` bias. It delegates the graph convolution to the ``conv``
    callable.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True, conv=cheb_conv):
        """Initialize the Chebyshev layer.

        Args:
            in_channels (int): Number of channels/features in the input graph.
            out_channels (int): Number of channels/features in the output graph.
            kernel_size (int): Number of trainable parameters per filter, which is also the size of the convolutional kernel.
                                The order of the Chebyshev polynomials is kernel_size - 1.
            bias (bool): Whether to add a bias term.
            conv (callable): Function which will perform the actual convolution,
                called as ``conv(laplacian, inputs, weight)``.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self._conv = conv

        shape = (kernel_size, in_channels, out_channels)
        # Allocated uninitialized; kaiming_initialization() fills the values below.
        self.weight = torch.nn.Parameter(torch.empty(*shape))

        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_channels))
        else:
            # Registering None keeps ``self.bias`` defined and out of the parameter list.
            self.register_parameter("bias", None)

        self.kaiming_initialization()

    def kaiming_initialization(self):
        """Initialize weights and bias.

        Weights are drawn from N(0, std^2) with
        std = sqrt(2 / (in_channels * kernel_size)). The bias, if present, is set
        to the constant 0.01.
        """
        std = math.sqrt(2 / (self.in_channels * self.kernel_size))
        # In-place init must not be recorded by autograd.
        with torch.no_grad():
            self.weight.normal_(0, std)
            if self.bias is not None:
                self.bias.fill_(0.01)

    def forward(self, laplacian, inputs):
        """Forward graph convolution.

        Args:
            laplacian (:obj:`torch.sparse.Tensor`): The laplacian corresponding to the current sampling of the sphere.
            inputs (:obj:`torch.Tensor`): The current input data being forwarded.

        Returns:
            :obj:`torch.Tensor`: The convoluted inputs.
        """
        outputs = self._conv(laplacian, inputs, self.weight)
        if self.bias is not None:
            # Out-of-place add. ``conv`` is user-supplied and may return a tensor
            # the caller still owns, or one autograd forbids modifying in place.
            outputs = outputs + self.bias
        return outputs
class SphericalChebConv(nn.Module):
    """Building Block with a Chebyshev Convolution.

    Stores the laplacian as a buffer named ``laplacian``. It applies a
    :class:`ChebConv` with that laplacian in :meth:`forward`.
    """

    def __init__(self, in_channels, out_channels, lap, kernel_size):
        """Initialization.

        Args:
            in_channels (int): initial number of channels
            out_channels (int): output number of channels
            lap (:obj:`torch.sparse.FloatTensor`): laplacian
            kernel_size (int): number of Chebyshev terms per filter (required, no default).
                The polynomial degree is ``kernel_size - 1``.
        """
        super().__init__()
        # A buffer follows the module across .to()/.cuda() but is not trained.
        self.register_buffer("laplacian", lap)
        self.chebconv = ChebConv(in_channels, out_channels, kernel_size)

    def state_dict(self, *args, **kwargs):
        """! WARNING !

        This function overrides the state dict in order to be able to save the model.
        This can be removed as soon as saving sparse matrices has been added to Pytorch.

        Every key ending in ``"laplacian"`` is removed from the returned dict.
        """
        state_dict = super().state_dict(*args, **kwargs)
        # NOTE(review): when a parent module's state_dict recurses into this one,
        # PyTorch passes a shared destination dict. This filter may then also drop
        # any previously collected key ending in "laplacian" — confirm for the
        # installed torch version.
        # Snapshot the matching keys first; deleting while iterating a dict raises.
        for key in [k for k in state_dict if k.endswith("laplacian")]:
            del state_dict[key]
        return state_dict

    def forward(self, x):
        """Forward pass.

        Args:
            x (:obj:`torch.tensor`): input [batch x vertices x channels/features]

        Returns:
            :obj:`torch.tensor`: output [batch x vertices x channels/features]
        """
        return self.chebconv(self.laplacian, x)