scripts package
Submodules
scripts.run_ar_tc_ignite module
Example script for running DeepSphere U-Net on the reduced AR_TC dataset.
scripts.run_ar_tc_ignite.add_tensorboard(engine_train, optimizer, model, log_dir)
Creates an ignite TensorBoard logger object and attaches training information to it, such as weight and gradient histograms.
Parameters:
- engine_train (ignite.engine.Engine) – the train engine to attach to the logger
- optimizer (torch.optim.Optimizer) – the model's optimizer
- model (torch.nn.Module) – the model being trained
- log_dir (string) – path to the directory where TensorBoard data should be saved
scripts.run_ar_tc_ignite.average_precision_compute_fn(y_pred, y_true)
Compute function attached to the custom ignite metric AveragePrecisionMultiLabel.
Parameters:
- y_pred (torch.Tensor) – model predictions
- y_true (torch.Tensor) – ground truths
Raises: RuntimeError – if scikit-learn (sklearn) is not installed; the user must install it.
Returns: average precision vector, with one entry per label present in the data.
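The per-label computation can be sketched in plain Python. This sketch assumes the function follows the step-wise definition used by scikit-learn's average_precision_score, AP = Σₙ (Rₙ − Rₙ₋₁) Pₙ over decreasing score thresholds, applied independently to each label column. The helper names average_precision and average_precision_per_label are hypothetical, and plain lists stand in for tensors.

```python
from typing import List, Sequence


def average_precision(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Step-wise AP for one label: sum over thresholds of (R_n - R_{n-1}) * P_n."""
    n_pos = sum(1 for label in labels if label)
    if n_pos == 0:
        # Design choice for this sketch: AP is undefined without positives.
        raise ValueError("average precision is undefined for a label with no positives")
    # Rank samples by decreasing score.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    tp = fp = 0
    prev_recall = 0.0
    ap = 0.0
    i = 0
    while i < len(order):
        threshold = scores[order[i]]
        # Samples tied at the same score share a single threshold.
        while i < len(order) and scores[order[i]] == threshold:
            if labels[order[i]]:
                tp += 1
            else:
                fp += 1
            i += 1
        recall = tp / n_pos
        precision = tp / (tp + fp)
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


def average_precision_per_label(
    y_pred: Sequence[Sequence[float]], y_true: Sequence[Sequence[int]]
) -> List[float]:
    """One AP value per label column; rows are samples, columns are labels."""
    n_labels = len(y_true[0])
    return [
        average_precision([row[k] for row in y_pred], [row[k] for row in y_true])
        for k in range(n_labels)
    ]


print(average_precision([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))  # ≈ 0.8333
```

The result is a list with one value per label, which matches the "one entry per label" shape described above.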
scripts.run_ar_tc_ignite.get_dataloaders(parser_args)
Creates the datasets and the corresponding dataloaders.
Parameters: parser_args (dict) – parsed arguments
Returns: train and validation dataloaders
Return type: (torch.utils.data.DataLoader, torch.utils.data.DataLoader)
scripts.run_ar_tc_ignite.main(parser_args)
Main function: creates the trainer engine, adds handlers to the train and validation engines, then runs the train engine to perform training and validation.
Parameters: parser_args (dict) – parsed arguments
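In ignite, running "training and validation" from the train engine is typically done by registering a handler that the trainer fires at the end of each epoch and that runs the validation engine. The toy sketch below mimics that event-handler pattern in plain Python. MiniEngine, the "epoch_completed" event string and run_validation are hypothetical stand-ins, not ignite's API (ignite itself provides Engine, Events.EPOCH_COMPLETED and Engine.add_event_handler), and the step functions are placeholders for a real training step and metric.

```python
from collections import defaultdict
from typing import Callable, Dict, Iterable, List


class MiniEngine:
    """Toy stand-in for an ignite Engine: runs a step function over data and fires events."""

    def __init__(self, process_fn: Callable) -> None:
        self.process_fn = process_fn
        self.handlers: Dict[str, List[Callable]] = defaultdict(list)
        self.state: Dict[str, object] = {}

    def add_event_handler(self, event: str, handler: Callable) -> None:
        self.handlers[event].append(handler)

    def fire(self, event: str) -> None:
        # Each handler receives the engine that fired the event.
        for handler in self.handlers[event]:
            handler(self)

    def run(self, data: Iterable, max_epochs: int = 1) -> Dict[str, object]:
        self.state = {"epoch": 0, "output": None}
        for epoch in range(1, max_epochs + 1):
            self.state["epoch"] = epoch
            for batch in data:
                self.state["output"] = self.process_fn(self, batch)
            self.fire("epoch_completed")
        return self.state


log = []
trainer = MiniEngine(lambda engine, batch: sum(batch))    # placeholder "training step"
evaluator = MiniEngine(lambda engine, batch: max(batch))  # placeholder "metric"


def run_validation(engine: MiniEngine) -> None:
    # Validation is driven by the trainer: it runs after every training epoch.
    val_state = evaluator.run([[2, 3]])
    log.append((engine.state["epoch"], val_state["output"]))


trainer.add_event_handler("epoch_completed", run_validation)
final_state = trainer.run([[1, 2], [3, 4]], max_epochs=2)
```

After two epochs, log holds one validation entry per training epoch, which is the reason only the train engine needs to be run explicitly.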
scripts.run_ar_tc_ignite.validate_output_transform(x, y, y_pred)
A transform that formats the output of the supervised evaluator before the metric is calculated.
Parameters:
- x (torch.Tensor) – the input to the model
- y (torch.Tensor) – the ground truth labels
- y_pred (torch.Tensor) – the output of the model
Returns: model predictions and ground truths, reformatted
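ignite's create_supervised_evaluator calls its output_transform as output_transform(x, y, y_pred) and, by default, returns (y_pred, y), the order in which ignite metrics consume an engine's output. A minimal sketch of such a transform follows. The [batch][pixel][label] layout and the flattening step are hypothetical illustrations of "reformatting" rather than what this script actually does, and nested lists stand in for tensors.

```python
from typing import List, Sequence, Tuple

Batch = Sequence[Sequence[Sequence[float]]]  # hypothetical layout: [batch][pixel][label]


def flatten_output_transform(
    x: object, y: Batch, y_pred: Batch
) -> Tuple[List[Sequence[float]], List[Sequence[float]]]:
    """Hypothetical transform: ignore x, flatten [batch][pixel][label] to [sample][label],
    and return (y_pred, y) in the order ignite metrics expect."""

    def flatten(batch: Batch) -> List[Sequence[float]]:
        # Every pixel of every sample becomes its own row of per-label values.
        return [pixel for sample in batch for pixel in sample]

    return flatten(y_pred), flatten(y)


preds, targets = flatten_output_transform(
    None, [[[0, 1]], [[1, 0]]], [[[0.2, 0.9]], [[0.7, 0.1]]]
)
```

Note that the predictions come first in the returned pair even though y precedes y_pred in the argument list.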
Module contents
DeepSphere Base Documentation doc