Поиск по сетке возвращает точно такой же результат для пользовательской модели

Я оборачиваю модель Scikit-Learn Random Forest в функцию следующим образом:

from sklearn.base import BaseEstimator, RegressorMixin

class Model(BaseEstimator, RegressorMixin):
    def __init__(self, model):
        self.model = model
    
    def fit(self, X, y):
        self.model.fit(X, y)
        
        return self
    
    def score(self, X, y):
           
        from sklearn.metrics import mean_squared_error
        
        return mean_squared_error(y_true=y, 
                                  y_pred=self.model.predict(X), 
                                  squared=False)
    
    def predict(self, X):
        return self.model.predict(X)
class RandomForest(Model):
    def __init__(self, n_estimators=100, 
                 max_depth=None, min_samples_split=2,
                 min_samples_leaf=1, max_features=None):
        
        self.n_estimators=n_estimators 
        self.max_depth=max_depth
        self.min_samples_split=min_samples_split
        self.min_samples_leaf=min_samples_leaf
        self.max_features=max_features
           
        from sklearn.ensemble import RandomForestRegressor
 
        self.model = RandomForestRegressor(n_estimators=self.n_estimators, 
                                           max_depth=self.max_depth, 
                                           min_samples_split=self.min_samples_split,
                                           min_samples_leaf=self.min_samples_leaf, 
                                           max_features=self.max_features,
                                           random_state = 777)
    
    
    def get_params(self, deep=True):
        return {"n_estimators": self.n_estimators,
                "max_depth": self.max_depth,
                "min_samples_split": self.min_samples_split,
                "min_samples_leaf": self.min_samples_leaf,
                "max_features": self.max_features}

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self

В основном я следую официальному руководству Scikit-Learn, которое можно найти по адресу https://scikit-learn.org/stable/developers/develop.html

Вот как выглядит мой поиск по сетке:

grid_search = GridSearchCV(estimator=RandomForest(), 
                            param_grid={'max_depth':[1, 3, 6], 'n_estimators':[10, 100, 300]},
                            n_jobs=-1, 
                            scoring='neg_root_mean_squared_error',
                            cv=5, verbose=True).fit(X, y)
    
print(pd.DataFrame(grid_search.cv_results_).sort_values(by='rank_test_score'))

Результат поиска по сетке и grid_search.cv_results_ напечатаны ниже

Fitting 5 folds for each of 9 candidates, totalling 45 fits
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
   mean_fit_time  std_fit_time  mean_score_time  std_score_time  \
0       0.210918      0.002450         0.016754        0.000223   
1       0.207049      0.001675         0.016579        0.000147   
2       0.206495      0.002001         0.016598        0.000158   
3       0.206799      0.002417         0.016740        0.000144   
4       0.207534      0.001603         0.016668        0.000269   
5       0.206384      0.001396         0.016605        0.000136   
6       0.220052      0.024280         0.017247        0.001137   
7       0.226838      0.027507         0.017351        0.000979   
8       0.205738      0.003420         0.016246        0.000626   

  param_max_depth param_n_estimators                                 params  \
0               1                 10   {'max_depth': 1, 'n_estimators': 10}   
1               1                100  {'max_depth': 1, 'n_estimators': 100}   
2               1                300  {'max_depth': 1, 'n_estimators': 300}   
3               3                 10   {'max_depth': 3, 'n_estimators': 10}   
4               3                100  {'max_depth': 3, 'n_estimators': 100}   
5               3                300  {'max_depth': 3, 'n_estimators': 300}   
6               6                 10   {'max_depth': 6, 'n_estimators': 10}   
7               6                100  {'max_depth': 6, 'n_estimators': 100}   
8               6                300  {'max_depth': 6, 'n_estimators': 300}   

   split0_test_score  split1_test_score  split2_test_score  split3_test_score  \
0          -5.246725          -3.200585          -3.326962          -3.209387   
1          -5.246725          -3.200585          -3.326962          -3.209387   
2          -5.246725          -3.200585          -3.326962          -3.209387   
3          -5.246725          -3.200585          -3.326962          -3.209387   
4          -5.246725          -3.200585          -3.326962          -3.209387   
5          -5.246725          -3.200585          -3.326962          -3.209387   
6          -5.246725          -3.200585          -3.326962          -3.209387   
7          -5.246725          -3.200585          -3.326962          -3.209387   
8          -5.246725          -3.200585          -3.326962          -3.209387   

   split4_test_score  mean_test_score  std_test_score  rank_test_score  
0          -2.911422        -3.579016        0.845021                1  
1          -2.911422        -3.579016        0.845021                1  
2          -2.911422        -3.579016        0.845021                1  
3          -2.911422        -3.579016        0.845021                1  
4          -2.911422        -3.579016        0.845021                1  
5          -2.911422        -3.579016        0.845021                1  
6          -2.911422        -3.579016        0.845021                1  
7          -2.911422        -3.579016        0.845021                1  
8          -2.911422        -3.579016        0.845021                1  
[Parallel(n_jobs=-1)]: Done  45 out of  45 | elapsed:    3.2s finished

Мой вопрос: почему поиск по сетке возвращает точно такой же результат для всех разбиений данных?

Я предполагаю, что поиск по сетке выполняет только 1 сетку параметров (например, {'max_depth': 1, 'n_estimators': 10}) для всех разделений данных. Если это так, то почему это происходит?

Наконец, как сделать так, чтобы поиск по сетке возвращал правильный результат для всех разбиений данных?


person glorian    schedule 05.08.2020    source источник
comment
Ваше предположение неверно; из cv_results_ видно, что перепробованы все комбинации гиперпараметров (поэтому и у вас 9 записей) - см. столбцы param_max_depth и param_n_estimators. Без ваших данных невозможно сказать что-то еще, но первым шагом отладки будет запуск этого без вашего класса-оболочки (т.е. с нативным РФ scikit-learn).   -  person desertnaut    schedule 05.08.2020
comment
Из того, что вы показали, я не понимаю, почему вы просто не используете RandomForestRegressor напрямую; зачем вам этот класс-оболочка?   -  person Ben Reiniger    schedule 05.08.2020
comment
@desertnaut, если я использую RandomForestRegressor() из scikit-learn, он работает совершенно нормально, то есть возвращает правильный результат для всех разделений данных   -  person glorian    schedule 06.08.2020


Ответы (1)


Ваш метод set_params на самом деле не изменяет гиперпараметры экземпляра RandomForestRegressor в атрибуте self.model. Вместо этого он напрямую устанавливает атрибуты для вашего экземпляра RandomForest (которых раньше не было и они не влияют на реальную модель!). Таким образом, поиск по сетке постоянно устанавливает эти новые параметры, которые не имеют значения, и фактическая подгонка модели каждый раз остается одной и той же. (Аналогичным образом метод get_params получает атрибуты RandomForest, которые не совпадают с атрибутами RandomForestRegressor.)

Вы должны быть в состоянии исправить большую часть этого, если set_params просто вызовет self.model.set_params (и пусть get_params будет использовать self.model.<parameter_name> вместо self.<parameter_name>.

Я думаю, есть еще одна проблема, но я не знаю, как вообще работает ваш пример из-за этого: вы создаете экземпляр атрибута model, используя self.<parameter_name>, но он никогда не определяется в __init__.

person Ben Reiniger    schedule 05.08.2020
comment
Большое спасибо за решение! Да, вы правы, я думаю, я неправильно присвоил атрибуты обертке. Что касается другой проблемы с атрибутами, которые никогда не создаются в __init__(), я не вставил ее в этот пост. Я отредактировал исходный вопрос. Спасибо за ваше замечание! - person glorian; 06.08.2020