mighty.trainer.MaskTrainer
- class mighty.trainer.MaskTrainer(image_shape, accuracy_measure=AccuracyArgmax(), learning_rate=0.1, show_progress=False)
Based on “Interpretable Explanations of Black Boxes by Meaningful Perturbation” [1].
Trains an occlusion mask that shows where in the input space a neural network “looks”.
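A minimal sketch of the perturbation step, assuming PyTorch and the blending form that [1] uses for its constant and noise perturbations: a mask value of 1 keeps a pixel and 0 replaces it with a baseline. This page does not say which perturbation MaskTrainer applies; using a box-blurred copy of the image as the baseline is an assumption made here for brevity ([1] uses Gaussian blur among its perturbations). box_blur and perturb are hypothetical helpers, not part of mighty.

```python
import torch
import torch.nn.functional as F


def box_blur(image: torch.Tensor, kernel_size: int = 11) -> torch.Tensor:
    """Blur a (C, H, W) image with an odd-sized box filter (stand-in for Gaussian blur)."""
    channels = image.shape[0]
    # One averaging kernel per channel, applied depthwise via groups=channels.
    kernel = torch.full((channels, 1, kernel_size, kernel_size),
                        1.0 / kernel_size**2, dtype=image.dtype)
    # Replicate-pad so the output keeps the input's spatial size.
    padded = F.pad(image.unsqueeze(0), [kernel_size // 2] * 4, mode="replicate")
    return F.conv2d(padded, kernel, groups=channels).squeeze(0)


def perturb(image: torch.Tensor, mask: torch.Tensor, baseline: torch.Tensor) -> torch.Tensor:
    """Blend image and baseline: mask == 1 keeps a pixel, mask == 0 replaces it.

    image, baseline: (C, H, W); mask: (1, H, W), a grayscale mask broadcast over C.
    """
    return mask * image + (1.0 - mask) * baseline


# Usage: "delete" the centre of a random image by replacing it with its blurred copy.
image = torch.rand(3, 32, 32)
blurred = box_blur(image)
mask = torch.ones(1, 32, 32)
mask[:, 8:24, 8:24] = 0.0
perturbed = perturb(image, mask, blurred)
```

A single-channel mask is enough because broadcasting applies the same keep/delete weight to every color channel of a pixel.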
- Parameters:
- image_shape : tuple
The shape of an input image.
- accuracy_measure : Accuracy, optional
Accuracy estimator. Default: AccuracyArgmax()
- learning_rate : float, optional
Optimizer learning rate. Default: 0.1
- show_progress : bool, optional
Whether to show the training progress bar. Default: False
References
[1] Fong, R. C., & Vedaldi, A. (2017). Interpretable explanations of black boxes by meaningful perturbation. In Proceedings of the IEEE International Conference on Computer Vision (pp. 3429-3437).
Methods
- __init__(image_shape[, accuracy_measure, ...])
- cpu()
- cuda()
- get_probability(outputs, label): Returns the probability of the label class of the outputs.
- train_mask(model, image, label_true): Train a grayscale occlusion mask.
Attributes
- l1_coeff
- mask_size
- max_iterations
- tv_beta
- tv_coeff
- get_probability(outputs, label)
Returns the probability of the label class of the outputs.
- Parameters:
- outputs : torch.Tensor
The output of a model.
- label : int
The true class label.
- Returns:
- float
The probability of outputs[label].
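A minimal sketch of what get_probability may compute, assuming outputs holds raw logits for a single image and that the probability is obtained with a softmax; this page does not say how the probability is derived. Returning a Python float via .item() detaches the value from autograd, so a training loss would need the probability as a tensor instead.

```python
import torch
import torch.nn.functional as F


def get_probability(outputs: torch.Tensor, label: int) -> float:
    """Softmax probability of class `label`, assuming `outputs` are raw logits.

    Accepts logits of shape (num_classes,) or (1, num_classes).
    """
    logits = outputs.detach().reshape(-1)  # flatten a batch of one into a vector
    probabilities = F.softmax(logits, dim=0)
    return probabilities[label].item()


logits = torch.tensor([[2.0, 0.0, -1.0]])  # a batch of one image, three classes
p = get_probability(logits, label=0)
```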
- train_mask(model, image, label_true)
Train a grayscale occlusion mask.
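For reference, a simplified form of the deletion objective from [1], which learns a mask m with values in [0, 1] over the pixel grid Λ. The first term keeps the deleted area small, the second (a total-variation penalty with exponent β) keeps the mask smooth, and the third is the model's score f_c for the true class c on the perturbed image Φ(x_0; m). The attribute names l1_coeff, tv_coeff and tv_beta suggest λ_1, λ_2 and β, but that correspondence is an assumption; this page does not document the exact loss MaskTrainer minimizes.

```latex
m^\ast = \arg\min_{m \in [0,1]^{\Lambda}}
  \lambda_1 \lVert \mathbf{1} - m \rVert_1
  + \lambda_2 \sum_{u \in \Lambda} \lVert \nabla m(u) \rVert_\beta^\beta
  + f_c\big(\Phi(x_0; m)\big)
```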
- Parameters:
- model : nn.Module
A neural network model.
- image : (C, H, W) torch.Tensor
An input image.
- label_true : int
The true class label of the image.
- Returns:
- mask_upsampled : torch.Tensor
The trained occlusion mask, upsampled to the input image size.
- image_perturbed : torch.Tensor
The input image with the mask applied.
- loss_trace : list of float
A list of training losses.
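The loop below is a hypothetical, simplified sketch of how a method like train_mask could be organized. It assumes PyTorch, an Adam optimizer, a precomputed perturbation baseline, a square low-resolution mask upsampled bilinearly, and the softmax probability of the true class as the score to minimize. The defaults (mask_size=8, l1_coeff=0.01, tv_coeff=0.2, tv_beta=3.0, max_iterations=100) are illustrative placeholders, not MaskTrainer's actual attribute values; only learning_rate=0.1 comes from the documented constructor default. The returned triple mirrors the documented (mask_upsampled, image_perturbed, loss_trace).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def tv_norm(mask: torch.Tensor, beta: float) -> torch.Tensor:
    """Anisotropic total variation of a (1, h, w) mask with exponent beta."""
    dh = (mask[:, 1:, :] - mask[:, :-1, :]).abs().pow(beta).mean()
    dw = (mask[:, :, 1:] - mask[:, :, :-1]).abs().pow(beta).mean()
    return dh + dw


def train_mask_sketch(model: nn.Module, image: torch.Tensor, label_true: int,
                      baseline: torch.Tensor, mask_size: int = 8,
                      l1_coeff: float = 0.01, tv_coeff: float = 0.2,
                      tv_beta: float = 3.0, learning_rate: float = 0.1,
                      max_iterations: int = 100):
    """Hypothetical sketch; returns (mask_upsampled, image_perturbed, loss_trace)."""
    model.eval()
    _, height, width = image.shape
    # Low-resolution grayscale mask, initialised to 1 (keep every pixel).
    mask = torch.ones(1, 1, mask_size, mask_size, requires_grad=True)
    optimizer = torch.optim.Adam([mask], lr=learning_rate)
    loss_trace = []

    def upsample(m: torch.Tensor) -> torch.Tensor:
        # (1, 1, h, w) -> (1, H, W); broadcast over the C image channels.
        return F.interpolate(m, size=(height, width), mode="bilinear",
                             align_corners=False)[0]

    for _ in range(max_iterations):
        mask_up = upsample(mask)
        perturbed = mask_up * image + (1.0 - mask_up) * baseline
        probs = F.softmax(model(perturbed.unsqueeze(0)), dim=1)
        # Small deleted area + smooth mask + low true-class probability.
        loss = (l1_coeff * (1.0 - mask).abs().mean()
                + tv_coeff * tv_norm(mask[0], tv_beta)
                + probs[0, label_true])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            mask.clamp_(0.0, 1.0)  # project back onto [0, 1]
        loss_trace.append(loss.item())

    with torch.no_grad():
        mask_up = upsample(mask)
        perturbed = mask_up * image + (1.0 - mask_up) * baseline
    return mask_up, perturbed, loss_trace
```

Optimizing a coarse mask and upsampling it acts as an extra smoothness prior on top of the total-variation term; clamping after each step is one simple way to enforce the [0, 1] constraint and is a choice made for this sketch.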