scripts package

Submodules

scripts.run_ar_tc_ignite module

Example script for running the DeepSphere U-Net on the reduced AR_TC dataset.

scripts.run_ar_tc_ignite.add_tensorboard(engine_train, optimizer, model, log_dir)

Creates an Ignite TensorBoard logger and attaches training elements to it, such as weight and gradient histograms.

Parameters:
  • engine_train (ignite.engine.Engine) – the training engine to attach the logger to
  • optimizer (torch.optim.Optimizer) – the model’s optimizer
  • model (torch.nn.Module) – the model being trained
  • log_dir (str) – directory where the TensorBoard data should be saved

scripts.run_ar_tc_ignite.average_precision_compute_fn(y_pred, y_true)

Compute function attached to the custom Ignite metric AveragePrecisionMultiLabel.

Parameters:
  • y_pred – predicted scores
  • y_true – ground-truth labels
Raises:

RuntimeError – if scikit-learn is not installed.

Returns:

Average precision for each label: a vector whose length equals the number of labels present in the data.

Return type:

numpy.ndarray
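
A minimal sketch of a per-label compute function, assuming it delegates to scikit-learn's average_precision_score with average=None; the script's real implementation may differ. As documented above, it raises RuntimeError when scikit-learn is missing.

```python
import numpy as np


def average_precision_compute_fn(y_pred, y_true):
    """Per-label average precision for multi-label data (sketch)."""
    try:
        from sklearn.metrics import average_precision_score
    except ImportError:
        raise RuntimeError("This metric requires scikit-learn to be installed.")
    # Inputs are assumed to be array-likes of shape (n_samples, n_labels);
    # torch tensors should be detached and moved to CPU before conversion.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # average=None returns one score per label instead of a single average.
    return average_precision_score(y_true, y_pred, average=None)
```

For example, with labels [[1, 0], [0, 1], [1, 1], [0, 0]] and scores [[0.9, 0.9], [0.2, 0.8], [0.7, 0.6], [0.1, 0.3]], the first label is ranked perfectly (AP = 1.0), while the second has a negative at the top of the ranking (AP = 7/12).
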

scripts.run_ar_tc_ignite.get_dataloaders(parser_args)

Creates the training and validation datasets and their corresponding dataloaders.

Parameters: parser_args (dict) – parsed arguments
Returns: training and validation dataloaders
Return type: (torch.utils.data.DataLoader, torch.utils.data.DataLoader)

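
A hedged sketch of building the two dataloaders from a dictionary of arguments. The random tensors stand in for the reduced AR_TC data, and the dictionary keys (n_samples, n_vertices, batch_size, val_fraction) are hypothetical, not the script's real options.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def get_dataloaders(parser_args):
    """Build training/validation dataloaders from a dict of arguments (sketch).

    All keys read from parser_args are hypothetical placeholders.
    """
    # Hypothetical stand-in data: 16 input channels per vertex and an
    # integer label in {0, 1, 2} per vertex.
    n, v = parser_args["n_samples"], parser_args["n_vertices"]
    dataset = TensorDataset(torch.randn(n, v, 16), torch.randint(0, 3, (n, v)))

    # Split into training and validation subsets.
    n_val = int(n * parser_args["val_fraction"])
    train_set, val_set = random_split(dataset, [n - n_val, n_val])

    # Shuffle only the training data.
    train_loader = DataLoader(train_set, batch_size=parser_args["batch_size"], shuffle=True)
    val_loader = DataLoader(val_set, batch_size=parser_args["batch_size"], shuffle=False)
    return train_loader, val_loader
```
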
scripts.run_ar_tc_ignite.main(parser_args)

Main function: creates the trainer engine, adds handlers to the training and validation engines, and then runs the training engine to perform training and validation.

Parameters: parser_args (dict) – parsed arguments

scripts.run_ar_tc_ignite.validate_output_transform(x, y, y_pred)

Transform that formats the output of the supervised evaluator before the metric is calculated.

Parameters:
  • x – input batch
  • y – ground-truth labels
  • y_pred – model predictions
Returns:

the reformatted model predictions and ground truths

Return type:

(torch.Tensor, torch.Tensor)
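
Ignite's create_supervised_evaluator calls its output_transform as output_transform(x, y, y_pred), which matches the signature above. The sketch below assumes, hypothetically, that the reformatting flattens per-vertex outputs and one-hot encodes integer labels so that a multi-label metric receives (samples, labels) tensors; the actual transform is not documented here.

```python
import torch


def validate_output_transform(x, y, y_pred):
    """Flatten per-vertex predictions and labels for a multi-label metric (sketch).

    Assumes y_pred has shape (batch, vertices, labels) and y holds integer class
    indices of shape (batch, vertices); both shapes are hypothetical.
    """
    n_labels = y_pred.shape[-1]
    # One row per vertex: (batch * vertices, labels).
    y_pred_flat = y_pred.reshape(-1, n_labels)
    # One-hot encode the integer labels so they match the prediction layout.
    y_flat = torch.nn.functional.one_hot(y.reshape(-1), num_classes=n_labels)
    return y_pred_flat, y_flat
```

Such a function would be passed as create_supervised_evaluator(model, metrics=..., output_transform=validate_output_transform).
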

Module contents
