Electroencephalogram example¶

This dataset contains EEG recordings from human neonates. Multi-channel EEG was recorded from 79 term neonates admitted to the neonatal intensive care unit (NICU) at the Helsinki University Hospital. The median recording duration was 74 minutes (IQR: 64 to 96 minutes).

In [1]:
import numpy as np
import mogptk
import pandas as pd
import torch

# Seed every RNG this notebook has imported so a Restart & Run All is repeatable.
# NumPy was already seeded; torch is seeded as well because the mogptk traceback
# further down shows the model running on torch (torch.linalg.cholesky).
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' keeps the returned Generator out of the cell output

MOGP prediction on the sensor values¶

We use the values of eight of the twenty-two sensors as channels, giving us eight channels to fit with a multi-output Gaussian process in order to find cross-correlations between the channels.

In [2]:
# Read the recordings: the first CSV row holds the column names and the first
# column is used as the row index.
dataset_pd = pd.read_csv('data/eeg.csv', header=0, index_col=0)

# The eight electrode columns used as output channels.
cols = [
    'EEG Fp1-Ref', 'EEG Fp2-Ref', 'EEG Fz-Ref', 'EEG Cz-Ref',
    'EEG T3-Ref', 'EEG T4-Ref', 'EEG O1-Ref', 'EEG O2-Ref',
]
t = dataset_pd['time'].values
y = dataset_pd[cols].values

# One mogptk.Data series per electrode, all sharing the same time axis.
data = mogptk.DataSet()
for name, signal in zip(cols, y.T):
    data.append(mogptk.Data(t, signal, name=name))

# Channels 4 and 6 (T3 and O1) additionally lose everything from t=45 onwards;
# the remaining channels keep their tail.
truncated_channels = {4, 6}
for idx, channel in enumerate(data):
    channel.transform(mogptk.TransformNormalize())
    # presumably drops a random 40% of the observations -- confirm against mogptk docs
    channel.remove_randomly(pct=0.4)
    if idx in truncated_channels:
        channel.remove_range(45, None)

# simulate sensor failure: a gap in the middle of channel 0 and a missing
# start (up to t=10) on channels 5 and 7
simulated_gaps = ((0, (25, 35)), (5, (None, 10)), (7, (None, 10)))
for idx, (start, end) in simulated_gaps:
    data[idx].remove_range(start, end)
In [3]:
# Visualise the prepared data set, including the removed ranges, before fitting.
# The trailing ';' suppresses the textual repr of the returned object.
data.plot();

Model training¶

In [6]:
# Build the MOHSM model on the eight channels with Q=2, P=2 (the recorded
# parameter dump lists four mixture components) and initialise its
# parameters with the 'BNSE' method.
model = mogptk.MOHSM(data, Q=2, P=2)
model.init_parameters('BNSE')

# Optimiser settings collected in one place so they are easy to tune.
# NOTE(review): in the recorded run this configuration stopped after iteration
# 64 with a CholeskyException (kernel matrix not positive-definite), so the
# training never completed -- the settings likely need revisiting.
train_options = {
    'method': 'Adam',
    'lr': 0.1,
    'iters': 400,
    'verbose': True,
    'error': 'MAE',
}
model.train(**train_options)
Starting optimization using Adam
‣ Model: MOHSM
‣ Channels: 8
‣ Parameters: 204
‣ Training points: 1265
‣ Initial loss: 1249.46
‣ Initial error: 2.88558e-05

Start Adam:
    0/400   0:00:03  loss=     1249.46  error= 2.88558e-05
    4/400   0:00:20  loss=     1204.49  error= 2.86722e-05
    8/400   0:00:37  loss=     1161.65  error= 2.87232e-05
   12/400   0:00:54  loss=     1119.39  error= 2.88007e-05
   16/400   0:01:10  loss=     1077.42  error= 2.87922e-05
   20/400   0:01:26  loss=     1036.37  error= 2.88113e-05
   24/400   0:01:44  loss=     995.901  error= 2.89047e-05
   28/400   0:02:00  loss=     956.301  error= 2.89434e-05
   32/400   0:02:17  loss=     917.417  error=  2.8951e-05
   36/400   0:02:33  loss=      879.38  error= 2.89157e-05
   40/400   0:02:50  loss=     842.203  error= 2.88625e-05
   44/400   0:03:07  loss=     806.025  error= 2.88363e-05
   48/400   0:03:24  loss=     770.834  error= 2.88819e-05
   52/400   0:03:40  loss=      736.62  error= 2.89305e-05
   56/400   0:03:57  loss=     703.367  error= 2.89737e-05
   60/400   0:04:13  loss=     671.035  error= 2.90739e-05
   64/400   0:04:29  loss=     638.972  error= 2.91687e-05
ERROR: torch.linalg_cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 1258 is not positive-definite).
NameRangeValue
Mixture[0].MOHSM.weight[1e-08, ∞)[9.56417766 4.31908331 6.50562154 6.46525677 4.81299564 8.56451074 7.02364936 8.90889109]
Mixture[0].MOHSM.mean[1e-08, ∞)[[0.02549104] [0.23747579] [0.0940852 ] [0.05668663] [0.02541778] [0.08369347] [0.20351463] [0.08385606]]
Mixture[0].MOHSM.variance[1e-08, ∞)[[9.84772739e-05] [1.13710101e-04] [5.78241497e-04] [2.64049394e-04] [8.17668270e-05] [1.72471227e-03] [5.43362127e-04] [4.26498943e-03]]
Mixture[0].MOHSM.lengthscale[1e-08, ∞)[0.00405632 0.00200596 0.00288502 0.00222919 0.00338237 0.00263217 0.00183108 0.00262867]
Mixture[0].MOHSM.center(-∞, ∞)[5.69423539]
Mixture[0].MOHSM.delay(-∞, ∞)[[ 3.40386784] [-0.23420072] [ 0.76156157] [ 0.64975614] [-1.16957554] [ 1.26906993] [ 0.23420072] [-1.17658926]]
Mixture[0].MOHSM.phase(-∞, ∞)[ 0.89169945 -0.22572085 0.75435152 0.24842922 1.2647724 -0.36961333 0.22572085 -0.0476424 ]
Mixture[1].MOHSM.weight[1e-08, ∞)[5.92660805 4.16041503 4.91548693 8.2709504 6.1945013 6.13877519 7.37456912 6.03555051]
Mixture[1].MOHSM.mean[1e-08, ∞)[[0.07238981] [0.04048969] [0.07259066] [0.16766437] [0.07252619] [0.07217512] [0.34827757] [0.1049849 ]]
Mixture[1].MOHSM.variance[1e-08, ∞)[[1.03102851e-04] [4.29638223e-05] [5.59670909e-05] [1.19854508e-04] [7.16697260e-05] [1.73783272e-05] [1.87286830e-04] [3.06911893e-05]]
Mixture[1].MOHSM.lengthscale[1e-08, ∞)[0.00243043 0.00202046 0.00265784 0.00302527 0.00275688 0.00200365 0.00203173 0.00186393]
Mixture[1].MOHSM.center(-∞, ∞)[0.51884715]
Mixture[1].MOHSM.delay(-∞, ∞)[[ 2.15652911e-001] [ 2.37925580e-006] [-9.85059126e-002] [ 9.05378422e-077] [ 9.28101091e-001] [-9.91024401e-001] [-3.20607624e-279] [-1.06286387e-003]]
Mixture[1].MOHSM.phase(-∞, ∞)[ 4.01139695e-001 7.56847724e-006 -1.87353904e-001 1.18830096e-076 8.18883066e-001 -9.59353347e-001 -2.18794704e-279 -1.98479618e-003]
Mixture[2].MOHSM.weight[1e-08, ∞)[11.0782579 4.77409199 9.34273408 10.1551508 8.55126663 12.44845653 8.13061063 10.64343961]
Mixture[2].MOHSM.mean[1e-08, ∞)[[0.02827339] [0.23754086] [0.08537728] [0.07348974] [0.02885545] [0.07864727] [0.2033961 ] [0.0759806 ]]
Mixture[2].MOHSM.variance[1e-08, ∞)[[0.00012196] [0.00013199] [0.00065596] [0.00046723] [0.00010792] [0.00231396] [0.00067343] [0.00269152]]
Mixture[2].MOHSM.lengthscale[1e-08, ∞)[0.00128426 0.00409839 0.00325203 0.00371448 0.00332959 0.0029388 0.00417381 0.00165126]
Mixture[2].MOHSM.center(-∞, ∞)[995.33665453]
Mixture[2].MOHSM.delay(-∞, ∞)[[ 2.96948983] [-0.232626 ] [ 0.55447587] [ 2.40706836] [-0.93859438] [ 1.65108312] [ 0.232626 ] [-1.83963697]]
Mixture[2].MOHSM.phase(-∞, ∞)[ 0.98795431 -0.20902396 0.50547442 0.82532578 0.98712018 -0.71767182 0.20902396 0.20449908]
Mixture[3].MOHSM.weight[1e-08, ∞)[ 9.49372105 4.67870156 5.34691321 7.84464581 4.68806397 10.29858104 8.046322 7.26677981]
Mixture[3].MOHSM.mean[1e-08, ∞)[[0.05787629] [0.04050175] [0.06787368] [0.16762761] [0.0683894 ] [0.05785544] [0.34932979] [0.10671217]]
Mixture[3].MOHSM.variance[1e-08, ∞)[[2.81373586e-04] [5.07293468e-05] [5.24319441e-05] [1.33508658e-04] [5.53518473e-05] [2.05231241e-05] [2.24503747e-04] [3.92696639e-05]]
Mixture[3].MOHSM.lengthscale[1e-08, ∞)[0.0017158 0.00394903 0.00445108 0.00313999 0.00424557 0.00247263 0.00403102 0.0039983 ]
Mixture[3].MOHSM.center(-∞, ∞)[998.07657853]
Mixture[3].MOHSM.delay(-∞, ∞)[[ 4.66779554e-001] [-1.16572534e+000] [-6.79629878e-001] [ 9.08460483e-076] [ 5.72472211e-001] [-5.35775274e-001] [-7.88727431e-284] [-3.13114632e-005]]
Mixture[3].MOHSM.phase(-∞, ∞)[ 5.02361011e-001 -1.38349666e+000 -7.40910394e-001 1.19227923e-075 5.49754361e-001 -5.20046184e-001 -5.33713933e-284 -5.86561149e-005]
Gaussian.scale[1e-08, ∞)[0.54075483 0.54517939 0.54945548 0.54801568 0.55589353 0.54054611 0.54745335 0.53813097]
---------------------------------------------------------------------------
_LinAlgError                              Traceback (most recent call last)
~/dev/cmm/mogp/mogptk/mogptk/gpr/model.py in _cholesky(self, K, add_jitter)
    267         try:
--> 268             return torch.linalg.cholesky(K)
    269         except RuntimeError as e:

_LinAlgError: torch.linalg_cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 1258 is not positive-definite).

During handling of the above exception, another exception occurred:

CholeskyException                         Traceback (most recent call last)
/tmp/ipykernel_38465/3647760507.py in <module>
      1 model = mogptk.MOHSM(data, Q=2, P=2)
      2 model.init_parameters('BNSE')
----> 3 model.train(method='Adam', lr=0.1, iters=400, verbose=True, error='MAE')

~/dev/cmm/mogp/mogptk/mogptk/model.py in train(self, method, iters, verbose, error, plot, **kwargs)
    436 
    437             for i in range(iters):
--> 438                 progress(i, self.loss())
    439                 optimizer.step()
    440         progress(iters, self.loss())

~/dev/cmm/mogp/mogptk/mogptk/model.py in loss(self)
    264             >>> model.loss()
    265         """
--> 266         return self.gpr.loss().detach().cpu().item()
    267 
    268     def error(self, method='MAE', use_all_data=False):

~/dev/cmm/mogp/mogptk/mogptk/gpr/model.py in loss(self)
    307         """
    308         self.zero_grad()
--> 309         loss = -self.log_marginal_likelihood() - self.log_prior()
    310         loss.backward()
    311         return loss

~/dev/cmm/mogp/mogptk/mogptk/gpr/model.py in log_marginal_likelihood(self)
    398         if self.data_variance is not None:
    399             Kff += self.data_variance
--> 400         L = self._cholesky(Kff, add_jitter=True)  # NxN
    401 
    402         if self.mean is not None:

~/dev/cmm/mogp/mogptk/mogptk/gpr/model.py in _cholesky(self, K, add_jitter)
    275             self.print_parameters()
    276             plot_gram(K)
--> 277             raise CholeskyException(e.args[0], K, self)
    278 
    279     def log_marginal_likelihood(self):

CholeskyException: torch.linalg_cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 1258 is not positive-definite).

Prediction¶

In [ ]:
# Plot the trained model's predictions for the channels.
# NOTE(review): this cell has no execution count and the training cell above
# ended in a CholeskyException, so it was presumably never run -- confirm it
# works once training completes. The exact meaning of transformed=True is not
# visible here; check the mogptk documentation.
model.plot_prediction(transformed=True);