scripts package
Submodules
scripts.run_ar_tc_ignite module
Example script for running DeepSphere U-Net on the reduced AR_TC dataset.
scripts.run_ar_tc_ignite.add_tensorboard(engine_train, optimizer, model, log_dir)
Creates an ignite logger object and attaches training elements to it, such as weight and gradient histograms.
Parameters:
- engine_train (ignite.engine.Engine) – the train engine to attach to the logger
- optimizer (torch.optim.Optimizer) – the model’s optimizer
- model (torch.nn.Module) – the model being trained
- log_dir (str) – path to the directory where TensorBoard data should be saved
scripts.run_ar_tc_ignite.average_precision_compute_fn(y_pred, y_true)
Function attached to the custom ignite metric AveragePrecisionMultiLabel; computes the average precision for each label.
Parameters:
- y_pred (torch.Tensor) – model predictions
- y_true (torch.Tensor) – ground truths
Raises: RuntimeError – if sklearn is not installed; the user must install it.
Returns: average precision vector, of the same length as the number of labels present in the data
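The documented function relies on scikit-learn (hence the RuntimeError). To illustrate what the returned vector contains, here is a numpy-only sketch, not the script's implementation. It assumes `y_pred` and `y_true` are arrays of shape (n_samples, n_labels) with 0/1 targets, and computes, for each label column, the step-wise average precision AP = Σₙ (Rₙ − Rₙ₋₁) Pₙ over the distinct score thresholds. The name `average_precision_per_label` is hypothetical.

```python
import numpy as np


def average_precision_per_label(y_pred, y_true):
    """Per-label average precision, AP = sum_n (R_n - R_{n-1}) * P_n.

    Hypothetical numpy-only sketch; assumes both inputs have shape
    (n_samples, n_labels) and that y_true holds 0/1 targets.
    """
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    ap = np.empty(y_pred.shape[1])
    for label in range(y_pred.shape[1]):
        # Rank samples by descending score for this label.
        order = np.argsort(-y_pred[:, label])
        scores = y_pred[order, label]
        targets = y_true[order, label]
        # Keep the last position of each run of tied scores: one threshold per distinct score.
        last_of_tie = np.append(np.diff(scores) != 0, True)
        tp = np.cumsum(targets)[last_of_tie]
        fp = np.cumsum(1.0 - targets)[last_of_tie]
        precision = tp / (tp + fp)
        # nan (with a numpy RuntimeWarning) if the label never occurs in y_true.
        recall = tp / targets.sum()
        ap[label] = np.sum(np.diff(np.concatenate(([0.0], recall))) * precision)
    return ap
```

For labels that occur at least once, `sklearn.metrics.average_precision_score(y_true, y_pred, average=None)` computes the same step-wise quantity and should return matching values.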
scripts.run_ar_tc_ignite.get_dataloaders(parser_args)
Creates the datasets and the corresponding dataloaders.
Parameters: parser_args (dict) – parsed arguments
Returns: train and validation dataloaders
Return type: (torch.utils.data.DataLoader, torch.utils.data.DataLoader)
scripts.run_ar_tc_ignite.main(parser_args)
Main function: creates the trainer engine and adds handlers to the train and validation engines, then runs the train engine to perform training and validation.
Parameters: parser_args (dict) – parsed arguments
scripts.run_ar_tc_ignite.validate_output_transform(x, y, y_pred)
A transform that formats the output of the supervised evaluator before the metric is calculated.
Parameters:
- x (torch.Tensor) – the input to the model
- y (torch.Tensor) – the ground truth labels
- y_pred (torch.Tensor) – the output of the model
Returns: model predictions and ground truths, reformatted
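ignite calls an evaluator's output transform with `(x, y, y_pred)` and, by default, keeps `(y_pred, y)`. The exact reformatting done in run_ar_tc_ignite is not documented here, so the sketch below is purely illustrative: it assumes per-vertex scores `y_pred` of shape (batch, vertices, labels) and integer labels `y` of shape (batch, vertices), and flattens them into the (n_samples, n_labels) score and one-hot target pair that a multi-label average-precision metric would consume. The function name and both shapes are hypothetical.

```python
import numpy as np


def validate_output_transform_sketch(x, y, y_pred):
    """Hypothetical stand-in for validate_output_transform.

    Assumed shapes (not taken from the script):
      y_pred: (batch, vertices, labels) per-label scores
      y:      (batch, vertices) integer label indices
    Returns (scores, one_hot_targets), each of shape (batch * vertices, labels).
    """
    del x  # the model input is not needed to score the predictions
    y_pred = np.asarray(y_pred)
    y = np.asarray(y)
    n_labels = y_pred.shape[-1]
    # Every vertex becomes one sample row.
    scores = y_pred.reshape(-1, n_labels)
    # One-hot encode the integer labels so each label column is a binary target.
    targets = np.eye(n_labels, dtype=int)[y.reshape(-1)]
    # Same (y_pred, y) order as ignite's default evaluator output.
    return scores, targets
```

Keeping the `(y_pred, y)` order means ignite metrics can consume the result with their default output transform, which unpacks the engine output as `(y_pred, y)`.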
Module contents