scikit изучает перекрестную проверку сетки, возвращая неверное среднее значение

Я использовал GridCV для перекрестной проверки по k сгибам, чтобы настроить мои гиперпараметры. Средние результаты, которые должны были быть средними по отдельным сгибам, неверны в моем атрибуте результатов «cv_results_». Ниже приведен мой код для того же:

gscv = GridSearchCV(n_jobs=n_jobs,cv=train_test_iterable, estimator=pipeline, param_grid=param_grid, 
                verbose=10, scoring=['accuracy', 'precision','recall','f1'], refit='f1', 
                    return_train_score=return_train_score, error_score=error_score,
                   )
gscv.fit(X,Y)
gscv.cv_results_

cv_results_ содержит следующий json (отображается в виде таблицы)

    mean_test_f1    split0_test_f1  split1_test_f1  Actual Mean
    0.934310796     0.935603198     0.933665455     0.934634326
    0.931279716     0.908430118     0.942689316     0.925559717
    0.927683609     0.912005672     0.935512149     0.923758911
    0.680908006     0.741198823     0.650802701     0.696000762
    0.680908006     0.741198823     0.650802701     0.696000762
    0.646005028     0.684483208     0.626791532     0.65563737
    0.840273248     0.847484083     0.836672627     0.842078355
    0.837160828     0.847484083     0.832006068     0.839745075
    0.833637        0.842109375     0.829406448     0.835757911

Вы можете видеть выше: «mean_test_f1» не является средним значением двух сгибов «split0_test_f1», «split1_test_f1». Фактическое среднее — последний столбец.

Примечание: F1 означает оценку f1.

Кто-нибудь сталкивался с подобными проблемами?


person mlengg    schedule 09.05.2018    source источник


Ответы (2)


Я думаю, что то, что вы видите, является взвешенным средним, а не прямым средним.

person tangy    schedule 09.05.2018

Попробуйте установить iid=False в GridSearchCV(...) и сравните.

Согласно документации:

iid : boolean, default=True

    If True, the data is assumed to be identically distributed across 
    the folds, and the loss minimized is the total loss per sample,
    and not the mean loss across the folds.

Поэтому, когда iid имеет значение True (по умолчанию), усреднение результатов теста включает вес, как указано здесь, в исходном коде:

    _store('test_%s' % scorer_name, test_scores[scorer_name],
                   splits=True, rank=True,
                   weights=test_sample_counts if iid else None)

Обратите внимание, что это не влияет на баллы поездов, так что также перепроверяйте средние баллы поездов.

person Vivek Kumar    schedule 10.05.2018