Source code for deepsphere.utils.initialization

"""Initializing device
"""


import torch
from torch import nn
from torchvision import transforms

from deepsphere.data.datasets.dataset import ARTCTemporaldataset
from deepsphere.data.transforms.transforms import Stack
from deepsphere.models.spherical_unet.unet_model import SphericalUNetTemporalConv, SphericalUNetTemporalLSTM


def init_device(device, unet):
    """Place ``unet`` on the CPU or on one or several CUDA devices.

    Args:
        device (None or list of int or str): device selection.
            - ``None``: use the CPU.
            - empty list: use the default ``"cuda"`` device and wrap the model
              in :class:`torch.nn.DataParallel` without explicit
              ``device_ids`` (PyTorch then uses every visible GPU).
            - one id: place the model on ``cuda:<id>``; no ``DataParallel``.
            - several ids: place the model on the first id and wrap it in
              ``DataParallel`` restricted to those ids (converted with ``int``).
        unet (torch.nn.Module): the model to place on the device(s).

    Returns:
        torch.nn.Module, torch.device: the model placed on the device (possibly
        wrapped in ``DataParallel``) and the primary device it lives on.
    """
    if device is None:
        cpu = torch.device("cpu")
        return unet.to(cpu), cpu

    if len(device) == 0:
        # No explicit ids: let DataParallel pick the GPUs itself.
        cuda = torch.device("cuda")
        return nn.DataParallel(unet.to(cuda)), cuda

    ids = list(device)
    # DataParallel requires the module to already live on device_ids[0],
    # so the first requested id is also the primary device.
    primary = torch.device("cuda:{}".format(ids[0]))
    unet = unet.to(primary)
    if len(ids) > 1:
        unet = nn.DataParallel(unet, device_ids=[int(i) for i in ids])
    return unet, primary
def init_unet_temp(parser):
    """Build the temporal spherical UNet selected by ``parser.type``.

    Args:
        parser: parsed command-line arguments, accessed by attribute. Reads
            ``pooling_class``, ``n_pixels``, ``depth``, ``laplacian_type``,
            ``sequence_length``, ``kernel_size`` and ``type``.

    Raises:
        Exception: if ``parser.type`` is neither ``"LSTM"`` nor ``"conv"``.

    Returns:
        SphericalUNetTemporalLSTM or SphericalUNetTemporalConv: the model.
    """
    # Both architectures take the same positional constructor arguments.
    model_args = (
        parser.pooling_class,
        parser.n_pixels,
        parser.depth,
        parser.laplacian_type,
        parser.sequence_length,
        parser.kernel_size,
    )
    temporal_type = parser.type
    if temporal_type == "LSTM":
        model_cls = SphericalUNetTemporalLSTM
    elif temporal_type == "conv":
        model_cls = SphericalUNetTemporalConv
    else:
        raise Exception("The first element after --temp must be either 'LSTM' or 'conv' to specify the type.")
    return model_cls(*model_args)
def _stack_and_flatten(item):
    """Stack a sequence of tensors along dim 1, then flatten all trailing dims.

    Module-level replacement for the former inline ``lambda``: a named
    top-level function can be pickled, so a dataset holding this transform
    can be sent to DataLoader worker processes started with ``spawn``.

    Args:
        item (sequence of torch.Tensor): tensors of identical shape.
            NOTE(review): presumably one tensor per time step — confirm
            against ARTCTemporaldataset.

    Returns:
        torch.Tensor: tensor reshaped to ``(item[0].size(0), -1)``.
    """
    return torch.stack(item, dim=1).reshape(item[0].size(0), -1)


def init_dataset_temp(parser, indices, transform_image, transform_labels):
    """Initialize the temporal ARTC dataset.

    Args:
        parser: parsed command-line arguments, accessed by attribute. Reads
            ``path_to_data``, ``download``, ``type``, ``sequence_length`` and
            ``prediction_shift``.
        indices (list): indices to include in the dataset.
        transform_image (list): torchvision transforms applied to the images
            (passed through to ``ARTCTemporaldataset``).
        transform_labels (list): torchvision transforms applied to the labels
            (passed through to ``ARTCTemporaldataset``).

    Raises:
        Exception: if ``parser.type`` is neither ``"LSTM"`` nor ``"conv"``.

    Returns:
        ARTCTemporaldataset: the dataset.
    """
    path_to_data = parser.path_to_data
    download = parser.download
    if parser.type == "LSTM":
        transform_sample = transforms.Compose([Stack()])
    elif parser.type == "conv":
        # Use a named top-level function rather than a lambda so the
        # transform (and thus the dataset) stays picklable.
        transform_sample = transforms.Compose([transforms.Lambda(_stack_and_flatten)])
    else:
        raise Exception("Invalid temporality type.")
    dataset = ARTCTemporaldataset(
        path=path_to_data,
        download=download,
        sequence_length=parser.sequence_length,
        prediction_shift=parser.prediction_shift,
        indices=indices,
        transform_image=transform_image,
        transform_labels=transform_labels,
        transform_sample=transform_sample,
    )
    return dataset