scripts package
Subpackages
Submodules
scripts.run_ar_tc module
Example script for running the DeepSphere U-Net on the reduced AR_TC dataset.
scripts.run_ar_tc.add_tensorboard(engine_train, optimizer, model, log_dir)
Creates an ignite logger object and adds training elements such as weight and gradient histograms.
- Parameters
engine_train (ignite.engine) – the train engine to attach to the logger
optimizer (torch.optim) – the model’s optimizer
model (torch.nn.Module) – the model being trained
log_dir (string) – path to where tensorboard data should be saved
scripts.run_ar_tc.average_precision_compute_fn(y_pred, y_true)
Compute function attached to the custom ignite metric AveragePrecisionMultiLabel.
- Parameters
y_pred (torch.Tensor) – model predictions
y_true (torch.Tensor) – ground truths
- Raises
RuntimeError – raised when sklearn is not installed; the user must install it.
- Returns
average precision vector, of the same length as the number of labels present in the data
- Return type
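The per-label computation can be illustrated without any dependencies. The following is a minimal sketch, not the script's implementation (which relies on sklearn). The helper names `average_precision` and `average_precision_per_label` are hypothetical, plain nested lists stand in for `torch.Tensor` inputs, and the definition assumed is the step-wise sum over recall increments, AP = Σ (R_n − R_{n−1}) · P_n, evaluated at distinct score thresholds.

```python
def average_precision(scores, targets):
    """Average precision for one label: sum of (recall step) * precision.

    scores  -- predicted scores, one per sample
    targets -- binary ground truths (0 or 1), one per sample
    """
    n_pos = sum(targets)
    if n_pos == 0:
        return float("nan")  # AP is undefined without positive samples

    # Rank samples by decreasing score.
    ranked = sorted(zip(scores, targets), key=lambda st: st[0], reverse=True)

    ap, tp, prev_recall = 0.0, 0, 0.0
    for i, (score, target) in enumerate(ranked):
        tp += target
        # Evaluate only at the last sample of a group of tied scores.
        is_last_of_tie = i + 1 == len(ranked) or ranked[i + 1][0] != score
        if is_last_of_tie:
            precision = tp / (i + 1)
            recall = tp / n_pos
            ap += (recall - prev_recall) * precision
            prev_recall = recall
    return ap


def average_precision_per_label(y_pred, y_true):
    """Return one AP value per label (column) of sample-major nested lists."""
    n_labels = len(y_true[0])
    return [
        average_precision([row[j] for row in y_pred], [row[j] for row in y_true])
        for j in range(n_labels)
    ]


# Four samples, two labels (illustrative values only).
y_pred = [[0.9, 0.2], [0.8, 0.7], [0.3, 0.6], [0.1, 0.4]]
y_true = [[1, 0], [0, 1], [1, 1], [0, 0]]
ap_vector = average_precision_per_label(y_pred, y_true)
```

The result has one entry per label, matching the "same length as the number of labels" description above.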
scripts.run_ar_tc.get_dataloaders(parser_args)
Creates the datasets and the corresponding dataloaders.
- Parameters
parser_args (dict) – parsed arguments
- Returns
train, validation dataloaders
- Return type
(torch.utils.data.dataloader, torch.utils.data.dataloader)
scripts.run_ar_tc.main(parser_args)
Main function. Creates the trainer engine, adds handlers to the train and validation engines, then runs the train engine to perform training and validation.
- Parameters
parser_args (dict) – parsed arguments
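Both `main` and `get_dataloaders` take `parser_args` as a dict. A minimal sketch of how such a dict could be built with `argparse` follows. The helper `parse_args` and the option names (`--batch_size`, `--learning_rate`, `--log_dir`) with their defaults are hypothetical and are not taken from the script.

```python
import argparse


def parse_args(argv=None):
    """Parse command-line options and return them as a plain dict."""
    parser = argparse.ArgumentParser(
        description="Hypothetical options for a training script"
    )
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--log_dir", type=str, default="./logs")
    # vars() converts the argparse.Namespace into a dict.
    return vars(parser.parse_args(argv))


parser_args = parse_args(["--batch_size", "4", "--log_dir", "./runs/example"])
# A call such as main(parser_args) would then receive this dict.
```

Passing an explicit list to `parse_args` keeps the sketch testable; with `argv=None`, `argparse` reads `sys.argv` instead.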
scripts.run_ar_tc.validate_output_transform(x, y, y_pred)
A transform to format the output of the supervised evaluator before calculating the metric.
- Parameters
x (torch.Tensor) – the input to the model
y (torch.Tensor) – the ground truth labels
y_pred (torch.Tensor) – the output of the model
- Returns
model predictions and ground truths reformatted
- Return type
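ignite's supervised evaluator passes `(x, y, y_pred)` to its output transform, and ignite metrics by default consume a `(y_pred, y)` pair. The sketch below shows only that reordering. `example_output_transform` is a hypothetical name, plain lists stand in for tensors, and whatever reshaping the script itself applies is not described on this page and is not reproduced here.

```python
def example_output_transform(x, y, y_pred):
    """Return (predictions, ground truths), the order ignite metrics expect.

    x      -- model input (not needed by the metric)
    y      -- ground truth labels
    y_pred -- model output
    """
    return y_pred, y


# Plain lists stand in for torch.Tensor objects in this illustration.
inputs = [[0.1, 0.2], [0.3, 0.4]]
labels = [0, 1]
predictions = [[0.8, 0.2], [0.3, 0.7]]

metric_input = example_output_transform(inputs, labels, predictions)
```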
Module contents
DeepSphere Base Documentation doc