# Source code for deepsphere.models.spherical_unet.utils
"""Layers used in both Encoder and Decoder.
"""
# pylint: disable=W0221
import torch.nn.functional as F
from torch import nn
from deepsphere.layers.chebyshev import SphericalChebConv
class SphericalChebBN(nn.Module):
    """Building block with a Chebyshev convolution, batch normalization, and ReLU activation.
    """

    def __init__(self, in_channels, out_channels, lap, kernel_size=3):
        """Initialization.

        Args:
            in_channels (int): initial number of channels.
            out_channels (int): output number of channels.
            lap (:obj:`torch.sparse.FloatTensor`): laplacian.
            kernel_size (int, optional): polynomial degree. Defaults to 3.
        """
        super().__init__()
        self.spherical_cheb = SphericalChebConv(in_channels, out_channels, lap, kernel_size)
        self.batchnorm = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """Forward Pass.

        Args:
            x (:obj:`torch.tensor`): input [batch x vertices x channels/features]

        Returns:
            :obj:`torch.tensor`: output [batch x vertices x channels/features]
        """
        x = self.spherical_cheb(x)
        # BatchNorm1d treats dim 1 as the channel dimension, so the channels
        # are moved from the last dim to dim 1 for the normalization...
        x = self.batchnorm(x.permute(0, 2, 1))
        # ...and moved back to the last dim before the activation.
        x = F.relu(x.permute(0, 2, 1))
        return x
class SphericalChebBNPool(nn.Module):
    """Building block with a pooling/unpooling step followed by the SphericalChebBN block.
    """

    def __init__(self, in_channels, out_channels, lap, pooling, kernel_size=3):
        """Initialization.

        Args:
            in_channels (int): initial number of channels.
            out_channels (int): output number of channels.
            lap (:obj:`torch.sparse.FloatTensor`): laplacian.
            pooling (:obj:`torch.nn.Module`): pooling/unpooling module.
            kernel_size (int, optional): polynomial degree. Defaults to 3.
        """
        super().__init__()
        self.pooling = pooling
        self.spherical_cheb_bn = SphericalChebBN(in_channels, out_channels, lap, kernel_size)

    def forward(self, x):
        """Forward Pass.

        Args:
            x (:obj:`torch.tensor`): input [batch x vertices x channels/features]

        Returns:
            :obj:`torch.tensor`: output [batch x vertices x channels/features]
        """
        # Pooling/unpooling is applied first; its output is what the
        # convolution block receives.
        x = self.pooling(x)
        x = self.spherical_cheb_bn(x)
        return x