plot_learning_curves
plot_learning_curves(X_train, y_train, X_test, y_test, clf, train_marker='o', test_marker='^', scoring='misclassification error', suppress_plot=False, print_model=True, style='fivethirtyeight', legend_loc='best')
Plots learning curves of a classifier.
Parameters
-
X_train
: array-like, shape = [n_samples, n_features]Feature matrix of the training dataset.
-
y_train
: array-like, shape = [n_samples]True class labels of the training dataset.
-
X_test
: array-like, shape = [n_samples, n_features]Feature matrix of the test dataset.
-
y_test
: array-like, shape = [n_samples]True class labels of the test dataset.
-
clf
: Classifier object. Must have a .predict .fit method. -
train_marker
: str (default: 'o')Marker for the training set line plot.
-
test_marker
: str (default: '^')Marker for the test set line plot.
-
scoring
: str (default: 'misclassification error')If not 'misclassification error', accepts the following metrics (from scikit-learn): {'accuracy', 'average_precision', 'f1_micro', 'f1_macro', 'f1_weighted', 'f1_samples', 'log_loss', 'precision', 'recall', 'roc_auc', 'adjusted_rand_score', 'mean_absolute_error', 'mean_squared_error', 'median_absolute_error', 'r2'}
-
suppress_plot=False
: bool (default: False)Suppress matplotlib plots if True. Recommended for testing purposes.
-
print_model
: bool (default: True)Print model parameters in plot title if True.
-
style
: str (default: 'fivethirtyeight')Matplotlib style
-
legend_loc
: str (default: 'best')Where to place the plot legend: {'best', 'upper left', 'upper right', 'lower left', 'lower right'}
Returns
errors
: (training_error, test_error): tuple of lists
Examples
For usage examples, please see http://rasbt.github.io/mlxtend/user_guide/plotting/plot_learning_curves/