# Source code for deepsphere.utils.stats_extractor

"""Compute the means and standard deviations of all features of a dataset."""
import numpy as np
import torch


def stats_extractor(dataset):
    """Compute the per-feature mean and standard deviation of a dataset.

    The dataset is traversed twice: a first pass accumulates the per-feature
    sums to obtain the means, and a second pass accumulates the squared
    deviations from those means to obtain the (sample) standard deviations.

    Args:
        dataset: indexable and re-iterable collection (``dataset[0]`` is read
            and the collection is iterated twice). The first element of each
            item, ``item[0]``, must be a 2-D array-like of shape ``(F, V)``;
            values are reduced along the second axis. ``F`` must be the same
            for every item, ``V`` may differ between items.

    Returns:
        :obj:`numpy.array`, :obj:`numpy.array`: the means and the standard
        deviations, each of shape ``(F,)`` and dtype float32.

    Note:
        The standard deviation uses the ``n - 1`` denominator, so with a
        single value per feature the result is not finite.
    """
    # Unpacking into exactly two names keeps the original 2-D requirement.
    num_features, _ = torch.as_tensor(dataset[0][0]).shape

    summing = torch.zeros(num_features)
    square_summing = torch.zeros(num_features)
    total = 0

    # First pass: per-feature sum and number of values actually seen.
    for item in dataset:
        features = torch.as_tensor(item[0], dtype=torch.float32)
        summing += torch.sum(features, dim=1)
        # Count this item's own width rather than assuming every item has
        # as many columns as the first one.
        total += features.shape[1]

    # Shape (F, 1) so it broadcasts against each (F, V) item below.
    means = torch.unsqueeze(summing / total, dim=1)

    # Second pass: per-feature sum of squared deviations from the mean.
    for item in dataset:
        features = torch.as_tensor(item[0], dtype=torch.float32)
        square_summing += torch.sum((features - means) ** 2, dim=1)

    # Stay in torch for the square root instead of relying on NumPy ufuncs
    # being applied to a tensor.
    stds = torch.sqrt(square_summing / (total - 1))

    return torch.squeeze(means, dim=1).numpy(), stds.numpy()