---
title: Inference
keywords: fastai
sidebar: home_sidebar
summary: "Classes for inference with model ensembles using TorchScript."
description: "Classes for inference with model ensembles using TorchScript."
nb_path: "nbs/04_inference.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
{% endraw %}

Helper functions

Gaussian weighting for merging different predictions.

{% raw %}

ScriptFunction object at 0x7f1617396950>[source]

ScriptFunction object at 0x7f1617396950>()

Returns a Gaussian window

{% endraw %} {% raw %}
{% endraw %} {% raw %}

ScriptFunction object at 0x7f1617396900>[source]

ScriptFunction object at 0x7f1617396900>()

Returns a 2D Gaussian kernel tensor.

{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Build a 256x256 Gaussian weighting kernel (sigma_scale = 1/8) and check that its peak is exactly 1.
gw = gaussian_kernel_2d((256, 256), sigma_scale=1/8)
peak = gw.max()
test_eq(peak, 1)
plt.imshow(gw);
{% endraw %}

Calculation of epistemic and aleatoric uncertainty

{% raw %}

ScriptFunction object at 0x7f161738fcc0>[source]

ScriptFunction object at 0x7f161738fcc0>()

{% endraw %} {% raw %}

ScriptFunction object at 0x7f161738f630>[source]

ScriptFunction object at 0x7f161738f630>()

{% endraw %} {% raw %}

ScriptFunction object at 0x7f161738fea0>[source]

ScriptFunction object at 0x7f161738fea0>()

{% endraw %} {% raw %}
{% endraw %}

Tiling

Functions and classes for tiling (slicing) images

{% raw %}

get_in_slices_1d[source]

get_in_slices_1d(center:Tensor, len_x:int, len_tile:int)

{% endraw %} {% raw %}

get_out_slices_1d[source]

get_out_slices_1d(center:Tensor, len_x:int, len_tile:int)

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class TileModule[source]

TileModule(tile_shape=(512, 512), scale:float=1.0, border_padding_factor:float=0.25, max_tile_shift:float=0.5) :: Module

Class for tiling data.

{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Round trip: one tile of size ts x ts, centred at (ts/2, ts/2), should reproduce `image`.
ts = 540
# NOTE(review): `os` shadows the stdlib module name; kept because the ONNX export cell reuses it.
os = ts / 2
t = TileModule(tile_shape=(ts, ts))
inp = torch.from_numpy(image).float()
inp.unsqueeze_(-1)  # append a trailing axis (presumably the channel axis — confirm against TileModule)
center = (os, os)
out = t(inp, center)[0, 0]
test_close(out, image, eps=1e-04)
{% endraw %} {% raw %}
TS = [256, 512, 1024]
SCALES = [0.5, 1., 2.]
SHIFTS = [0.5, 0.9, 1.]
BPFACTORS = [0.25, 0.1, 0.]
inp = torch.from_numpy(image).float().unsqueeze_(-1)

# Every (tile size, scale, max shift, border padding) combination, in the same
# order as nested loops over TS -> SCALES -> SHIFTS -> BPFACTORS.
configs = [(ts, scale, max_tile_shift, border_padding_factor)
           for ts in TS
           for scale in SCALES
           for max_tile_shift in SHIFTS
           for border_padding_factor in BPFACTORS]

for ts, scale, max_tile_shift, border_padding_factor in configs:
    t = TileModule(tile_shape=(ts, ts), scale=scale, max_tile_shift=max_tile_shift,
                   border_padding_factor=border_padding_factor)
    # Output canvas: input height/width divided by the tiling scale, plus a trailing axis.
    out_hw = [int(x / t.scale) for x in inp.shape[:2]]
    out = torch.zeros(*out_hw).unsqueeze_(-1)
    in_slices, out_slices, center_points = t.get_slices_and_centers(inp.shape)
    assert len(center_points) != 0
    for i, cp in enumerate(center_points):
        # slices are indexed as [axis][start/end][tile index]
        ix0, ix1 = in_slices[0][0][i], in_slices[0][1][i]
        iy0, iy1 = in_slices[1][0][i], in_slices[1][1][i]
        ox0, ox1 = out_slices[0][0][i], out_slices[0][1][i]
        oy0, oy1 = out_slices[1][0][i], out_slices[1][1][i]
        # the region read from a tile must be exactly as large as where it is written
        assert (ix1-ix0) == (ox1-ox0), 'Input/Output slices do not match'
        assert (iy1-iy0) == (oy1-oy0), 'Input/Output slices do not match'
        tile = t(inp, cp)
        # NOTE(review): `out` is accumulated but never compared against the input.
        out[ox0:ox1, oy0:oy1, 0] += tile[0, 0, ix0:ix1, iy0:iy1]
{% endraw %} {% raw %}
import re

# Compare the torch version numerically. Comparing the split version strings
# lexicographically would wrongly treat e.g. "1.9" or "1.2" as newer than "1.12".
_torch_ver = tuple(int(v) for v in re.match(r'(\d+)\.(\d+)', torch.__version__).groups())
if _torch_ver >= (1, 12):
    print(torch.__version__)
    t_scripted = torch.jit.script(t)
    ost = torch.tensor(os)
    path = Path('TileModule.onnx')
    try:
        _ = torch.onnx.export(t_scripted, (inp, torch.tensor([ost, ost])), path, verbose=True, opset_version=16)
    finally:
        # Always remove the temporary export. missing_ok avoids masking an export
        # error when the file was never written.
        path.unlink(missing_ok=True)
1.12.0
Exported graph: graph(%x.1 : Float(540, 540, 1, strides=[540, 1, 1], requires_grad=0, device=cpu),
      %center.1 : Float(2, strides=[1], requires_grad=0, device=cpu)):
  %onnx::Gather_2 : Long(device=cpu) = onnx::Constant[value={0}, onnx_name="Constant_0"]()
  %onnx::Gather_3 : Long(device=cpu) = onnx::Constant[value={1}, onnx_name="Constant_1"]()
  %onnx::Sub_4 : Float(device=cpu) = onnx::Gather[axis=0, onnx_name="Gather_2"](%center.1, %onnx::Gather_2) # /tmp/ipykernel_2119845/628550723.py:55:31
  %onnx::Sub_5 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={270}, onnx_name="Constant_3"]() # /tmp/ipykernel_2119845/628550723.py:55:41
  %onnx::Div_6 : Float(device=cpu) = onnx::Sub[onnx_name="Sub_4"](%onnx::Sub_4, %onnx::Sub_5) # /tmp/ipykernel_2119845/628550723.py:55:31
  %onnx::Div_7 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={270}, onnx_name="Constant_5"]() # /tmp/ipykernel_2119845/628550723.py:55:41
  %relative_center : Float(device=cpu) = onnx::Div[onnx_name="Div_6"](%onnx::Div_6, %onnx::Div_7) # /tmp/ipykernel_2119845/628550723.py:55:31
  %onnx::Add_9 : Float(1024, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = onnx::Constant[value=<Tensor>, onnx_name="Constant_7"]() # /tmp/ipykernel_2119845/628550723.py:56:22
  %coords : Float(1024, 1024, strides=[1024, 1], device=cpu) = onnx::Add[onnx_name="Add_8"](%onnx::Add_9, %relative_center) # /tmp/ipykernel_2119845/628550723.py:56:22
  %onnx::Unsqueeze_11 : Float(1024, 1024, strides=[1024, 1], device=cpu) = onnx::Cast[to=1, onnx_name="Cast_9"](%coords) # /tmp/ipykernel_2119845/628550723.py:57:25
  %onnx::Sub_12 : Float(device=cpu) = onnx::Gather[axis=0, onnx_name="Gather_10"](%center.1, %onnx::Gather_3) # /tmp/ipykernel_2119845/628550723.py:55:31
  %onnx::Sub_13 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={270}, onnx_name="Constant_11"]() # /tmp/ipykernel_2119845/628550723.py:55:41
  %onnx::Div_14 : Float(device=cpu) = onnx::Sub[onnx_name="Sub_12"](%onnx::Sub_12, %onnx::Sub_13) # /tmp/ipykernel_2119845/628550723.py:55:31
  %onnx::Div_15 : Float(requires_grad=0, device=cpu) = onnx::Constant[value={270}, onnx_name="Constant_13"]() # /tmp/ipykernel_2119845/628550723.py:55:41
  %relative_center.3 : Float(device=cpu) = onnx::Div[onnx_name="Div_14"](%onnx::Div_14, %onnx::Div_15) # /tmp/ipykernel_2119845/628550723.py:55:31
  %onnx::Add_17 : Float(1024, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = onnx::Constant[value=<Tensor>, onnx_name="Constant_15"]() # /tmp/ipykernel_2119845/628550723.py:56:22
  %coords.3 : Float(1024, 1024, strides=[1024, 1], device=cpu) = onnx::Add[onnx_name="Add_16"](%onnx::Add_17, %relative_center.3) # /tmp/ipykernel_2119845/628550723.py:56:22
  %onnx::Unsqueeze_19 : Float(1024, 1024, strides=[1024, 1], device=cpu) = onnx::Cast[to=1, onnx_name="Cast_17"](%coords.3) # /tmp/ipykernel_2119845/628550723.py:57:25
  %onnx::Unsqueeze_20 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={-1}, onnx_name="Constant_18"]() # /tmp/ipykernel_2119845/628550723.py:60:16
  %onnx::Concat_21 : Float(1024, 1024, 1, strides=[1024, 1, 1], device=cpu) = onnx::Unsqueeze[onnx_name="Unsqueeze_19"](%onnx::Unsqueeze_19, %onnx::Unsqueeze_20) # /tmp/ipykernel_2119845/628550723.py:60:16
  %onnx::Unsqueeze_22 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={-1}, onnx_name="Constant_20"]() # /tmp/ipykernel_2119845/628550723.py:60:16
  %onnx::Concat_23 : Float(1024, 1024, 1, strides=[1024, 1, 1], device=cpu) = onnx::Unsqueeze[onnx_name="Unsqueeze_21"](%onnx::Unsqueeze_11, %onnx::Unsqueeze_22) # /tmp/ipykernel_2119845/628550723.py:60:16
  %onnx::Cast_24 : Float(1024, 1024, 2, strides=[2048, 2, 1], device=cpu) = onnx::Concat[axis=-1, onnx_name="Concat_22"](%onnx::Concat_21, %onnx::Concat_23) # /tmp/ipykernel_2119845/628550723.py:60:16
  %onnx::Unsqueeze_25 : Float(1024, 1024, 2, strides=[2048, 2, 1], device=cpu) = onnx::Cast[to=1, onnx_name="Cast_23"](%onnx::Cast_24) # /tmp/ipykernel_2119845/628550723.py:60:16
  %onnx::Unsqueeze_26 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="Constant_24"]() # /tmp/ipykernel_2119845/628550723.py:60:16
  %onnx::GridSample_27 : Float(1, 1024, 1024, 2, strides=[2097152, 2048, 2, 1], device=cpu) = onnx::Unsqueeze[onnx_name="Unsqueeze_25"](%onnx::Unsqueeze_25, %onnx::Unsqueeze_26) # /tmp/ipykernel_2119845/628550723.py:60:16
  %onnx::Unsqueeze_28 : Float(1, 540, 540, strides=[291600, 540, 1], device=cpu) = onnx::Transpose[perm=[2, 0, 1], onnx_name="Transpose_26"](%x.1) # /tmp/ipykernel_2119845/628550723.py:63:12
  %onnx::Unsqueeze_29 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}, onnx_name="Constant_27"]() # /tmp/ipykernel_2119845/628550723.py:63:12
  %onnx::GridSample_30 : Float(1, 1, 540, 540, strides=[291600, 291600, 540, 1], device=cpu) = onnx::Unsqueeze[onnx_name="Unsqueeze_28"](%onnx::Unsqueeze_28, %onnx::Unsqueeze_29) # /tmp/ipykernel_2119845/628550723.py:63:12
  %x : Float(1, 1, 1024, 1024, strides=[1048576, 1048576, 1024, 1], requires_grad=0, device=cpu) = onnx::GridSample[align_corners=0, mode="nearest", padding_mode="reflection", onnx_name="GridSample_29"](%onnx::GridSample_30, %onnx::GridSample_27) # /home/magr/.conda/envs/fastai2/lib/python3.9/site-packages/torch/nn/functional.py:4223:11
  return (%x)

{% endraw %}

Inference Ensemble

Scriptable module for inference with multiple models

{% raw %}

class InferenceEnsemble[source]

InferenceEnsemble(models:List[Module], num_classes:int, in_channels:int, channel_means:List[float], channel_stds:List[float], tile_shape:Tuple[int, int]=(512, 512), use_gaussian:bool=True, gaussian_kernel_sigma_scale:float=0.125, use_tta:bool=True, border_padding_factor:float=0.25, max_tile_shift:float=0.9, scale:float=1.0, device:str='cpu') :: Module

Class for model ensemble inference

{% endraw %} {% raw %}
{% endraw %} {% raw %}
class DummyModule(torch.nn.Module):
    "Stand-in model for tests: returns `num_classes` copies of the input's first channel (dim 1)."

    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, x):
        # Slice with 0:1 (not 0) so dim 1 is kept with length 1, then tile it
        # `num_classes` times along dim 1. Assumes x is (batch, channels, H, W) — TODO confirm.
        first = x[:, 0:1]
        return first.repeat(1, self.num_classes, 1, 1)
{% endraw %} {% raw %}
CHANNELS = [1,3]
NUM_CLASSES = [2, 5]
SCALES = [0.5, 1., 4.]
IMG_SHAPES = [512]  # extend to e.g. [256, 512, 1024] for a broader (slower) check

for ch in CHANNELS:
    # one mean/std entry per input channel (0 and 1, presumably leaving inputs unnormalised)
    channel_means = [0.]*ch
    channel_stds = [1.]*ch
    for num_classes in NUM_CLASSES:
        models = [DummyModule(num_classes=num_classes) for _ in range(3)]
        for sx in IMG_SHAPES:
            for sy in IMG_SHAPES:
                inp = torch.rand(sx, sy, ch)
                for scale in SCALES:
                    ensemble = InferenceEnsemble(models,
                                                 num_classes=num_classes,
                                                 # must match the channel count of `inp`
                                                 # (was hard-coded to 3, untested for ch=1)
                                                 in_channels=ch,
                                                 channel_means=channel_means,
                                                 channel_stds=channel_stds,
                                                 scale=scale)
                    ensemble = torch.jit.script(ensemble)
                    outs = ensemble(inp)
                    # outputs come back at input resolution regardless of `scale`
                    test_eq(outs[0].shape, (sx, sy))
                    test_eq(outs[1].shape, (num_classes, sx, sy))
                    test_eq(outs[2].shape, (sx, sy))
{% endraw %} {% raw %}
# Round-trip check: the scripted ensemble can be saved to disk and loaded back.
path_pt = Path('ensemble.pt')
scripted_ensemble = torch.jit.script(ensemble)
try:
    scripted_ensemble.save(path_pt)
    _ = torch.jit.load(path_pt)
finally:
    # Remove the temporary file even if saving or loading fails.
    path_pt.unlink(missing_ok=True)
{% endraw %}

ONNX Export (experimental)

Not supported yet; needs further testing.

{% raw %}
def onnx_export(model, example_input, f, verbose=True):
    'Export model to onnx format'
    if torch.__version__.split(".") >= ["1", "12", "0"]:
        model = torch.jit.script(model)
        input_names=["inp"]
        output_names=["argmax", "softmax", "stdeviation"]
        dynamic_axes = {"inp": {0: "heigth", 0: "width"}, 
                        "argmax": {0: "heigth", 0: "width"}, 
                        "softmax": {0: "num_classes", 1: "heigth", 2: "width"}, 
                        "stdeviation": {0: "heigth", 0: "width"}}
        torch.onnx.export(model, example_input.float(), f, verbose=verbose, opset_version=16, do_constant_folding=True,
                          input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes)
    else:
        print('Require pytorch version >= 1.12.0')
{% endraw %} {% raw %}
# Disabled example (ONNX export is experimental): export the scripted ensemble
# for a (1024, 1024, 3) float input.
# inp = torch.rand(1024,1024,3).float()
# path_onnx = Path('ensemble.onnx')
# onnx_export(scripted_ensemble, inp, path_onnx)

# Optional follow-up: open the exported file with onnxruntime (an extra dependency).
#import onnxruntime
#ort_session = onnxruntime.InferenceSession(path_onnx.as_posix())
{% endraw %}