---
title: Inference
keywords: fastai
sidebar: home_sidebar
summary: "Classes for inference with model ensembles using TorchScript."
description: "Classes for inference with model ensembles using TorchScript."
nb_path: "nbs/04_inference.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
{% endraw %}

Helper functions

Gaussian weighting for merging different predictions.

{% raw %}

ScriptFunction object at 0x7f3b47cc23b0>[source]

ScriptFunction object at 0x7f3b47cc23b0>()

Returns a Gaussian window

{% endraw %} {% raw %}
{% endraw %} {% raw %}

gaussian_kernel_2d[source]

gaussian_kernel_2d()

Returns a 2D Gaussian kernel tensor.

{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Build a 256x256 Gaussian weight map. sigma_scale=1/8 presumably sets sigma
# relative to the window size — confirm against gaussian_kernel_2d.
gw = gaussian_kernel_2d((256, 256), sigma_scale=0.125)

# The largest weight in the kernel must be exactly 1.
peak = gw.max()
test_eq(peak, 1)

plt.imshow(gw);
{% endraw %}

Calculation of epistemic and aleatoric uncertainty

{% raw %}

ScriptFunction object at 0x7f3b47ccbdb0>[source]

ScriptFunction object at 0x7f3b47ccbdb0>()

{% endraw %} {% raw %}

ScriptFunction object at 0x7f3b47ccb9a0>[source]

ScriptFunction object at 0x7f3b47ccb9a0>()

{% endraw %} {% raw %}

ScriptFunction object at 0x7f3b47ccb860>[source]

ScriptFunction object at 0x7f3b47ccb860>()

{% endraw %} {% raw %}
{% endraw %}

Tiling

Functions and classes for tiling (slicing) images

{% raw %}

get_in_slices_1d[source]

get_in_slices_1d(center:Tensor, len_x:int, len_tile:int)

{% endraw %} {% raw %}

get_out_slices_1d[source]

get_out_slices_1d(center:Tensor, len_x:int, len_tile:int)

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class TileModule[source]

TileModule(tile_shape=(512, 512), scale:float=1.0, border_padding_factor:float=0.25, max_tile_shift:float=0.5) :: Module

Class for tiling data.

{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Round-trip check: one 540x540 tile centred at (ts/2, ts/2) must reproduce
# `image` (presumably `image` is a 2-D array that fits in a single tile).
ts = 540
os = ts / 2  # NOTE: shadows the stdlib `os` name; a later export cell reads this value

t = TileModule(tile_shape=(ts, ts))

# Add a trailing singleton axis to the float image (assumes `image` is 2-D).
inp = torch.from_numpy(image).float().unsqueeze_(-1)

tile_center = (os, os)
tile_out = t(inp, tile_center)[0, 0]
test_close(tile_out, image, eps=1e-04)
{% endraw %} {% raw %}
# Sweep over tiling configurations. For every tile reported by
# get_slices_and_centers, the input-slice and output-slice extents must agree,
# and the tile is pasted into an output canvas of the rescaled size.
TS = [256, 512, 1024]
SCALES = [0.5, 1., 2.]
SHIFTS = [0.5, 0.9, 1.]
BPFACTORS = [0.25, 0.1, 0.]
inp = torch.from_numpy(image).float().unsqueeze_(-1)

sweep = ((tile_size, tile_scale, tile_shift, bp_factor)
         for tile_size in TS
         for tile_scale in SCALES
         for tile_shift in SHIFTS
         for bp_factor in BPFACTORS)

for tile_size, tile_scale, tile_shift, bp_factor in sweep:
    t = TileModule(tile_shape=(tile_size, tile_size), scale=tile_scale,
                   max_tile_shift=tile_shift, border_padding_factor=bp_factor)

    # Canvas: the input's first two axes divided by the module's scale,
    # plus a trailing singleton axis.
    canvas_hw = [int(dim / t.scale) for dim in inp.shape[:2]]
    out = torch.zeros(*canvas_hw).unsqueeze_(-1)

    in_slices, out_slices, center_points = t.get_slices_and_centers(inp.shape)
    assert len(center_points)!=0

    # Per-axis start/stop sequences, indexed by tile number below.
    in_x_lo, in_x_hi = in_slices[0][0], in_slices[0][1]
    in_y_lo, in_y_hi = in_slices[1][0], in_slices[1][1]
    out_x_lo, out_x_hi = out_slices[0][0], out_slices[0][1]
    out_y_lo, out_y_hi = out_slices[1][0], out_slices[1][1]

    for i, cp in enumerate(center_points):
        ix0, ix1 = in_x_lo[i], in_x_hi[i]
        iy0, iy1 = in_y_lo[i], in_y_hi[i]
        ox0, ox1 = out_x_lo[i], out_x_hi[i]
        oy0, oy1 = out_y_lo[i], out_y_hi[i]
        assert (ix1-ix0) == (ox1-ox0), 'Input/Output slices do not match'
        assert (iy1-iy0) == (oy1-oy0), 'Input/Output slices do not match'

        tile = t(inp, cp)
        out[ox0:ox1, oy0:oy1, 0] += tile[0, 0, ix0:ix1, iy0:iy1]
{% endraw %} {% raw %}
# Export the scripted TileModule to ONNX (opset 16), then delete the file.
# Compare the numeric (major, minor) release rather than lists of strings:
# string comparison is lexicographic ("9" > "12"), so torch 1.9 would wrongly
# pass a ["1", "12", "0"] check. The local build suffix ("+cu117") is dropped first.
_torch_major_minor = tuple(int(part) for part in torch.__version__.split("+")[0].split(".")[:2])
if _torch_major_minor >= (1, 12):
    print(torch.__version__)
    t_scripted = torch.jit.script(t)
    ost = torch.tensor(os)
    path = Path('TileModule.onnx')
    try:
        _ = torch.onnx.export(t_scripted, (inp, torch.tensor([ost, ost])), path, verbose=True, opset_version=16)
    finally:
        # Remove the exported file even if export fails partway through.
        path.unlink(missing_ok=True)
{% endraw %}

Inference Ensemble

Scriptable module for inference with multiple models

{% raw %}

class InferenceEnsemble[source]

InferenceEnsemble(models:List[Module], num_classes:int, in_channels:int, channel_means:List[float], channel_stds:List[float], tile_shape:Tuple[int, int]=(512, 512), use_gaussian:bool=True, gaussian_kernel_sigma_scale:float=0.125, use_tta:bool=True, border_padding_factor:float=0.25, max_tile_shift:float=0.9, scale:float=1.0, device:str='cpu') :: Module

Class for model ensemble inference

{% endraw %} {% raw %}
{% endraw %} {% raw %}
class DummyModule(torch.nn.Module):
    "Test stand-in model: copies input channel 0 into `num_classes` output channels."
    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Keep channel 0 as a size-1 axis (dim 1), then tile it num_classes times.
        first_channel = x[:, 0:1]
        return first_channel.repeat(1, self.num_classes, 1, 1)
{% endraw %} {% raw %}
# Shape check for the scripted InferenceEnsemble across channel counts, class
# counts, image sizes and scales. The outputs must be (H, W), (num_classes, H, W)
# and (H, W) for an (H, W, C) input.
CHANNELS = [1, 3]
NUM_CLASSES = [2, 5]
SCALES = [0.5, 1., 4.]
IMG_SHAPES = [512]  # alternatively [256, 512, 1024] for a broader sweep

for ch in CHANNELS:
    channel_means = [0.]*ch
    channel_stds = [1.]*ch
    for num_classes in NUM_CLASSES:
        models = [DummyModule(num_classes=num_classes) for _ in range(2)]
        for sx in IMG_SHAPES:
            for sy in IMG_SHAPES:
                inp = torch.rand(sx, sy, ch)
                for scale in SCALES:
                    ensemble = InferenceEnsemble(models,
                                                 num_classes=num_classes,
                                                 # Must match the input's channel count and the
                                                 # length of channel_means/channel_stds
                                                 # (previously hard-coded to 3).
                                                 in_channels=ch,
                                                 channel_means=channel_means,
                                                 channel_stds=channel_stds,
                                                 scale=scale)
                    ensemble = torch.jit.script(ensemble)
                    outs = ensemble(inp)
                    test_eq(outs[0].shape, (sx, sy))
                    test_eq(outs[1].shape, (num_classes, sx, sy))
                    test_eq(outs[2].shape, (sx, sy))
{% endraw %} {% raw %}
# Serialize the scripted ensemble with TorchScript and check that it loads back.
path_pt = Path('ensemble.pt')
scripted_ensemble = torch.jit.script(ensemble)
try:
    scripted_ensemble.save(path_pt)
    _ = torch.jit.load(path_pt)
finally:
    # Clean up the temp file even if save/load raises.
    path_pt.unlink(missing_ok=True)
{% endraw %} {% raw %}
# Export the scripted ensemble to ONNX (opset 16) with named inputs/outputs.
# Compare the numeric (major, minor) release rather than lists of strings:
# string comparison is lexicographic ("9" > "12"), so torch 1.9 would wrongly
# pass a ["1", "12", "0"] check. The local build suffix ("+cu117") is dropped first.
_torch_major_minor = tuple(int(part) for part in torch.__version__.split("+")[0].split(".")[:2])
if _torch_major_minor >= (1, 12):
    print(torch.__version__)
    input_names=["inp"]
    output_names=["argmax", "softmax", "stdeviation"]
    # Spatial axes (plus the class axis of "softmax") are declared dynamic so the
    # graph is not tied to the example input's size.
    dynamic_axes = {"inp": [0, 1], "argmax": [0, 1], "softmax": [0, 1, 2], "stdeviation": [0, 1]}
    path_onnx = Path('ensemble.onnx')
    try:
        torch.onnx.export(scripted_ensemble, inp, f=path_onnx, verbose=True, opset_version=16,
                          input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes)

        # Optional manual check (requires onnxruntime):
        #import onnxruntime
        #ort_session = onnxruntime.InferenceSession(path_onnx.as_posix())
    finally:
        # Don't leave the exported file behind (consistent with the other export/save cells).
        path_onnx.unlink(missing_ok=True)
{% endraw %}