# Source code for deepsphere.models.spherical_unet.decoder

"""Decoder for Spherical UNet.
"""
# pylint: disable=W0221

import torch
from torch import nn

from deepsphere.layers.chebyshev import SphericalChebConv
from deepsphere.models.spherical_unet.utils import SphericalChebBN, SphericalChebBNPool


class SphericalChebBNPoolCheb(nn.Module):
    """Two-stage block: a SphericalChebBNPool stage followed by a bare SphericalChebConv.

    The first stage maps ``in_channels`` to ``middle_channels`` and is given the
    pooling/unpooling module; the second stage is a Chebyshev convolution from
    ``middle_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, middle_channels, out_channels, lap, pooling, kernel_size=3):
        """Build both stages.

        Args:
            in_channels (int): number of channels entering the block.
            middle_channels (int): channels handed from the first stage to the second.
            out_channels (int): number of channels leaving the block.
            lap (:obj:`torch.sparse.FloatTensor`): laplacian shared by both stages.
            pooling (:obj:`torch.nn.Module`): pooling/unpooling module given to the first stage.
            kernel_size (int, optional): polynomial degree used by both stages. Defaults to 3.
        """
        super().__init__()
        # Built in the same order as they are registered (pool stage first) so
        # parameter initialisation order is unchanged.
        first_stage = SphericalChebBNPool(in_channels, middle_channels, lap, pooling, kernel_size)
        second_stage = SphericalChebConv(middle_channels, out_channels, lap, kernel_size)
        # Attribute names determine the state_dict keys, so they must not change.
        self.spherical_cheb_bn_pool = first_stage
        self.spherical_cheb = second_stage

    def forward(self, x):
        """Apply the pool stage, then the Chebyshev convolution.

        Args:
            x (:obj:`torch.Tensor`): input [batch x vertices x channels/features]

        Returns:
            :obj:`torch.Tensor`: output [batch x vertices x channels/features]
        """
        return self.spherical_cheb(self.spherical_cheb_bn_pool(x))
class SphericalChebBNPoolConcat(nn.Module):
    """Stage with a skip connection.

    Runs a SphericalChebBNPool stage on the incoming tensor, concatenates the
    result with a second tensor along dim 2 (the channel axis), and feeds the
    concatenation through a SphericalChebBN stage.
    """

    def __init__(self, in_channels, out_channels, lap, pooling, kernel_size=3):
        """Build the pool stage and the post-concatenation stage.

        Args:
            in_channels (int): number of channels entering the block.
            out_channels (int): number of channels leaving the block.
            lap (:obj:`torch.sparse.FloatTensor`): laplacian shared by both stages.
            pooling (:obj:`torch.nn.Module`): pooling/unpooling module given to the pool stage.
            kernel_size (int, optional): polynomial degree. Defaults to 3.
        """
        super().__init__()
        # The SphericalChebBN stage is sized for in_channels + out_channels inputs.
        # Assuming the pool stage emits out_channels, the tensor concatenated in
        # forward() must therefore carry in_channels channels.
        fused_channels = in_channels + out_channels
        self.spherical_cheb_bn_pool = SphericalChebBNPool(in_channels, out_channels, lap, pooling, kernel_size)
        self.spherical_cheb_bn = SphericalChebBN(fused_channels, out_channels, lap, kernel_size)

    def forward(self, x, concat_data):
        """Pool ``x``, append ``concat_data`` on the channel axis, then convolve.

        Args:
            x (:obj:`torch.Tensor`): input [batch x vertices x channels/features]
            concat_data (:obj:`torch.Tensor`): encoder layer output [batch x vertices x channels/features]

        Returns:
            :obj:`torch.Tensor`: output [batch x vertices x channels/features]
        """
        pooled = self.spherical_cheb_bn_pool(x)
        # pylint: disable=E1101
        merged = torch.cat([pooled, concat_data], dim=2)
        # pylint: enable=E1101
        return self.spherical_cheb_bn(merged)
class Decoder(nn.Module):
    """The decoder of the Spherical UNet.

    Five stages, all sharing one unpooling module:

    * ``dec_l1`` .. ``dec_l4``: SphericalChebBNPoolConcat stages with channel
      widths 512 -> 512 -> 256 -> 128 -> 64, each merging in one encoder output;
    * ``dec_l5``: a SphericalChebBNPoolCheb stage mapping 64 -> 32 -> ``out_channels``.
    """

    def __init__(self, unpooling, laps, kernel_size=3, out_channels=3):
        """Initialization.

        Args:
            unpooling (:obj:`torch.nn.Module`): The unpooling object, shared by every stage.
            laps (list): List of laplacians. Entries 1 to 5 are used (one per
                stage), so at least 6 entries are required; entry 0 is not read here.
            kernel_size (int, optional): polynomial degree of every Chebyshev
                convolution in the decoder. Defaults to 3, matching the stage defaults.
            out_channels (int, optional): number of channels produced by the final
                stage (e.g. number of classes). Defaults to 3.
        """
        super().__init__()
        self.unpooling = unpooling
        # Attribute names are kept stable: they define the state_dict keys.
        self.dec_l1 = SphericalChebBNPoolConcat(512, 512, laps[1], self.unpooling, kernel_size)
        self.dec_l2 = SphericalChebBNPoolConcat(512, 256, laps[2], self.unpooling, kernel_size)
        self.dec_l3 = SphericalChebBNPoolConcat(256, 128, laps[3], self.unpooling, kernel_size)
        self.dec_l4 = SphericalChebBNPoolConcat(128, 64, laps[4], self.unpooling, kernel_size)
        self.dec_l5 = SphericalChebBNPoolCheb(64, 32, out_channels, laps[5], self.unpooling, kernel_size)
        # Switch from Logits to Probabilities if evaluating model
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x_enc0, x_enc1, x_enc2, x_enc3, x_enc4):
        """Forward Pass.

        Args:
            x_enc0 (:obj:`torch.Tensor`): input to the first decoder stage
                [batch x vertices x channels/features].
            x_enc1, x_enc2, x_enc3, x_enc4 (:obj:`torch.Tensor`): tensors
                concatenated in by ``dec_l1`` .. ``dec_l4`` respectively.

        Returns:
            :obj:`torch.Tensor`: raw logits while training; in eval mode,
            probabilities from a softmax over dim 2 (the channel axis).
        """
        x = self.dec_l1(x_enc0, x_enc1)
        x = self.dec_l2(x, x_enc2)
        x = self.dec_l3(x, x_enc3)
        x = self.dec_l4(x, x_enc4)
        x = self.dec_l5(x)
        # Training losses are expected to consume logits; only eval returns probabilities.
        if not self.training:
            x = self.softmax(x)
        return x