scripts.temporality package

Submodules

scripts.temporality.run_ar_tc module

Example script for running DeepSphere U-Net on the reduced AR_TC dataset.

scripts.temporality.run_ar_tc.add_tensorboard(engine_train, optimizer, model, log_dir)

Creates an Ignite TensorBoard logger and attaches training information to it, such as weight and gradient histograms.

Parameters
  • engine_train (ignite.engine.Engine) – the training engine to attach to the logger

  • optimizer (torch.optim.Optimizer) – the model’s optimizer

  • model (torch.nn.Module) – the model being trained

  • log_dir (str) – directory in which the TensorBoard data is saved

scripts.temporality.run_ar_tc.average_precision_compute_fn(y_pred, y_true)

Compute function attached to the custom Ignite metric AveragePrecisionMultiLabel.

Parameters
  • y_pred – model predictions

  • y_true – ground-truth labels

Raises

RuntimeError – if scikit-learn is not installed.

Returns

Average precision vector, with one entry per label present in the data.

Return type

numpy.ndarray
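The documented function relies on scikit-learn. As a dependency-light illustration only, the sketch below computes the same per-label quantity in NumPy using the step-wise definition AP = Σₙ (Rₙ − Rₙ₋₁) Pₙ that scikit-learn's average_precision_score uses, where Pₙ and Rₙ are the precision and recall at the n-th distinct score threshold. It assumes that y_pred holds scores and y_true holds binary labels, both with shape (n_samples, n_labels); torch tensors would first need converting with .numpy(). The name average_precision_per_label is hypothetical, and this is not the script's own implementation.

```python
import numpy as np


def average_precision_per_label(y_pred, y_true):
    """Hypothetical sketch: average precision for each label (column).

    y_pred: scores of shape (n_samples, n_labels), higher means more likely positive.
    y_true: binary ground truth with the same shape.
    Returns one AP value per label; NaN for labels without positive samples.
    """
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    n_labels = y_true.shape[1]
    ap = np.full(n_labels, np.nan)
    for j in range(n_labels):
        truth = y_true[:, j]
        if truth.sum() == 0:
            continue  # AP is undefined when a label has no positives
        # Sort by decreasing score; a stable sort keeps results deterministic
        order = np.argsort(-y_pred[:, j], kind="mergesort")
        scores, truth = y_pred[order, j], truth[order]
        # One threshold per distinct score: last index of each run of tied scores
        last_of_run = np.r_[np.flatnonzero(np.diff(scores)), scores.size - 1]
        tps = np.cumsum(truth)[last_of_run]   # true positives at each threshold
        fps = (last_of_run + 1) - tps         # false positives at each threshold
        precision = tps / (tps + fps)
        recall = tps / tps[-1]
        # AP = sum_n (R_n - R_{n-1}) * P_n, with R_{-1} = 0
        ap[j] = np.sum(np.diff(np.r_[0.0, recall]) * precision)
    return ap


# Usage: 4 samples, 2 labels
y_true = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
y_pred = np.array([[0.9, 0.2], [0.8, 0.7], [0.7, 0.6], [0.1, 0.9]])
print(average_precision_per_label(y_pred, y_true))  # approximately [0.8333, 0.5833]
```

Tied scores are grouped into a single threshold, so ties neither reward nor penalize the ordering within the tie.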

scripts.temporality.run_ar_tc.get_dataloaders(parser_args)

Creates the training and validation datasets and their corresponding dataloaders.

Parameters

parser_args (dict) – parsed arguments

Returns

Training and validation dataloaders.

Return type

(torch.utils.data.DataLoader, torch.utils.data.DataLoader)

scripts.temporality.run_ar_tc.main(parser_args)

Main function: creates the trainer engine, adds handlers to the training and validation engines, and then runs the training engine to perform training and validation.

Parameters

parser_args (dict) – parsed arguments

scripts.temporality.run_ar_tc.validate_output_transform(x, y, y_pred)

Transform that formats the output of the supervised evaluator before the metric is computed.

Parameters
  • x – input batch

  • y – ground-truth labels

  • y_pred – model predictions

Returns

Reformatted model predictions and ground truths.

Return type

(torch.Tensor, torch.Tensor)
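Ignite's create_supervised_evaluator calls its output_transform with (x, y, y_pred) and, by default, returns (y_pred, y). This page does not say how the script reformats them, so the sketch below shows only one plausible choice: it assumes that predictions and one-hot targets both have shape (batch, n_vertices, n_labels) and flattens them so that every vertex of the sphere becomes one sample for a per-label metric. The function name and the shapes are hypothetical.

```python
import numpy as np


def validate_output_transform_sketch(x, y, y_pred):
    # Hypothetical: y_pred and y have shape (batch, n_vertices, n_labels).
    # x (the input batch) is not needed to compute the metric.
    n_labels = y_pred.shape[-1]
    # Merge the batch and vertex axes so every vertex becomes one sample
    return y_pred.reshape(-1, n_labels), y.reshape(-1, n_labels)


# Dummy data: 2 samples, 5 vertices, 4 input features, 3 labels
x = np.zeros((2, 5, 4))
y = np.ones((2, 5, 3))
y_pred = np.random.rand(2, 5, 3)
preds, targets = validate_output_transform_sketch(x, y, y_pred)
print(preds.shape, targets.shape)  # (10, 3) (10, 3)
```

Because the sketch only calls reshape, it behaves the same on torch.Tensor inputs; applying an activation such as a sigmoid, or converting class indices to one-hot targets, would be a further assumption.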

Module contents