--- title: Patches for the `fastai` Learner keywords: fastai sidebar: home_sidebar summary: "Implements functions necessary to build a `Learner` suitable for bioimage segmentation" description: "Implements functions necessary to build a `Learner` suitable for bioimage segmentation" nb_path: "nbs/00_learner.ipynb" ---
{% raw %}
{% endraw %} {% raw %}
{% endraw %}

%nbdev_export import numpy as np import torch from torch import nn import torch.nn.functional as F from fastai2.learner import Learner from fastprogress.fastprogress import progress_bar

{% raw %}
{% endraw %} {% raw %}

Learner.predict_tiles[source]

Learner.predict_tiles(ds_idx=1, dl=None, mc_dropout=False, n_times=1)

Make predictions with dropout applied.

{% endraw %} {% raw %}
{% endraw %} {% raw %}

Learner.apply_dropout[source]

Learner.apply_dropout()

If a module contains 'dropout', it will be switched to .train() mode.

{% endraw %} {% raw %}
@patch
def get_mc_dropout_results(self, plot=True, dl=None, tile_ds:TileDataset=None,
                           max_n=9, n_times=20, figsize=(15,15), **kwargs):
    """Get results with MC Dropout enabled. Plotting is enabled by default.

    Runs `self.predict_tiles_with_mc_dropout(dl, tile_ds, n_times)` and,
    when `plot` is truthy, draws one four-panel figure per file in
    `tile_ds.files`: image, target mask, prediction and std-dev map.

    Args:
        plot: If truthy, create a matplotlib figure for every file.
        dl: Data loader to predict on; defaults to `self.dls.valid`.
        tile_ds: Tile dataset providing `files`, `get_images()` and
            `lbl_wgt_pdf`; defaults to `self.dls.valid_ds`.
        max_n: Accepted for interface compatibility; currently unused.
        n_times: Forwarded to `predict_tiles_with_mc_dropout`.
        figsize: Size passed to `plt.subplots` for each figure.
        **kwargs: Accepted for interface compatibility; currently unused.

    Returns:
        The `(smxs, segs, std_devs)` tuple exactly as returned by
        `predict_tiles_with_mc_dropout`.
    """
    if dl is None:
        dl = self.dls.valid
    if tile_ds is None:
        tile_ds = self.dls.valid_ds

    smxs, segs, std_devs = self.predict_tiles_with_mc_dropout(dl, tile_ds, n_times)

    if plot:
        imgs = tile_ds.get_images()
        for i, path in enumerate(tile_ds.files):
            img = imgs[i]
            # Fall back to an all-ones mask shaped like *this* image when no
            # label is registered for the file.
            if path.name in tile_ds.lbl_wgt_pdf:
                msk = tile_ds.lbl_wgt_pdf[path.name][0]
            else:
                msk = np.ones_like(img)
            pred = segs[i]
            std_dev = std_devs[i]
            # Scalar shown in the last panel's title; computed from index 1
            # of the last axis of the std-dev array.
            # NOTE(review): presumably the foreground class — confirm.
            entr = entropy(std_dev[...,1]).mean()

            fig, axs = plt.subplots(nrows=1, ncols=4, figsize=figsize)

            axs[0].imshow(img, cmap='binary_r')
            axs[0].set_axis_off()
            axs[0].set_title('Image {}'.format(path.name))

            axs[1].imshow(msk, cmap='binary_r')
            axs[1].set_axis_off()
            axs[1].set_title('Target')

            # NOTE(review): assumes `pred` is a 2-D segmentation map that
            # `imshow` can render — confirm against predict_tiles_with_mc_dropout.
            axs[2].imshow(pred, cmap='binary_r')
            axs[2].set_axis_off()
            axs[2].set_title('Prediction')

            # Same slice that feeds the entropy value in the title.
            axs[3].imshow(std_dev[...,1])
            axs[3].set_axis_off()
            axs[3].set_title('Std ({} Entropy)'.format(np.round(entr,2)))

    return smxs, segs, std_devs
{% endraw %}