# Source code for deepsphere.models.spherical_unet.unet_model

"""Spherical Graph Convolutional Neural Network with UNet autoencoder architecture.
"""

# pylint: disable=W0221


from torch import nn

from deepsphere.layers.samplings.equiangular_pool_unpool import Equiangular
from deepsphere.layers.samplings.healpix_pool_unpool import Healpix
from deepsphere.layers.samplings.icosahedron_pool_unpool import Icosahedron
from deepsphere.models.spherical_unet.decoder import Decoder
from deepsphere.models.spherical_unet.encoder import Encoder
from deepsphere.utils.laplacian_funcs import get_equiangular_laplacians, get_healpix_laplacians, get_icosahedron_laplacians


class SphericalUNet(nn.Module):
    """Spherical GCNN autoencoder laid out as a UNet.

    The type of the ``pooling_class`` instance decides which family of graph
    Laplacians is built (icosahedron, healpix or equiangular). The same
    instance also supplies the ``pooling`` / ``unpooling`` operators that are
    handed to the encoder and the decoder respectively.
    """

    def __init__(self, pooling_class, N, depth, laplacian_type, ratio=1):
        """Build the Laplacians, then the encoder and decoder that share them.

        Args:
            pooling_class (obj): An *instance* of ``Icosahedron``, ``Healpix``
                or ``Equiangular`` (it is checked with ``isinstance``). Its
                ``pooling`` attribute goes to the encoder and its
                ``unpooling`` attribute goes to the decoder.
            N (int): Number of pixels in the input image.
            depth (int): The depth of the UNet, which is bounded by N and the
                type of pooling.
            laplacian_type: Forwarded unchanged to the selected
                ``get_*_laplacians`` helper.
            ratio (float): Parameter for equiangular sampling. It is only
                passed on when ``pooling_class`` is an ``Equiangular``.

        Raises:
            ValueError: If ``pooling_class`` is not an instance of any of the
                three supported sampling classes.
        """
        super().__init__()
        self.pooling_class = pooling_class
        self.ratio = ratio

        # Ordered (sampling class, Laplacian builder) pairs. The first
        # isinstance match wins, so the order mirrors the checks in priority.
        laplacian_builders = (
            (Icosahedron, lambda: get_icosahedron_laplacians(N, depth, laplacian_type)),
            (Healpix, lambda: get_healpix_laplacians(N, depth, laplacian_type)),
            (Equiangular, lambda: get_equiangular_laplacians(N, depth, self.ratio, laplacian_type)),
        )
        for sampling_cls, build_laplacians in laplacian_builders:
            if isinstance(self.pooling_class, sampling_cls):
                self.laps = build_laplacians()
                break
        else:
            raise ValueError("Error: sampling method unknown. Please use icosahedron, healpix or equiangular.")

        # Encoder and decoder receive the same Laplacians; only the
        # down/up-sampling operator differs between them.
        self.encoder = Encoder(self.pooling_class.pooling, self.laps)
        self.decoder = Decoder(self.pooling_class.unpooling, self.laps)

    def forward(self, x):
        """Run ``x`` through the encoder, then the decoder.

        Args:
            x (:obj:`torch.Tensor`): input to be forwarded.

        Returns:
            :obj:`torch.Tensor`: output
        """
        # The encoder returns a sequence whose items become the decoder's
        # positional arguments.
        encoded = self.encoder(x)
        return self.decoder(*encoded)