Как построить точность и отзывчивость мультиклассового классификатора?

Я использую scikit learn и хочу построить кривые точности и запоминания. я использую классификатор RandomForestClassifier. Все ресурсы в документации scikit learn используют двоичную классификацию. Кроме того, могу ли я построить кривую ROC для мультикласса?

Кроме того, я нашел только SVM для мультиэкрана, и у него есть decision_function, у которого RandomForest нет


person John Sall    schedule 11.05.2019    source источник
comment
Здесь есть параграф с примером: scikit-learn.org/stable/auto_examples/ model_selection /. Разве это не то, что вы хотите?   -  person Yohst    schedule 11.05.2019
comment
scikit-learn.org/0.15/auto_examples/plot_precision_recall.html   -  person secretive    schedule 11.05.2019
comment
@Yohst в этом примере используется svm с функцией принятия решения, а в RandomForest нет функций принятия решения.   -  person John Sall    schedule 11.05.2019


Ответы (1)


Из документации scikit-learn:

Кривые точности-отзыва обычно используются в двоичной классификации для изучения результатов работы классификатора. Чтобы расширить кривую точности-отзыва и средней точности до классификации по нескольким классам или меткам, необходимо преобразовать выходные данные в двоичную форму. Для каждой метки можно нарисовать одну кривую, но можно также нарисовать кривую точного отзыва, рассматривая каждый элемент индикаторной матрицы метки как двоичный прогноз (микро-усреднение).

Кривые ROC обычно используются в двоичной классификации для изучения результатов работы классификатора. Чтобы расширить кривую ROC и область ROC до классификации с несколькими классами или метками, необходимо преобразовать выходные данные в двоичную форму. Для каждой метки можно нарисовать одну кривую ROC, но можно также нарисовать кривую ROC, рассматривая каждый элемент индикаторной матрицы метки как двоичный прогноз (микро-усреднение).

Следовательно, вы должны преобразовать выходные данные в двоичную форму и рассмотреть кривые точности-отзыва и roc для каждого класса. Кроме того, вы собираетесь использовать predict_proba, чтобы получить вероятности классов.

Я делю код на три части:

  1. общие настройки, обучение и предсказание
  2. кривая точности-отзыва
  3. Кривая ROC

1. общие настройки, обучение и прогнозирование

from sklearn.datasets import fetch_mldata
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import precision_recall_curve, roc_curve
from sklearn.preprocessing import label_binarize

import matplotlib.pyplot as plt
#%matplotlib inline

mnist = fetch_mldata("MNIST original")
n_classes = len(set(mnist.target))

Y = label_binarize(mnist.target, classes=[*range(n_classes)])

X_train, X_test, y_train, y_test = train_test_split(mnist.data,
                                                    Y,
                                                    random_state = 42)

clf = OneVsRestClassifier(RandomForestClassifier(n_estimators=50,
                             max_depth=3,
                             random_state=0))
clf.fit(X_train, y_train)

y_score = clf.predict_proba(X_test)

2. кривая точности-отзыва

# precision recall curve
precision = dict()
recall = dict()
for i in range(n_classes):
    precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
                                                        y_score[:, i])
    plt.plot(recall[i], precision[i], lw=2, label='class {}'.format(i))
    
plt.xlabel("recall")
plt.ylabel("precision")
plt.legend(loc="best")
plt.title("precision vs. recall curve")
plt.show()

введите описание изображения здесь

3. Кривая ROC

# roc curve
fpr = dict()
tpr = dict()

for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i],
                                  y_score[:, i]))
    plt.plot(fpr[i], tpr[i], lw=2, label='class {}'.format(i))

plt.xlabel("false positive rate")
plt.ylabel("true positive rate")
plt.legend(loc="best")
plt.title("ROC curve")
plt.show()

введите описание изображения здесь

person sentence    schedule 11.05.2019
comment
почему я использую OneVsRestClassifier? разве RandomForest уже не поддерживает мультикласс? - person John Sall; 12.05.2019
comment
У меня возникают эти ошибки, когда я запускаю первую часть: UserWarning: Label not 0 присутствует во всех обучающих примерах UserWarning: Label not 1 присутствует во всех обучающих примерах UserWarning: Label not 2 присутствует во всех обучающих примерах - person John Sall; 12.05.2019
comment
Обратите внимание, что предупреждение НЕ является ошибкой. Учитывая эту строку Y = label_binarize(mnist.target, classes=[*range(n_classes)]), вы должны указать классы в своем наборе данных. В моем примере это классы [0,1,2,...,9]. - person sentence; 12.05.2019