Пошаговое руководство по поиску и пониманию проблем в вашей модели машинного обучения - и их устранению!

Модели машинного обучения можно использовать для создания классных приложений и демонстраций. Однако надежное развертывание систем машинного обучения в реальном мире по-прежнему является проблемой, поскольку они часто неожиданно выходят из строя, когда появляются неожиданные данные.

Например, классификация цифр в наборе данных рукописного ввода MNIST - одна из самых основных задач машинного обучения, часто используемая в качестве примера «решенной проблемы» во вводных курсах машинного обучения. Тем не менее, даже классификатор машинного обучения, который обеспечивает точность 99%, может делать неверные прогнозы, когда цифры меняются необычным образом:

Цель этого сообщения в блоге - помочь вам обнаружить эти точки данных (которые мы будем называть «достоверными данными»), которые вызывают точки сбоя машинного обучения в процессе обучения, чтобы вы не Поставляйте свои модели со всевозможными трещинами, которые затем могут превратиться в еще большие трещины, поскольку распределение данных продолжает меняться после развертывания вашей модели. Мы рассмотрим 3 важных шага, необходимых для устранения точек отказа вашей модели:

  1. Определение достоверных данных
  2. Визуализация достоверных данных и понимание того, почему они нарушают модель
  3. Исправление модели

Итак, приступим!

0. Настройка: загрузите вашу модель

Мы начнем с загрузки простой модели MNIST в TensorFlow, которую мы будем использовать в оставшейся части статьи. В нашем случае мы загрузим предварительно обученную модель из корзины S3, хотя, конечно, вы можете легко обучить модель MNIST самостоятельно или использовать совершенно другую модель.

1. Определение достоверных данных

После обучения модели нейронная сеть обычно достигает очень высокой точности в наборе обучающих данных, поэтому для определения потенциальных точек отказа необходимо использовать данные, которые модель не видела раньше. В нашем случае мы будем использовать набор тестов MNIST для идентификации достоверных данных.

Чтобы идентифицировать достоверные данные полезным способом, мы должны учитывать два фактора:

  • Сложность: по каким точкам данных модель была наименее точной? В качестве прокси-сервера мы рассмотрим уровни достоверности нейронной сети.
  • Разнообразие: как выбрать разнообразный набор достоверных данных, чтобы исследовать репрезентативный набор достоверных данных? Есть много способов использовать это, но мы рассмотрим, пожалуй, самый простой: убедитесь, что наши собранные образцы включают несколько примеров из каждого из 10 классов этикеток.

Приведенный ниже код выполняет эти шаги, сначала запуская модель на тестовых данных, а затем определяя 20 точных данных, то есть 2 самых сложных из каждой цифры. Вы заметите, что затем мы сохраняем каждый из этих примеров как изображения, что будет важно на следующем этапе.

2. Визуализация достоверных данных и понимание того, почему они нарушают модель

После того, как вы определили достоверные данные, теперь пора заняться грязными руками и по-настоящему понять, почему эти точки данных нарушают вашу модель. Для этого полезно использовать библиотеку визуализации, такую ​​как Gradio (https://github.com/gradio-app/gradio), которая может помочь вам увидеть ваши точные данные и их прогнозы по модели.

Чтобы использовать библиотеку Gradio, мы должны написать функцию, которая обтекает нашу модель. В приведенном ниже коде это recognize_digit(). Мы также определяем, какой пользовательский интерфейс мы хотели бы создать (sketchpadlabel), а затем мы можем передать список examples, который представляет собой список имен файлов из предыдущий шаг. Вот полный код:

Это создает следующий графический интерфейс, который загружает примеры и позволяет мне понять, почему разные образцы нарушают модель:

3. Исправление модели

А теперь самое важное! Для исправления модели необходимо досконально понять, почему различные точные данные нарушают ее. В зависимости от причины могут потребоваться различные шаги, чтобы сделать модель более надежной. Давайте применим это к достоверным данным, которые мы определили для модели MNIST, и обсудим, что нам нужно сделать, чтобы исправить эту модель.

3А. Добавить дополнение данных

Первая причина, по которой ваша модель может сломаться, заключается в том, что входные данные преобразуются необычным образом, как в примере с повернутой цифрой 2 ниже. Часто это можно легко исправить с помощью увеличения данных (синтетического преобразования ваших входных данных во время обучения). Вот руководство по увеличению объема данных.

Дополнительное примечание: вы можете использовать встроенные функции библиотеки Gradio, чтобы повернуть образец и наблюдать разницу в предсказаниях. Используя интерфейс image → label, у нас есть возможность попробовать множество различных преобразований в нашем образце (см. Colab notebook для более подробной информации). Вот результат, когда мы поворачиваем образец на 30 ° назад в правильную ориентацию, подтверждая, что, скорее всего, виновато вращение:

3Б. Соберите больше редких образцов

Вторая причина, по которой ваша модель может потерпеть неудачу, - это редкие данные, которые она не часто видит во время обучения. Рассмотрим цифру «5» ниже, у которой очень мало места между горизонтальной палкой, образующей вершину «5», и вершиной изогнутой части. Поскольку таких обучающих примеров не так много, модель MNIST не понимает, что это «5». Эту проблему трудно исправить с помощью увеличения данных - решение состоит в получении большего количества таких образцов.

3С. Классы баланса

Просматривая достоверные данные, я заметил одну вещь: модель редко предсказывает 5 вне зависимости от того, что это за цифра на самом деле. Например, цифра ниже неоднозначна - это может быть 5, 0 или 9, но модель уверена, что это 0, а не 5 (истинная метка). Это может отражать неравномерное распределение обучающих данных - возможно, во время обучения модель встречает 5 реже, чем другие цифры. И на самом деле, хотя в обучающих данных много цифр 5, они являются наименее распространенным классом в обучающих данных. Это снижает вероятность того, что модель предсказывает 5. Чтобы решить эту проблему, вы можете попробовать сбалансировать тренировочную выборку.

3D. Удалить неверные данные

Иногда достоверные данные классифицируются неправильно просто потому, что это неверная точка данных. Наборы данных обычно аннотируются людьми, которые могут ошибаться. Если вы уверены, что точка данных имеет неправильную маркировку (например, я полагаю, что точка данных ниже, которая якобы имеет значение «0» на основе присвоенной ей метки) и не репрезентативна для данных вашего развертывания, вам следует удалить ее из наборов тестов. чтобы ваши оценочные показатели были более реалистичными.

В этом посте я обрисовал в общих чертах процесс выявления точек сбоя моделей и их раннего исправления. Хотя особенности будут зависеть от того, какая модель у вас есть, этот общий подход выявления, понимания и исправления трудностей улучшит ваш набор данных и создаст надежную модель, которая не только имеет более высокую точность тестирования, но и, вероятно, будет более надежным при развертывании в реальном мире. Удачи!

Примечание: вы можете запустить весь приведенный выше код сразу в этой записной книжке Colab: https://colab.research.google.com/drive/1eOrsRGCcovZxCDxf8GL3xm8j2yzngr0Y?usp=sharing