---
title: Inference
keywords: fastai
sidebar: home_sidebar
summary: "Classes for inference with model ensembles using Torchscript."
description: "Classes for inference with model ensembles using Torchscript."
nb_path: "nbs/04_inference.ipynb"
---
gw = gaussian_kernel_2d((256,256), sigma_scale=1/8)
test_eq(gw.max(), 1)
plt.imshow(gw);
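Such a peak-normalized Gaussian weight map is commonly used to down-weight tile borders when blending overlapping tile predictions. Below is a minimal sketch of how a kernel like this can be built; it is illustrative only and the actual `gaussian_kernel_2d` implementation may differ.
import torch

def gaussian_kernel_2d_sketch(shape, sigma_scale=1/8):
    "Illustrative peak-normalized 2D Gaussian weight map (not the library implementation)."
    h, w = shape
    ys = torch.arange(h, dtype=torch.float32) - (h - 1) / 2
    xs = torch.arange(w, dtype=torch.float32) - (w - 1) / 2
    gy = torch.exp(-ys**2 / (2 * (h * sigma_scale)**2))
    gx = torch.exp(-xs**2 / (2 * (w * sigma_scale)**2))
    k = gy[:, None] * gx[None, :]
    return k / k.max()  # maximum weight is exactly 1 at the center

gw_sketch = gaussian_kernel_2d_sketch((256, 256), sigma_scale=1/8)
assert gw_sketch.max() == 1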
Calculation of epistemic and aleatoric uncertainty
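A common way to estimate both quantities from an ensemble is sketched below: aleatoric (data) uncertainty as the expected per-member entropy and epistemic (model) uncertainty as the disagreement between members. This is one possible decomposition, not necessarily the exact formula implemented here; the ensemble in this module returns a per-pixel standard deviation (`stdeviation`) alongside the softmax output.
import torch

def uncertainty_sketch(probs):
    "probs: (n_members, n_classes, H, W) stacked softmax predictions of an ensemble."
    eps = 1e-12
    aleatoric = -(probs * (probs + eps).log()).sum(dim=1).mean(dim=0)  # expected entropy, (H, W)
    epistemic = probs.std(dim=0).mean(dim=0)                           # member disagreement, (H, W)
    return aleatoric, epistemic

probs = torch.softmax(torch.randn(4, 2, 64, 64), dim=1)  # 4 members, 2 classes
aleatoric, epistemic = uncertainty_sketch(probs)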
# Extracting a single tile centered on the image should reproduce the input
ts = 540
os = ts/2
t = TileModule(tile_shape=(ts, ts))
inp = torch.from_numpy(image).float().unsqueeze_(-1)
out = t(inp, (os, os))[0, 0]
test_close(out, image, eps=1e-04)
TS = [256, 512, 1024]
SCALES = [0.5, 1., 2.]
SHIFTS = [0.5, 0.9, 1.]
BPFACTORS = [0.25, 0.1, 0.]
inp = torch.from_numpy(image).float().unsqueeze_(-1)
# Check that input and output slices line up for every combination of tile shape, scale, shift and border padding
for ts in TS:
    for scale in SCALES:
        for max_tile_shift in SHIFTS:
            for border_padding_factor in BPFACTORS:
                t = TileModule(tile_shape=(ts, ts), scale=scale, max_tile_shift=max_tile_shift, border_padding_factor=border_padding_factor)
                out = torch.zeros(*[int(x/t.scale) for x in inp.shape[:2]]).unsqueeze_(-1)
                in_slices, out_slices, center_points = t.get_slices_and_centers(inp.shape)
                assert len(center_points) != 0
                for i, cp in enumerate(center_points):
                    ix0, ix1, iy0, iy1 = in_slices[0][0][i], in_slices[0][1][i], in_slices[1][0][i], in_slices[1][1][i]
                    ox0, ox1, oy0, oy1 = out_slices[0][0][i], out_slices[0][1][i], out_slices[1][0][i], out_slices[1][1][i]
                    assert (ix1-ix0) == (ox1-ox0), 'Input/Output slices do not match'
                    assert (iy1-iy0) == (oy1-oy0), 'Input/Output slices do not match'
                    tile = t(inp, cp)
                    out[ox0:ox1, oy0:oy1, 0] += tile[0, 0, ix0:ix1, iy0:iy1]
                #plt.imshow(out)
                #plt.show()
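The same slice bookkeeping can be combined with the Gaussian weights from above to blend overlapping tile predictions instead of simply overwriting them. The helper below is a hypothetical sketch (assuming `predict` returns a prediction with the same spatial layout as the tile and that `gw` matches the tile shape); it is not part of the library.
def blend_tiles_sketch(inp, t, gw, predict):
    "Gaussian-weighted stitching of overlapping tile predictions (illustrative only)."
    H, W = [int(x / t.scale) for x in inp.shape[:2]]
    acc = torch.zeros(H, W)
    norm = torch.zeros(H, W)
    in_slices, out_slices, center_points = t.get_slices_and_centers(inp.shape)
    for i, cp in enumerate(center_points):
        ix0, ix1, iy0, iy1 = in_slices[0][0][i], in_slices[0][1][i], in_slices[1][0][i], in_slices[1][1][i]
        ox0, ox1, oy0, oy1 = out_slices[0][0][i], out_slices[0][1][i], out_slices[1][0][i], out_slices[1][1][i]
        pred = predict(t(inp, cp))[0, 0]          # per-tile prediction, same spatial size as the tile
        w = gw[ix0:ix1, iy0:iy1]
        acc[ox0:ox1, oy0:oy1] += pred[ix0:ix1, iy0:iy1] * w
        norm[ox0:ox1, oy0:oy1] += w
    return acc / norm.clamp_min(1e-8)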
# ONNX export with opset 16 requires torch >= 1.12
if tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2]) >= (1, 12):
    print(torch.__version__)
    t_scripted = torch.jit.script(t)
    ost = torch.tensor(os)
    path = Path('TileModule.onnx')
    _ = torch.onnx.export(t_scripted, (inp, torch.tensor([ost, ost])), path, verbose=True, opset_version=16)
    path.unlink()
class DummyModule(torch.nn.Module):
    'Dummy Module for testing'
    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
    def forward(self, x):
        return x[:, :1].repeat(1, self.num_classes, 1, 1)
CHANNELS = [1,3]
NUM_CLASSES = [2, 5]
SCALES = [0.5, 1., 4.]
IMG_SHAPES = [512]#[256, 512, 1024]
for ch in CHANNELS:
    channel_means = [0.]*ch
    channel_stds = [1.]*ch
    for num_classes in NUM_CLASSES:
        models = [DummyModule(num_classes=num_classes) for _ in range(2)]
        for sx in IMG_SHAPES:
            for sy in IMG_SHAPES:
                inp = torch.rand(sx, sy, ch)
                for scale in SCALES:
                    ensemble = InferenceEnsemble(models,
                                                 num_classes=num_classes,
                                                 in_channels=3,
                                                 channel_means=channel_means,
                                                 channel_stds=channel_stds,
                                                 scale=scale)
                    ensemble = torch.jit.script(ensemble)
                    outs = ensemble(inp)
                    test_eq(outs[0].shape, (sx, sy))               # argmax segmentation map
                    test_eq(outs[1].shape, (num_classes, sx, sy))  # softmax probabilities
                    test_eq(outs[2].shape, (sx, sy))               # per-pixel standard deviation
path_pt = Path('ensemble.pt')
scripted_ensemble = torch.jit.script(ensemble)
scripted_ensemble.save(path_pt)
_ = torch.jit.load(path_pt)
path_pt.unlink()
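In a deployment script the saved TorchScript file can be loaded without any of the Python classes defined here (note that the test above unlinks `ensemble.pt` again; keep the file around for this to work). A minimal sketch with a random placeholder image:
import torch

deployed = torch.jit.load('ensemble.pt')
img = torch.rand(512, 512, 3)                 # replace with a real (H, W, C) float image
argmax, softmax, stdeviation = deployed(img)  # outputs correspond to the names used for the ONNX export below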
if tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2]) >= (1, 12):
    print(torch.__version__)
    input_names = ["inp"]
    output_names = ["argmax", "softmax", "stdeviation"]
    dynamic_axes = {"inp": [0, 1], "argmax": [0, 1], "softmax": [0, 1, 2], "stdeviation": [0, 1]}
    path_onnx = Path('ensemble.onnx')
    torch.onnx.export(scripted_ensemble, inp, f=path_onnx, verbose=True, opset_version=16,
                      input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes)
#import onnxruntime
#ort_session = onnxruntime.InferenceSession(path_onnx.as_posix())
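The commented-out lines above hint at checking the export with onnxruntime. A minimal version of that check, assuming onnxruntime is installed and the export above was executed:
import onnxruntime

ort_session = onnxruntime.InferenceSession(path_onnx.as_posix(), providers=["CPUExecutionProvider"])
ort_outs = ort_session.run(None, {"inp": inp.numpy()})
for name, o in zip(output_names, ort_outs):
    print(name, o.shape)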