Простое руководство по дистилляции BERT

в 15:26, , рубрики: BERT, deep learning, natural language processing, neural networks, nlp (natural language processing), машинное обучение

Если вы интересуетесь машинным обучением, то наверняка слышали про BERT и трансформеры.

BERT — это языковая модель от Google, показавшая state-of-the-art результаты с большим отрывом на целом ряде задач. BERT, и вообще трансформеры, стали совершенно новым шагом развития алгоритмов обработки естественного языка (NLP). Статью о них и «турнирную таблицу» по разным бенчмаркам можно найти на сайте Papers With Code.

С BERT есть одна проблема: её проблематично использовать в промышленных системах. BERT-base содержит 110М параметров, BERT-large — 340М. Из-за такого большого числа параметров эту модель сложно загружать на устройства с ограниченными ресурсами, например, мобильные телефоны. К тому же, большое время инференса делает эту модель непригодной там, где скорость ответа критична. Поэтому поиск путей ускорения BERT является очень горячей темой.

Нам в Авито часто приходится решать задачи текстовой классификации. Это типичная задача прикладного машинного обучения, которая хорошо изучена. Но всегда есть соблазн попробовать что-то новое. Эта статья родилась из попытки применить BERT в повседневных задачах машинного обучения. В ней я покажу, как можно значительно улучшить качество существующей модели с помощью BERT, не добавляя новых данных и не усложняя модель.

Простое руководство по дистилляции BERT - 1

Knowledge distillation как метод ускорения нейронных сетей

Существует несколько способов ускорения/облегчения нейронных сетей. Самый подробный их обзор, который я встречал, опубликован в блоге Intento на Медиуме.

Способы можно грубо разделить на три группы:

  1. Изменение архитектуры сети.
  2. Сжатие модели (quantization, pruning).
  3. Knowledge distillation.

Если первые два способа сравнительно известны и понятны, то третий менее распространён. Впервые идею дистилляции предложил Рич Каруана в статье “Model Compression”. Её суть проста: можно обучить легковесную модель, которая будет имитировать поведение модели-учителя или даже ансамбля моделей. В нашем случае учителем будет BERT, учеником — любая легкая модель.

Задача

Давайте разберём дистилляцию на примере бинарной классификации. Возьмём открытый датасет SST-2 из стандартного набора задач, на которых тестируют модели для NLP.

Этот датасет представляет собой набор обзоров фильмов с IMDb с разбивкой на эмоциональный окрас — позитивный или негативный. В качестве метрики на этом датасете используют accuracy.

Обучение BERT-based модели или «учителя»

Прежде всего необходимо обучить «большую» BERT-based модель, которая станет учителем. Самый простой способ это сделать — взять эмбеддинги из BERT и обучить классификатор поверх них, добавив один слой в сеть.

Благодаря библиотеке tranformers сделать это довольно легко, потому что там есть готовый класс модели BertForSequenceClassification. На мой взгляд, самое подробное и понятное руководство по обучению этой модели опубликовали Towards Data Science.

Давайте представим, что мы получили обученную модель BertForSequenceClassification. В нашем случае num_labels=2, так как у нас бинарная классификация. Эту модель мы будем использовать в качестве «учителя».

Обучение «ученика»

В качестве ученика можно взять любую архитектуру: нейронную сеть, линейную модель, дерево решений. Давайте для большей наглядности попробуем обучить BiLSTM. Для начала обучим BiLSTM без BERT.

Чтобы подавать на вход нейронной сети текст, нужно представить его в виде вектора. Один из самых простых способов — это сопоставить каждому слову его индекс в словаре. Словарь будет состоять из топ-n самых популярных слов в нашем датасете плюс два служебных слова: “pad” — «слово-пустышка», чтобы все последовательности были одной длины, и “unk” — для слов за пределами словаря. Построим словарь с помощью стандартного набора инструментов из torchtext. Для простоты я не стал использовать предобученные эмбеддинги слов.
 

import torch
from torchtext import data

def get_vocab(X):
    X_split = [t.split() for t in X]
    text_field = data.Field()
    text_field.build_vocab(X_split, max_size=10000)
    return text_field

def pad(seq, max_len):
    if len(seq) < max_len:
        seq = seq + ['<pad>'] * (max_len - len(seq))
    return seq[0:max_len]

def to_indexes(vocab, words):
    return [vocab.stoi[w] for w in words]

def to_dataset(x, y, y_real):
    torch_x = torch.tensor(x, dtype=torch.long)
    torch_y = torch.tensor(y, dtype=torch.float)
    torch_real_y = torch.tensor(y_real, dtype=torch.long)
    return TensorDataset(torch_x, torch_y, torch_real_y)

Модель BiLSTM

Код для модели будет выглядеть так:

import torch
from torch import nn
from torch.autograd import Variable

class SimpleLSTM(nn.Module):

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim, n_layers,
                 bidirectional, dropout, batch_size, device=None):
        super(SimpleLSTM, self).__init__()
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.embedding = nn.Embedding(input_dim, embedding_dim)

        self.rnn = nn.LSTM(embedding_dim,
                           hidden_dim,
                           num_layers=n_layers,
                           bidirectional=bidirectional,
                           dropout=dropout)

        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.device = self.init_device(device)
        self.hidden = self.init_hidden()

    @staticmethod
    def init_device(device):
        if device is None:
            return torch.device('cuda')
        return device

    def init_hidden(self):
        return (Variable(torch.zeros(2 * self.n_layers, self.batch_size, self.hidden_dim).to(self.device)),
                Variable(torch.zeros(2 * self.n_layers, self.batch_size, self.hidden_dim).to(self.device)))

    def forward(self, text, text_lengths=None):
        self.hidden = self.init_hidden()
        x = self.embedding(text)
        x, self.hidden = self.rnn(x, self.hidden)
        hidden, cell = self.hidden
        hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        x = self.fc(hidden)

        return x

Обучение

Для этой модели размерность выходного вектора будет (batch_size, output_dim). При обучении будем использовать обычный logloss. В PyTorch есть класс BCEWithLogitsLoss, который комбинирует сигмоиду и кросс-энтропию. То, что надо.

def loss(self, output, bert_prob, real_label):
    criterion = torch.nn.BCEWithLogitsLoss()
    return criterion(output, real_label.float())

Код для одной эпохи обучения:

def get_optimizer(model):
    optimizer = torch.optim.Adam(model.parameters())
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.9)
    return optimizer, scheduler

def epoch_train_func(model, dataset, loss_func, batch_size):
    train_loss = 0
    train_sampler = RandomSampler(dataset)
    data_loader = DataLoader(dataset, sampler=train_sampler,
                             batch_size=batch_size,
                             drop_last=True)
    model.train()
    optimizer, scheduler = get_optimizer(model)
    for i, (text, bert_prob, real_label) in enumerate(tqdm(data_loader, desc='Train')):
        text, bert_prob, real_label = to_device(text, bert_prob, real_label)
        model.zero_grad()
        output = model(text.t(), None).squeeze(1)
        loss = loss_func(output, bert_prob, real_label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    scheduler.step()
    return train_loss / len(data_loader)

Код для проверки после эпохи:

def epoch_evaluate_func(model, eval_dataset, loss_func, batch_size):
    eval_sampler = SequentialSampler(eval_dataset)
    data_loader = DataLoader(eval_dataset, sampler=eval_sampler,
                             batch_size=batch_size,
                             drop_last=True)

    eval_loss = 0.0
    model.eval()
    for i, (text, bert_prob, real_label) in enumerate(tqdm(data_loader, desc='Val')):
        text, bert_prob, real_label = to_device(text, bert_prob, real_label)
        output = model(text.t(), None).squeeze(1)
        loss = loss_func(output, bert_prob, real_label)
        eval_loss += loss.item()

    return eval_loss / len(data_loader)

Если это всё собрать воедино, то получится такой код для обучения модели:

import os
import torch
from torch.utils.data import (TensorDataset, random_split,
                              RandomSampler, DataLoader,
                              SequentialSampler)
from torchtext import data
from tqdm import tqdm

def device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def to_device(text, bert_prob, real_label):
    text = text.to(device())
    bert_prob = bert_prob.to(device())
    real_label = real_label.to(device())
    return text, bert_prob, real_label

class LSTMBaseline(object):
    vocab_name = 'text_vocab.pt'
    weights_name = 'simple_lstm.pt'

    def __init__(self, settings):
        self.settings = settings
        self.criterion = torch.nn.BCEWithLogitsLoss().to(device())

    def loss(self, output, bert_prob, real_label):
        return self.criterion(output, real_label.float())

    def model(self, text_field):
        model = SimpleLSTM(
            input_dim=len(text_field.vocab),
            embedding_dim=64,
            hidden_dim=128,
            output_dim=1,
            n_layers=1,
            bidirectional=True,
            dropout=0.5,
            batch_size=self.settings['train_batch_size'])
        return model

    def train(self, X, y, y_real, output_dir):
        max_len = self.settings['max_seq_length']
        text_field = get_vocab(X)

        X_split = [t.split() for t in X]
        X_pad = [pad(s, max_len) for s in tqdm(X_split, desc='pad')]
        X_index = [to_indexes(text_field.vocab, s) for s in tqdm(X_pad, desc='to index')]

        dataset = to_dataset(X_index, y, y_real)
        val_len = int(len(dataset) * 0.1)
        train_dataset, val_dataset = random_split(dataset, (len(dataset) - val_len, val_len))

        model = self.model(text_field)
        model.to(device())

        self.full_train(model, train_dataset, val_dataset, output_dir)
        torch.save(text_field, os.path.join(output_dir, self.vocab_name))

    def full_train(self, model, train_dataset, val_dataset, output_dir):
        train_settings = self.settings
        num_train_epochs = train_settings['num_train_epochs']
        best_eval_loss = 100000
        for epoch in range(num_train_epochs):
            train_loss = epoch_train_func(model, train_dataset, self.loss, self.settings['train_batch_size'])
            eval_loss = epoch_evaluate_func(model, val_dataset, self.loss, self.settings['eval_batch_size'])

            if eval_loss < best_eval_loss:
                best_eval_loss = eval_loss
                torch.save(model.state_dict(), os.path.join(output_dir, self.weights_name))

Дистилляция

Идея этого способа дистилляции взята из статьи исследователей из Университета Ватерлоо. Как я говорил выше, «ученик» должен научиться имитировать поведение «учителя». Что именно является поведением? В нашем случае это предсказания модели-учителя на обучающей выборке. Причём ключевая идея — использовать выход сети до применения функции активации. Предполагается, что так модель сможет лучше выучить внутреннее представление, чем в случае с финальными вероятностями.

В оригинальной статье предлагается в функцию потерь добавить слагаемое, которое будет отвечать за ошибку «подражания» — MSE между логитами моделей.

Простое руководство по дистилляции BERT - 2

Для этих целей сделаем два небольших изменения: изменим количество выходов сети с 1 до 2 и поправим функцию потерь.

def loss(self, output, bert_prob, real_label):
    a = 0.5
    criterion_mse = torch.nn.MSELoss()
    criterion_ce = torch.nn.CrossEntropyLoss()
    return a*criterion_ce(output, real_label) + (1-a)*criterion_mse(output, bert_prob)

Можно переиспользовать весь код, который мы написали, переопределив только модель и loss:


class LSTMDistilled(LSTMBaseline):
    vocab_name = 'distil_text_vocab.pt'
    weights_name = 'distil_lstm.pt'

    def __init__(self, settings):
        super(LSTMDistilled, self).__init__(settings)
        self.criterion_mse = torch.nn.MSELoss()
        self.criterion_ce = torch.nn.CrossEntropyLoss()
        self.a = 0.5

    def loss(self, output, bert_prob, real_label):
        return self.a * self.criterion_ce(output, real_label) + (1 - self.a) * self.criterion_mse(output, bert_prob)

    def model(self, text_field):
        model = SimpleLSTM(
            input_dim=len(text_field.vocab),
            embedding_dim=64,
            hidden_dim=128,
            output_dim=2,
            n_layers=1,
            bidirectional=True,
            dropout=0.5,
            batch_size=self.settings['train_batch_size'])
        return model

Вот и всё, теперь наша модель учится «подражать».

Сравнение моделей

В оригинальной статье наилучшие результаты классификации на SST-2 получаются при a=0, когда модель учится только подражать, не учитывая реальные лейблы. Accuracy всё ещё меньше, чем у BERT, но значительно лучше обычной BiLSTM.

Простое руководство по дистилляции BERT - 3

Я старался повторить результаты из статьи, но в моих экспериментах лучший результат получался при a=0,5.

Так выглядят графики loss и accuracy при обучении LSTM обычным способом. Судя по поведению loss, модель быстро обучилась, а где-то после шестой эпохи пошло переобучение.

Простое руководство по дистилляции BERT - 4

Графики при дистилляции:

Простое руководство по дистилляции BERT - 5

Дистиллированная BiLSTM стабильно лучше обычной. Важно, что по архитектуре они абсолютно идентичны, разница только в способе обучения. Полный код обучения я выложил на ГитХаб.

Заключение

В этом руководстве я постарался объяснить базовую идею подхода дистилляции. Конкретная архитектура ученика будет зависеть от решаемой задачи. Но в целом этот подход применим в любой практической задаче. За счёт усложнения на этапе обучения модели, можно получить значительный прирост её качества, сохранив изначальную простоту архитектуры.

Автор: pgladkov

Источник


* - обязательные к заполнению поля


https://ajax.googleapis.com/ajax/libs/jquery/3.4.1/jquery.min.js