Source code for deepsphere.data.datasets.dataset

"""Dataset for reduced atmospheric river and tropical cyclone detection dataset.
"""
import os

import numpy as np
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_and_extract_archive


[docs]class ARTCDataset(Dataset): """Dataset for reduced atmospheric river and tropical cyclone dataset. """ resource = "http://island.me.berkeley.edu/ugscnn/data/climate_sphere_l5.zip" def __init__(self, path, indices=None, transform_data=None, transform_labels=None, download=False): """Initialization. Args: path (str): Path to the data or desired place the data will be downloaded to. indices (list): List of indices representing the subset of the data used for the current dataset. transform_data (:obj:`transform.Compose`): List of torchvision transforms for the data. transform_labels (:obj:`transform.Compose`): List of torchvision transforms for the labels. download (bool): Flag to decide if data should be downloaded or not. """ self.path = path if download: self.download() self.files = indices if indices is not None else os.listdir(self.path) self.transform_data = transform_data self.transform_labels = transform_labels @property def indices(self): """Get files. Returns: list: List of strings, which represent the files contained in the dataset. """ return self.files def __len__(self): """Get length of dataset. Returns: int: Number of files contained in the dataset. """ return len(self.files) def __getitem__(self, idx): """Get an item from the dataset. Args: idx (int): The index of the desired datapoint. Returns: obj, obj: The data and labels corresponding to the desired index. The type depends on the applied transforms. """ item = np.load(os.path.join(self.path, self.files[idx])) data, labels = item["data"], item["labels"] if self.transform_data: data = self.transform_data(data) if self.transform_labels: labels = self.transform_labels(labels) return data, labels
[docs] def get_runs(self, runs): """Get datapoints corresponding to specific runs. Args: runs (list): List of desired runs. Returns: list: List of strings, which represents the files in the dataset, which belong to one of the desired runs. """ files = [] for file in self.files: for i in runs: if file.endswith("{}-mesh.npz".format(i)): files.append(file) return files
[docs] def download(self): """Download the dataset if it doesn't already exist. """ if not self.check_exists(): download_and_extract_archive(self.resource, download_root=os.path.split(self.path)[0]) else: print("Data already exists")
[docs] def check_exists(self): """Check if dataset already exists. """ return os.path.exists(self.path)