Module imodels.experimental.bartpy.diagnostics.trees
Expand source code
import numpy as np
from matplotlib import pyplot as plt
from imodels.experimental.bartpy.sklearnmodel import SklearnModel
def plot_tree_depth(model: SklearnModel, ax=None):
if ax is None:
_, ax = plt.subplots(1, 1)
min_depth, mean_depth, max_depth = [], [], []
for sample in model.model_samples:
model_depths = []
for tree in sample.trees:
model_depths += [x.depth for x in tree.nodes]
min_depth.append(np.min(model_depths))
mean_depth.append(np.mean(model_depths))
max_depth.append(np.max(model_depths))
ax.plot(min_depth, label="Min Depth")
ax.plot(mean_depth, label="Mean Depth")
ax.plot(max_depth, label="Max Depth")
ax.set_ylabel("Depth")
ax.set_xlabel("Iteration")
ax.legend()
ax.set_title("Tree Depth by Iteration")
return ax
Functions
def plot_tree_depth(model: SklearnModel, ax=None)
-
Expand source code
def plot_tree_depth(model: SklearnModel, ax=None): if ax is None: _, ax = plt.subplots(1, 1) min_depth, mean_depth, max_depth = [], [], [] for sample in model.model_samples: model_depths = [] for tree in sample.trees: model_depths += [x.depth for x in tree.nodes] min_depth.append(np.min(model_depths)) mean_depth.append(np.mean(model_depths)) max_depth.append(np.max(model_depths)) ax.plot(min_depth, label="Min Depth") ax.plot(mean_depth, label="Mean Depth") ax.plot(max_depth, label="Max Depth") ax.set_ylabel("Depth") ax.set_xlabel("Iteration") ax.legend() ax.set_title("Tree Depth by Iteration") return ax