Module mogptk.plot

Expand source code Browse git
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

def plot_spectrum(means, scales, weights=None, nyquist=None, titles=None, show=True, filename=None, title=None):
    """
    Plot spectral Gaussians of given means, scales and weights.

    Each panel (j, i) corresponds to output dimension j and input dimension i. It shows:
    every mixture component as a dashed weighted Gaussian density, their sum as a solid
    line, and a short grey tick at each component mean.

    Args:
        means (array_like): Component means of shape (mixtures,output_dims,input_dims).
            A 2D array of shape (mixtures,output_dims) is treated as input_dims=1.
        scales (array_like): Component standard deviations, same shape as means.
        weights (array_like, optional): Component weights of shape (mixtures,output_dims).
            A 1D array is treated as output_dims=1. A scalar is used for every component.
            Defaults to all ones.
        nyquist (array_like, optional): Upper frequency limit per panel, of shape
            (output_dims,input_dims). A 1D array is treated as input_dims=1. A scalar is
            used for every panel.
        titles (sequence, optional): Title per output dimension. titles[j] is set on every
            panel in row j.
        show (bool): Call plt.show() before returning.
        filename (str, optional): If given, the figure is saved to filename + '.pdf'.
        title (str, optional): Figure-level title.

    Returns:
        tuple: (fig, axes), where axes is a 2D array of shape (output_dims,input_dims).

    Raises:
        ValueError: If an argument does not have the documented shape, or if any
            dimension of means is empty.
    """
    means = np.asarray(means, dtype=float)
    scales = np.asarray(scales, dtype=float)
    if means.ndim == 2:
        means = np.expand_dims(means, axis=2)
    if scales.ndim == 2:
        scales = np.expand_dims(scales, axis=2)

    if means.ndim != 3:
        raise ValueError('means and scales must have shape (mixtures,output_dims,input_dims)')
    if means.shape != scales.shape:
        raise ValueError('means and scales must have the same shape (mixtures,output_dims,input_dims)')
    if 0 in means.shape:
        raise ValueError('means and scales must not have an empty dimension')

    mixtures, output_dims, input_dims = means.shape

    # Accept any array_like (lists, scalars) rather than only np.ndarray, so such
    # inputs are used and validated instead of being silently ignored.
    if weights is None:
        weights = np.ones((mixtures, output_dims))
    else:
        weights = np.asarray(weights, dtype=float)
        if weights.ndim == 0:
            weights = np.full((mixtures, output_dims), float(weights))
        elif weights.ndim == 1:
            weights = np.expand_dims(weights, axis=1)
        if weights.shape != (mixtures, output_dims):
            raise ValueError('weights must have shape (mixtures,output_dims)')

    if nyquist is not None:
        nyquist = np.asarray(nyquist, dtype=float)
        if nyquist.ndim == 0:
            nyquist = np.full((output_dims, input_dims), float(nyquist))
        elif nyquist.ndim == 1:
            nyquist = np.expand_dims(nyquist, axis=1)
        if nyquist.shape != (output_dims, input_dims):
            raise ValueError('nyquist must have shape (output_dims,input_dims)')

    h = 3.0*output_dims
    fig, axes = plt.subplots(output_dims, input_dims, figsize=(12,h), squeeze=False, constrained_layout=True)
    if title is not None:
        # Anchored above the axes area (y > 1); see the savefig call below.
        fig.suptitle(title, y=(h+0.8)/h, fontsize=18)

    for j in range(output_dims):
        for i in range(input_dims):
            ax = axes[j,i]
            mu = means[:,j,i]
            sigma = scales[:,j,i]

            # The x range covers the 1st-99th percentile of all components.
            # It is clipped below at zero frequency and above at the nyquist limit, if given.
            x_low = max(0.0, norm.ppf(0.01, loc=mu, scale=sigma).min())
            x_high = norm.ppf(0.99, loc=mu, scale=sigma).max()
            if nyquist is not None:
                x_high = min(x_high, nyquist[j,i])

            x = np.linspace(x_low, x_high, 1000)
            psd = np.zeros(x.shape)
            for q in range(mixtures):
                single_psd = weights[q,j] * norm.pdf(x, loc=mu[q], scale=sigma[q])
                ax.plot(x, single_psd, ls='--', c='k', zorder=2)
                ax.axvline(mu[q], ymin=0.001, ymax=0.1, lw=2, color='silver')
                psd += single_psd

            ax.plot(x, psd, ls='-', c='k', zorder=1)
            ax.set_yticks([])
            ax.set_ylim(0, None)
            if titles is not None:
                ax.set_title(titles[j])

    # Label every column of the bottom row, not only the last one.
    for ax in axes[-1,:]:
        ax.set_xlabel('Frequency')

    legends = []
    legends.append(plt.Line2D([0], [0], ls='-', color='k', label='Total'))
    legends.append(plt.Line2D([0], [0], ls='--', color='k', label='Mixture'))
    legends.append(plt.Line2D([0], [0], ls='-', lw=2, color='silver', label='Peak location'))
    fig.legend(handles=legends, loc="upper center", bbox_to_anchor=(0.5,(h+0.4)/h), ncol=3)

    if filename is not None:
        # The suptitle and legend sit above the figure canvas (y > 1).
        # A tight bounding box keeps them from being clipped out of the saved file.
        fig.savefig(filename+'.pdf', dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    return fig, axes

Functions

def plot_spectrum(means, scales, weights=None, nyquist=None, titles=None, show=True, filename=None, title=None)

Plot spectral Gaussians of given means, scales and weights.

Expand source code Browse git
def plot_spectrum(means, scales, weights=None, nyquist=None, titles=None, show=True, filename=None, title=None):
    """
    Plot spectral Gaussians of given means, scales and weights.

    Each panel (j, i) corresponds to output dimension j and input dimension i. It shows:
    every mixture component as a dashed weighted Gaussian density, their sum as a solid
    line, and a short grey tick at each component mean.

    Args:
        means (array_like): Component means of shape (mixtures,output_dims,input_dims).
            A 2D array of shape (mixtures,output_dims) is treated as input_dims=1.
        scales (array_like): Component standard deviations, same shape as means.
        weights (array_like, optional): Component weights of shape (mixtures,output_dims).
            A 1D array is treated as output_dims=1. A scalar is used for every component.
            Defaults to all ones.
        nyquist (array_like, optional): Upper frequency limit per panel, of shape
            (output_dims,input_dims). A 1D array is treated as input_dims=1. A scalar is
            used for every panel.
        titles (sequence, optional): Title per output dimension. titles[j] is set on every
            panel in row j.
        show (bool): Call plt.show() before returning.
        filename (str, optional): If given, the figure is saved to filename + '.pdf'.
        title (str, optional): Figure-level title.

    Returns:
        tuple: (fig, axes), where axes is a 2D array of shape (output_dims,input_dims).

    Raises:
        ValueError: If an argument does not have the documented shape, or if any
            dimension of means is empty.
    """
    means = np.asarray(means, dtype=float)
    scales = np.asarray(scales, dtype=float)
    if means.ndim == 2:
        means = np.expand_dims(means, axis=2)
    if scales.ndim == 2:
        scales = np.expand_dims(scales, axis=2)

    if means.ndim != 3:
        raise ValueError('means and scales must have shape (mixtures,output_dims,input_dims)')
    if means.shape != scales.shape:
        raise ValueError('means and scales must have the same shape (mixtures,output_dims,input_dims)')
    if 0 in means.shape:
        raise ValueError('means and scales must not have an empty dimension')

    mixtures, output_dims, input_dims = means.shape

    # Accept any array_like (lists, scalars) rather than only np.ndarray, so such
    # inputs are used and validated instead of being silently ignored.
    if weights is None:
        weights = np.ones((mixtures, output_dims))
    else:
        weights = np.asarray(weights, dtype=float)
        if weights.ndim == 0:
            weights = np.full((mixtures, output_dims), float(weights))
        elif weights.ndim == 1:
            weights = np.expand_dims(weights, axis=1)
        if weights.shape != (mixtures, output_dims):
            raise ValueError('weights must have shape (mixtures,output_dims)')

    if nyquist is not None:
        nyquist = np.asarray(nyquist, dtype=float)
        if nyquist.ndim == 0:
            nyquist = np.full((output_dims, input_dims), float(nyquist))
        elif nyquist.ndim == 1:
            nyquist = np.expand_dims(nyquist, axis=1)
        if nyquist.shape != (output_dims, input_dims):
            raise ValueError('nyquist must have shape (output_dims,input_dims)')

    h = 3.0*output_dims
    fig, axes = plt.subplots(output_dims, input_dims, figsize=(12,h), squeeze=False, constrained_layout=True)
    if title is not None:
        # Anchored above the axes area (y > 1); see the savefig call below.
        fig.suptitle(title, y=(h+0.8)/h, fontsize=18)

    for j in range(output_dims):
        for i in range(input_dims):
            ax = axes[j,i]
            mu = means[:,j,i]
            sigma = scales[:,j,i]

            # The x range covers the 1st-99th percentile of all components.
            # It is clipped below at zero frequency and above at the nyquist limit, if given.
            x_low = max(0.0, norm.ppf(0.01, loc=mu, scale=sigma).min())
            x_high = norm.ppf(0.99, loc=mu, scale=sigma).max()
            if nyquist is not None:
                x_high = min(x_high, nyquist[j,i])

            x = np.linspace(x_low, x_high, 1000)
            psd = np.zeros(x.shape)
            for q in range(mixtures):
                single_psd = weights[q,j] * norm.pdf(x, loc=mu[q], scale=sigma[q])
                ax.plot(x, single_psd, ls='--', c='k', zorder=2)
                ax.axvline(mu[q], ymin=0.001, ymax=0.1, lw=2, color='silver')
                psd += single_psd

            ax.plot(x, psd, ls='-', c='k', zorder=1)
            ax.set_yticks([])
            ax.set_ylim(0, None)
            if titles is not None:
                ax.set_title(titles[j])

    # Label every column of the bottom row, not only the last one.
    for ax in axes[-1,:]:
        ax.set_xlabel('Frequency')

    legends = []
    legends.append(plt.Line2D([0], [0], ls='-', color='k', label='Total'))
    legends.append(plt.Line2D([0], [0], ls='--', color='k', label='Mixture'))
    legends.append(plt.Line2D([0], [0], ls='-', lw=2, color='silver', label='Peak location'))
    fig.legend(handles=legends, loc="upper center", bbox_to_anchor=(0.5,(h+0.4)/h), ncol=3)

    if filename is not None:
        # The suptitle and legend sit above the figure canvas (y > 1).
        # A tight bounding box keeps them from being clipped out of the saved file.
        fig.savefig(filename+'.pdf', dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    return fig, axes