# Source code for deepsphere.models.spherical_unet.encoder

"""Encoder for Spherical UNet.
"""
# pylint: disable=W0221
from torch import nn

from deepsphere.layers.chebyshev import SphericalChebConv
from deepsphere.models.spherical_unet.utils import SphericalChebBN, SphericalChebBNPool


class SphericalChebBN2(nn.Module):
    """Two stacked (convolution, batchnorm, activation) building blocks.

    Channels flow ``in_channels -> middle_channels`` through the first block
    and ``middle_channels -> out_channels`` through the second. Both blocks
    use the same laplacian and the same polynomial degree.
    """

    def __init__(self, in_channels, middle_channels, out_channels, lap, kernel_size=3):
        """Create the two building blocks.

        Args:
            in_channels (int): number of channels entering the first block.
            middle_channels (int): channels produced by the first block and
                consumed by the second.
            out_channels (int): number of channels leaving the second block.
            lap (:obj:`torch.sparse.FloatTensor`): laplacian shared by both blocks.
            kernel_size (int, optional): polynomial degree. Defaults to 3.
        """
        super().__init__()
        shared_args = (lap, kernel_size)
        # Keep these attribute names: registered submodule names become the
        # keys of the state_dict (checkpoints depend on them).
        self.spherical_cheb_bn_1 = SphericalChebBN(in_channels, middle_channels, *shared_args)
        self.spherical_cheb_bn_2 = SphericalChebBN(middle_channels, out_channels, *shared_args)

    def forward(self, x):
        """Apply the first block, then the second.

        Args:
            x (:obj:`torch.Tensor`): input [batch x vertices x channels/features]

        Returns:
            :obj:`torch.Tensor`: output [batch x vertices x channels/features]
        """
        return self.spherical_cheb_bn_2(self.spherical_cheb_bn_1(x))


class SphericalChebPool(nn.Module):
    """A pooling (or unpooling) step followed by one Chebyshev convolution."""

    def __init__(self, in_channels, out_channels, lap, pooling, kernel_size=3):
        """Store the pooling module and build the convolution.

        Args:
            in_channels (int): number of channels entering the convolution.
            out_channels (int): number of channels produced by the convolution.
            lap (:obj:`torch.sparse.FloatTensor`): laplacian used by the convolution.
            pooling (:obj:`torch.nn.Module`): pooling/unpooling module, applied
                before the convolution.
            kernel_size (int, optional): polynomial degree. Defaults to 3.
        """
        super().__init__()
        # Registration order (pooling first, then the convolution) and the
        # attribute names are preserved: they determine child ordering and
        # state_dict keys.
        self.pooling = pooling
        self.spherical_cheb = SphericalChebConv(in_channels, out_channels, lap, kernel_size)

    def forward(self, x):
        """Pool/unpool the input, then convolve it.

        Args:
            x (:obj:`torch.Tensor`): input [batch x vertices x channels/features]

        Returns:
            :obj:`torch.Tensor`: output [batch x vertices x channels/features];
            the vertex count is whatever ``self.pooling`` produces.
        """
        resampled = self.pooling(x)
        return self.spherical_cheb(resampled)


class Encoder(nn.Module):
    """Encoder for the Spherical UNet.

    Six stages are applied in sequence:

    * ``enc_l5``: two (convolution, batchnorm, activation) blocks on the input,
      ``in_channels -> 32 -> 64``, using ``laps[5]``.
    * ``enc_l4`` ... ``enc_l1``: ``SphericalChebBNPool`` stages built with the
      shared ``pooling`` module and ``laps[4]`` ... ``laps[1]``
      (channels 64 -> 128 -> 256 -> 512 -> 512).
    * ``enc_l0``: pooling followed by a single Chebyshev convolution,
      512 -> 512, using ``laps[0]``.
    """

    def __init__(self, pooling, laps, *, in_channels=16):
        """Build the six encoder stages.

        Args:
            pooling (:obj:`torch.nn.Module`): pooling layer, shared by every
                stage after the first.
            laps (list): list of laplacians; indices 0 to 5 are used, with
                ``laps[5]`` for the first stage and ``laps[0]`` for the last.
            in_channels (int, optional): number of input features per vertex.
                Keyword-only. Defaults to 16, the value previously hard-coded.
        """
        super().__init__()
        self.pooling = pooling
        self.enc_l5 = SphericalChebBN2(in_channels, 32, 64, laps[5])
        self.enc_l4 = SphericalChebBNPool(64, 128, laps[4], self.pooling)
        self.enc_l3 = SphericalChebBNPool(128, 256, laps[3], self.pooling)
        self.enc_l2 = SphericalChebBNPool(256, 512, laps[2], self.pooling)
        self.enc_l1 = SphericalChebBNPool(512, 512, laps[1], self.pooling)
        self.enc_l0 = SphericalChebPool(512, 512, laps[0], self.pooling)

    def forward(self, x):
        """Run the input through all six encoder stages.

        Args:
            x (:obj:`torch.Tensor`): input [batch x vertices x channels/features]
                with ``in_channels`` features.

        Returns:
            tuple: ``(x_enc0, x_enc1, x_enc2, x_enc3, x_enc4)``, the outputs of
            stages ``enc_l0`` (last) back to ``enc_l4``, each a
            :obj:`torch.Tensor` [batch x vertices x channels/features]. The
            first stage's output ``x_enc5`` only feeds ``enc_l4`` and is not
            returned.
        """
        x_enc5 = self.enc_l5(x)
        x_enc4 = self.enc_l4(x_enc5)
        x_enc3 = self.enc_l3(x_enc4)
        x_enc2 = self.enc_l2(x_enc3)
        x_enc1 = self.enc_l1(x_enc2)
        x_enc0 = self.enc_l0(x_enc1)
        # NOTE(review): presumably the decoder consumes exactly these five
        # tensors (deepest first) as skip connections; confirm before
        # changing the tuple's length or order.
        return x_enc0, x_enc1, x_enc2, x_enc3, x_enc4