Ошибка mnlogit при вызове предсказать с новыми данными

Я хотел бы подогнать модель с использованием пакета mnlogit и использовать ее для прогнозирования вне выборки. Я создал игрушечный пример, используя данные о рыбалке, которые поставляются с mnlogit:

library(data.table)
library(mnlogit)

data(Fish, package="mnlogit")
fish_dt <- data.table(Fish)
rm(Fish)

unique_id <- unique(fish_dt[, chid])
set.seed(54321)
train_id <- sample(unique_id, size=0.5*length(unique_id))

setkey(fish_dt, chid, alt)
train <- fish_dt[J(train_id)]
test <- fish_dt[!J(train_id)]
setkey(train, chid, alt)
setkey(test, chid, alt)
stopifnot(nrow(train) + nrow(test) == nrow(fish_dt))  # Partition fish_dt

mnlogit_formula <- mode ~ catch | income
mnlogit_model <- mnlogit(mnlogit_formula, data=train, choiceVar="alt")

train_predictions <- predict(mnlogit_model, probability=F)
stopifnot(length(train_predictions) == length(unique(train[, chid])))  # One per choice
mean(subset(train, mode)[, alt] == train_predictions)  # Around 0.42 accuracy in sample

## Would like to do the same out of sample, i.e. with data table "test"
test_predictions <- predict(mnlogit_model, newdata=test, probability=F)  # Error
test_predictions <- predict(mnlogit_model, newdata=as.data.frame(test), probability=F)  # Same error

Я получаю следующее сообщение об ошибке:

Ошибка в _2 _ (_ 3_, value = list (chid = c (1L, 2L, 3L, 4L, 5L,: длина 'dimnames' [2] не равна экстенту массива

Я использую R версии 3.0.2 (25.09.2013) на Ubuntu 14.04.2 LTS.

Я неправильно использую пакет или это ошибка?

Изменить: см. Комментарии: я попытался удалить столбец «режим» из таблицы «тестовых» данных, но это дает мне ошибку «новые данные должны иметь те же столбцы, что и данные обучения»:

test[, mode := NULL]
mnlogit_predictions <- predict(mnlogit_model, newdata=test, probability=F)  # Error

Изменить: вот пример, в котором я использую пакет mlogit (который похож, но может быть значительно медленнее для больших проблем):

library(data.table)
library(mlogit)

data(Fish, package="mnlogit")
fish_dt <- data.table(Fish)
rm(Fish)

unique_id <- unique(fish_dt[, chid])
set.seed(54321)
train_id <- sample(unique_id, size=0.5*length(unique_id))

setkey(fish_dt, chid, alt)
train <- fish_dt[J(train_id)]
test <- fish_dt[!J(train_id)]
setkey(train, chid, alt)
setkey(test, chid, alt)
stopifnot(nrow(train) + nrow(test) == nrow(fish_dt))  # Partition fish_dt

train_mlogit <- mlogit.data(train, choice="mode", shape="long",
                            chid.var="chid", alt.var="alt")
test_mlogit <- mlogit.data(test, choice="mode", shape="long",
                           chid.var="chid", alt.var="alt")

model_formula <- mode ~ catch | income
mlogit_model <- mlogit(model_formula, data=train_mlogit)

## In-sample performance
train_predictions <- predict(mlogit_model, newdata=train_mlogit)
stopifnot(nrow(train_predictions) == length(unique(train[, chid])))  # One per choice
train_predictions <- colnames(train_predictions)[apply(train_predictions, 1, which.max)]
mean(subset(train, mode)[, alt] == train_predictions)  # Around 0.42 accuracy in sample

## Out-of-sample performance
test_predictions <- predict(mlogit_model, newdata=test_mlogit)
test_predictions <- colnames(test_predictions)[apply(test_predictions, 1, which.max)]
mean(subset(test, mode)[, alt] == test_predictions)  # Around 0.41 accuracy out of sample

Я бы хотел сделать именно это, но с mnlogit вместо mlogit.


person Adrian    schedule 08.04.2015    source источник
comment
Вы можете проверить пакет на странице github. Одна из возможных ошибок может заключаться в определении переменной ответа (для тестового набора данных) в вашем коде.   -  person S Das    schedule 09.04.2015
comment
@SubasishDas Вы имеете в виду столбец режима в тестовом фрейме данных? Если я удалю его, у меня должны быть те же столбцы, что и ошибка обучающих данных.   -  person Adrian    schedule 10.04.2015


Ответы (1)


Predict хорошо работает с объектами mnlogit. Однако ваш «тестовый» объект - это таблица данных, а не объект mlogit.data, поэтому вам также потребуется передать choiceVar в вызове predicter.

test_predictions <- predict(mnlogit_model, newdata=test,choiceVar="alt")  # works
test_predictions <- predict(mnlogit_model, newdata=as.data.frame(test), probability=F,choiceVar="alt")  # this also works

Спасибо

person user2778822    schedule 07.06.2017