This dataset contains EEG recordings from human neonates. Multi-channel EEG was recorded from 79 term neonates admitted to the neonatal intensive care unit (NICU) at the Helsinki University Hospital. The median recording duration was 74 minutes (IQR: 64 to 96 minutes).
import numpy as np
import mogptk
import pandas as pd
import torch
# Seed both random number generators this notebook depends on so runs are
# reproducible: NumPy (presumably used by mogptk's random point removal below)
# and PyTorch, which mogptk uses for training (see torch.linalg.cholesky in
# the recorded traceback).
np.random.seed(1)
torch.manual_seed(1)
We use the signals of eight of the twenty-two sensors as channels. This leaves us with eight channels to fit with a multi-output Gaussian process, in order to find cross-correlations between them.
# Load the recording. The first CSV column becomes the index; the 'time'
# column provides the input values shared by every channel.
dataset_pd = pd.read_csv('data/eeg.csv', header=0, index_col=0)

# The eight sensor columns modelled as MOGP channels (column order = channel order).
cols = [
    'EEG Fp1-Ref', 'EEG Fp2-Ref', 'EEG Fz-Ref', 'EEG Cz-Ref',
    'EEG T3-Ref', 'EEG T4-Ref', 'EEG O1-Ref', 'EEG O2-Ref',
]
t = dataset_pd['time'].values
y = dataset_pd[cols].values  # 2-D: one column per entry of `cols`

# Build one mogptk.Data per channel, all sharing the same input vector `t`.
data = mogptk.DataSet()
for i, name in enumerate(cols):
    data.append(mogptk.Data(t, y[:, i], name=name))
# Normalise each channel, then thin it out: drop 40% of its points at random,
# and for every channel not listed in `keep_tail`, also call
# remove_range(45, None) (presumably removing all inputs from 45 onwards).
keep_tail = (0, 1, 2, 3, 5, 7)
for idx, ch in enumerate(data):
    ch.transform(mogptk.TransformNormalize())
    ch.remove_randomly(pct=0.4)
    if idx in keep_tail:
        continue
    ch.remove_range(45, None)
# Simulate sensor failures: cut the 25..35 window out of channel 0, and the
# range before 10 (presumably from the start of the recording) out of
# channels 5 and 7.
data[0].remove_range(25, 35)
for failed in (5, 7):
    data[failed].remove_range(None, 10)

# Trailing semicolon suppresses Jupyter's display of the return value.
data.plot();
# MOHSM model with Q=2 and P=2. The recorded parameter dump shows four mixture
# components (Mixture[0]..Mixture[3]). Parameters are initialised with 'BNSE'.
model = mogptk.MOHSM(data, P=2, Q=2)
model.init_parameters('BNSE')

# NOTE(review): the recorded run of this cell aborted at iteration 64/400 with
# a CholeskyException (Gram matrix not positive-definite). A smaller `lr` may
# avoid this — TODO confirm.
train_kwargs = dict(method='Adam', lr=0.1, iters=400, verbose=True, error='MAE')
model.train(**train_kwargs)
Starting optimization using Adam ‣ Model: MOHSM ‣ Channels: 8 ‣ Parameters: 204 ‣ Training points: 1265 ‣ Initial loss: 1249.46 ‣ Initial error: 2.88558e-05 Start Adam: 0/400 0:00:03 loss= 1249.46 error= 2.88558e-05 4/400 0:00:20 loss= 1204.49 error= 2.86722e-05 8/400 0:00:37 loss= 1161.65 error= 2.87232e-05 12/400 0:00:54 loss= 1119.39 error= 2.88007e-05 16/400 0:01:10 loss= 1077.42 error= 2.87922e-05 20/400 0:01:26 loss= 1036.37 error= 2.88113e-05 24/400 0:01:44 loss= 995.901 error= 2.89047e-05 28/400 0:02:00 loss= 956.301 error= 2.89434e-05 32/400 0:02:17 loss= 917.417 error= 2.8951e-05 36/400 0:02:33 loss= 879.38 error= 2.89157e-05 40/400 0:02:50 loss= 842.203 error= 2.88625e-05 44/400 0:03:07 loss= 806.025 error= 2.88363e-05 48/400 0:03:24 loss= 770.834 error= 2.88819e-05 52/400 0:03:40 loss= 736.62 error= 2.89305e-05 56/400 0:03:57 loss= 703.367 error= 2.89737e-05 60/400 0:04:13 loss= 671.035 error= 2.90739e-05 64/400 0:04:29 loss= 638.972 error= 2.91687e-05 ERROR: torch.linalg_cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 1258 is not positive-definite).
Name | Range | Value |
---|---|---|
Mixture[0].MOHSM.weight | [1e-08, ∞) | [9.56417766 4.31908331 6.50562154 6.46525677 4.81299564 8.56451074 7.02364936 8.90889109] |
Mixture[0].MOHSM.mean | [1e-08, ∞) | [[0.02549104] [0.23747579] [0.0940852 ] [0.05668663] [0.02541778] [0.08369347] [0.20351463] [0.08385606]] |
Mixture[0].MOHSM.variance | [1e-08, ∞) | [[9.84772739e-05] [1.13710101e-04] [5.78241497e-04] [2.64049394e-04] [8.17668270e-05] [1.72471227e-03] [5.43362127e-04] [4.26498943e-03]] |
Mixture[0].MOHSM.lengthscale | [1e-08, ∞) | [0.00405632 0.00200596 0.00288502 0.00222919 0.00338237 0.00263217 0.00183108 0.00262867] |
Mixture[0].MOHSM.center | (-∞, ∞) | [5.69423539] |
Mixture[0].MOHSM.delay | (-∞, ∞) | [[ 3.40386784] [-0.23420072] [ 0.76156157] [ 0.64975614] [-1.16957554] [ 1.26906993] [ 0.23420072] [-1.17658926]] |
Mixture[0].MOHSM.phase | (-∞, ∞) | [ 0.89169945 -0.22572085 0.75435152 0.24842922 1.2647724 -0.36961333 0.22572085 -0.0476424 ] |
Mixture[1].MOHSM.weight | [1e-08, ∞) | [5.92660805 4.16041503 4.91548693 8.2709504 6.1945013 6.13877519 7.37456912 6.03555051] |
Mixture[1].MOHSM.mean | [1e-08, ∞) | [[0.07238981] [0.04048969] [0.07259066] [0.16766437] [0.07252619] [0.07217512] [0.34827757] [0.1049849 ]] |
Mixture[1].MOHSM.variance | [1e-08, ∞) | [[1.03102851e-04] [4.29638223e-05] [5.59670909e-05] [1.19854508e-04] [7.16697260e-05] [1.73783272e-05] [1.87286830e-04] [3.06911893e-05]] |
Mixture[1].MOHSM.lengthscale | [1e-08, ∞) | [0.00243043 0.00202046 0.00265784 0.00302527 0.00275688 0.00200365 0.00203173 0.00186393] |
Mixture[1].MOHSM.center | (-∞, ∞) | [0.51884715] |
Mixture[1].MOHSM.delay | (-∞, ∞) | [[ 2.15652911e-001] [ 2.37925580e-006] [-9.85059126e-002] [ 9.05378422e-077] [ 9.28101091e-001] [-9.91024401e-001] [-3.20607624e-279] [-1.06286387e-003]] |
Mixture[1].MOHSM.phase | (-∞, ∞) | [ 4.01139695e-001 7.56847724e-006 -1.87353904e-001 1.18830096e-076 8.18883066e-001 -9.59353347e-001 -2.18794704e-279 -1.98479618e-003] |
Mixture[2].MOHSM.weight | [1e-08, ∞) | [11.0782579 4.77409199 9.34273408 10.1551508 8.55126663 12.44845653 8.13061063 10.64343961] |
Mixture[2].MOHSM.mean | [1e-08, ∞) | [[0.02827339] [0.23754086] [0.08537728] [0.07348974] [0.02885545] [0.07864727] [0.2033961 ] [0.0759806 ]] |
Mixture[2].MOHSM.variance | [1e-08, ∞) | [[0.00012196] [0.00013199] [0.00065596] [0.00046723] [0.00010792] [0.00231396] [0.00067343] [0.00269152]] |
Mixture[2].MOHSM.lengthscale | [1e-08, ∞) | [0.00128426 0.00409839 0.00325203 0.00371448 0.00332959 0.0029388 0.00417381 0.00165126] |
Mixture[2].MOHSM.center | (-∞, ∞) | [995.33665453] |
Mixture[2].MOHSM.delay | (-∞, ∞) | [[ 2.96948983] [-0.232626 ] [ 0.55447587] [ 2.40706836] [-0.93859438] [ 1.65108312] [ 0.232626 ] [-1.83963697]] |
Mixture[2].MOHSM.phase | (-∞, ∞) | [ 0.98795431 -0.20902396 0.50547442 0.82532578 0.98712018 -0.71767182 0.20902396 0.20449908] |
Mixture[3].MOHSM.weight | [1e-08, ∞) | [ 9.49372105 4.67870156 5.34691321 7.84464581 4.68806397 10.29858104 8.046322 7.26677981] |
Mixture[3].MOHSM.mean | [1e-08, ∞) | [[0.05787629] [0.04050175] [0.06787368] [0.16762761] [0.0683894 ] [0.05785544] [0.34932979] [0.10671217]] |
Mixture[3].MOHSM.variance | [1e-08, ∞) | [[2.81373586e-04] [5.07293468e-05] [5.24319441e-05] [1.33508658e-04] [5.53518473e-05] [2.05231241e-05] [2.24503747e-04] [3.92696639e-05]] |
Mixture[3].MOHSM.lengthscale | [1e-08, ∞) | [0.0017158 0.00394903 0.00445108 0.00313999 0.00424557 0.00247263 0.00403102 0.0039983 ] |
Mixture[3].MOHSM.center | (-∞, ∞) | [998.07657853] |
Mixture[3].MOHSM.delay | (-∞, ∞) | [[ 4.66779554e-001] [-1.16572534e+000] [-6.79629878e-001] [ 9.08460483e-076] [ 5.72472211e-001] [-5.35775274e-001] [-7.88727431e-284] [-3.13114632e-005]] |
Mixture[3].MOHSM.phase | (-∞, ∞) | [ 5.02361011e-001 -1.38349666e+000 -7.40910394e-001 1.19227923e-075 5.49754361e-001 -5.20046184e-001 -5.33713933e-284 -5.86561149e-005] |
Gaussian.scale | [1e-08, ∞) | [0.54075483 0.54517939 0.54945548 0.54801568 0.55589353 0.54054611 0.54745335 0.53813097] |
--------------------------------------------------------------------------- _LinAlgError Traceback (most recent call last) ~/dev/cmm/mogp/mogptk/mogptk/gpr/model.py in _cholesky(self, K, add_jitter) 267 try: --> 268 return torch.linalg.cholesky(K) 269 except RuntimeError as e: _LinAlgError: torch.linalg_cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 1258 is not positive-definite). During handling of the above exception, another exception occurred: CholeskyException Traceback (most recent call last) /tmp/ipykernel_38465/3647760507.py in <module> 1 model = mogptk.MOHSM(data, Q=2, P=2) 2 model.init_parameters('BNSE') ----> 3 model.train(method='Adam', lr=0.1, iters=400, verbose=True, error='MAE') ~/dev/cmm/mogp/mogptk/mogptk/model.py in train(self, method, iters, verbose, error, plot, **kwargs) 436 437 for i in range(iters): --> 438 progress(i, self.loss()) 439 optimizer.step() 440 progress(iters, self.loss()) ~/dev/cmm/mogp/mogptk/mogptk/model.py in loss(self) 264 >>> model.loss() 265 """ --> 266 return self.gpr.loss().detach().cpu().item() 267 268 def error(self, method='MAE', use_all_data=False): ~/dev/cmm/mogp/mogptk/mogptk/gpr/model.py in loss(self) 307 """ 308 self.zero_grad() --> 309 loss = -self.log_marginal_likelihood() - self.log_prior() 310 loss.backward() 311 return loss ~/dev/cmm/mogp/mogptk/mogptk/gpr/model.py in log_marginal_likelihood(self) 398 if self.data_variance is not None: 399 Kff += self.data_variance --> 400 L = self._cholesky(Kff, add_jitter=True) # NxN 401 402 if self.mean is not None: ~/dev/cmm/mogp/mogptk/mogptk/gpr/model.py in _cholesky(self, K, add_jitter) 275 self.print_parameters() 276 plot_gram(K) --> 277 raise CholeskyException(e.args[0], K, self) 278 279 def log_marginal_likelihood(self): CholeskyException: torch.linalg_cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 1258 is not 
positive-definite).
# Plot the model's predictions. transformed=True presumably plots in the
# normalised space set up by TransformNormalize — TODO confirm.
# NOTE(review): if the training cell raised (as in the recorded run), this
# plots the model with whatever parameters it had when training stopped.
# Trailing semicolon suppresses Jupyter's display of the return value.
model.plot_prediction(transformed=True);