Uncertainty Estimation using TensorFlow Probability

In [1]:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
from tensorflow_probability.python.layers import DenseVariational, DenseReparameterization, DenseFlipout, Convolution2DFlipout, Convolution2DReparameterization
from tensorflow_probability.python.layers import DistributionLambda
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Input, Dense, Flatten, BatchNormalization, Activation, LeakyReLU
from tensorflow.keras.utils import plot_model
from tensorflow.keras.optimizers import Adam
tf.compat.v1.enable_eager_execution()

import numpy as np
from scipy.special import softmax
import matplotlib.pyplot as plt

%matplotlib inline

print('TensorFlow version:', tf.__version__)
print('TensorFlow Probability version:', tfp.__version__)
TensorFlow version: 1.15.0-dev20190821
TensorFlow Probability version: 0.8.0-dev20190828

Build a synthetic dataset for regression

In [2]:
def load_dataset(n, w0, b0, x_low, x_high):
    def s(x):
        g = (x - x_low) / (x_high - x_low)
        return 3 * (0.25 + g**2)
    def f(x, w, b):
        return w * x * (1. + np.sin(x)) + b
    x = (x_high - x_low) * np.random.rand(n) + x_low  # Uniform(x_low, x_high)
    x = np.sort(x)
    eps = np.random.randn(n) * s(x)
    y = f(x, w0, b0) + eps
    return x, y
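The inner helper s makes the noise heteroscedastic: its standard deviation grows quadratically with x. A standalone copy of the formula (a hypothetical helper, not part of the cell above) makes the endpoints explicit:

def noise_scale(x, x_low=-20, x_high=60):
    # same formula as s(x) above
    g = (x - x_low) / (x_high - x_low)
    return 3 * (0.25 + g**2)

print(noise_scale(-20), noise_scale(60))  # 0.75 3.75 -> data get noisier as x grows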
In [3]:
n_data = 500
n_train = 400
w0 = 0.125
b0 = 5.0
x_low, x_high = -20, 60

X, y = load_dataset(n_data, w0, b0, x_low, x_high)
X = np.expand_dims(X, 1)
y = np.expand_dims(y, 1)

idx_randperm = np.random.permutation(n_data)
idx_train = np.sort(idx_randperm[:n_train])
idx_test = np.sort(idx_randperm[n_train:])

X_train, y_train = X[idx_train], y[idx_train]
X_test = X[idx_test]

print("X_train.shape =", X_train.shape)
print("y_train.shape =", y_train.shape)
print("X_test.shape =", X_test.shape)

plt.scatter(X_train, y_train, marker='+', label='Training data')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Noisy training data and ground truth')
plt.legend()
X_train.shape = (400, 1)
y_train.shape = (400, 1)
X_test.shape = (100, 1)
Out[3]:
<matplotlib.legend.Legend at 0x22645118630>

Traditional point-estimate neural network

Define the negative log-likelihood loss (y_pred is a distribution object)

In [4]:
def neg_log_likelihood_with_dist(y_true, y_pred):
    return -tf.reduce_mean(y_pred.log_prob(y_true))
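Since the model heads below output a tfd.Normal with a fixed scale, this loss reduces to mean squared error up to a constant. A quick check (a sketch, assuming eager mode):

dist = tfd.Normal(loc=0., scale=1.)
print(-dist.log_prob(0.5).numpy())              # 1.0439...
print(0.5 * 0.5**2 + 0.5 * np.log(2 * np.pi))   # same value: 0.5*(y - mu)^2 + const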

Define and train the model

In [5]:
batch_size = 100
n_epochs = 3000
lr = 5e-3

def build_point_estimate_model(scale=1):
    model_in = Input(shape=(1,))
    x = Dense(16)(model_in)
    x = LeakyReLU(0.1)(x)
    x = Dense(64)(x)
    x = LeakyReLU(0.1)(x)
    x = Dense(16)(x)
    x = LeakyReLU(0.1)(x)
    x = Dense(1)(x)
    model_out = DistributionLambda(lambda t: tfd.Normal(loc=t, scale=scale))(x)
    model = Model(model_in, model_out)
    return model

pe_model = build_point_estimate_model()
pe_model.compile(loss=neg_log_likelihood_with_dist, optimizer=Adam(lr), metrics=['mse'])
pe_model.summary()
hist = pe_model.fit(X_train, y_train, batch_size=batch_size, epochs=n_epochs, verbose=0)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 1)]               0         
_________________________________________________________________
dense (Dense)                (None, 16)                32        
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                1088      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 64)                0         
_________________________________________________________________
dense_2 (Dense)              (None, 16)                1040      
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 16)                0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 17        
_________________________________________________________________
distribution_lambda (Distrib ((None, 1), (None, 1))    0         
=================================================================
Total params: 2,177
Trainable params: 2,177
Non-trainable params: 0
_________________________________________________________________

Plot the training loss and predict the test data

In [6]:
fig, ax = plt.subplots(2, 1, figsize=(15, 10))
ax[0].plot(range(n_epochs), hist.history['loss'])
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Training loss')
ax[1].plot(range(n_epochs), hist.history['mean_squared_error'])
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Mean squared error')
y_test_pred_pe = pe_model(X_test)

Plot the training data and test predictions

In [7]:
plt.scatter(X_train, y_train, marker='+', label='Training data')
plt.plot(X_test, y_test_pred_pe.mean(), 'r-', marker='+', label='Test data')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Noisy training data and ground truth')
plt.legend()
Out[7]:
<matplotlib.legend.Legend at 0x22715cc0518>

Estimate aleatoric uncertainty

Define and train the model

In [8]:
def build_aleatoric_model():
    model_in = Input(shape=(1,))
    x = Dense(16)(model_in)
    x = LeakyReLU(0.1)(x)
    x = Dense(64)(x)
    x = LeakyReLU(0.1)(x)
    x = Dense(16)(x)
    x = LeakyReLU(0.1)(x)
    model_out_loc = Dense(1)(x)
    model_out_scale = Dense(1)(x)
    model_out = DistributionLambda(lambda t: tfd.Normal(loc=t[0],
                                                        scale=1e-7 + tf.math.softplus(1e-3 * t[1])))([model_out_loc,
                                                                                                      model_out_scale])
    model = Model(model_in, model_out)
    return model

al_model = build_aleatoric_model()
al_model.compile(loss=neg_log_likelihood_with_dist, optimizer=Adam(lr), metrics=['mse'])
al_model.summary()
hist = al_model.fit(X_train, y_train, batch_size=batch_size, epochs=n_epochs, verbose=0)
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 16)           32          input_2[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)       (None, 16)           0           dense_4[0][0]                    
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 64)           1088        leaky_re_lu_3[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)       (None, 64)           0           dense_5[0][0]                    
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 16)           1040        leaky_re_lu_4[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, 16)           0           dense_6[0][0]                    
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 1)            17          leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 1)            17          leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
distribution_lambda_1 (Distribu ((None, 1), (None, 1 0           dense_7[0][0]                    
                                                                 dense_8[0][0]                    
==================================================================================================
Total params: 2,194
Trainable params: 2,194
Non-trainable params: 0
__________________________________________________________________________________________________
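The scale head above predicts an unconstrained value that softplus maps to a positive standard deviation; the 1e-3 factor keeps the scale in a gentle range early in training. A quick look at the squashing (a sketch, assuming eager mode):

raw = tf.constant([-100., 0., 100., 1000.])
print((1e-7 + tf.math.softplus(1e-3 * raw)).numpy())
# ~[0.645, 0.693, 0.744, 1.313]: always positive and slowly varying in the raw output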

Plot the training loss and predict the test data

In [9]:
fig, ax = plt.subplots(2, 1, figsize=(15, 10))
ax[0].plot(range(n_epochs), hist.history['loss'])
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Training loss')
ax[1].plot(range(n_epochs), hist.history['mean_squared_error'])
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Mean squared error')
y_test_pred_al = al_model(X_test)
y_test_pred_al_mean = y_test_pred_al.mean()
y_test_pred_al_stddev = y_test_pred_al.stddev()

Plot the training data and test predictions

In [10]:
plt.scatter(X_train, y_train, marker='+', label='Training data')
plt.plot(X_test, y_test_pred_al_mean, 'r-', marker='+', label='Test data')
plt.fill_between(np.squeeze(X_test), 
                 np.squeeze(y_test_pred_al_mean + 2 * y_test_pred_al_stddev),
                 np.squeeze(y_test_pred_al_mean - 2 * y_test_pred_al_stddev),
                 alpha=0.5, label='Aleatoric uncertainty')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Noisy training data and ground truth')
plt.legend()
Out[10]:
<matplotlib.legend.Legend at 0x2271887fba8>

Estimate epistemic uncertainty

In [11]:
n_epochs = 10000
lr = 5e-3
n_test = 10

Specify the surrogate posterior over kernel and bias of DenseVariational

In [12]:
def posterior_mean_field(kernel_size, bias_size=0, dtype=None):
    n = kernel_size + bias_size
    c = np.log(np.expm1(1.0))
    return Sequential([tfp.layers.VariableLayer(2 * n, dtype=dtype),
                       tfp.layers.DistributionLambda(lambda t: tfd.Independent(
                           tfd.Normal(loc=t[..., :n], scale=1e-7 + tf.nn.softplus(c + t[..., n:])),
                           reinterpreted_batch_ndims=1))
    ])

Specify the prior over kernel and bias of DenseVariational

In [13]:
def prior_trainable(kernel_size, bias_size=0, dtype=None):
    n = kernel_size + bias_size
    return Sequential([tfp.layers.VariableLayer(n, dtype=dtype),
                       tfp.layers.DistributionLambda(lambda t: tfd.Independent(
                           tfd.Normal(loc=t, scale=1.0), reinterpreted_batch_ndims=1)),
    ])
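To see what these factories produce, one can instantiate the surrogate posterior for a tiny layer and sample a weight vector from it (a minimal sketch, assuming eager mode; the input to the returned model is ignored by VariableLayer):

q = posterior_mean_field(kernel_size=3, bias_size=1, dtype=tf.float32)
w_dist = q(tf.zeros([1]))        # an Independent Normal over 3 kernel weights + 1 bias
print(w_dist.sample().numpy())   # one concrete weight draw; DenseVariational does this per call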

Define and train the model

In [14]:
def build_epistemic_model(train_size, scale=1):
    model_in = Input(shape=(1,))
    x = DenseVariational(16, posterior_mean_field, prior_trainable, kl_weight=1/train_size)(model_in)
    x = LeakyReLU(0.1)(x)
    x = DenseVariational(64, posterior_mean_field, prior_trainable, kl_weight=1/train_size)(x)
    x = LeakyReLU(0.1)(x)
    x = DenseVariational(16, posterior_mean_field, prior_trainable, kl_weight=1/train_size)(x)
    x = LeakyReLU(0.1)(x)
    x = DenseVariational(1, posterior_mean_field, prior_trainable, kl_weight=1/train_size)(x)
    model_out = DistributionLambda(lambda t: tfd.Normal(loc=t, scale=scale))(x)
    model = Model(model_in, model_out)
    return model

ep_model = build_epistemic_model(n_train)
ep_model.compile(loss=neg_log_likelihood_with_dist, optimizer=Adam(lr), metrics=['mse'])
ep_model.summary()
hist = ep_model.fit(X_train, y_train, batch_size=batch_size, epochs=n_epochs, verbose=0)
Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_3 (InputLayer)         [(None, 1)]               0         
_________________________________________________________________
dense_variational (DenseVari (None, 16)                96        
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 16)                0         
_________________________________________________________________
dense_variational_1 (DenseVa (None, 64)                3264      
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 64)                0         
_________________________________________________________________
dense_variational_2 (DenseVa (None, 16)                3120      
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 16)                0         
_________________________________________________________________
dense_variational_3 (DenseVa (None, 1)                 51        
_________________________________________________________________
distribution_lambda_2 (Distr ((None, 1), (None, 1))    0         
=================================================================
Total params: 6,531
Trainable params: 6,531
Non-trainable params: 0
_________________________________________________________________
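Because DenseVariational samples a fresh weight vector on every call, two forward passes on the same input give different predictions; this is what the Monte Carlo loop over n_test below exploits. A quick check (a sketch, assuming eager mode):

x_probe = np.array([[0.0]], dtype=np.float32)
print(ep_model(x_probe).mean().numpy(), ep_model(x_probe).mean().numpy())  # two different means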

Show all layer names, weight names and shapes

In [15]:
print({l.name: l.weights for l in ep_model.layers})
{'input_3': [], 'dense_variational': [<tf.Variable 'dense_variational/constant:0' shape=(64,) dtype=float32, numpy=
array([-0.04704537, -0.09531302, -0.05468397, -0.11580291, -0.08971833,
       -0.07729009, -0.05076814, -0.08324057, -0.12371637, -0.05294088,
       -0.08828061, -0.09085072, -0.14504395, -0.1380819 , -0.0870083 ,
       -0.27899066, -1.9161842 , -1.8248898 , -2.0795515 , -2.4224415 ,
       -2.1526618 , -2.125303  , -1.8890059 , -1.7628058 , -2.7586875 ,
       -1.9747292 , -1.9358826 , -1.620749  , -3.2749026 , -2.855188  ,
       -2.1074066 , -3.862741  , -4.304486  , -4.014289  , -4.6453576 ,
       -4.387102  , -4.175872  , -4.225246  , -3.7834823 , -4.292671  ,
       -4.809787  , -3.940326  , -4.010811  , -4.3495092 , -5.026194  ,
       -4.5197973 , -3.9998908 , -5.616435  , -0.9938686 , -0.58735746,
       -0.9358099 , -1.1768152 , -0.9848648 , -1.0890893 , -0.36814827,
       -0.8628556 , -1.5795509 , -0.64255416, -0.7174097 , -0.7457821 ,
       -1.5089829 , -1.5389388 , -0.77139425, -2.554386  ], dtype=float32)>, <tf.Variable 'dense_variational/constant:0' shape=(32,) dtype=float32, numpy=
array([-0.03552532, -0.09396114, -0.05786705, -0.10888591, -0.09632728,
       -0.07216997, -0.03780396, -0.0857416 , -0.11594326, -0.03910633,
       -0.07292157, -0.08665093, -0.15043956, -0.14847077, -0.09428117,
       -0.26843002, -1.9407692 , -1.8078051 , -2.1089375 , -2.379748  ,
       -2.1603942 , -2.217171  , -1.9734799 , -1.740873  , -2.7568405 ,
       -2.0162845 , -2.0360692 , -1.5466782 , -3.2942357 , -2.9302049 ,
       -2.0141025 , -3.8510332 ], dtype=float32)>], 'leaky_re_lu_6': [], 'dense_variational_1': [<tf.Variable 'dense_variational_1/constant:0' shape=(2176,) dtype=float32, numpy=
array([-0.35227278, -0.22739749,  0.00543146, ..., -0.15598762,
       -1.1102257 , -0.4004959 ], dtype=float32)>, <tf.Variable 'dense_variational_1/constant:0' shape=(1088,) dtype=float32, numpy=
array([-0.38646337, -0.14000496, -0.1035843 , ..., -2.9416277 ,
       -4.947252  , -3.6286163 ], dtype=float32)>], 'leaky_re_lu_7': [], 'dense_variational_2': [<tf.Variable 'dense_variational_2/constant:0' shape=(2080,) dtype=float32, numpy=
array([-0.20594649, -0.5097729 , -0.44125563, ..., -0.30167702,
       -0.05558572, -0.003965  ], dtype=float32)>, <tf.Variable 'dense_variational_2/constant:0' shape=(1040,) dtype=float32, numpy=
array([-0.17168942, -0.53252125, -0.35130063, ..., -2.4285772 ,
       -5.924332  , -6.627571  ], dtype=float32)>], 'leaky_re_lu_8': [], 'dense_variational_3': [<tf.Variable 'dense_variational_3/constant:0' shape=(34,) dtype=float32, numpy=
array([ 3.8052346e-03, -8.4818482e-02, -4.8954707e-02, -4.6195488e-02,
       -4.8472621e-02, -5.2426271e-02, -5.6273747e-02, -2.3019690e-02,
       -2.5252184e-02, -4.0765688e-02, -4.5673251e-02, -3.7789512e-02,
       -1.9367348e-02, -6.7534581e-02, -1.0069254e-02, -2.1955183e-02,
        6.0810575e+00, -4.9888983e+00, -5.2000089e+00, -4.8150182e+00,
       -5.1294537e+00, -5.0586605e+00, -5.9611487e+00, -5.3447561e+00,
       -4.8540487e+00, -5.1902680e+00, -4.9880152e+00, -5.2330332e+00,
       -4.8953261e+00, -4.9589343e+00, -5.5525970e+00, -4.8847165e+00,
       -5.0921049e+00, -3.6405253e+00], dtype=float32)>, <tf.Variable 'dense_variational_3/constant:0' shape=(17,) dtype=float32, numpy=
array([ 3.0475669e-03, -8.7187938e-02, -5.3479671e-02, -5.3049676e-02,
       -4.1727580e-02, -5.1039308e-02, -5.0959606e-02, -1.7846076e-02,
       -3.1419937e-02, -3.1960133e-02, -4.2406023e-02, -3.1950708e-02,
       -1.7996188e-02, -6.4437412e-02,  3.7976943e-03, -2.0693699e-02,
        6.0708361e+00], dtype=float32)>], 'distribution_lambda_2': ListWrapper([])}

Plot the training loss and predict the test data

In [16]:
fig, ax = plt.subplots(2, 1, figsize=(15, 10))
ax[0].plot(range(n_epochs), hist.history['loss'])
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Training loss')
ax[0].set_yscale('log')
ax[1].plot(range(n_epochs), hist.history['mean_squared_error'])
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Mean squared error')
ax[1].set_yscale('log')
y_test_pred_ep_list = [ep_model(X_test) for _ in range(n_test)]

Plot the training data and test predictions

In [17]:
plt.scatter(X_train, y_train, marker='+', label='Training data')
avg_mean = np.zeros_like(X_test)
for i, y in enumerate(y_test_pred_ep_list):
    y_mean = y.mean()
    plt.plot(X_test, y_mean, 'r-', marker='+', label='Ensemble tests' if i == 0 else None, linewidth=0.5)
    avg_mean += y_mean
plt.plot(X_test, avg_mean/n_test, 'g-', marker='+', label='Averaged tests', linewidth=2)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Noisy training data and ground truth')
plt.legend()
Out[17]:
<matplotlib.legend.Legend at 0x2273848c320>

Estimate aleatoric + epistemic uncertainty

In [18]:
n_epochs = 30000

Define and train the model

In [19]:
def build_aleatoric_epistemic_model(train_size):
    model_in = Input(shape=(1,))
    x = DenseVariational(16, posterior_mean_field, prior_trainable, kl_weight=1/train_size)(model_in)
    x = LeakyReLU(0.1)(x)
    x = DenseVariational(64, posterior_mean_field, prior_trainable, kl_weight=1/train_size)(x)
    x = LeakyReLU(0.1)(x)
    x = DenseVariational(16, posterior_mean_field, prior_trainable, kl_weight=1/train_size)(x)
    x = LeakyReLU(0.1)(x)
    model_out_loc = DenseVariational(1, posterior_mean_field, prior_trainable, kl_weight=1/train_size)(x)
    model_out_scale = DenseVariational(1, posterior_mean_field, prior_trainable, kl_weight=1/train_size)(x)
    model_out = DistributionLambda(lambda t: tfd.Normal(loc=t[0],
                                                        scale=1e-7 + tf.math.softplus(1e-3 * t[1])))([model_out_loc,
                                                                                                      model_out_scale])
    model = Model(model_in, model_out)
    return model

ae_model = build_aleatoric_epistemic_model(n_train)
ae_model.compile(loss=neg_log_likelihood_with_dist, optimizer=Adam(lr), metrics=['mse'])
ae_model.summary()
hist = ae_model.fit(X_train, y_train, batch_size=batch_size, epochs=n_epochs, verbose=0)
Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
dense_variational_4 (DenseVaria (None, 16)           96          input_4[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU)       (None, 16)           0           dense_variational_4[0][0]        
__________________________________________________________________________________________________
dense_variational_5 (DenseVaria (None, 64)           3264        leaky_re_lu_9[0][0]              
__________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU)      (None, 64)           0           dense_variational_5[0][0]        
__________________________________________________________________________________________________
dense_variational_6 (DenseVaria (None, 16)           3120        leaky_re_lu_10[0][0]             
__________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU)      (None, 16)           0           dense_variational_6[0][0]        
__________________________________________________________________________________________________
dense_variational_7 (DenseVaria (None, 1)            51          leaky_re_lu_11[0][0]             
__________________________________________________________________________________________________
dense_variational_8 (DenseVaria (None, 1)            51          leaky_re_lu_11[0][0]             
__________________________________________________________________________________________________
distribution_lambda_3 (Distribu ((None, 1), (None, 1 0           dense_variational_7[0][0]        
                                                                 dense_variational_8[0][0]        
==================================================================================================
Total params: 6,582
Trainable params: 6,582
Non-trainable params: 0
__________________________________________________________________________________________________

Show all layer names, weight names and shapes

In [20]:
print({l.name: l.weights for l in ae_model.layers})
{'input_4': [], 'dense_variational_4': [<tf.Variable 'dense_variational_4/constant:0' shape=(64,) dtype=float32, numpy=
array([-0.09862926, -0.11476649, -0.07899566, -0.04197057, -0.02217351,
       -0.11380268, -0.0481223 , -0.09360223, -0.03561402, -0.03081092,
       -0.33977365, -0.05825329, -0.06220264, -0.05583094, -0.08798434,
       -0.12000691, -2.5273845 , -2.4766982 , -2.3864813 , -2.483623  ,
       -2.381097  , -2.2686546 , -2.4497914 , -2.4368212 , -2.5325873 ,
       -2.433751  , -3.2118866 , -2.3995125 , -2.2658179 , -2.3294208 ,
       -2.441587  , -2.4014895 , -3.1474376 , -3.046749  , -3.3582273 ,
       -3.1774943 , -3.0754395 , -3.0821257 , -3.068309  , -3.1905553 ,
       -3.0303633 , -3.1168997 , -4.9372067 , -3.0782003 , -3.0511806 ,
       -3.1341414 , -3.0179408 , -3.1620212 , -0.41232404, -0.23841205,
       -0.47780365, -0.3329982 , -0.1601861 , -0.29951492, -0.15853305,
       -0.37995142, -0.05261213, -0.1921717 , -1.9523746 , -0.28087547,
       -0.24126856, -0.26046717, -0.19583614, -0.22576518], dtype=float32)>, <tf.Variable 'dense_variational_4/constant:0' shape=(32,) dtype=float32, numpy=
array([-0.09596992, -0.10569854, -0.09335621, -0.0595726 ,  0.0146834 ,
       -0.10611212, -0.08417413, -0.1043053 , -0.06240533, -0.05138741,
       -0.34233236, -0.11072929, -0.03926694, -0.0545993 , -0.09282467,
       -0.10566173, -2.4582586 , -2.4655597 , -2.385962  , -2.4510012 ,
       -2.3168685 , -2.3953304 , -2.327874  , -2.3989172 , -2.479299  ,
       -2.555362  , -3.2063603 , -2.354704  , -2.3181937 , -2.2886453 ,
       -2.490134  , -2.3655572 ], dtype=float32)>], 'leaky_re_lu_9': [], 'dense_variational_5': [<tf.Variable 'dense_variational_5/constant:0' shape=(2176,) dtype=float32, numpy=
array([-0.07682855, -0.10121486,  0.2486316 , ...,  0.00329416,
       -0.04051789, -0.11259051], dtype=float32)>, <tf.Variable 'dense_variational_5/constant:0' shape=(1088,) dtype=float32, numpy=
array([-0.03216608, -0.02258643,  0.3282548 , ..., -2.2870789 ,
       -2.5762296 , -1.3740698 ], dtype=float32)>], 'leaky_re_lu_10': [], 'dense_variational_6': [<tf.Variable 'dense_variational_6/constant:0' shape=(2080,) dtype=float32, numpy=
array([-0.137996  , -0.9982297 , -0.68139344, ..., -0.02119915,
        0.01147775, -0.01906822], dtype=float32)>, <tf.Variable 'dense_variational_6/constant:0' shape=(1040,) dtype=float32, numpy=
array([-0.07852437, -0.94465953, -0.7332415 , ..., -2.213003  ,
       -2.8716516 , -2.2289624 ], dtype=float32)>], 'leaky_re_lu_11': [], 'dense_variational_7': [<tf.Variable 'dense_variational_7/constant:0' shape=(34,) dtype=float32, numpy=
array([-0.00993999,  0.13302565, -0.05715121,  0.01200709, -0.03035622,
       -0.03412971, -0.01329443, -0.0100791 ,  0.02336478, -0.02682818,
       -0.01447157,  0.01115929,  0.00999235, -0.0528665 , -0.03448028,
       -0.03277865,  4.297745  , -3.8062973 , -5.5671377 , -3.8215764 ,
       -4.1935534 , -3.8484457 , -4.041553  , -4.1952405 , -4.2259684 ,
       -7.6016836 , -3.941422  , -3.8715405 , -3.911588  , -4.1214776 ,
       -3.9944503 , -3.907611  , -4.1094184 , -2.6508656 ], dtype=float32)>, <tf.Variable 'dense_variational_7/constant:0' shape=(17,) dtype=float32, numpy=
array([-0.01208552,  0.13188832, -0.05139474,  0.0077449 , -0.00782503,
       -0.03098016, -0.00436794, -0.00431121,  0.01966239, -0.04230103,
       -0.00754995, -0.0179634 ,  0.02300935, -0.03148894, -0.0397683 ,
       -0.02528351,  4.3014326 ], dtype=float32)>], 'dense_variational_8': [<tf.Variable 'dense_variational_8/constant:0' shape=(34,) dtype=float32, numpy=
array([-4.7227371e-01, -6.2047906e+00, -6.4698035e-01, -3.0392554e-01,
       -6.3928866e-01, -2.0529075e+00, -4.9356914e-01, -8.4995079e-01,
        1.5248514e+01, -5.4005849e-01, -2.3926420e+00, -9.0026522e-01,
       -1.1381882e+00, -1.5625995e+00, -4.9507833e-01, -9.1090596e-01,
        4.5232415e+00,  9.1445476e-02, -1.4099446e-02, -4.9302932e-02,
       -9.3029387e-02, -6.0380854e-02, -1.8556746e-02,  2.9846998e-03,
        3.8176846e-02, -8.9852488e-01, -3.5951409e-02, -4.6424832e-02,
       -4.9651567e-02, -6.4862065e-02,  5.8297817e-02,  3.0502148e-02,
        1.0854327e-02, -1.0304969e-02], dtype=float32)>, <tf.Variable 'dense_variational_8/constant:0' shape=(17,) dtype=float32, numpy=
array([-0.4323448 , -6.1359973 , -0.6771927 , -0.36630574, -0.68904454,
       -1.8776996 , -0.49424997, -0.9684594 , 15.205367  , -0.56479084,
       -2.3902519 , -1.0373361 , -1.1890415 , -1.5385239 , -0.66423213,
       -0.9370167 ,  4.464679  ], dtype=float32)>], 'distribution_lambda_3': ListWrapper([])}

Plot the training loss and predict the test data

In [21]:
fig, ax = plt.subplots(2, 1, figsize=(15, 10))
ax[0].plot(range(n_epochs), hist.history['loss'])
ax[0].set_xlabel('Epochs')
ax[0].set_ylabel('Training loss')
ax[0].set_yscale('log')
ax[1].plot(range(n_epochs), hist.history['mean_squared_error'])
ax[1].set_xlabel('Epochs')
ax[1].set_ylabel('Mean squared error')
ax[1].set_yscale('log')
y_test_pred_ae_list = [ae_model(X_test) for _ in range(n_test)]

Plot the training data and test predictions

In [22]:
plt.scatter(X_train, y_train, marker='+', label='Training data')
avg_mean = np.zeros_like(X_test)
for i, y in enumerate(y_test_pred_ae_list):
    y_mean = y.mean()
    y_stddev = y.stddev()
    plt.plot(X_test, y_mean, 'r-', marker='+', label='Ensemble tests' if i == 0 else None, linewidth=0.5)
    plt.fill_between(np.squeeze(X_test), 
                     np.squeeze(y_mean + 2 * y_stddev),
                     np.squeeze(y_mean - 2 * y_stddev),
                     alpha=0.5, label='Aleatoric uncertainties' if i == 0 else None)
    avg_mean += y_mean
plt.plot(X_test, avg_mean/n_test, 'g-', marker='+', label='Averaged tests', linewidth=2)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Noisy training data and ground truth')
plt.legend()
Out[22]:
<matplotlib.legend.Legend at 0x22747ed3c88>
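A common way to combine the two sources of uncertainty into a single predictive band is the law of total variance: average the per-draw variances (aleatoric) and add the variance of the per-draw means (epistemic). A sketch reusing the draws above (not part of the original notebook):

means = np.stack([y.mean().numpy() for y in y_test_pred_ae_list], axis=0)    # (n_test, 100, 1)
stds = np.stack([y.stddev().numpy() for y in y_test_pred_ae_list], axis=0)
aleatoric_var = (stds ** 2).mean(axis=0)   # average per-draw variance
epistemic_var = means.var(axis=0)          # spread of the per-draw means
total_stddev = np.sqrt(aleatoric_var + epistemic_var)  # combined predictive std per test point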

Uncertainties in CNNs

Define the negative log-likelihood loss (y_pred is a tensor of logits)

In [23]:
def neg_log_likelihood_with_logits(y_true, y_pred):
    y_pred_dist = tfp.distributions.Categorical(logits=y_pred)
    return -tf.reduce_mean(y_pred_dist.log_prob(tf.argmax(y_true, axis=-1)))
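For one-hot labels this is just the standard softmax cross-entropy averaged over the batch; a quick equivalence check (a sketch, assuming eager mode):

logits = tf.constant([[2.0, 0.5, -1.0]])
labels = tf.constant([[1.0, 0.0, 0.0]])
print(neg_log_likelihood_with_logits(labels, logits).numpy())
print(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits).numpy())  # same value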

Load MNIST dataset

In [24]:
n_class = 10

batch_size = 128
n_epochs = 20
lr = 1e-3

print('Loading MNIST dataset')
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = np.expand_dims(X_train, -1)
n_train = X_train.shape[0]
X_test = np.expand_dims(X_test, -1)
y_train = tf.keras.utils.to_categorical(y_train, n_class)
y_test = tf.keras.utils.to_categorical(y_test, n_class)

# Normalize data
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255

print("X_train.shape =", X_train.shape)
print("y_train.shape =", y_train.shape)
print("X_test.shape =", X_test.shape)
print("y_test.shape =", y_test.shape)

plt.imshow(X_train[0, :, :, 0], cmap='gist_gray')
Loading MNIST dataset
X_train.shape = (60000, 28, 28, 1)
y_train.shape = (60000, 10)
X_test.shape = (10000, 28, 28, 1)
y_test.shape = (10000, 10)
Out[24]:
<matplotlib.image.AxesImage at 0x227480fc080>

Define a kernel KL-divergence function with an adjustable weight

In [25]:
def get_kernel_divergence_fn(train_size, w=1.0):
    """
    Get the kernel Kullback-Leibler divergence function

    # Arguments
        train_size (int): size of the training dataset for normalization
        w (float): weight to the function

    # Returns
        kernel_divergence_fn: kernel Kullback-Leibler divergence function
    """
    def kernel_divergence_fn(q, p, _):  # need the third ignorable argument
        kernel_divergence = tfp.distributions.kl_divergence(q, p) / tf.cast(train_size, tf.float32)
        return w * kernel_divergence
    return kernel_divergence_fn
In [26]:
def add_kl_weight(layer, train_size, w_value=1.0):
    w = layer.add_weight(name=layer.name+'/kl_loss_weight', shape=(),
                         initializer=tf.initializers.constant(w_value), trainable=False)
    layer.kernel_divergence_fn = get_kernel_divergence_fn(train_size, w)
    return layer
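A quick sanity check of the weighted, normalized divergence (a sketch with hypothetical distributions, assuming eager mode):

q = tfd.Normal(0., 1.)
p = tfd.Normal(0., 2.)
fn = get_kernel_divergence_fn(train_size=60000, w=1.0)
print(fn(q, p, None).numpy())  # KL(q || p) / 60000 ~= 5.3e-6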

Build and train the Bayesian CNN model

In [27]:
def build_bayesian_bcnn_model(input_shape, train_size):
    model_in = Input(shape=input_shape)
    conv_1 = Convolution2DFlipout(32, kernel_size=(3, 3), padding="same", strides=2,
                                  kernel_divergence_fn=None)
    conv_1 = add_kl_weight(conv_1, train_size)
    x = conv_1(model_in)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    conv_2 = Convolution2DFlipout(64, kernel_size=(3, 3), padding="same", strides=2,
                                  kernel_divergence_fn=None)
    conv_2 = add_kl_weight(conv_2, train_size)
    x = conv_2(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Flatten()(x)
    dense_1 = DenseFlipout(512, activation='relu',
                           kernel_divergence_fn=None)
    dense_1 = add_kl_weight(dense_1, train_size)
    x = dense_1(x)
    dense_2 = DenseFlipout(10, activation=None,
                           kernel_divergence_fn=None)
    dense_2 = add_kl_weight(dense_2, train_size)
    model_out = dense_2(x)  # logits
    model = Model(model_in, model_out)
    return model
    
bcnn_model = build_bayesian_bcnn_model(X_train.shape[1:], n_train)
bcnn_model.compile(loss=neg_log_likelihood_with_logits, optimizer=Adam(lr), metrics=['acc'],
                   experimental_run_tf_function=False)
bcnn_model.summary()
hist = bcnn_model.fit(X_train, y_train, batch_size=batch_size, epochs=n_epochs, verbose=1, validation_split=0.1)
Model: "model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d_flipout (Conv2DFlipou (None, 14, 14, 32)        609       
_________________________________________________________________
batch_normalization (BatchNo (None, 14, 14, 32)        128       
_________________________________________________________________
activation (Activation)      (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_flipout_1 (Conv2DFlip (None, 7, 7, 64)          36929     
_________________________________________________________________
batch_normalization_1 (Batch (None, 7, 7, 64)          256       
_________________________________________________________________
activation_1 (Activation)    (None, 7, 7, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 3136)              0         
_________________________________________________________________
dense_flipout (DenseFlipout) (None, 512)               3211777   
_________________________________________________________________
dense_flipout_1 (DenseFlipou (None, 10)                10251     
=================================================================
Total params: 3,259,950
Trainable params: 3,259,754
Non-trainable params: 196
_________________________________________________________________
Train on 54000 samples, validate on 6000 samples
Epoch 1/20
54000/54000 [==============================] - 13s 244us/sample - loss: 67.6189 - acc: 0.8352 - val_loss: 64.8947 - val_acc: 0.9190
Epoch 2/20
54000/54000 [==============================] - 10s 186us/sample - loss: 62.1857 - acc: 0.9376 - val_loss: 59.3274 - val_acc: 0.9502
Epoch 3/20
54000/54000 [==============================] - 10s 185us/sample - loss: 56.3240 - acc: 0.9561 - val_loss: 53.2225 - val_acc: 0.9647
Epoch 4/20
54000/54000 [==============================] - 10s 186us/sample - loss: 50.0751 - acc: 0.9654 - val_loss: 46.8906 - val_acc: 0.9678
Epoch 5/20
54000/54000 [==============================] - 10s 185us/sample - loss: 43.7388 - acc: 0.9709 - val_loss: 40.6024 - val_acc: 0.9720
Epoch 6/20
54000/54000 [==============================] - 10s 185us/sample - loss: 37.5904 - acc: 0.9739 - val_loss: 34.6494 - val_acc: 0.9740
Epoch 7/20
54000/54000 [==============================] - 10s 185us/sample - loss: 31.8904 - acc: 0.9759 - val_loss: 29.2451 - val_acc: 0.9780
Epoch 8/20
54000/54000 [==============================] - 10s 186us/sample - loss: 26.8405 - acc: 0.9766 - val_loss: 24.5636 - val_acc: 0.9758
Epoch 9/20
54000/54000 [==============================] - 10s 186us/sample - loss: 22.5185 - acc: 0.9778 - val_loss: 20.6052 - val_acc: 0.9783
Epoch 10/20
54000/54000 [==============================] - 10s 186us/sample - loss: 18.9320 - acc: 0.9789 - val_loss: 17.3798 - val_acc: 0.9785
Epoch 11/20
54000/54000 [==============================] - 10s 186us/sample - loss: 16.0087 - acc: 0.9801 - val_loss: 14.7672 - val_acc: 0.9767
Epoch 12/20
54000/54000 [==============================] - 10s 186us/sample - loss: 13.6396 - acc: 0.9793 - val_loss: 12.6266 - val_acc: 0.9780
Epoch 13/20
54000/54000 [==============================] - 10s 186us/sample - loss: 11.7202 - acc: 0.9798 - val_loss: 10.8952 - val_acc: 0.9788
Epoch 14/20
54000/54000 [==============================] - 10s 186us/sample - loss: 10.1429 - acc: 0.9803 - val_loss: 9.4645 - val_acc: 0.9793
Epoch 15/20
54000/54000 [==============================] - 10s 187us/sample - loss: 8.8226 - acc: 0.9805 - val_loss: 8.2442 - val_acc: 0.9792
Epoch 16/20
54000/54000 [==============================] - 10s 186us/sample - loss: 7.7130 - acc: 0.9809 - val_loss: 7.2372 - val_acc: 0.9773
Epoch 17/20
54000/54000 [==============================] - 10s 186us/sample - loss: 6.7763 - acc: 0.9811 - val_loss: 6.3775 - val_acc: 0.9790
Epoch 18/20
54000/54000 [==============================] - 10s 186us/sample - loss: 5.9771 - acc: 0.9811 - val_loss: 5.6293 - val_acc: 0.9788
Epoch 19/20
54000/54000 [==============================] - 10s 186us/sample - loss: 5.2980 - acc: 0.9804 - val_loss: 5.0170 - val_acc: 0.9783
Epoch 20/20
54000/54000 [==============================] - 10s 186us/sample - loss: 4.7217 - acc: 0.9806 - val_loss: 4.4656 - val_acc: 0.9820
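Every forward pass through the Flipout layers samples a new set of weights, so the predicted class probabilities vary between calls; averaging over many Monte Carlo runs, as done below with n_mc_run, turns this spread into an uncertainty estimate. A minimal sketch for a single image (assuming the trained bcnn_model above):

probs = np.stack([softmax(bcnn_model.predict(X_test[:1]), axis=-1) for _ in range(5)])
print(probs.mean(axis=0))  # averaged class probabilities for the first test image
print(probs.std(axis=0))   # disagreement across weight samples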

Quantify the uncertainty in predictions

In [28]:
n_mc_run = 100
med_prob_thres = 0.2

y_pred_logits_list = [bcnn_model.predict(X_test) for _ in range(n_mc_run)]  # a list of predicted logits
y_pred_prob_all = np.concatenate([softmax(y, axis=-1)[:, :, np.newaxis] for y in y_pred_logits_list], axis=-1)
y_pred = [[int(np.median(y) >= med_prob_thres) for y in y_pred_prob] for y_pred_prob in y_pred_prob_all]
y_pred = np.array(y_pred)

idx_valid = [any(y) for y in y_pred]
print('Number of recognizable samples:', sum(idx_valid))

idx_invalid = [not any(y) for y in y_pred]
print('Unrecognizable samples:', np.where(idx_invalid)[0])

print('Test accuracy on MNIST (recognizable samples):',
      sum(np.equal(np.argmax(y_test[idx_valid], axis=-1), np.argmax(y_pred[idx_valid], axis=-1))) / len(y_test[idx_valid]))

print('Test accuracy on MNIST (unrecognizable samples):',
      sum(np.equal(np.argmax(y_test[idx_invalid], axis=-1), np.argmax(y_pred[idx_invalid], axis=-1))) / len(y_test[idx_invalid]))
Number of recognizable samples: 9992
Unrecognizable samples: [1039 1319 1790 2293 3808 4248 6572 6651]
Test accuracy on MNIST (recognizable samples): 0.9885908726981585
Test accuracy on MNIST (unrecognizable samples): 0.125
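The decision rule above marks a class as recognized when the median of its n_mc_run Monte Carlo probabilities reaches med_prob_thres, and a sample as unrecognizable when no class passes. On a toy probability array (hypothetical numbers) the rule looks like this:

toy = np.array([[0.01] * 100, [0.90] * 100] + [[0.01] * 100] * 8)  # 10 classes x 100 MC draws
print([int(np.median(p) >= 0.2) for p in toy])  # [0, 1, 0, ..., 0] -> only class 1 is recognized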

Define a helper that plots, for each class, the histogram of its predicted probabilities across the Monte Carlo runs

In [29]:
def plot_pred_hist(y_pred, n_class, n_mc_run, n_bins=30, med_prob_thres=0.2, n_subplot_rows=2, figsize=(25, 10)):
    bins = np.logspace(-n_bins, 0, n_bins+1)
    fig, ax = plt.subplots(n_subplot_rows, n_class // n_subplot_rows + 1, figsize=figsize)
    for i in range(n_subplot_rows):
        for j in range(n_class // n_subplot_rows + 1):
            idx = i * (n_class // n_subplot_rows + 1) + j
            if idx < n_class:
                ax[i, j].hist(y_pred[idx], bins)
                ax[i, j].set_xscale('log')
                ax[i, j].set_ylim([0, n_mc_run])
                ax[i, j].title.set_text("{} (median prob: {:.2f}) ({})".format(str(idx),
                                                                               np.median(y_pred[idx]),
                                                                               str(np.median(y_pred[idx]) >= med_prob_thres)))
            else:
                ax[i, j].axis('off')
    plt.show()

A recognizable example

In [30]:
idx = 0
plt.imshow(X_test[idx, :, :, 0], cmap='gist_gray')
print("True label of the test sample {}: {}".format(idx, np.argmax(y_test[idx], axis=-1)))

plot_pred_hist(y_pred_prob_all[idx], n_class, n_mc_run, med_prob_thres=med_prob_thres)

if any(y_pred[idx]):
    print("Predicted label of the test sample {}: {}".format(idx, np.argmax(y_pred[idx], axis=-1)))
else:
    print("I don't know!")
True label of the test sample 0: 7
Predicted label of the test sample 0: 7

Unrecognizable examples

In [31]:
for idx in np.where(idx_invalid)[0]:
    plt.imshow(X_test[idx, :, :, 0], cmap='gist_gray')
    print("True label of the test sample {}: {}".format(idx, np.argmax(y_test[idx], axis=-1)))

    plot_pred_hist(y_pred_prob_all[idx], n_class, n_mc_run, med_prob_thres=med_prob_thres)

    if any(y_pred[idx]):
        print("Predicted label of the test sample {}: {}".format(idx, np.argmax(y_pred[idx], axis=-1)))
    else:
        print("I don't know!")
True label of the test sample 1039: 7
I don't know!
True label of the test sample 1319: 8
I don't know!
True label of the test sample 1790: 2
I don't know!
True label of the test sample 2293: 9
I don't know!
True label of the test sample 3808: 7
I don't know!
True label of the test sample 4248: 2
I don't know!
True label of the test sample 6572: 1
I don't know!
True label of the test sample 6651: 0
I don't know!

Load Fashion-MNIST dataset

In [32]:
print('Loading Fashion-MNIST dataset')
_, (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

X_test = np.expand_dims(X_test, -1)
y_test = tf.keras.utils.to_categorical(y_test, n_class)

print("X_test.shape =", X_test.shape)
print("y_test.shape =", y_test.shape)
Loading Fashion-MNIST dataset
X_test.shape = (10000, 28, 28, 1)
y_test.shape = (10000, 10)

Quantify the uncertainty in predictions

In [33]:
y_pred_logits_list = [bcnn_model.predict(X_test) for _ in range(n_mc_run)]  # a list of predicted logits
y_pred_prob_all = np.concatenate([softmax(y, axis=-1)[:, :, np.newaxis] for y in y_pred_logits_list], axis=-1)
y_pred = [[int(np.median(y) >= med_prob_thres) for y in y_pred_prob] for y_pred_prob in y_pred_prob_all]
y_pred = np.array(y_pred)

idx_valid = [any(y) for y in y_pred]
print('Number of recognizable samples:', sum(idx_valid))

idx_invalid = [not any(y) for y in y_pred]
print('Unrecognizable samples:', np.where(idx_invalid)[0])

print('Test accuracy on Fashion-MNIST (recognizable samples):',
      sum(np.equal(np.argmax(y_test[idx_valid], axis=-1), np.argmax(y_pred[idx_valid], axis=-1))) / len(y_test[idx_valid]))

print('Test accuracy on Fashion-MNIST (unrecognizable samples):',
      sum(np.equal(np.argmax(y_test[idx_invalid], axis=-1), np.argmax(y_pred[idx_invalid], axis=-1))) / len(y_test[idx_invalid]))
Number of recognizable samples: 1055
Unrecognizable samples: [   1    2    3 ... 9997 9998 9999]
Test accuracy on Fashion-MNIST (recognizable samples): 0.02085308056872038
Test accuracy on Fashion-MNIST (unrecognizable samples): 0.1106763555058692

An unrecognizable example

In [34]:
idx = np.where(idx_invalid)[0][0]
plt.imshow(X_test[idx, :, :, 0], cmap='gist_gray')
print("True label of the test sample {}: {}".format(idx, np.argmax(y_test[idx], axis=-1)))

plot_pred_hist(y_pred_prob_all[idx], n_class, n_mc_run, n_bins=50, med_prob_thres=med_prob_thres)

if any(y_pred[idx]):
    print("Predicted label of the test sample {}: {}".format(idx, np.argmax(y_pred[idx], axis=-1)))
else:
    print("I don't know!")
True label of the test sample 1: 2
I don't know!

Generate an HTML version of this notebook

In [35]:
!!python -m nbconvert *.ipynb
Out[35]:
['[NbConvertApp] Converting notebook tfp_bnn.ipynb to html',
 '[NbConvertApp] Writing 1138460 bytes to tfp_bnn.html']