Module imodels.experimental.bartpy.diagnostics.diagnostics

Expand source code
from matplotlib import pyplot as plt

from imodels.experimental.bartpy.diagnostics.residuals import plot_qq, plot_homoskedasity_diagnostics
from imodels.experimental.bartpy.diagnostics.sampling import plot_tree_mutation_acceptance_rate
from imodels.experimental.bartpy.diagnostics.sigma import plot_sigma_convergence
from imodels.experimental.bartpy.diagnostics.trees import plot_tree_depth
from imodels.experimental.bartpy.sklearnmodel import SklearnModel


def plot_diagnostics(model: SklearnModel):
    fig, ((ax1, ax2, ax3), (ax4, ax5, _)) = plt.subplots(2, 3, figsize=(10, 10))
    fig.suptitle("Diagnostics")
    plot_qq(model, ax1)
    plot_tree_depth(model, ax2)
    plot_sigma_convergence(model, ax3)
    plot_homoskedasity_diagnostics(model, ax4)
    plot_tree_mutation_acceptance_rate(model, ax5)
    plt.show()

Functions

def plot_diagnostics(model: SklearnModel)
Expand source code
def plot_diagnostics(model: SklearnModel):
    fig, ((ax1, ax2, ax3), (ax4, ax5, _)) = plt.subplots(2, 3, figsize=(10, 10))
    fig.suptitle("Diagnostics")
    plot_qq(model, ax1)
    plot_tree_depth(model, ax2)
    plot_sigma_convergence(model, ax3)
    plot_homoskedasity_diagnostics(model, ax4)
    plot_tree_mutation_acceptance_rate(model, ax5)
    plt.show()