07 Sparse Multi-Input example using the Abalone dataset

[Estimated time of execution: 5 min]

This notebook shows how to train a model on a multi-input dataset using MOGPTK.

Abalone is a dataset containing measurements of 4177 abalones (a type of marine snail), including their length, weight, sex, and age. Specifically, the following features are given:

  • Sex: M, F, and I (infant)
  • Length in mm: longest shell measurement
  • Diameter in mm: perpendicular to length
  • Height in mm: with meat in shell
  • Whole weight in grams: whole abalone
  • Shucked weight in grams: weight of meat
  • Viscera weight in grams: gut weight (after bleeding)
  • Shell weight in grams: after being dried
  • Rings: +1.5 gives the age in years

Here we will use the seven length and weight features as independent variables, and the number of rings (i.e. the age) as the dependent variable.

In [10]:
import numpy as np
import mogptk
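
Before handing the file to MOGPTK, it can help to peek at it directly. A minimal sketch using pandas (an assumption here; pandas is not used elsewhere in this notebook), relying on the fact that the UCI file is a headerless CSV:

import pandas as pd

# Read the raw file; the column names are listed in the introduction above.
columns = ["Sex", "Length", "Diameter", "Height", "Whole weight",
           "Shucked weight", "Viscera weight", "Shell weight", "Rings"]
df = pd.read_csv("data/abalone/abalone.data", names=columns)

print(df.shape)                    # (4177, 9)
print((df["Rings"] + 1.5).head())  # rings + 1.5 gives the age in years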

Load in the dataset and select the relevant columns. Then standardize the number of rings so that it has zero mean and unit variance.

In [11]:
columns = ["Sex","Length","Diameter","Height","Whole weight","Shucked weight","Viscera weight","Shell weight","Rings"]

data = mogptk.LoadCSV("data/abalone/abalone.data",
                      x_col=["Length","Diameter","Height","Whole weight","Shucked weight","Viscera weight","Shell weight"],
                      y_col=["Rings"], names=columns)
data.transform(mogptk.TransformNormalize())
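
The transform maps the ring counts y to (y − μ)/σ using their sample mean and standard deviation. A minimal numpy sketch of the same operation, independent of MOGPTK's own implementation:

import numpy as np

# Zero-mean, unit-variance standardization as described above.
y = np.array([15.0, 7.0, 9.0, 10.0, 7.0])  # a few example ring counts
y_std = (y - y.mean()) / y.std()

print(y_std.mean(), y_std.std())  # ~0.0 and 1.0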

Set up the model using Titsias' sparse Gaussian process inference with 100 inducing (sparse) points. The initial inducing points will be spread randomly over the input space.

We will use the squared exponential kernel over our seven input dimensions and initialize its lengthscales to small random values to help training.
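
For reference, a squared exponential kernel with magnitude σ² and a per-dimension lengthscale ℓ_d (generic notation, not MOGPTK's own) takes the form

$$k(\mathbf{x}, \mathbf{x}') = \sigma^2 \exp\left(-\frac{1}{2}\sum_{d=1}^{7}\frac{(x_d - x'_d)^2}{\ell_d^2}\right),$$

so each input dimension gets its own lengthscale, which is why we assign a vector of 7 values below.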

In [13]:
# Choose kernel
kernel = mogptk.gpr.SquaredExponentialKernel(input_dims=7)

# Choose inference
inference = mogptk.Titsias(inducing_points=100, init_inducing_points='random')

# Set up model
model = mogptk.Model(data, kernel, model=inference)

# Initialize parameters
model.gpr.kernel.lengthscale.assign(np.random.rand(7))
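
Loosely, sparse inference sidesteps the O(n³) cost of exact GP regression by routing the computation through the m inducing points: the n×n kernel matrix is replaced by the low-rank (Nyström) approximation K_fu K_uu⁻¹ K_uf, bringing the dominant cost down to O(nm²). A minimal numpy illustration on toy data (not the model above):

import numpy as np

def se_kernel(a, b, lengthscale=1.0):
    # Squared exponential kernel between the rows of a and b.
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * d2 / lengthscale**2)

rng = np.random.default_rng(0)
X = rng.uniform(size=(500, 7))  # n = 500 data points, 7 input dims
Z = rng.uniform(size=(20, 7))   # m = 20 inducing points

Kuu = se_kernel(Z, Z) + 1e-6 * np.eye(len(Z))  # jitter for stability
Kfu = se_kernel(X, Z)
K_approx = Kfu @ np.linalg.solve(Kuu, Kfu.T)   # rank-m approximation of Kff

K_full = se_kernel(X, X)
print(np.abs(K_full - K_approx).max())  # worst-case approximation error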

Training will be performed using the default Adam optimizer with a learning rate of 1.0 for 100 iterations. We additionally measure the error between the predictions and the target values using the mean absolute percentage error (MAPE), a relative error with respect to the target values. In this case we reach an error of roughly 16% relative to the target values.
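
For reference, MAPE is the mean of |y − ŷ|/|y| expressed as a percentage. A minimal sketch; the library's exact implementation may differ, e.g. in how it guards against zero targets:

import numpy as np

def mape(y, y_hat):
    # Mean absolute percentage error between targets and predictions.
    return 100.0 * np.mean(np.abs((y - y_hat) / y))

y     = np.array([10.0, 8.0, 12.0])
y_hat = np.array([ 9.0, 9.0, 11.0])
print(mape(y, y_hat))  # ~10.3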

In [14]:
model.train(iters=100, verbose=True, error='mape', lr=1.0)
model.plot_losses();
Starting optimization using Adam
‣ Channels: 1
‣ Parameters: 709
‣ Training points: 4177
‣ Initial loss: 6126.56
‣ Initial error: 47.3229

Start Adam:
    0/100   0:00:00  loss=     6126.56  error=     47.3229
    1/100   0:00:00  loss=     6120.71  error=     68.9888
    2/100   0:00:01  loss=     5971.82  error=     69.3777
    3/100   0:00:01  loss=     5804.55  error=     69.3452
    4/100   0:00:02  loss=     5618.33  error=     66.8451
    5/100   0:00:02  loss=     5474.17  error=     68.8705
    6/100   0:00:03  loss=     5133.53  error=     50.8475
    7/100   0:00:04  loss=     5049.64  error=     53.3974
    8/100   0:00:05  loss=     4847.99  error=     49.6542
    9/100   0:00:05  loss=     4266.26  error=     29.7818
   10/100   0:00:06  loss=     4094.61  error=      23.752
   11/100   0:00:07  loss=     3907.53  error=     24.8311
   12/100   0:00:08  loss=     3211.78  error=     23.6807
   13/100   0:00:08  loss=     3475.56  error=     29.3339
   14/100   0:00:09  loss=     3108.79  error=     22.7633
   15/100   0:00:09  loss=     2755.55  error=      21.255
   16/100   0:00:10  loss=     2660.32  error=     20.5218
   17/100   0:00:10  loss=     1991.66  error=     19.5937
   18/100   0:00:10  loss=     1475.62  error=     19.4742
   19/100   0:00:11  loss=     929.391  error=     18.5868
   20/100   0:00:11  loss=     701.319  error=     19.5765
   21/100   0:00:12  loss=     244.136  error=     20.5502
   22/100   0:00:12  loss=    -99.1879  error=     17.7463
   23/100   0:00:12  loss=    -551.284  error=     16.5445
   24/100   0:00:13  loss=    -630.891  error=     19.3101
   25/100   0:00:13  loss=    -913.641  error=     17.2445
   26/100   0:00:14  loss=    -1231.32  error=      16.772
   27/100   0:00:14  loss=    -1456.35  error=     16.4394
   28/100   0:00:14  loss=    -1410.82  error=     16.0906
   29/100   0:00:15  loss=    -1362.12  error=     15.9848
   30/100   0:00:15  loss=    -1388.66  error=     16.2976
   31/100   0:00:15  loss=    -1382.88  error=     16.2884
   32/100   0:00:16  loss=    -1450.98  error=     16.2159
   33/100   0:00:16  loss=     -1500.5  error=     16.2068
   34/100   0:00:17  loss=    -1529.53  error=     16.2635
   35/100   0:00:17  loss=    -1567.78  error=     16.2285
   36/100   0:00:18  loss=    -1594.77  error=     16.2156
   37/100   0:00:18  loss=    -1592.97  error=     16.2748
   38/100   0:00:18  loss=    -1573.42  error=      16.397
   39/100   0:00:19  loss=    -1554.17  error=       16.42
   40/100   0:00:19  loss=    -1537.76  error=     16.4927
   41/100   0:00:20  loss=    -1541.73  error=     16.4523
   42/100   0:00:20  loss=    -1550.85  error=     16.4409
   43/100   0:00:20  loss=    -1563.28  error=      16.479
   44/100   0:00:21  loss=    -1576.86  error=      16.479
   45/100   0:00:21  loss=    -1594.16  error=     16.4208
   46/100   0:00:21  loss=    -1613.18  error=     16.3559
   47/100   0:00:22  loss=    -1619.82  error=     16.3572
   48/100   0:00:22  loss=    -1616.14  error=     16.3165
   49/100   0:00:23  loss=    -1614.87  error=     16.3746
   50/100   0:00:23  loss=    -1625.01  error=     16.2991
   51/100   0:00:23  loss=    -1639.96  error=     16.2531
   52/100   0:00:24  loss=    -1647.48  error=     16.2403
   53/100   0:00:24  loss=    -1658.27  error=     16.2207
   54/100   0:00:24  loss=    -1669.21  error=     16.1348
   55/100   0:00:25  loss=    -1687.74  error=     16.1067
   56/100   0:00:25  loss=    -1692.51  error=     16.1007
   57/100   0:00:26  loss=    -1695.01  error=     16.1387
   58/100   0:00:26  loss=    -1704.65  error=     16.0803
   59/100   0:00:26  loss=    -1703.73  error=     16.0059
   60/100   0:00:27  loss=    -1712.71  error=     16.0312
   61/100   0:00:27  loss=    -1708.87  error=     16.0543
   62/100   0:00:28  loss=    -1670.39  error=     16.0188
   63/100   0:00:28  loss=    -1714.96  error=      16.029
   64/100   0:00:29  loss=     -1691.4  error=     16.1434
   65/100   0:00:29  loss=    -1688.04  error=     16.0785
   66/100   0:00:29  loss=    -1684.94  error=     16.0819
   67/100   0:00:30  loss=    -1689.45  error=     16.0612
   68/100   0:00:30  loss=    -1665.14  error=     16.1858
   69/100   0:00:31  loss=    -1692.01  error=     16.1951
   70/100   0:00:31  loss=    -1607.48  error=     15.9962
   71/100   0:00:31  loss=    -1487.53  error=     16.3899
   72/100   0:00:32  loss=    -1430.82  error=     16.2627
   73/100   0:00:32  loss=    -1550.05  error=     16.2665
   74/100   0:00:33  loss=    -1477.61  error=       16.34
   75/100   0:00:33  loss=    -1580.03  error=     16.3224
   76/100   0:00:33  loss=    -1456.46  error=     16.5551
   77/100   0:00:34  loss=     -1568.6  error=     16.2653
   78/100   0:00:34  loss=    -1574.67  error=     16.1926
   79/100   0:00:35  loss=    -1603.25  error=      16.207
   80/100   0:00:35  loss=    -1555.55  error=     16.4502
   81/100   0:00:35  loss=    -1567.64  error=     16.3196
   82/100   0:00:36  loss=    -1588.42  error=     16.2377
   83/100   0:00:36  loss=    -1595.23  error=     16.3114
   84/100   0:00:36  loss=    -1594.37  error=     16.4269
   85/100   0:00:37  loss=     -1565.5  error=     16.3261
   86/100   0:00:37  loss=    -1595.39  error=     16.2985
   87/100   0:00:38  loss=     -1604.5  error=     16.3282
   88/100   0:00:38  loss=    -1595.72  error=     16.2864
   89/100   0:00:39  loss=    -1609.97  error=     16.3143
   90/100   0:00:39  loss=     -1622.1  error=     16.3284
   91/100   0:00:40  loss=    -1631.72  error=     16.3173
   92/100   0:00:40  loss=    -1626.35  error=     16.3244
   93/100   0:00:41  loss=    -1636.14  error=     16.3022
   94/100   0:00:41  loss=    -1641.45  error=     16.3255
   95/100   0:00:41  loss=    -1649.17  error=     16.2843
   96/100   0:00:42  loss=    -1659.47  error=     16.2725
   97/100   0:00:42  loss=    -1674.35  error=     16.2014
   98/100   0:00:43  loss=    -1680.57  error=     16.2241
   99/100   0:00:43  loss=    -1686.02  error=     16.1437
  100/100   0:00:44  loss=     -1696.7  error=      16.183
Finished

Optimization finished in 44.160 seconds
‣ Iterations: 100
‣ Final loss: -1696.7
‣ Final error: 16.183

The loss and error decline rapidly and then level off, confirming that training was successful. The following kernel and model parameters have been obtained.

In [15]:
model.print_parameters()
Name              Range        Value
SE.magnitude      [1e-08, ∞)   0.06266966105625149
SE.lengthscale    [1e-08, ∞)   [8.5721219  2.8215086  6.23715182  3.07448305  1.40495237  6.43906609  1.46569563]
Gaussian.scale    [1e-08, ∞)   0.15722606101998848
induction_points  (-∞, ∞)      100×7 matrix of inducing point locations (output abridged)
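
Since the lengthscale is per input dimension (ARD), inputs with smaller lengthscales make the output vary more quickly and are, loosely speaking, more relevant. A quick ranking of the fitted lengthscales from the table above; note that the inputs were not rescaled (mm vs. grams), so this is only suggestive:

import numpy as np

# Rank features by fitted lengthscale (smaller = more relevant),
# using the rounded values printed above.
features = ["Length", "Diameter", "Height", "Whole weight",
            "Shucked weight", "Viscera weight", "Shell weight"]
lengthscale = np.array([8.572, 2.822, 6.237, 3.074, 1.405, 6.439, 1.466])

for i in np.argsort(lengthscale):
    print(f"{features[i]:<15} {lengthscale[i]:.3f}")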