Spatial Transformer Networks — пространственная трансформация изображения в нейронной сети

Сети пространственных преобразований представляют собой обобщение дифференцированного внимания к любому пространственному преобразованию. Сети пространственных преобразований (сокращенно STN) позволяют нейронной сети научиться выполнять пространственные преобразования входного изображения, чтобы повысить геометрическую инвариантность модели. Например, он может обрезать интересующую область, масштабировать и корректировать ориентацию изображения. Это может быть полезный механизм, поскольку CNN не инвариантны к вращению, масштабированию и более общим аффинным преобразованиям.

Оригинальная работа представлена здесь: https://arxiv.org/pdf/1506.02025

Пример с цифрами рукописного текста:

Данная сеть может добавляться в другую сеть, тем самым улучшая ее работу:

2024-06-27_18-36-16

U - исходное изображение, V - результирующее.

Примеры задач, куда может быть интегрирована сеть:

  • классификация изображений: если CNN обучена выполнять многофакторную классификацию изображений в зависимости от того, содержат ли они определенную цифру – где положение и размер цифры могут значительно меняться в зависимости от каждой выборки (и не коррелируют с классом) , пространственный преобразователь, который вырезает и нормализует масштаб соответствующей области, может упростить последующую задачу классификации и привести к превосходной производительности классификации;
  • совместная локализация: при наличии набора изображения, содержащие разные экземпляры одного и того же (но неизвестного) класса, для их локализации в каждом изображении можно использовать пространственный преобразователь;
  • пространственное внимание: пространственный преобразователь можно использовать для задачи, требующие механизма внимания, но он более гибок и может обучаться исключительно с помощью обратного распространения ошибки без обучения с подкреплением.

Основным свойство STN является то, что для его внедрения не нужно как-то по особому изменять обучение. Т.е. в сеть добавляется вставка, которая автоматически подбирает оптимальные веса в зависимости от хода обучения. Например, в Keras при обучении классификатора изображений так:

2024-06-27_19-12-03

где функция stn описана в примере keras, приведенном в начале статьи.

Однако, применение данной модели (или ее разновидности) не обязательно даст положительный результат в вашей задаче. Описанные примеры связаны с распознавание рукописных цифр из датасета MNIST. Там всего 10 объектов.

К примеру, в моей задаче было около 40 объектов. И после добавления слоя validation loss немного улучшился (с 0.03 до 0.02). Однако при тестировании на реальных данных был получен результат более низкий по качеству. Причины? Большое количество объектов и их схожесть при определенных наклонах. При этом с локализацией до этого проблем не было и STN только поворачивала картинку.

Есть и другие примеры применения STN, например - распознавание автомобильных номеров.  Например, пример, описанный тут https://github.com/xuexingyu24/License_Plate_Detection_Pytorch

Там STN обучается в паре с сеткой для распознавания однострочных автомобильных номеров LPRNet, и результаты записываются в разные файлы моделей.

Если говорить про быстродействие, то STN не сильно замедляет обучение. Время инференса при этом совсем слабо увеличивается.