Есть ли ограничение на количество классов в mlib NaiveBayes? Ошибка при вызове model.save()

Я пытаюсь обучить модель прогнозированию категории текстовых входных данных. Я сталкиваюсь с численной нестабильностью, используя классификатор pyspark.ml.classification.NaiveBayes для набора слов, когда количество классов превышает определенное количество.

В моем реальном проекте у меня порядка ~1 миллиарда записей и ~50 классов. Я могу обучать свою модель и делать прогнозы, но получаю сообщение об ошибке, когда пытаюсь сохранить ее с помощью model.save(). С точки зрения эксплуатации это раздражает, так как мне приходится каждый раз переобучать свою модель с нуля.

При попытке отладки я уменьшил свои данные примерно до 10 тыс. строк и столкнулся с той же проблемой при попытке сохранить. Однако сохранение работает нормально, если я уменьшу количество меток классов.

Это наводит меня на мысль, что существует ограничение на количество ярлыков. Я не могу воспроизвести свои точные проблемы, но приведенный ниже код связан. Если я установлю num_labels на значение больше 31, model.fit() выдаст ошибку.

Мои вопросы:

  1. Есть ли ограничение на количество классов в mllib реализации NaiveBayes?
  2. По каким причинам я не могу сохранить свою модель, если я могу успешно использовать ее для прогнозирования?
  3. Если действительно существует ограничение, можно ли будет разделить мои данные на группы более мелких классов, обучить отдельные модели и объединить?

Полный рабочий пример

Создайте фиктивные данные.

Я собираюсь использовать nltk.corpus.comparitive_sentences и nltk.corpus.sentence_polarity. Имейте в виду, что это всего лишь иллюстративный пример с бессмысленными данными — меня не интересует производительность подобранной модели.

import pandas as pd
from pyspark.sql.types import StringType

# create some dummy data
from nltk.corpus import comparative_sentences, sentence_polarity
df = pd.DataFrame(
    {
        'sentence': [" ".join(s) for s in cs.sents() + sp.sents()]
    }
)

# assign a 'category' to each row
num_labels = 31  # seems to be the upper limit
df['category'] = (df.index%num_labels).astype(str)

# make it into a spark dataframe
spark_df = sqlCtx.createDataFrame(df)

Конвейер подготовки данных

from pyspark.ml.feature import NGram, Tokenizer, StopWordsRemover
from pyspark.ml.feature import HashingTF, IDF, StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.linalg import Vector

indexer = StringIndexer(inputCol='category', outputCol='label')
tokenizer = Tokenizer(inputCol="sentence", outputCol="sentence_tokens")
remove_stop_words = StopWordsRemover(inputCol="sentence_tokens", outputCol="filtered")
unigrammer = NGram(n=1, inputCol="filtered", outputCol="tokens") 
hashingTF = HashingTF(inputCol="tokens", outputCol="hashed_tokens")
idf = IDF(inputCol="hashed_tokens", outputCol="tf_idf_tokens")

clean_up = VectorAssembler(inputCols=['tf_idf_tokens'], outputCol='features')

data_prep_pipe = Pipeline(
    stages=[indexer, tokenizer, remove_stop_words, unigrammer, hashingTF, idf, clean_up]
)
transformed = data_prep_pipe.fit(spark_df).transform(spark_df)
clean_data = transformed.select(['label','features'])

Обучение модели

from pyspark.ml.classification import NaiveBayes
nb = NaiveBayes()
(training,testing) = clean_data.randomSplit([0.7,0.3], seed=12345)
model = nb.fit(training)
test_results = model.transform(testing)

Оценить модель

from pyspark.ml.evaluation import MulticlassClassificationEvaluator
acc_eval = MulticlassClassificationEvaluator()
acc = acc_eval.evaluate(test_results)
print("Accuracy of model at predicting label was: {}".format(acc))

На моей машине это печатает:

Accuracy of model at predicting label was: 0.0305764788269

Сообщение об ошибке

Если я изменяю num_labels на 32 или выше, это ошибка, которую я получаю при вызове model.fit():

Py4JJavaError: Произошла ошибка при вызове o1336.fit. : org.apache.spark.SparkException: задание прервано из-за сбоя этапа: задача 0 на этапе 86.0 завершилась неудачно 4 раза, последний сбой: потеряна задача 0.3 на этапе 86.0 (TID 1984, someserver.somecompany.net, исполнитель 22): org .apache.spark.SparkException: сбой сериализации Kryo: переполнение буфера. Доступно: 7, требуется: 8 Трассировка сериализации: значения (org.apache.spark.ml.linalg.DenseVector). Чтобы этого избежать, увеличьте значение spark.kryoserializer.buffer.max. ... ... бла-бла-бла больше Java-вещей, которые продолжаются вечно

Примечания

  • В этом примере, если я добавлю функцию для биграмм, ошибка произойдет, если num_labels > 15. Интересно, совпадение ли это, что это также на 1 меньше, чем степень числа 2.
  • В моем реальном проекте я также получаю сообщение об ошибке при попытке вызвать model.theta. (Я не думаю, что сами ошибки имеют смысл - это просто исключения, переданные из методов java/scala.)

person pault    schedule 12.01.2018    source источник
comment
бла-бла-бла, больше Java-вещей, которые продолжаются вечно — это то, на чем вы должны сосредоточиться в большинстве случаев, и обычно это то, что нам нужно :)   -  person zero323    schedule 13.01.2018


Ответы (1)


Жесткие ограничения:

Количество функций * Количество классов должно быть меньше Integer.MAX_VALUE (231 - 1). Вы далеки от этих значений.

Мягкие ограничения:

Матрица тета (условные вероятности) имеет размер Количество признаков * Количество классов. Тета хранится как локально в драйвере (как часть модели), так и сериализуется и отправляется воркерам. Это означает, что всем машинам требуется как минимум достаточно памяти для сериализации или десериализации и сохранения результата.

Поскольку вы используете настройки по умолчанию для HashingTF.numFeatures (220), каждый дополнительный класс добавляет 262144 — это не так много, но быстро увеличивается. Судя по опубликованной вами частичной трассировке, неисправным компонентом является сериализатор Kryo. Та же самая трассировка также предлагает решение, которое увеличивает spark.kryoserializer.buffer.max.

Вы также можете попробовать использовать стандартную сериализацию Java, установив:

 spark.serializer org.apache.spark.serializer.JavaSerializer 

Поскольку вы используете PySpark с pyspark.ml и pyspark.sql, это может быть приемлемо без значительной потери производительности.

Помимо конфигурации, я бы сосредоточился на компоненте разработки функций. Использование двоичного CountVetorizer (см. примечание о HashingTF ниже) с ChiSqSelector может предоставить один из способов как повысить интерпретируемость, так и эффективно уменьшить количество функций. Вы также можете рассмотреть более сложные подходы (определение важности функций и применение Наивного Байеса только к подмножеству данных, более продвинутая обработка текста, такая как лемматизация / стемминг, или использование некоторого варианта автоэнкодера для получения более компактного векторного представления).

Примечания:

  • Пожалуйста, имейте в виду, что многонациональные наивные байесовцы учитывают только бинарные функции. NaiveBayes справится с этим внутри, но я все же рекомендую использовать setBinary для ясности.
  • Возможно, HashingTF здесь довольно бесполезен. Если оставить в стороне коллизии хэшей, очень разреженные функции и по существу бессмысленные функции, то это плохой выбор в качестве шага предварительной обработки для NaiveBayes.
person zero323    schedule 12.01.2018
comment
Понятно, значит, проблема в памяти во время сериализации. Это объясняет, почему я могу делать прогнозы, но не могу сохранить модель. Является ли единственным решением увеличение размера буфера? Любое понимание обучения отдельных моделей на подмножествах данных и их объединения? - person pault; 13.01.2018
comment
Возможно, я смогу обучить несколько n моделей, а затем использовать каждую для прогнозирования и выбора класса с наибольшей вероятностью. - person pault; 13.01.2018
comment
Вам придется пройти через трассировку стека, чтобы выяснить, кто виноват. Обычно замыкания никогда не используют KryoSerializer, только данные. Так что вас поражает, когда тета сериализуется для записи (здесь мы используем DataFrames). Вы также можете отключить сериализацию Kryo (в любом случае она имеет ограниченное количество приложений в PySpark). Но я бы предпочел сосредоточиться на разработке функций - ИМХО › 200 000 функций - это много на практике, чтобы получить интерпретируемую модель. - person zero323; 13.01.2018
comment
Это все очень полезная информация. Последний вопрос: есть ли у вас какие-либо ссылки на документы по этому материалу? Я никогда не видел флаг setBinary — никогда не знал, что это вариант. - person pault; 13.01.2018