---
title: Test Time Augmentation
keywords: fastai
sidebar: home_sidebar
summary: "Code adapted from https://github.com/qubvel/ttach."
description: "Code adapted from https://github.com/qubvel/ttach."
nb_path: "nbs/07_tta.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
{% endraw %}

## Functional

{% raw %}

rot90[source]

rot90(x, k=1)

Rotate a batch of images by 90 degrees `k` times.

{% endraw %} {% raw %}

hflip[source]

hflip(x)

Flip a batch of images horizontally.

{% endraw %} {% raw %}

vflip[source]

vflip(x)

Flip a batch of images vertically.

{% endraw %} {% raw %}
{% endraw %}

## Base Classes

{% raw %}

class BaseTransform[source]

BaseTransform(pname:str, params:Union[list, tuple])

{% endraw %} {% raw %}

class Chain[source]

Chain(functions:List[callable])

{% endraw %} {% raw %}

class Transformer[source]

Transformer(image_pipeline:Chain, mask_pipeline:Chain)

{% endraw %} {% raw %}

class Compose[source]

Compose(aug_transforms:List[BaseTransform])

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class Merger[source]

Merger()

{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Random batch shared by the test cells below (batch 4, 2 channels, 356x356).
imgs = TensorImage(torch.randn(4, 2, 356, 356))

# Every merge type must return a result with the same shape as a single input.
merge_types = ('mean', 'max', 'std', 'uncertainty', 'entropy',
               'aleatoric_uncertainty', 'epistemic_uncertainty')
for merge_type in merge_types:
    merger = Merger()
    for _ in range(10):
        merger.append(imgs)
    test_eq(imgs.shape, merger.result(merge_type).shape)
{% endraw %}

## Transform Classes

{% raw %}

class HorizontalFlip[source]

HorizontalFlip() :: BaseTransform

Flip images horizontally (left->right)

{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Round trip: augmenting and then de-augmenting must reproduce the input batch.
# NOTE(review): no transform parameter is passed explicitly here; confirm that
# the method defaults actually apply the flip rather than acting as identity.
hflip_tfm = HorizontalFlip()
test_eq(imgs, hflip_tfm.apply_deaug_mask(hflip_tfm.apply_aug_image(imgs)))
{% endraw %} {% raw %}

class VerticalFlip[source]

VerticalFlip() :: BaseTransform

Flip images vertically (up->down)

{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Round trip: augmenting and then de-augmenting must reproduce the input batch.
# NOTE(review): no transform parameter is passed explicitly here; confirm that
# the method defaults actually apply the flip rather than acting as identity.
vflip_tfm = VerticalFlip()
test_eq(imgs, vflip_tfm.apply_deaug_mask(vflip_tfm.apply_aug_image(imgs)))
{% endraw %} {% raw %}

class Rotate90[source]

Rotate90(angles:List[int]) :: BaseTransform

Rotate images by 0, 90, 180 or 270 degrees (given in `angles`).

{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Round trip through the bare transform.
t = Rotate90([180])
aug = t.apply_aug_image(imgs)
deaug = t.apply_deaug_mask(aug)
test_eq(imgs, deaug)

# A 180-degree rotation is its own inverse, so the check above cannot detect a
# de-augmentation that rotates in the wrong direction. Also round-trip the
# 90/270 cases (which are not self-inverse). These go through Compose, the same
# path the pipeline test uses. imgs is square, so rotation preserves shape.
for angle in [90, 180, 270]:
    for transformer in Compose([Rotate90([angle])]):
        test_eq(imgs, transformer.deaugment_mask(transformer.augment_image(imgs)))
{% endraw %}

## Pipeline Test

{% raw %}
# End-to-end check: every augmentation combination yielded by Compose must be
# undone by its de-augmentation. Averaging the recovered batches must then give
# back the original images.
pipeline = Compose([HorizontalFlip(), VerticalFlip(), Rotate90(angles=[90, 180, 270])])
merger = Merger()
for transformer in pipeline:
    restored = transformer.deaugment_mask(transformer.augment_image(imgs))
    test_eq(imgs, restored)
    merger.append(restored)
test_close(imgs, merger.result())
{% endraw %}