scripts package

Subpackages

Submodules

scripts.run_ar_tc module

Example script for running the DeepSphere U-Net on the reduced AR_TC dataset.

scripts.run_ar_tc.add_tensorboard(engine_train, optimizer, model, log_dir)[source]

Creates an ignite logger object and attaches training elements to it, such as weight and gradient histograms.

Parameters
  • engine_train (ignite.engine) – the train engine to attach to the logger

  • optimizer (torch.optim) – the model’s optimizer

  • model (torch.nn.Module) – the model being trained

  • log_dir (string) – path to where tensorboard data should be saved

scripts.run_ar_tc.average_precision_compute_fn(y_pred, y_true)[source]

Compute function attached to the custom ignite metric AveragePrecisionMultiLabel.

Parameters
Raises

RuntimeError – Indicates that sklearn should be installed by the user.

Returns

average precision vector, of the same length as the number of labels present in the data

Return type

numpy.array
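To make the shape of the returned vector concrete, here is a minimal pure-Python sketch of per-label average precision. It is not the script's implementation, which relies on sklearn; the helper names `average_precision` and `average_precision_per_label` are hypothetical. The sketch returns a plain list rather than a numpy.array, does not treat tied scores specially, and assumes a value of 0.0 for a label with no positive samples.

```python
from typing import List, Sequence


def average_precision(scores: Sequence[float], targets: Sequence[int]) -> float:
    """Average precision for a single label (hypothetical helper).

    Samples are ranked by descending score; the result is the mean of
    precision@k taken at every rank k that holds a positive sample.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits = 0
    precision_sum = 0.0
    for rank, idx in enumerate(order, start=1):
        if targets[idx] == 1:
            hits += 1
            precision_sum += hits / rank
    # Assumption of this sketch: a label with no positives scores 0.0.
    return precision_sum / hits if hits else 0.0


def average_precision_per_label(
    y_pred: Sequence[Sequence[float]], y_true: Sequence[Sequence[int]]
) -> List[float]:
    """Return one average precision value per label (column)."""
    n_labels = len(y_true[0])
    return [
        average_precision(
            [row[j] for row in y_pred],
            [row[j] for row in y_true],
        )
        for j in range(n_labels)
    ]


# Four samples, two labels: rows are samples, columns are labels.
y_pred = [[0.9, 0.2], [0.1, 0.8], [0.8, 0.6], [0.3, 0.1]]
y_true = [[1, 0], [0, 1], [0, 1], [1, 0]]
ap_vector = average_precision_per_label(y_pred, y_true)
```

Here `ap_vector` has two entries, one per label, matching the "same length as the number of labels" contract described above.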

scripts.run_ar_tc.get_dataloaders(parser_args)[source]

Creates the datasets and the corresponding dataloaders

Parameters

parser_args (dict) – parsed arguments

Returns

train and validation dataloaders

Return type

(torch.utils.data.dataloader, torch.utils.data.dataloader)

scripts.run_ar_tc.main(parser_args)[source]

Main function. Creates the trainer engine, adds handlers to the train and validation engines, then runs the train engine to perform training and validation.

Parameters

parser_args (dict) – parsed arguments
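Both get_dataloaders and main take parser_args as a dict. The following is a minimal sketch of how command-line arguments could be turned into such a dict with the standard library's argparse. Every flag name here (`--batch-size`, `--learning-rate`, `--log-dir`) and both helper functions are hypothetical placeholders, not the script's actual options.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    """Build a parser with hypothetical training options."""
    parser = argparse.ArgumentParser(description="Hypothetical AR_TC options")
    # All flag names below are illustrative assumptions.
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=1e-3)
    parser.add_argument("--log-dir", type=str, default="./logs")
    return parser


def parse_to_dict(argv: Optional[List[str]] = None) -> dict:
    """Parse argv (or sys.argv when None) and return a plain dict."""
    namespace = build_parser().parse_args(argv)
    # vars() exposes the Namespace attributes as a dict;
    # dashes in flag names become underscores in the keys.
    return vars(namespace)


# Explicit argv keeps the example independent of the real command line.
parser_args = parse_to_dict(["--batch-size", "16"])
```

A dict built this way could then be passed on as `main(parser_args)`, with each function reading the keys it needs.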

scripts.run_ar_tc.validate_output_transform(x, y, y_pred)[source]

A transform to format the output of the supervised evaluator before calculating the metric

Parameters
Returns

model predictions and ground truths reformatted

Return type

(torch.Tensor, torch.Tensor)

Module contents
