[Estimated time of execution: 5 min]
This notebook shows how to train a model on a multi-input dataset using MOGPTK.
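Abalone is a dataset containing measurements of 4177 abalones (a type of marine snail), covering their size, weights, sex and age. Specifically, the following features are given: Sex, Length, Diameter, Height, Whole weight, Shucked weight, Viscera weight, Shell weight and Rings.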
Here we will use the seven size and weight measurements (Length, Diameter, Height and the four weight columns) as independent variables, and the number of rings (an indicator of age) as the dependent variable.
import numpy as np
import mogptk
Load the dataset and select the relevant columns. Then normalize the number of rings with TransformNormalize, which rescales the values to lie between -1 and 1.
# Column names of the CSV file
columns = ["Sex","Length","Diameter","Height","Whole weight","Shucked weight","Viscera weight","Shell weight","Rings"]
# Use the seven size and weight measurements as inputs and Rings as the output
data = mogptk.LoadCSV("data/abalone/abalone.data",
                      x_col=["Length","Diameter","Height","Whole weight","Shucked weight","Viscera weight","Shell weight"],
                      y_col=["Rings"], names=columns)
# Rescale the target values (Rings)
data.transform(mogptk.TransformNormalize())
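The sketch below contrasts two common ways of rescaling a target with plain NumPy: min-max normalization to [-1, 1] and standardization to zero mean and unit variance. It is an independent illustration, not MOGPTK's code; the assumption that TransformNormalize corresponds to the first variant should be checked against the MOGPTK documentation. The ring counts are made-up values.

```python
import numpy as np

def normalize(y):
    """Linearly map y onto [-1, 1] using its minimum and maximum."""
    y_min, y_max = y.min(), y.max()
    return -1.0 + 2.0 * (y - y_min) / (y_max - y_min)

def standardize(y):
    """Shift and scale y to zero mean and unit (population) variance."""
    return (y - y.mean()) / y.std()

# Hypothetical ring counts, for illustration only.
rings = np.array([4.0, 9.0, 10.0, 11.0, 16.0])

n = normalize(rings)
z = standardize(rings)
print(n.min(), n.max())   # prints -1.0 1.0
print(z.mean(), z.std())  # approximately 0 and 1
```

Whichever rescaling is used, predictions come out on the transformed scale and have to be mapped back to ring counts before they are interpreted.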
Set up the model using Titsias' sparse Gaussian process inference with 100 inducing (sparse) points. The initial inducing points are spread randomly over the input space.
We use the squared exponential kernel over our 7 input dimensions. The length scales are initialized with random values in [0, 1) to give the optimizer a starting point.
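Judging from the seven values printed for SE.lengthscale further down, the kernel keeps one length scale per input dimension (automatic relevance determination). A minimal NumPy sketch of that standard ARD squared exponential form follows; it assumes the parameterization k(x, x') = magnitude · exp(-½ Σ_d (x_d - x'_d)² / ℓ_d²), which may differ in detail from MOGPTK's implementation. The inputs are random placeholders.

```python
import numpy as np

def se_ard_kernel(X1, X2, magnitude, lengthscales):
    """Squared exponential kernel with one length scale per input dimension.

    X1: (n, d) array, X2: (m, d) array, lengthscales: (d,) array.
    Returns the (n, m) matrix magnitude * exp(-0.5 * sum_d ((x_d - x'_d) / l_d)^2).
    """
    A = X1 / lengthscales  # divide each dimension by its own length scale
    B = X2 / lengthscales
    sqdist = (
        np.sum(A**2, axis=1)[:, None]
        + np.sum(B**2, axis=1)[None, :]
        - 2.0 * A @ B.T
    )
    sqdist = np.maximum(sqdist, 0.0)  # clip tiny negatives caused by round-off
    return magnitude * np.exp(-0.5 * sqdist)

rng = np.random.default_rng(0)
X = rng.random((5, 7))   # 5 hypothetical points with 7 features
ell = rng.random(7)      # random length scales in [0, 1), like the initialization
K = se_ard_kernel(X, X, magnitude=1.0, lengthscales=ell)
print(K.shape)           # (5, 5)
```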
# Choose kernel
kernel = mogptk.gpr.SquaredExponentialKernel(input_dims=7)
# Choose inference
inference = mogptk.Titsias(inducing_points=100, init_inducing_points='random')
# Set up model
model = mogptk.Model(data, kernel, model=inference)
# Initialize the 7 length scales with random values in [0, 1)
model.gpr.kernel.lengthscale.assign(np.random.rand(7))
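Titsias' method replaces the exact GP marginal likelihood with a lower bound built from M inducing inputs Z: log N(y | 0, Q + σ²I) - tr(K - Q) / (2σ²), with Q = K_nz K_zz⁻¹ K_zn. The sketch below computes this collapsed bound densely in NumPy for a small random problem and compares it with the exact log marginal likelihood. It is an O(N²) illustration written for readability, not MOGPTK's implementation; the kernel, data, noise level, jitter and the uniform "random" placement of Z are all assumptions.

```python
import numpy as np

def se_kernel(A, B, magnitude=1.0, lengthscale=1.0):
    """Isotropic squared exponential kernel (assumed form)."""
    sq = np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)
    return magnitude * np.exp(-0.5 * sq / lengthscale**2)

def gaussian_logpdf(y, cov):
    """log N(y | 0, cov) via a Cholesky factorization."""
    L = np.linalg.cholesky(cov)
    alpha = np.linalg.solve(L, y)
    n = y.shape[0]
    return -0.5 * alpha @ alpha - np.sum(np.log(np.diag(L))) - 0.5 * n * np.log(2 * np.pi)

def titsias_bound(X, y, Z, noise_var, jitter=1e-8):
    """Collapsed variational lower bound of Titsias, dense version."""
    Knn = se_kernel(X, X)
    Kzz = se_kernel(Z, Z) + jitter * np.eye(len(Z))
    Knz = se_kernel(X, Z)
    Qnn = Knz @ np.linalg.solve(Kzz, Knz.T)  # Nystrom approximation of Knn
    Qnn = 0.5 * (Qnn + Qnn.T)                # remove numerical asymmetry
    fit = gaussian_logpdf(y, Qnn + noise_var * np.eye(len(X)))
    trace_term = np.trace(Knn - Qnn) / (2.0 * noise_var)
    return fit - trace_term

rng = np.random.default_rng(1)
X = rng.random((50, 7))  # 50 hypothetical 7-D inputs
y = np.sin(X.sum(axis=1)) + 0.1 * rng.standard_normal(50)
# Inducing points drawn uniformly within the bounding box of the inputs.
Z = rng.uniform(X.min(axis=0), X.max(axis=0), size=(10, 7))

noise_var = 0.1**2
exact = gaussian_logpdf(y, se_kernel(X, X) + noise_var * np.eye(50))
bound = titsias_bound(X, y, Z, noise_var)
print(bound <= exact)  # the bound should not exceed the exact value
```

The trace term penalizes inducing points that explain the prior covariance poorly, which is why optimizing the bound also moves the inducing points.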
Training is performed using the default Adam optimizer with a learning rate of 1.0 for 100 iterations. We additionally measure the error between the prediction and the target values using the mean absolute percentage error (MAPE), which expresses the error relative to the target values. In our case, the model reaches an error of about 16% relative to the target values.
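For reference, MAPE averages the absolute error relative to each target and reports it in percent: MAPE = 100/N · Σ |y_i - ŷ_i| / |y_i|. A plain NumPy version is sketched below. Whether MOGPTK evaluates it on the original or the transformed scale is not stated here, so treat this as the textbook definition rather than the library's exact computation. The values are hypothetical.

```python
import numpy as np

def mape(y_true, y_pred):
    """Mean absolute percentage error, in percent (undefined if any target is 0)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return 100.0 * np.mean(np.abs(y_true - y_pred) / np.abs(y_true))

# Hypothetical targets and predictions: relative errors 20%, 25% and 0%.
print(mape([10, 20, 50], [12, 15, 50]))  # roughly 15.0
```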
model.train(iters=100, verbose=True, error='mape', lr=1.0)
model.plot_losses();
Starting optimization using Adam
‣ Channels: 1
‣ Parameters: 709
‣ Training points: 4177
‣ Initial loss: 6126.56
‣ Initial error: 47.3229

Start Adam:
0/100 0:00:00 loss= 6126.56 error= 47.3229
1/100 0:00:00 loss= 6120.71 error= 68.9888
2/100 0:00:01 loss= 5971.82 error= 69.3777
3/100 0:00:01 loss= 5804.55 error= 69.3452
4/100 0:00:02 loss= 5618.33 error= 66.8451
5/100 0:00:02 loss= 5474.17 error= 68.8705
6/100 0:00:03 loss= 5133.53 error= 50.8475
7/100 0:00:04 loss= 5049.64 error= 53.3974
8/100 0:00:05 loss= 4847.99 error= 49.6542
9/100 0:00:05 loss= 4266.26 error= 29.7818
10/100 0:00:06 loss= 4094.61 error= 23.752
11/100 0:00:07 loss= 3907.53 error= 24.8311
12/100 0:00:08 loss= 3211.78 error= 23.6807
13/100 0:00:08 loss= 3475.56 error= 29.3339
14/100 0:00:09 loss= 3108.79 error= 22.7633
15/100 0:00:09 loss= 2755.55 error= 21.255
16/100 0:00:10 loss= 2660.32 error= 20.5218
17/100 0:00:10 loss= 1991.66 error= 19.5937
18/100 0:00:10 loss= 1475.62 error= 19.4742
19/100 0:00:11 loss= 929.391 error= 18.5868
20/100 0:00:11 loss= 701.319 error= 19.5765
21/100 0:00:12 loss= 244.136 error= 20.5502
22/100 0:00:12 loss= -99.1879 error= 17.7463
23/100 0:00:12 loss= -551.284 error= 16.5445
24/100 0:00:13 loss= -630.891 error= 19.3101
25/100 0:00:13 loss= -913.641 error= 17.2445
26/100 0:00:14 loss= -1231.32 error= 16.772
27/100 0:00:14 loss= -1456.35 error= 16.4394
28/100 0:00:14 loss= -1410.82 error= 16.0906
29/100 0:00:15 loss= -1362.12 error= 15.9848
30/100 0:00:15 loss= -1388.66 error= 16.2976
31/100 0:00:15 loss= -1382.88 error= 16.2884
32/100 0:00:16 loss= -1450.98 error= 16.2159
33/100 0:00:16 loss= -1500.5 error= 16.2068
34/100 0:00:17 loss= -1529.53 error= 16.2635
35/100 0:00:17 loss= -1567.78 error= 16.2285
36/100 0:00:18 loss= -1594.77 error= 16.2156
37/100 0:00:18 loss= -1592.97 error= 16.2748
38/100 0:00:18 loss= -1573.42 error= 16.397
39/100 0:00:19 loss= -1554.17 error= 16.42
40/100 0:00:19 loss= -1537.76 error= 16.4927
41/100 0:00:20 loss= -1541.73 error= 16.4523
42/100 0:00:20 loss= -1550.85 error= 16.4409
43/100 0:00:20 loss= -1563.28 error= 16.479
44/100 0:00:21 loss= -1576.86 error= 16.479
45/100 0:00:21 loss= -1594.16 error= 16.4208
46/100 0:00:21 loss= -1613.18 error= 16.3559
47/100 0:00:22 loss= -1619.82 error= 16.3572
48/100 0:00:22 loss= -1616.14 error= 16.3165
49/100 0:00:23 loss= -1614.87 error= 16.3746
50/100 0:00:23 loss= -1625.01 error= 16.2991
51/100 0:00:23 loss= -1639.96 error= 16.2531
52/100 0:00:24 loss= -1647.48 error= 16.2403
53/100 0:00:24 loss= -1658.27 error= 16.2207
54/100 0:00:24 loss= -1669.21 error= 16.1348
55/100 0:00:25 loss= -1687.74 error= 16.1067
56/100 0:00:25 loss= -1692.51 error= 16.1007
57/100 0:00:26 loss= -1695.01 error= 16.1387
58/100 0:00:26 loss= -1704.65 error= 16.0803
59/100 0:00:26 loss= -1703.73 error= 16.0059
60/100 0:00:27 loss= -1712.71 error= 16.0312
61/100 0:00:27 loss= -1708.87 error= 16.0543
62/100 0:00:28 loss= -1670.39 error= 16.0188
63/100 0:00:28 loss= -1714.96 error= 16.029
64/100 0:00:29 loss= -1691.4 error= 16.1434
65/100 0:00:29 loss= -1688.04 error= 16.0785
66/100 0:00:29 loss= -1684.94 error= 16.0819
67/100 0:00:30 loss= -1689.45 error= 16.0612
68/100 0:00:30 loss= -1665.14 error= 16.1858
69/100 0:00:31 loss= -1692.01 error= 16.1951
70/100 0:00:31 loss= -1607.48 error= 15.9962
71/100 0:00:31 loss= -1487.53 error= 16.3899
72/100 0:00:32 loss= -1430.82 error= 16.2627
73/100 0:00:32 loss= -1550.05 error= 16.2665
74/100 0:00:33 loss= -1477.61 error= 16.34
75/100 0:00:33 loss= -1580.03 error= 16.3224
76/100 0:00:33 loss= -1456.46 error= 16.5551
77/100 0:00:34 loss= -1568.6 error= 16.2653
78/100 0:00:34 loss= -1574.67 error= 16.1926
79/100 0:00:35 loss= -1603.25 error= 16.207
80/100 0:00:35 loss= -1555.55 error= 16.4502
81/100 0:00:35 loss= -1567.64 error= 16.3196
82/100 0:00:36 loss= -1588.42 error= 16.2377
83/100 0:00:36 loss= -1595.23 error= 16.3114
84/100 0:00:36 loss= -1594.37 error= 16.4269
85/100 0:00:37 loss= -1565.5 error= 16.3261
86/100 0:00:37 loss= -1595.39 error= 16.2985
87/100 0:00:38 loss= -1604.5 error= 16.3282
88/100 0:00:38 loss= -1595.72 error= 16.2864
89/100 0:00:39 loss= -1609.97 error= 16.3143
90/100 0:00:39 loss= -1622.1 error= 16.3284
91/100 0:00:40 loss= -1631.72 error= 16.3173
92/100 0:00:40 loss= -1626.35 error= 16.3244
93/100 0:00:41 loss= -1636.14 error= 16.3022
94/100 0:00:41 loss= -1641.45 error= 16.3255
95/100 0:00:41 loss= -1649.17 error= 16.2843
96/100 0:00:42 loss= -1659.47 error= 16.2725
97/100 0:00:42 loss= -1674.35 error= 16.2014
98/100 0:00:43 loss= -1680.57 error= 16.2241
99/100 0:00:43 loss= -1686.02 error= 16.1437
100/100 0:00:44 loss= -1696.7 error= 16.183
Finished

Optimization finished in 44.160 seconds
‣ Iterations: 100
‣ Final loss: -1696.7
‣ Final error: 16.183
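The loss drops rapidly over the first 30 or so iterations and then levels off, fluctuating between roughly -1380 and -1715. The error first rises from 47% to about 69%, then falls to about 16% by iteration 28 and stays there. This suggests that training was successful and has largely converged. The following kernel and model parameters have been obtained.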
model.print_parameters()
Name | Range | Value |
---|---|---|
SE.magnitude | [1e-08, ∞) | 0.06266966105625149 |
SE.lengthscale | [1e-08, ∞) | [8.5721219 2.8215086 6.23715182 3.07448305 1.40495237 6.43906609 1.46569563] |
Gaussian.scale | [1e-08, ∞) | 0.15722606101998848 |
induction_points | (-∞, ∞) | 100 × 7 array (one row per inducing point, listed below) |

induction_points:
[[ 6.61802711 6.47048147 -5.08921644 -3.57544769 -4.73587658 -5.38632516 5.97771272]
 [ 6.33032587 6.31329965 -5.77576607 -4.4124692 6.41913084 -5.25870884 6.37386517]
 [ 1.80218202 1.90373382 0.91611634 3.23064571 0.38905421 -2.39662741 -0.1151143 ]
 [ -1.24920835 0.4153708 1.83085541 -1.15973424 -1.152318 -1.24259471 -1.05118522]
 [ -4.81916026 -4.9156052 -4.88441164 6.32737742 -4.96187574 5.62193111 5.8207054 ]
 [ 5.42367902 5.28490817 4.73242625 -3.96394128 4.03854014 4.7315908 -4.6876553 ]
 [ -5.40295007 -5.36749677 -5.59379197 -4.81502702 6.27931386 -5.55019005 -5.52429264]
 [ 6.23742908 6.44362768 -5.81783665 -3.66139268 6.81941351 -5.40527837 6.01257386]
 [ -1.03677572 0.51086156 -4.4540333 6.94791753 3.39084355 -2.75774956 -5.67272476]
 [ -0.15512309 -0.16822341 0.91674728 0.42048189 0.83927777 -1.56973629 0.90808607]
 [ -3.98464449 -3.02719252 -5.19120551 -2.78970955 -0.45313988 1.61304657 5.64553087]
 [ 0.46277921 0.4796228 0.52558705 0.9093049 -0.19621949 -0.37455812 0.36573013]
 [ -5.26418055 -5.45564464 -5.86499155 6.73464569 6.43907417 -5.60861426 -5.05202552]
 [ 6.37625132 6.35643813 -5.41298833 -4.87242461 6.16868597 -5.48276787 6.1317929 ]
 [ 1.93341083 2.3154003 -2.09876999 1.46998742 0.46290059 0.92879278 -0.32724924]
 [ 5.67487025 4.96433968 -4.61959004 5.49992719 -5.67058563 3.36221484 -5.84185852]
 [ 1.10435611 1.89872537 -10.23228089 6.87454898 -9.93308971 -3.57248164 5.67492503]
 [ -4.6229609 -4.51086612 -4.82619961 -3.41901637 6.15177486 5.35438272 5.65325611]
 [ -5.3311893 5.40158529 -5.47914053 5.5275653 -5.47109818 -5.29120084 -5.06855472]
 [ 6.16753724 6.26796416 -5.27753557 -5.5058128 6.03129946 -5.54085261 -5.30467099]
 [ 7.49170174 -2.27540065 -12.68160209 6.57582539 7.93059665 7.68063642 5.71124924]
 [ 6.44563587 6.53276274 -5.6841908 -3.77761462 6.70943778 6.17712503 6.23270465]
 [ -5.25692834 -5.64231881 6.02895105 4.6793944 7.20423644 -5.49423655 6.27682082]
 [ 10.23886449 -3.62230796 -10.26065685 -10.61365508 5.4737814 10.21115889 0.70221664]
 [ 1.54575609 1.28820356 5.05853037 -3.08186556 0.10786015 4.83892149 -4.20616671]
 [ -1.44073083 -4.49177409 -3.43769994 -1.68813373 5.7304344 -6.58124698 -5.39076775]
 [ -3.9647811 -4.23427742 5.49893695 0.95136425 1.8532288 3.50029717 -8.03195655]
 [ 3.9751144 5.46767384 -1.72835275 -7.83041292 -1.09963042 -1.92890754 -6.20553439]
 [ 1.9501396 -0.62039393 -2.80304112 1.32090419 0.30803303 3.82603812 -0.38070342]
 [ 2.45829508 2.90701232 2.51137216 -0.22501395 1.70924115 2.8644518 -6.10577681]
 [ 0.90152809 0.40851513 0.28342608 1.78027003 0.76533985 0.81507915 0.76922849]
 [ 5.82089516 0.08711939 1.06335858 1.98635293 0.4906126 0.84789967 1.24403454]
 [ -0.20514137 -1.26954153 4.36907797 7.84468961 -5.68424717 -1.9016879 -6.34595327]
 [ 0.78132314 0.42223676 0.03942347 1.09882011 0.41256056 0.44179873 0.27756618]
 [ 6.50333495 -5.55588618 -5.12370795 6.84978275 -5.47254504 6.15885161 -5.74575873]
 [ -6.85437715 -2.8964145 5.14766208 7.05689684 -4.16583091 -5.31644104 -5.1007049 ]
 [ -5.23147106 6.4957753 -5.62021223 7.36832282 -4.79757219 6.21769617 -5.32377532]
 [ 6.41094064 6.29769777 6.16345472 7.77181179 -4.93232416 6.3343712 -5.15964513]
 [ -4.17556719 -4.37203662 -3.35130153 -3.19537814 -3.50405461 -4.49734927 5.36786874]
 [ 6.03431739 -5.45303464 -5.17700974 6.15653011 -5.58881942 -5.52453024 -5.1946913 ]
 [ 1.72995244 2.17024948 -7.52276499 6.09661904 7.1590757 4.84583129 6.40933525]
 [ 6.35794416 6.18317877 -5.67543958 -5.04918861 6.27346563 -6.28383583 -5.04100431]
 [ 6.45719977 -5.38218752 6.07313075 -3.71999115 6.67256945 -5.40188496 6.79172664]
 [ 6.08720096 6.42992327 -5.48531207 7.48377732 -5.20269577 6.02432849 -5.10062475]
 [ 11.26301547 9.34147902 1.30916816 3.7398371 -11.91788446 -11.97288807 -8.18466697]
 [ 5.94943451 -4.79671474 -4.80285371 -3.22714878 5.7522071 -5.19179257 4.88332551]
 [ -5.46074307 6.27566281 -5.66611164 6.27366596 -5.75570808 -5.46030953 -5.04160857]
 [ -1.70653714 -1.50917671 1.0553611 0.59680504 -1.06294104 0.9972682 -2.83085916]
 [ -5.29226533 6.53795141 -5.44017467 7.08026413 -4.94413696 6.25696172 -5.59487932]
 [ -4.37883944 9.623168 -8.32266547 -3.58554155 5.89888512 -1.67255761 4.15429597]
 [ -3.84914838 0.21216138 4.23281255 8.42640835 2.09757777 -3.56731129 4.7028442 ]
 [ 6.15654896 -5.40394149 -5.738115 -3.99955638 -5.07711845 6.37408462 6.2669852 ]
 [ -3.53976161 -1.58517378 -1.53209439 6.57228736 -3.78585784 -3.89605312 -4.10042635]
 [ 0.65077106 -0.48141499 2.70841527 0.19221339 -1.71114859 0.43124329 4.34691051]
 [ 3.39375286 0.47145896 -2.181357 -2.38077593 1.58554717 1.81809284 0.53830334]
 [ 5.58077207 -4.5877737 -4.35183562 6.61772251 -4.35924357 -4.69268345 5.45551603]
 [ -3.78137812 5.85965479 -5.01866671 -3.24002004 5.45404331 5.39445987 5.13763332]
 [ 6.2952957 -5.44079606 -5.78192871 -4.45914586 6.18184312 -5.33905223 6.05099895]
 [ 6.4740828 6.37011637 -5.33120866 7.94875418 -4.87883646 -5.33528574 6.02029069]
 [ 3.1041986 -4.92717626 4.22361391 2.9559308 -3.86609032 -5.8351792 -0.66916177]
 [ -4.03958476 -4.13247927 3.79015529 -3.4655948 5.06982151 4.86698795 5.00912151]
 [ 6.38782548 -5.57317548 -5.60283165 -5.26756609 6.31618163 -5.74793075 -5.3853681 ]
 [ -5.45603408 6.23767435 -5.82732711 7.14567797 -5.15309274 -5.79476997 -5.90790297]
 [ -1.16149347 1.62709719 -0.51856388 4.0245873 2.64656277 -0.23537898 -0.29628835]
 [ -5.31604082 6.28303464 -5.15083014 -4.69071547 6.58542545 -5.63062024 6.1504005 ]
 [ 0.26490653 0.0665221 0.03765576 -0.23738541 0.06066124 -0.02248117 -0.04964459]
 [ 6.50861438 6.54475816 -5.69037571 -3.48224652 -4.71050667 -5.29755287 -5.25216452]
 [ 6.139225 -5.66132042 6.09138401 6.10903401 -5.52221132 -5.94119229 -5.44438525]
 [ -5.22370607 6.15078435 -5.46722848 -5.49805058 -5.79256562 -5.4153686 -5.76382923]
 [ 2.7772881 -0.06295554 -1.89851761 2.93014518 1.22978271 1.05882566 -0.49937926]
 [ 7.70974885 7.91293432 -7.33966515 -5.99732426 11.35202136 -3.12383367 -10.62976276]
 [ 6.02540918 -4.66614967 5.52878511 -3.47686595 6.1739871 5.81390326 -4.41636628]
 [ -5.36206963 6.4592948 6.13639694 -4.17273196 6.35302837 -5.47378464 6.18753149]
 [ 0.30847654 0.6689605 0.97222689 2.02982164 5.30081598 1.15750277 6.4172784 ]
 [ -5.54747991 6.06264173 -5.1960345 -5.78617448 -5.92778327 -5.76590498 -5.52775509]
 [ 6.09296203 -5.48520079 6.93752391 6.62775546 -5.38728772 -5.82461045 6.01131378]
 [ 6.73981027 6.32712478 -5.74562519 7.03080524 -5.25195234 -5.70723859 -5.23141757]
 [ -0.36725461 0.97413796 0.83773667 -0.64310114 0.15811825 -1.19152946 0.76372862]
 [ 6.55548112 -5.41891628 -5.52147157 -4.74895819 6.50052014 -5.50525058 -5.40967354]
 [ -9.10705935 -0.70565321 -1.510057 -3.81166567 9.20534517 4.22932435 -4.95362992]
 [ -1.6232858 2.07960332 -1.02247972 -2.7933715 2.14080178 0.84937094 -4.79755969]
 [ 4.63574862 5.05968814 -4.49002171 5.3175073 -4.4512494 -4.57594594 -1.16470318]
 [ 0.51807746 0.5777796 -0.42766339 0.608072 -1.87331397 0.39891602 3.21560745]
 [ 6.30317894 6.15389756 5.20734048 -6.02593563 -4.30971342 -5.02713209 2.2770625 ]
 [ -5.2005763 -5.43114167 -5.11467092 -3.82700302 6.90606367 6.44673419 -5.46855483]
 [ -5.56055682 -5.62330327 -4.93329121 -4.6307466 6.09548761 -5.46221477 -5.73368983]
 [ 3.61375408 -1.04931364 -0.62529218 3.06951291 -0.4717515 3.61877606 0.37127328]
 [ 14.21927117 -9.90281858 -9.94888361 -4.48322074 -3.97833022 -7.11948009 -11.06311209]
 [ 5.52115293 -4.58148124 -4.82135765 5.75497363 -4.51655449 5.20042001 -4.51530594]
 [ 6.3070379 6.2509505 -5.1614492 6.60308799 -5.49641349 -5.68212323 -5.85756125]
 [ -5.524322 -5.51930772 -5.02427024 -5.03707481 6.24244978 -5.62208858 6.32642836]
 [ 6.15885889 6.32142981 -5.73550174 -3.7282155 -4.9294399 -5.88224055 11.37183961]
 [ 1.52635012 2.53110275 -8.23339585 -2.82327244 -5.77566184 -5.86390685 -5.56672072]
 [ 6.4140891 6.74006632 -5.48274256 7.91234938 -4.56924729 6.1146253 -5.03910395]
 [ 5.81903702 -4.94746712 -5.15982777 -3.34867728 6.55393391 -4.86459469 -4.81633041]
 [ 5.57843234 5.6684662 -4.51404028 5.60527466 -4.86822101 5.3736594 -4.78002884]
 [ -5.32631193 -5.38694535 -5.65693489 6.68516782 6.10702773 -5.34779138 6.08909165]
 [ 6.31266915 6.42502513 6.12737119 7.08861213 -5.36245189 -5.58164971 -5.03466372]
 [ 13.11810867 -8.65529715 7.6845961 8.01355136 10.52833839 13.19652612 -1.95978409]
 [ 9.04978533 -5.1949991 -3.62739159 -4.24526529 -5.72755824 2.58923685 -5.45824353]]
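One way to read the fitted SE.lengthscale values is the usual ARD heuristic: a shorter length scale means the prediction changes faster along that input, which is often taken as a sign of higher relevance. The sketch below pairs the seven printed length scales with the column names passed as x_col and sorts them. It assumes the length scales are reported in the same order as x_col, and the heuristic itself assumes the inputs are on comparable scales, which the notebook does not verify. Read the ranking as indicative rather than conclusive.

```python
import numpy as np

# Input columns, in the order passed as x_col when loading the data.
features = ["Length", "Diameter", "Height", "Whole weight",
            "Shucked weight", "Viscera weight", "Shell weight"]

# SE.lengthscale values as printed by model.print_parameters().
lengthscales = np.array([8.5721219, 2.8215086, 6.23715182, 3.07448305,
                         1.40495237, 6.43906609, 1.46569563])

# Shortest length scale first: the output varies fastest along these inputs.
for i in np.argsort(lengthscales):
    print(f"{features[i]:<15} {lengthscales[i]:.3f}")
```

With the values above, the two shortest length scales belong to Shucked weight and Shell weight, and the longest to Length.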