Из документации scikit-learn:
Кривые точности-отзыва обычно используются в двоичной классификации для изучения результатов работы классификатора. Чтобы расширить кривую точности-отзыва и средней точности до классификации по нескольким классам или меткам, необходимо преобразовать выходные данные в двоичную форму. Для каждой метки можно нарисовать одну кривую, но можно также нарисовать кривую точного отзыва, рассматривая каждый элемент индикаторной матрицы метки как двоичный прогноз (микро-усреднение).
Кривые ROC обычно используются в двоичной классификации для изучения результатов работы классификатора. Чтобы расширить кривую ROC и область ROC до классификации с несколькими классами или метками, необходимо преобразовать выходные данные в двоичную форму. Для каждой метки можно нарисовать одну кривую ROC, но можно также нарисовать кривую ROC, рассматривая каждый элемент индикаторной матрицы метки как двоичный прогноз (микро-усреднение).
Следовательно, вы должны преобразовать выходные данные в двоичную форму и рассмотреть кривые точности-отзыва и roc для каждого класса. Кроме того, вы собираетесь использовать predict_proba
, чтобы получить вероятности классов.
Я делю код на три части:
- общие настройки, обучение и предсказание
- кривая точности-отзыва
- Кривая ROC
1. общие настройки, обучение и прогнозирование
from sklearn.datasets import fetch_mldata
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import precision_recall_curve, roc_curve
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
#%matplotlib inline
mnist = fetch_mldata("MNIST original")
n_classes = len(set(mnist.target))
Y = label_binarize(mnist.target, classes=[*range(n_classes)])
X_train, X_test, y_train, y_test = train_test_split(mnist.data,
Y,
random_state = 42)
clf = OneVsRestClassifier(RandomForestClassifier(n_estimators=50,
max_depth=3,
random_state=0))
clf.fit(X_train, y_train)
y_score = clf.predict_proba(X_test)
2. кривая точности-отзыва
# precision recall curve
precision = dict()
recall = dict()
for i in range(n_classes):
precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
y_score[:, i])
plt.plot(recall[i], precision[i], lw=2, label='class {}'.format(i))
plt.xlabel("recall")
plt.ylabel("precision")
plt.legend(loc="best")
plt.title("precision vs. recall curve")
plt.show()
![введите описание изображения здесь](https://i.stack.imgur.com/6g2Ir.png)
3. Кривая ROC
# roc curve
fpr = dict()
tpr = dict()
for i in range(n_classes):
fpr[i], tpr[i], _ = roc_curve(y_test[:, i],
y_score[:, i]))
plt.plot(fpr[i], tpr[i], lw=2, label='class {}'.format(i))
plt.xlabel("false positive rate")
plt.ylabel("true positive rate")
plt.legend(loc="best")
plt.title("ROC curve")
plt.show()
![введите описание изображения здесь](https://i.stack.imgur.com/VHLHl.png)
person
sentence
schedule
11.05.2019