Aims to understand how neural networks perceive images
Evolved in response to a desire to make neural nets more interpretable
For the latest developments: the brilliant series of articles on Distill
Open source feature visualisation toolkit
Supports torchvision models
Available to install via pip!
$ pip install flashtorch
First introduced in 2013
Gradients of target class w.r.t. input image via backpropagation
Pixels with positive gradients: some intuition of attention
Available via flashtorch.saliency
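The core idea can be sketched in plain PyTorch. This is a minimal illustration using a tiny stand-in CNN and a random tensor in place of a real image (not FlashTorch's actual implementation, just the underlying pattern it wraps):

```python
import torch
import torch.nn as nn

# Tiny stand-in CNN (hypothetical); in practice this would be a
# torchvision model such as AlexNet.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 1000),  # 1000 "classes", like ImageNet
)
model.eval()

# Stand-in for a preprocessed image; requires_grad lets us ask for
# gradients w.r.t. the input pixels.
input_ = torch.randn(1, 3, 224, 224, requires_grad=True)
target_class = 24  # e.g. 'great grey owl' in ImageNet

score = model(input_)[0, target_class]
score.backward()  # gradients of the class score w.r.t. every input pixel

# Max gradient magnitude across colour channels gives one saliency
# value per pixel.
saliency = input_.grad.abs().max(dim=1)[0]
print(saliency.shape)  # torch.Size([1, 224, 224])
```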
API
import matplotlib.pyplot as plt

from flashtorch.utils import load_image
image = load_image('../../examples/images/great_grey_owl.jpg')
plt.imshow(image)
plt.title('Great grey owl')
plt.axis('off');
from flashtorch.utils import apply_transforms, denormalize, format_for_plotting
input_ = apply_transforms(image)
print(f'Before: {type(image)}')
print(f'After: {type(input_)}, {input_.shape}')
plt.imshow(format_for_plotting(denormalize(input_)))
plt.title('Input tensor')
plt.axis('off');
Before: <class 'PIL.Image.Image'>
After: <class 'torch.Tensor'>, torch.Size([1, 3, 224, 224])
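Under the hood, this kind of preprocessing typically follows the standard ImageNet recipe: scale to 224x224, normalise with the ImageNet channel statistics, and add a batch dimension. A hedged sketch of that recipe (assuming the usual mean/std values; a random tensor stands in for the image):

```python
import torch

# Stand-in for a resized 224x224 RGB image with values in [0, 1].
image = torch.rand(3, 224, 224)

# Standard ImageNet channel statistics (assumption: apply_transforms
# uses this common normalisation).
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

input_ = ((image - mean) / std).unsqueeze(0)  # add batch dimension
print(input_.shape)  # torch.Size([1, 3, 224, 224])
```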
from flashtorch.saliency import Backprop
from torchvision import models

model = models.alexnet(pretrained=True)
backprop = Backprop(model)
To calculate gradients:
Signature:
backprop.calculate_gradients(input_, target_class=None, ...)
from flashtorch.utils import ImageNetIndex
imagenet = ImageNetIndex()
target_class = imagenet['great grey owl']
print(f'Target class index: {target_class}')
# Ready to calculate gradients!
gradients = backprop.calculate_gradients(input_, target_class)
max_gradients = backprop.calculate_gradients(input_, target_class, take_max=True)
print(type(gradients), gradients.shape)
print(type(max_gradients), max_gradients.shape)
Target class index: 24
<class 'torch.Tensor'> torch.Size([3, 224, 224])
<class 'torch.Tensor'> torch.Size([1, 224, 224])
from flashtorch.utils import visualize
visualize(input_, gradients, max_gradients)
guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)
max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)
visualize(input_, guided_gradients, max_guided_gradients)
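What guided=True changes can be illustrated with the guided-backprop ReLU rule (a sketch of the concept; the assumption here is that FlashTorch applies it via backward hooks on ReLU layers): a gradient is passed back only where both the forward activation and the incoming gradient are positive, which suppresses pixels that would decrease the class score.

```python
import torch

# Forward activations of a ReLU layer and the gradient arriving from
# the layer above (toy values for illustration).
forward_output = torch.tensor([0.5, 0.0, 2.0, 0.0])
incoming_grad = torch.tensor([1.0, 1.0, -1.0, -1.0])

# Guided-backprop rule: zero out negative incoming gradients AND
# positions where the forward activation was not positive.
guided_grad = incoming_grad.clamp(min=0) * (forward_output > 0).float()
print(guided_grad)  # tensor([1., 0., 0., 0.])
```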
A model developed for one task is reused as the starting point for another task
Often used in computer vision & natural language processing tasks
Save compute & time resources
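The usual transfer-learning recipe is to freeze the pretrained feature extractor and train only a new classifier head. A hedged sketch with a tiny stand-in backbone (in practice the backbone would be e.g. a torchvision model's feature layers; the class count of 102 is an arbitrary example):

```python
import torch
import torch.nn as nn

# Stand-in feature extractor; imagine these weights were pretrained.
backbone = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
for param in backbone.parameters():
    param.requires_grad = False  # reuse the learned features as-is

# New classifier head for the target task (e.g. 102 flower classes).
head = nn.Linear(16, 102)
model = nn.Sequential(backbone, head)

# Only the head's parameters reach the optimiser, so training is
# cheap compared with training the whole network.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=1e-3)
print(sum(p.numel() for p in trainable))  # 16*102 weights + 102 biases = 1734
```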
plt.imshow(load_image('images/foxgloves.jpg'))
plt.title('Foxgloves')
plt.axis('off');
backprop = Backprop(pretrained_model)
guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)
visualize(input_, guided_gradients, guided_max_gradients)
/Users/misao/Projects/personal/flashtorch/flashtorch/saliency/backprop.py:94: UserWarning: The predicted class index 82 does not equal the target class index 96. Calculating the gradient w.r.t. the predicted class.
backprop = Backprop(trained_model)
guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)
visualize(input_, guided_gradients, guided_max_gradients)
With feature visualisation, we're better equipped to:
Diagnose what and why the network gets things wrong
Spot and correct biases in algorithms
Step forward from only looking at accuracy
Understand why the network behaves in the way it does
More focus on mechanisms of how neural nets learn
🌡 Like what you saw? Try out FlashTorch 🔦 on Google Colab
🙏 Questions and feedback on the talk: submit issues
🤝 General suggestions & contributions: open a pull request
Slide deck @ tinyurl.com/flashtorch-hopperx1