scripts.temporality package
Submodules
scripts.temporality.run_ar_tc module
Example script for running the DeepSphere U-Net on the reduced AR_TC dataset.
-
scripts.temporality.run_ar_tc.add_tensorboard(engine_train, optimizer, model, log_dir)
Creates an ignite logger object and adds training elements such as weight and gradient histograms.
- Parameters
engine_train (ignite.engine) – the train engine to attach to the logger
optimizer (torch.optim) – the model’s optimizer
model (torch.nn.Module) – the model being trained
log_dir (string) – path to where tensorboard data should be saved
-
scripts.temporality.run_ar_tc.average_precision_compute_fn(y_pred, y_true)
Compute function attached to the custom ignite metric AveragePrecisionMultiLabel.
- Parameters
y_pred (torch.Tensor) – model predictions
y_true (torch.Tensor) – ground truths
- Raises
RuntimeError – if sklearn is not installed; the user must install it.
- Returns
average precision vector, of the same length as the number of labels present in the data
- Return type
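To make the returned vector concrete, here is a minimal pure-Python sketch of per-label average precision. It is not the script's code, which relies on sklearn according to the Raises entry; the names `average_precision` and `average_precision_per_label` are hypothetical, and plain lists stand in for tensors.

```python
def average_precision(scores, labels):
    """AP for one label: sum over thresholds of (R_n - R_{n-1}) * P_n."""
    total_pos = sum(labels)
    if total_pos == 0:
        raise ValueError("label has no positive samples")
    # Rank samples from highest to lowest score.
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    ap, true_pos, prev_recall = 0.0, 0, 0.0
    i = 0
    while i < len(ranked):
        j = i
        # Consume every sample tied at the current score (one threshold).
        while j < len(ranked) and ranked[j][0] == ranked[i][0]:
            true_pos += ranked[j][1]
            j += 1
        recall = true_pos / total_pos
        precision = true_pos / j  # j samples are predicted positive so far
        ap += (recall - prev_recall) * precision
        prev_recall = recall
        i = j
    return ap


def average_precision_per_label(y_pred, y_true):
    """Return one AP value per label column (hypothetical helper)."""
    n_labels = len(y_true[0])
    return [
        average_precision([row[k] for row in y_pred], [row[k] for row in y_true])
        for k in range(n_labels)
    ]


# Rows are samples, columns are labels.
y_pred = [[0.1, 0.9], [0.4, 0.2], [0.35, 0.8], [0.8, 0.1]]
y_true = [[0, 1], [0, 0], [1, 1], [1, 0]]
ap_vector = average_precision_per_label(y_pred, y_true)
```

The result has one entry per label, which matches the description of the returned vector above.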
-
scripts.temporality.run_ar_tc.get_dataloaders(parser_args)
Creates the datasets and the corresponding dataloaders.
- Parameters
parser_args (dict) – parsed arguments
- Returns
train, validation dataloaders
- Return type
(torch.utils.data.dataloader, torch.utils.data.dataloader)
-
scripts.temporality.run_ar_tc.main(parser_args)
Main function. Creates the trainer engine, adds handlers to the train and validation engines, then runs the train engine to perform training and validation.
- Parameters
parser_args (dict) – parsed arguments
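How the script builds `parser_args` is not shown on this page. As a rough sketch only, a dict of parsed arguments could be produced with `argparse` as below; the function name `parse_args_sketch` and every flag name are hypothetical and are not taken from the script.

```python
import argparse


def parse_args_sketch(argv=None):
    """Parse command-line options into a plain dict (hypothetical flags)."""
    parser = argparse.ArgumentParser(description="Illustrative argument parser")
    # All option names below are assumptions made for illustration.
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--n_epochs", type=int, default=10)
    # vars() turns the argparse.Namespace into a dict.
    return vars(parser.parse_args(argv))


parser_args = parse_args_sketch(["--batch_size", "8"])
```

A dict built this way could then be passed to a function documented as taking `parser_args (dict)`.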
-
scripts.temporality.run_ar_tc.validate_output_transform(x, y, y_pred)
A transform to format the output of the supervised evaluator before calculating the metric.
- Parameters
x (torch.Tensor) – the input to the model
y (torch.Tensor) – the ground truth labels
y_pred (torch.Tensor) – the output of the model
- Returns
model predictions and ground truths reformatted
- Return type
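For orientation, ignite's supervised evaluator calls its output transform as `output_transform(x, y, y_pred)`, and ignite metrics by default consume a `(y_pred, y)` pair. The sketch below shows only that reordering; the name `output_transform_sketch` is hypothetical, plain lists stand in for tensors, and whatever reshaping the script's own transform performs is not reproduced here.

```python
def output_transform_sketch(x, y, y_pred):
    """Drop the model input and return (predictions, targets) for the metric."""
    # x: model input, y: ground truth labels, y_pred: model output.
    return y_pred, y


predictions, targets = output_transform_sketch(
    x=[[0.0, 1.0]],   # stand-in for an input batch
    y=[1],            # stand-in for ground truth labels
    y_pred=[[0.2, 0.8]],  # stand-in for model output
)
```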