Seeing is believing

Using FlashTorch 🔦 to shine a light on what neural nets "see"


Misa Ogura

Follow along @ tinyurl.com/flashtorch-hopperx1

Hello, I'm Misa 👋


  • Originally from Tokyo, now based in London
  • Cancer Cell Biologist, turned Software Engineer
  • Currently at BBC R&D
  • Co-founder of Women Driven Development
  • Women in Data Science London Ambassador

Feature visualisation


Introducing FlashTorch 🔦


Image processing & CNNs 101


Kernel & convolution


Kernel: a small matrix used for edge detection, blurring, sharpening, embossing, etc.

Convolution: an operation that calculates the weighted sum of neighbouring pixels

Examples of convolution: detecting edges
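
As a minimal sketch of the idea (using NumPy and SciPy, not FlashTorch; the kernel and image here are stand-ins for illustration):

import numpy as np
from scipy.signal import convolve2d

# A classic 3x3 Sobel kernel that responds to vertical edges
kernel = np.array([[-1, 0, 1],
                   [-2, 0, 2],
                   [-1, 0, 1]])

image = np.random.rand(224, 224)  # stand-in for a grayscale image

# Each output pixel is the weighted sum of its 3x3 neighbourhood
edges = convolve2d(image, kernel, mode='same')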


Typical CNN architecture


  • Kernel weights are learnt during training

  • Extract features that are relevant to the task at hand

Feature visualisation technique

Saliency maps


Saliency


  • A subjective quality in human visual perception

  • Makes certain items stand out and grabs our attention

Saliency maps in computer vision: indications of the most “salient” regions

Saliency maps for CNNs


  • First introduced in 2013

  • Gradients of target class w.r.t. input image via backpropagation

  • Pixels with positive gradients give some intuition of where the network's attention lies

  • Available via the flashtorch.saliency API (see the sketch below)
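
A rough sketch of the underlying idea in plain PyTorch (not FlashTorch's internals; the input tensor and class index here are placeholders):

import torch
from torchvision import models

model = models.alexnet(pretrained=True).eval()

# Stand-in input; a real image would be preprocessed to this shape
input_ = torch.randn(1, 3, 224, 224, requires_grad=True)

output = model(input_)
target_class = 24  # e.g. 'great grey owl' in ImageNet

# Backpropagate the target class score to the input pixels
output[0, target_class].backward()

# Gradient magnitude per pixel, taking the max over colour channels
saliency = input_.grad.abs().max(dim=1)[0]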

FlashTorch demo 1

Visualising saliency maps with backpropagation


Install FlashTorch & load an image



$ pip install flashtorch

...
In [2]:
import matplotlib.pyplot as plt

from flashtorch.utils import load_image

image = load_image('../../examples/images/great_grey_owl.jpg')

plt.imshow(image)
plt.title('Great grey owl')
plt.axis('off');

Apply transformations


In [3]:
from flashtorch.utils import apply_transforms, denormalize, format_for_plotting

input_ = apply_transforms(image)

print(f'Before: {type(image)}')
print(f'After: {type(input_)}, {input_.shape}')

plt.imshow(format_for_plotting(denormalize(input_)))
plt.title('Input tensor')
plt.axis('off');
Before: <class 'PIL.Image.Image'>
After: <class 'torch.Tensor'>, torch.Size([1, 3, 224, 224])

Create a Backprop object with a pre-trained model


In [4]:
from flashtorch.saliency import Backprop
from torchvision import models

model = models.alexnet(pretrained=True)

backprop = Backprop(model)
  • Registers custom functions to model layers
  • Grabs intermediate gradients out of the computational graph (illustrated below)
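
For illustration, here is what such a custom function might look like (an assumption based on the bullets above, not FlashTorch's actual code): a backward hook that captures gradients as they flow through a layer.

import torch

captured = {}

def store_gradients(module, grad_in, grad_out):
    # Called during backprop; grad_in holds gradients w.r.t. the inputs
    captured['gradients'] = grad_in[0]

# Hook the first conv layer of AlexNet (the layer choice is illustrative)
model.features[0].register_backward_hook(store_gradients)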

To calculate gradients:

Signature:

    backprop.calculate_gradients(input_, target_class=None, ...)

Calculate the gradients of target class w.r.t. input


In [5]:
from flashtorch.utils import ImageNetIndex 

imagenet = ImageNetIndex()
target_class = imagenet['great grey owl']

print(f'Target class index: {target_class}')

# Ready to calculate gradients!

gradients = backprop.calculate_gradients(input_, target_class)

max_gradients = backprop.calculate_gradients(input_, target_class, take_max=True)

print(type(gradients), gradients.shape)
print(type(max_gradients), max_gradients.shape)
Target class index: 24
<class 'torch.Tensor'> torch.Size([3, 224, 224])
<class 'torch.Tensor'> torch.Size([1, 224, 224])

Let's visualise gradients


In [6]:
from flashtorch.utils import visualize

visualize(input_, gradients, max_gradients)
Pixels where the animal is present have the strongest positive effects.

But it's quite noisy...

FlashTorch demo 2

Visualising saliency maps with guided backpropagation


Guided backpropagation


  • Additional guidance from the higher layers during backprop

  • Masks out neurons that had no effect or negative effects on the prediction

  • Preventing the flow of such gradients means less noise (see the sketch below)
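
One way to picture this trick (a sketch of the idea, not FlashTorch's exact implementation): at every ReLU, clamp the gradients flowing backwards so that negative ones are dropped.

import torch
import torch.nn as nn

def guided_relu_hook(module, grad_in, grad_out):
    # Zero out negative gradients so only "positive evidence" flows back
    return (torch.clamp(grad_in[0], min=0.0),)

for module in model.modules():
    if isinstance(module, nn.ReLU):
        module.register_backward_hook(guided_relu_hook)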

In [7]:
guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)

max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)

Visualise guided gradients


In [8]:
visualize(input_, guided_gradients, max_guided_gradients)
Now that's much less noisy!

Pixels around the head and eyes have the strongest positive effects.

What about other birds?

What makes a peacock a peacock?


In [10]:
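# (assumption: input_ and the gradients were recomputed for the
#  peacock image in cells omitted from this deck)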
visualize(input_, guided_gradients, max_guided_gradients)

... or a toucan?


In [12]:
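# (assumption: input_ and the gradients were recomputed for the
#  toucan image in cells omitted from this deck)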
visualize(input_, guided_gradients, max_guided_gradients)
🤖 We've seen what the network has learnt.

FlashTorch can also help us understand how its perception changes through training.

FlashTorch demo 3

Gaining insights on how neural nets learn


Transfer learning


  • A model developed for a task is reused as a starting point for another task

  • Often used in computer vision & natural language processing tasks

  • Saves compute and time resources

Building a flower classifier


From: DenseNet model, pre-trained on ImageNet (1000 classes)

To: Flower classifier to recognise 102 species of flowers (dataset)
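
A minimal sketch of this setup (assumed, not the exact training code from the talk; the DenseNet variant is a guess, since the talk doesn't specify one):

import torch.nn as nn
from torchvision import models

# Start from a DenseNet pre-trained on ImageNet
model = models.densenet201(pretrained=True)

# Freeze the pre-trained feature extractor
for param in model.parameters():
    param.requires_grad = False

# Swap the classifier head: 102 flower classes instead of 1000
model.classifier = nn.Linear(model.classifier.in_features, 102)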

Pre-trained model - 0.1% test accuracy... why is it so bad?

In [13]:
plt.imshow(load_image('images/foxgloves.jpg'))
plt.title('Foxgloves')
plt.axis('off');

Pre-trained model - 0.1% test accuracy 😨


In [16]:
backprop = Backprop(pretrained_model)
guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)
/Users/misao/Projects/personal/flashtorch/flashtorch/saliency/backprop.py:94: UserWarning: The predicted class index 82 does not equal the target class index 96. Calculating the gradient w.r.t. the predicted class.
Trained model achieved 98.7% test accuracy... but why?

Trained model - 98.7% test accuracy 💡


In [17]:
backprop = Backprop(trained_model)
guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)
The trained model has learnt to shift its focus onto the most distinguishing patterns.

We tend to strive for test accuracy, but I believe this tool can be powerful in other ways!

Let's make neural nets more interpretable & explainable


With feature visualisation, we're better equipped to:

  • Diagnose what the network gets wrong, and why

  • Spot and correct biases in algorithms

  • Move beyond looking only at accuracy

  • Understand why the network behaves in the way it does

  • Focus more on the mechanisms of how neural nets learn

Thank you!


🌡 Like what you saw? Try out FlashTorch 🔦 on Google Colab

🙏 Questions and feedback on the talk: Pull Request

🤝 General suggestions & contribution: Submit issues

Slide deck @ tinyurl.com/flashtorch-hopperx1