Я пытаюсь обучить модель прогнозированию категории текстовых входных данных. Я сталкиваюсь с численной нестабильностью, используя классификатор pyspark.ml.classification.NaiveBayes
для набора слов, когда количество классов превышает определенное количество.
В моем реальном проекте у меня порядка ~1 миллиарда записей и ~50 классов. Я могу обучать свою модель и делать прогнозы, но получаю сообщение об ошибке, когда пытаюсь сохранить ее с помощью model.save()
. С точки зрения эксплуатации это раздражает, так как мне приходится каждый раз переобучать свою модель с нуля.
При попытке отладки я уменьшил свои данные примерно до 10 тыс. строк и столкнулся с той же проблемой при попытке сохранить. Однако сохранение работает нормально, если я уменьшу количество меток классов.
Это наводит меня на мысль, что существует ограничение на количество ярлыков. Я не могу воспроизвести свои точные проблемы, но приведенный ниже код связан. Если я установлю num_labels
на значение больше 31, model.fit()
выдаст ошибку.
Мои вопросы:
- Есть ли ограничение на количество классов в
mllib
реализацииNaiveBayes
? - По каким причинам я не могу сохранить свою модель, если я могу успешно использовать ее для прогнозирования?
- Если действительно существует ограничение, можно ли будет разделить мои данные на группы более мелких классов, обучить отдельные модели и объединить?
Полный рабочий пример
Создайте фиктивные данные.
Я собираюсь использовать nltk.corpus.comparitive_sentences
и nltk.corpus.sentence_polarity
. Имейте в виду, что это всего лишь иллюстративный пример с бессмысленными данными — меня не интересует производительность подобранной модели.
import pandas as pd
from pyspark.sql.types import StringType
# create some dummy data
from nltk.corpus import comparative_sentences, sentence_polarity
df = pd.DataFrame(
{
'sentence': [" ".join(s) for s in cs.sents() + sp.sents()]
}
)
# assign a 'category' to each row
num_labels = 31 # seems to be the upper limit
df['category'] = (df.index%num_labels).astype(str)
# make it into a spark dataframe
spark_df = sqlCtx.createDataFrame(df)
Конвейер подготовки данных
from pyspark.ml.feature import NGram, Tokenizer, StopWordsRemover
from pyspark.ml.feature import HashingTF, IDF, StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.linalg import Vector
indexer = StringIndexer(inputCol='category', outputCol='label')
tokenizer = Tokenizer(inputCol="sentence", outputCol="sentence_tokens")
remove_stop_words = StopWordsRemover(inputCol="sentence_tokens", outputCol="filtered")
unigrammer = NGram(n=1, inputCol="filtered", outputCol="tokens")
hashingTF = HashingTF(inputCol="tokens", outputCol="hashed_tokens")
idf = IDF(inputCol="hashed_tokens", outputCol="tf_idf_tokens")
clean_up = VectorAssembler(inputCols=['tf_idf_tokens'], outputCol='features')
data_prep_pipe = Pipeline(
stages=[indexer, tokenizer, remove_stop_words, unigrammer, hashingTF, idf, clean_up]
)
transformed = data_prep_pipe.fit(spark_df).transform(spark_df)
clean_data = transformed.select(['label','features'])
Обучение модели
from pyspark.ml.classification import NaiveBayes
nb = NaiveBayes()
(training,testing) = clean_data.randomSplit([0.7,0.3], seed=12345)
model = nb.fit(training)
test_results = model.transform(testing)
Оценить модель
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
acc_eval = MulticlassClassificationEvaluator()
acc = acc_eval.evaluate(test_results)
print("Accuracy of model at predicting label was: {}".format(acc))
На моей машине это печатает:
Accuracy of model at predicting label was: 0.0305764788269
Сообщение об ошибке
Если я изменяю num_labels
на 32 или выше, это ошибка, которую я получаю при вызове model.fit()
:
Py4JJavaError: Произошла ошибка при вызове o1336.fit. : org.apache.spark.SparkException: задание прервано из-за сбоя этапа: задача 0 на этапе 86.0 завершилась неудачно 4 раза, последний сбой: потеряна задача 0.3 на этапе 86.0 (TID 1984, someserver.somecompany.net, исполнитель 22): org .apache.spark.SparkException: сбой сериализации Kryo: переполнение буфера. Доступно: 7, требуется: 8 Трассировка сериализации: значения (org.apache.spark.ml.linalg.DenseVector). Чтобы этого избежать, увеличьте значение spark.kryoserializer.buffer.max. ... ... бла-бла-бла больше Java-вещей, которые продолжаются вечно
Примечания
- В этом примере, если я добавлю функцию для биграмм, ошибка произойдет, если
num_labels
> 15. Интересно, совпадение ли это, что это также на 1 меньше, чем степень числа 2. - В моем реальном проекте я также получаю сообщение об ошибке при попытке вызвать
model.theta
. (Я не думаю, что сами ошибки имеют смысл - это просто исключения, переданные из методов java/scala.)