Module mogptk.plot
Expand source code Browse git
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
def plot_spectrum(means, scales, weights=None, nyquist=None, titles=None, show=True, filename=None, title=None):
    """
    Plot weighted Gaussian spectral components and their sum.

    One subplot is drawn per (output channel, input dimension) pair. Each
    mixture component is drawn as a dashed red Gaussian and the sum of all
    components as a solid black line.

    Args:
        means: Array-like of shape (mixtures, input_dims, output_dims), or
            (mixtures, input_dims) for a single output channel.
        scales: Array-like with the same shape as `means`; used as the
            standard deviation of each Gaussian.
        weights: Optional array-like of shape (mixtures, output_dims), or
            (mixtures,) for a single output channel. A scalar is broadcast
            to every component. Defaults to all ones.
        nyquist: Optional array-like of shape (input_dims, output_dims), or
            (input_dims,) for a single output channel. A scalar is broadcast
            to every subplot. When given, the upper x-limit of each subplot
            is clipped to this value.
        titles: Optional sequence indexed by output channel giving the
            subplot titles.
        show: If true, call `plt.show()` after drawing.
        filename: Optional path without extension; the figure is saved to
            `filename + '.pdf'`.
        title: Optional figure-level title.

    Returns:
        Tuple `(fig, axes)` where `axes` is a 2-D array of shape
        (output_dims, input_dims).

    Raises:
        ValueError: If any of the arrays has an incompatible shape.
    """
    # Accept lists/tuples as well as arrays; a no-op for ndarray input.
    means = np.asarray(means)
    scales = np.asarray(scales)
    if means.ndim == 2:
        means = np.expand_dims(means, axis=2)
    if scales.ndim == 2:
        scales = np.expand_dims(scales, axis=2)

    if means.ndim != 3:
        raise ValueError('means and scales must have shape (mixtures,input_dims,output_dims)')
    if means.shape != scales.shape:
        raise ValueError('means and scales must have the same shape (mixtures,input_dims,output_dims)')
    mixtures, input_dims, output_dims = means.shape

    if weights is None:
        weights = np.ones((mixtures, output_dims))
    else:
        # Non-ndarray weights used to be silently replaced by ones.
        weights = np.asarray(weights)
        if weights.ndim == 0:
            weights = np.broadcast_to(weights, (mixtures, output_dims))
        elif weights.ndim == 1:
            weights = np.expand_dims(weights, axis=1)
        if weights.shape != (mixtures, output_dims):
            raise ValueError('weights must have shape (mixtures,output_dims)')

    if nyquist is not None:
        # Non-ndarray nyquist values used to be silently ignored.
        nyquist = np.asarray(nyquist)
        if nyquist.ndim == 0:
            nyquist = np.broadcast_to(nyquist, (input_dims, output_dims))
        elif nyquist.ndim == 1:
            nyquist = np.expand_dims(nyquist, axis=1)
        if nyquist.shape != (input_dims, output_dims):
            raise ValueError('nyquist must have shape (input_dims,output_dims)')

    # squeeze=False keeps `axes` two-dimensional even for a single subplot.
    fig, axes = plt.subplots(output_dims, input_dims, figsize=(20, output_dims*5), sharey=False, constrained_layout=True, squeeze=False)
    if title is not None:
        fig.suptitle(title, fontsize=36)

    for channel in range(output_dims):
        for i in range(input_dims):
            ax = axes[channel, i]
            loc = means[:, i, channel]
            scale = scales[:, i, channel]

            # Span the 1%-99% quantile range over all components, never
            # going below zero frequency.
            x_low = max(0.0, norm.ppf(0.01, loc=loc, scale=scale).min())
            x_high = norm.ppf(0.99, loc=loc, scale=scale).max()
            if nyquist is not None:
                x_high = min(x_high, nyquist[i, channel])

            x = np.linspace(x_low, x_high, 1000)
            psd = np.zeros(x.shape)
            for q in range(mixtures):
                single_psd = weights[q, channel] * norm.pdf(x, loc=loc[q], scale=scale[q])
                ax.plot(x, single_psd, '--', c='r', zorder=2)
                psd += single_psd

            ax.plot(x, psd, 'k-', zorder=1)
            ax.set_yticks([])
            ax.set_ylim(0, None)
            # Identity test: `titles` may be an array, for which `!= None`
            # is evaluated element-wise.
            if titles is not None:
                ax.set_title(titles[channel])

    # Label the bottom row of every input column explicitly instead of
    # relying on the value of the loop variable.
    for i in range(input_dims):
        axes[output_dims - 1, i].set_xlabel('Frequency')

    if filename is not None:
        # Save this figure rather than whatever pyplot considers current.
        fig.savefig(filename + '.pdf', dpi=300)
    if show:
        plt.show()
    return fig, axes
Functions
def plot_spectrum(means, scales, weights=None, nyquist=None, titles=None, show=True, filename=None, title=None)
-
Expand source code Browse git
def plot_spectrum(means, scales, weights=None, nyquist=None, titles=None, show=True, filename=None, title=None):
    """
    Plot weighted Gaussian spectral components and their sum.

    One subplot is drawn per (output channel, input dimension) pair. Each
    mixture component is drawn as a dashed red Gaussian and the sum of all
    components as a solid black line.

    Args:
        means: Array-like of shape (mixtures, input_dims, output_dims), or
            (mixtures, input_dims) for a single output channel.
        scales: Array-like with the same shape as `means`; used as the
            standard deviation of each Gaussian.
        weights: Optional array-like of shape (mixtures, output_dims), or
            (mixtures,) for a single output channel. A scalar is broadcast
            to every component. Defaults to all ones.
        nyquist: Optional array-like of shape (input_dims, output_dims), or
            (input_dims,) for a single output channel. A scalar is broadcast
            to every subplot. When given, the upper x-limit of each subplot
            is clipped to this value.
        titles: Optional sequence indexed by output channel giving the
            subplot titles.
        show: If true, call `plt.show()` after drawing.
        filename: Optional path without extension; the figure is saved to
            `filename + '.pdf'`.
        title: Optional figure-level title.

    Returns:
        Tuple `(fig, axes)` where `axes` is a 2-D array of shape
        (output_dims, input_dims).

    Raises:
        ValueError: If any of the arrays has an incompatible shape.
    """
    # Accept lists/tuples as well as arrays; a no-op for ndarray input.
    means = np.asarray(means)
    scales = np.asarray(scales)
    if means.ndim == 2:
        means = np.expand_dims(means, axis=2)
    if scales.ndim == 2:
        scales = np.expand_dims(scales, axis=2)

    if means.ndim != 3:
        raise ValueError('means and scales must have shape (mixtures,input_dims,output_dims)')
    if means.shape != scales.shape:
        raise ValueError('means and scales must have the same shape (mixtures,input_dims,output_dims)')
    mixtures, input_dims, output_dims = means.shape

    if weights is None:
        weights = np.ones((mixtures, output_dims))
    else:
        # Non-ndarray weights used to be silently replaced by ones.
        weights = np.asarray(weights)
        if weights.ndim == 0:
            weights = np.broadcast_to(weights, (mixtures, output_dims))
        elif weights.ndim == 1:
            weights = np.expand_dims(weights, axis=1)
        if weights.shape != (mixtures, output_dims):
            raise ValueError('weights must have shape (mixtures,output_dims)')

    if nyquist is not None:
        # Non-ndarray nyquist values used to be silently ignored.
        nyquist = np.asarray(nyquist)
        if nyquist.ndim == 0:
            nyquist = np.broadcast_to(nyquist, (input_dims, output_dims))
        elif nyquist.ndim == 1:
            nyquist = np.expand_dims(nyquist, axis=1)
        if nyquist.shape != (input_dims, output_dims):
            raise ValueError('nyquist must have shape (input_dims,output_dims)')

    # squeeze=False keeps `axes` two-dimensional even for a single subplot.
    fig, axes = plt.subplots(output_dims, input_dims, figsize=(20, output_dims*5), sharey=False, constrained_layout=True, squeeze=False)
    if title is not None:
        fig.suptitle(title, fontsize=36)

    for channel in range(output_dims):
        for i in range(input_dims):
            ax = axes[channel, i]
            loc = means[:, i, channel]
            scale = scales[:, i, channel]

            # Span the 1%-99% quantile range over all components, never
            # going below zero frequency.
            x_low = max(0.0, norm.ppf(0.01, loc=loc, scale=scale).min())
            x_high = norm.ppf(0.99, loc=loc, scale=scale).max()
            if nyquist is not None:
                x_high = min(x_high, nyquist[i, channel])

            x = np.linspace(x_low, x_high, 1000)
            psd = np.zeros(x.shape)
            for q in range(mixtures):
                single_psd = weights[q, channel] * norm.pdf(x, loc=loc[q], scale=scale[q])
                ax.plot(x, single_psd, '--', c='r', zorder=2)
                psd += single_psd

            ax.plot(x, psd, 'k-', zorder=1)
            ax.set_yticks([])
            ax.set_ylim(0, None)
            # Identity test: `titles` may be an array, for which `!= None`
            # is evaluated element-wise.
            if titles is not None:
                ax.set_title(titles[channel])

    # Label the bottom row of every input column explicitly instead of
    # relying on the value of the loop variable.
    for i in range(input_dims):
        axes[output_dims - 1, i].set_xlabel('Frequency')

    if filename is not None:
        # Save this figure rather than whatever pyplot considers current.
        fig.savefig(filename + '.pdf', dpi=300)
    if show:
        plt.show()
    return fig, axes