Ищем пневмонию на рентгеновских снимках с Fast.ai

в 8:02, , рубрики: cnn, deep learning, Fast.AI, python, искусственный интеллект, машинное обучение

Наткнулся на статью в блоге компании Школа Данных и решил проверить, на что способна библиотека Fast.ai на том же датасете, который упоминается в статье. Здесь вы не найдете рассуждений о том, своевременно и правильно диагностировать пневмонию, будут ли нужны врачи-рентгенологи, можно ли считать предсказание нейронной сети медицинским диагнозом и т.д. Основная цель — показать, что машинное обучение в современных библиотеках может быть довольно простым (буквально требует немного строчек кода) и дает отличные результаты. Запомним пока результат из статьи (precision = 0.84, recall = 0.96) и посмотрим, что получится у нас.
Берем данные для обучения отсюда. Данные представляют собой 5856 рентгеновских снимков, распределенных по двум классам — с признаками пневмонии и без. Задача нейронной сети — дать нам качественный бинарный классификатор рентгеновских снимков для определения признаков пневмонии.
Начинаем с импортирования библиотек и некоторых стандартных настроек:

%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.vision import *
from fastai.metrics import error_rate
import os

Далее определяем batch size. При обучении на GPU важно его подобрать таким образом, чтобы у вас не переполнялась память. При необходимости его можно уменьшить в два раза.

bs = 64

Определяем пути к нашим данным

path = Path('storage/chest_xray')
path.ls()

и проверяем, что все папки на месте:

Out:
[PosixPath('storage/chest_xray/train'),
 PosixPath('storage/chest_xray/val'),
 PosixPath('storage/chest_xray/test')]

Готовим наши данные для «загрузки» в нейросеть. Важно отметить, что в Fast.ai есть несколько методов сопоставления изображения метке. Метод from_folder говорит нам о том, что метки нужно брать из имени папки, в которой находится изображение. В нашем случае мы будем игнорировать распределение изображений по папкам train, val и test. Вместо этого мы указываем, что хотим разделить весь датасет случайным образом на train и validation в соотношении 80/20. Параметр size означает, что мы ресайзим все изображения до размера 299х299 (наши алгоритмы работают с квадратными изображениями). Функция get_transforms дает нам аугментацию изображений для увеличения объема данных для обучения (мы оставляем здесь дефолтные настройки).

data = ImageDataBunch.from_folder(path, valid_pct=0.2, size=299, bs=bs, ds_tfms=get_transforms()).normalize(imagenet_stats)

Заглянем в данные:

data.show_batch(rows=3, figsize=(6,6))

Ищем пневмонию на рентгеновских снимках с Fast.ai - 1
Для проверки смотрим, какие классы у нас получились и какое количественное распределение изображений между train и validation:

data.classes, data.c, len(data.train_ds), len(data.valid_ds)

Out:
(['NORMAL', 'PNEUMONIA'], 2, 4685, 1171)

Определяем модель обучения на архитектуре Resnet50:
learn = cnn_learner(data, models.resnet50, metrics=error_rate)
и начинаем обучение на 8 эпох, основываясь на One Cycle Policy:

learn.fit_one_cycle(8)

Ищем пневмонию на рентгеновских снимках с Fast.ai - 2
Видим, что мы уже получили неплохую точность в 96,4%. Запишем пока веса нашей модели и попробуем улучшить результат.

learn.save('stage-1-50')

«Размораживаем» всю модель, т.к. до этого мы обучали модель только на последней группе слоев, а веса остальных были взяты из предобученной на Imagenet модели и «заморожены»:

learn.unfreeze()

Ищем оптимальный learning rate для продолжения обучения:

learn.lr_find()
learn.recorder.plot()

Ищем пневмонию на рентгеновских снимках с Fast.ai - 3
Запускаем обучение на 10 эпох с различными learning rate для каждой группы слоев.

learn.fit_one_cycle(10, max_lr=slice(5e-6, 1e-4))

Ищем пневмонию на рентгеновских снимках с Fast.ai - 4
Видим, что точность нашей модели повысилась до 97,8%.
Запишем веса и посмотрим на Confusion Matrix:

learn.save('stage-2-50')
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

Ищем пневмонию на рентгеновских снимках с Fast.ai - 5
Здесь мы вспоминаем, что сам по себе параметр точности (accuracy) недостаточен, особенно для несбалансированных классов. Например, если в реальной жизни пневмония будет встречаться только у 0,1% тех, кто проходит рентген исследование, система может просто выдавать отсутствие пневмонии во всех случаях и ее точность будет на уровне 99,9% с абсолютно нулевой полезностью.
Здесь и вступают в игру метрики Precision и Recall:

  • TP — истино-положительное предсказание;
  • TN — истино-отрицательное предсказание;
  • FP — ложно-положительное предсказание;
  • FN — ложно-отрицательное предсказание.

$Precision=TP / (TP + FP)=857 / 867=0,988$

$Recall=TP / (TP + FN)=857 / 873=0,982$

Видим, что полученный нами результат существенно выше, чем тот, который был упомянут в статье. При дальнейшей работе над задачей стоит помнить, что Recall крайне важный параметр в медицинских задачах, т.к. False Negative ошибки наиболее опасны с точки зрения диагностики (означает, что мы можем просто «проглядеть» опасный диагноз).

Автор: Олег Замощин

Источник


* - обязательные к заполнению поля


https://ajax.googleapis.com/ajax/libs/jquery/3.4.1/jquery.min.js