Как создать пользовательскую модель (используя трюк цикла/подмоделей в каретке)

Я борюсь с этой проблемой в течение смущающе долгого времени. Чувствую себя абсолютным кретином, так как ответ наверное до боли очевиден, но не могу найти ни одной ветки объясняющей как это сделать.

Часть документации о создании пользовательских моделей мне кажется этой. Я чувствую, что где-то во время моего обучения я пропустил какой-то очень специфический урок, который теперь помнят все, кроме меня, потому что все, что я нахожу, это «да, просто создайте пользовательскую модель и готово».

Актуальные вопросы здесь:

Я хочу получать прогнозы для каждой отдельной итерации gbm в caret. Например, в gbm я могу просто использовать n.trees в predict(..., n.trees = 1:100), и все готово.

В caret, по-видимому, мне нужно использовать что-то, называемое трюком с подмоделями, что означает, если я правильно понимаю, что я должен создать свою собственную модель.

Но я вижу в getModelInfo('gbm'), что есть какая-то функция цикла!

$gbm$loop
function (grid) 
{
    loop <- plyr::ddply(grid, c("shrinkage", "interaction.depth", 
        "n.minobsinnode"), function(x) c(n.trees = max(x$n.trees)))
    submodels <- vector(mode = "list", length = nrow(loop))
    for (i in seq(along = loop$n.trees)) {
        index <- which(grid$interaction.depth == loop$interaction.depth[i] & 
            grid$shrinkage == loop$shrinkage[i] & grid$n.minobsinnode == 
            loop$n.minobsinnode[i])
        trees <- grid[index, "n.trees"]
        submodels[[i]] <- data.frame(n.trees = trees[trees != 
            loop$n.trees[i]])
    }
    list(loop = loop, submodels = submodels)

Как мне это использовать? Почему он не работает по умолчанию? Действительно ли мне нужно создавать пользовательскую модель, а может и нет?

Отказ от ответственности 1: я не хочу использовать перекрестную проверку. Я просто хочу делать прогнозы для каждой итерации одного прогона gbm.

Отказ от ответственности 2: я не хочу использовать predict.gbm() на $finalModel, так как я хочу также протестировать некоторые другие алгоритмы, которые также используют этот прием с подмоделями. Я не хочу использовать все различные функции predict(), специфичные для алгоритма, потому что тогда зачем мне вообще возиться с кареткой.

Я даже не знаю, что мне поставить в качестве тиражируемого примера. С кодом проблем нет. Я просто понятия не имею, как эта штука должна работать.


person M. Ike    schedule 26.08.2018    source источник
comment
Итак, вы хотели бы получить прогнозы по обучающим данным для каждого дерева? Какой в ​​этом смысл? Я мог бы помочь, если вы хотите получить прогнозы с перекрестной проверкой/загрузкой для каждого дерева без создания пользовательской модели. Насколько мне известно, каретка не позволяет легко получить прогнозы поездов для любой модели, поскольку они очень мало значат.   -  person missuse    schedule 27.08.2018
comment
@missuse Почти так, но я хочу получить прогнозы для каждого дерева как для обучающих, так и для тестовых данных, чтобы позже создать кривые обучения для презентационных целей. Я хочу иметь полный контроль над параметрами и хочу максимально уменьшить количество элементов типа «черный ящик», поэтому на данный момент меня не интересуют варианты перекрестной проверки. Я также не понимаю, почему карет делает это настолько трудным для достижения. Этот функционал мне кажется достаточно важным для сравнения производительности разных алгоритмов за раз на презентациях.   -  person M. Ike    schedule 27.08.2018
comment
Производительность алгоритма не может быть оценена на данных поезда, это, скорее всего, причина, по которой автор карета решил не предоставлять легкий доступ к прогнозам данных поезда.   -  person missuse    schedule 27.08.2018
comment
@missuse Я указал, что хочу также (и в основном) получать прогнозы для тестовых данных.   -  person M. Ike    schedule 27.08.2018


Ответы (1)


Вот пример того, как получить желаемые прогнозы для тестовых данных для каждого дерева:

library(caret)
library(mlbench) #for the data set
data(Sonar) #some data set I always use on stack overflow

res <- train(Class~.,
             data = Sonar,
             method = "gbm",
             trControl = trainControl(method = "cv", #some evaluations scheme
                                      number = 5,
                                      savePredictions = "all"), #tell caret you would like to save all,
             tuneGrid = expand.grid(shrinkage = 0.01,
                                    interaction.depth = 2, 
                                    n.minobsinnode = 10,
                                    n.trees = 1:100)) #some random values and all the trees

res$pred #results are stored in here

По сути, код, который вы показываете в посте, говорит Caret не настраивать все модели n.tree, а просто настраивать одну с max(n.trees) для каждой комбинации гиперпараметров, а затем использовать ее для получения прогнозов для n.trees < max(n.trees)

какой-то сюжет

library(ggplot2)

ggplot(res$results)+
  geom_line(aes(x = n.trees, y = Accuracy))

введите здесь описание изображения

Вы также можете отказаться от savePredictions = "all", так как это создает объект поезда, требующий памяти. А лучше использовать res$results в котором бы вы считали все нужные метрики.

person missuse    schedule 27.08.2018
comment
Это очень похоже на то, что я хочу получить. Но можно ли сделать то же самое без перекрестной проверки? (trainControl(метод = «нет»)). Извините, я с мобильного, редактирование здесь довольно ограничено. Потому что, когда я это сделал, каретка говорила мне, что я не могу указать диапазон значений для параметра. - person M. Ike; 27.08.2018
comment
К сожалению, нет, так как с method = “none” прогнозы не предоставляются, и вы можете указать только одну комбинацию гиперпараметров. - person missuse; 27.08.2018
comment
Это действительно странно, что это не разрешено. Кажется, мой единственный вариант - создать собственную модель, я думаю? Спасибо за все что ты сделал для меня. - person M. Ike; 27.08.2018
comment
Вы можете извлечь как поезд, так и тестовый прогноз с помощью библиотеки mlr. Проверьте мой ответ здесь: stackoverflow.com/questions/48754886/ - person missuse; 27.08.2018