Как быстро написать и выкатить в продакшн алгоритм машинного обучения

в 10:06, , рубрики: big data, data mining, data science, kaggle, machine learning, python, машинное обучение

Сейчас анализ данных все шире используется в самых разных, зачастую далеких от ИТ, областях и задачи, стоящие перед специалистом на ранних этапах проекта радикально отличаются от тех, с которыми сталкиваются крупные компании с развитыми отделами аналитики. В этой статье я расскажу о том, как быстро сделать полезный прототип и подготовить простой API для его использования прикладным программистом.

Для примера рассмотрим задачу предсказания цены на трубы размещенную на платформе для соревнований Kaggle. Описание и данные можно найти здесь. На самом деле на практике очень часто встречаются задачи в которых надо быстро сделать прототип имея очень небольшое количество данных, а то и вообще не имея реальных данных до момента первого внедрения. В этих случаях приходится подходить к задаче творчески, начинать с несложных эвристик и ценить каждый запрос или размеченный объект. Но в нашей модельной ситуации таких проблем, к счастью, нет и поэтому мы можем сразу начать с обзора данных, определения задачи и попыток применения алгоритмов.

Итак, распаковав архив с данными, можно обнаружить, что в нем содержится приблизительно пара десятков csv-файлов. Согласно описанию на странице соревнования, основными из них являются train_set.csv и test_set.csv. Они содержат базовую информацию касающуюся цен. Остальные файлы содержат вспомогательные, но немногим менее важные данные. Давайте рассмотрим их поподробнее.

Поместив архив в поддиректорию data корневой директории проекта и разархивировав его командами

$ cd data/
$ unzip data.zip
$ cd ..

Мы можем посмотреть, что находится в интересующих нас файлах:

$ head data/competition_data/train_set.csv
tube_assembly_id,supplier,quote_date,annual_usage,min_order_quantity,bracket_pricing,quantity,cost
TA-00002,S-0066,2013-07-07,0,0,Yes,1,21.9059330191461
TA-00002,S-0066,2013-07-07,0,0,Yes,2,12.3412139792904
TA-00002,S-0066,2013-07-07,0,0,Yes,5,6.60182614356538
TA-00002,S-0066,2013-07-07,0,0,Yes,10,4.6877695119712
TA-00002,S-0066,2013-07-07,0,0,Yes,25,3.54156118026073
TA-00002,S-0066,2013-07-07,0,0,Yes,50,3.22440644770007
TA-00002,S-0066,2013-07-07,0,0,Yes,100,3.08252143576504
TA-00002,S-0066,2013-07-07,0,0,Yes,250,2.99905966403855
TA-00004,S-0066,2013-07-07,0,0,Yes,1,21.9727024365273

Видим столбцы с данными, соответствующими объекту (instance), цену которого мы будем предсказывать. А именно — идентефикатор сборки (пока совершенно непонятно что это такое, но волшебство машинного обучения, помимо прочего, состоит в том, что для эффективного использования данных иногда совершенно необязательно понимать, что они означают), номер поставщика, дата и так далее. Отдельно обратим внимание на предпоследний столбец с количеством единиц поставки. Наконец последний столбец — метка (label), его цена. Чем больше больше единиц товара поставляется, тем ниже его цена, что вполне согласуется с нашими представления о происходящем в рельном мире.

Посмотрим теперь на файл с тестовыми данными.

$ head data/competition_data/test_set.csv
id,tube_assembly_id,supplier,quote_date,annual_usage,min_order_quantity,bracket_pricing,quantity
1,TA-00001,S-0066,2013-06-23,0,0,Yes,1
2,TA-00001,S-0066,2013-06-23,0,0,Yes,2
3,TA-00001,S-0066,2013-06-23,0,0,Yes,5
4,TA-00001,S-0066,2013-06-23,0,0,Yes,10
5,TA-00001,S-0066,2013-06-23,0,0,Yes,25
6,TA-00001,S-0066,2013-06-23,0,0,Yes,50
7,TA-00001,S-0066,2013-06-23,0,0,Yes,100
8,TA-00001,S-0066,2013-06-23,0,0,Yes,250
9,TA-00003,S-0066,2013-07-07,0,0,Yes,1

Те же самые столбцы за исключением последнего — а именно предсказываемой цены. Что логично для соревнования — именно на этих данных надо сформировать ответ для отправки на сайт соревнования и получения промежуточного результата.

Как мы уже отметили выше, помимо двух основных файлов с данными в нашем распоряжении есть еще и немало дополнительных. Имеет смысл заглянуть хотя бы в парочку из них для того, чтобы уловить основную идею данных в них содержащихся.

$ head data/competition_data/tube.csv
tube_assembly_id,material_id,diameter,wall,length,num_bends,bend_radius,end_a_1x,end_a_2x,end_x_1x,end_x_2x,end_a,end_x,num_boss,num_bracket,other
TA-00001,SP-0035,12.7,1.65,164,5,38.1,N,N,N,N,EF-003,EF-003,0,0,0
TA-00002,SP-0019,6.35,0.71,137,8,19.05,N,N,N,N,EF-008,EF-008,0,0,0
TA-00003,SP-0019,6.35,0.71,127,7,19.05,N,N,N,N,EF-008,EF-008,0,0,0
TA-00004,SP-0019,6.35,0.71,137,9,19.05,N,N,N,N,EF-008,EF-008,0,0,0
TA-00005,SP-0029,19.05,1.24,109,4,50.8,N,N,N,N,EF-003,EF-003,0,0,0
TA-00006,SP-0029,19.05,1.24,79,4,50.8,N,N,N,N,EF-003,EF-003,0,0,0
TA-00007,SP-0035,12.7,1.65,202,5,38.1,N,N,N,N,EF-003,EF-003,0,0,0
TA-00008,SP-0039,6.35,0.71,174,6,19.05,N,N,N,N,EF-008,EF-008,0,0,0
TA-00009,SP-0029,25.4,1.65,135,4,63.5,N,N,N,N,EF-003,EF-003,0,0,0

Ага, это просто соответствие между неким «идентификатором сборки», упоминавшимся выше и материалом, диаметром и тому подобными сведениями. Информация, вероятно, вполне полезная для обучения алгоритма.

Наконец, посмотрим, что содержится в файле посвященном материалам из которых сделаны трубы.

$ head data/competition_data/bill_of_materials.csv

tube_assembly_id,component_id_1,quantity_1,component_id_2,quantity_2,component_id_3,quantity_3,component_id_4,quantity_4,component_id_5,quantity_5,component_id_6,quantity_6,component_id_7,quantity_7,component_id_8,quantity_8

TA-00001,C-1622,2,C-1629,2,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00002,C-1312,2,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00003,C-1312,2,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00004,C-1312,2,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00005,C-1624,1,C-1631,1,C-1641,1,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00006,C-1624,1,C-1631,1,C-1641,1,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00007,C-1622,2,C-1629,2,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00008,C-1312,2,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00009,C-1625,2,C-1632,2,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA

Опять видим просто список соответствий между «идентификатором сборки» и компонентами. Обратим внимание на обилие полей с значением «NA». Как правило они обозначают пропуски в данных, однако в данном примере они соответствуют случаям, когда компонент просто недостаточно много для того, чтобы заполнить все 8 имеющихся в шапке позиций.

Казалось бы данные в общих чертах изучены, пора переходить непосредственно к настройке алгоритма. А вот и нет — прежде чем взяться за настройку алгоритма стоит понять, чего мы от него собираемся добиться, как мы будем сравнивать между собой два алгоритма, какой из них хуже, а какой — лучше. Во-первых нам придется оставить часть размеченных данных для проверки качества обученного алгоритма, в машинном обучении эта часть данных называется валидационной выборкой. Во-вторых нам придется исключить эту часть данных из тех, на которых мы будем проводить обучение, иначе модель предсказывающая ровно ту метку (в нашем случае — цену), какая была в обучающем объекте и случайную — в любом другом случае, будет давать идеальные предсказания на валидационной выборке, но при этом совершенно ужасные — при реальном применении. Исходные данные после исключения из них валидационной выборки называются тренировочной выборкой. Процесс обучения модели на тренировочной выборке и проверки качества ее работы на валидационной называется валидационной процедурой. Но и это еще не все — помимо выбора валидационной процедуры (которая, кстати говоря, может меняться в процессе экспериментов) необходимо выбрать еще и способ оценки качества предсказания при имеющихся экспериментальных результатах. А именно — функцию, которая возьмет на вход соответствующие два массива и даст на выходе оценку того, насколько предсказания соответствуют эксперименту. Такая функция называется метрикой качества и от ее выбора в реальных ситуациях зачастую зависит намного больше, чем от того, какой алгоритм мы выберем и как будем его настраивать. Не вдаваясь пока в детали этого тонкого процесса воспользуемся предложенной организаторами соревнования Root Mean Squared Logarithic Error (RMSLE).

Итак, принципиальные вопросы решены и мы можем приступить к написанию кода, загружающего данные и обучающего наш алгоритм. Для того, чтобы протестировать базовый пайплайн разделим выборку на тренировочную и валидационную части в соотношении 70/30, используем две простые фичи и один из простейших алгоритмов — линейную регрессию.

Напишем базовую функцию для загрузки данных и проверим ее работоспособность:

def load_data():
  list_of_instances = []
  list_of_labels = []

  with open('./data/competition_data/train_set.csv') as input_stream:
    header_line = input_stream.readline()
    columns = header_line.strip().split(',')
    for line in input_stream:
      new_instance = dict(zip(columns[:-1], line.split(',')[:-1]))
      new_label = float(line.split(',')[-1])
      list_of_instances.append(new_instance)
      list_of_labels.append(new_label)
  return list_of_instances, list_of_labels

>>> list_of_instances, list_of_labels = load_data()
>>> print(len(list_of_instances), len(list_of_labels))
30213 30213
>>> print(list_of_instances[:3])
[{'annual_usage': '0', 'quote_date': '2013-07-07', 'tube_assembly_id': 'TA-00002', 'min_order_quantity': '0', 'bracket_pricing': 'Yes', 'quantity': '1', 'supplier': 'S-0066'}, {'annual_usage': '0', 'quote_date': '2013-07-07', 'tube_assembly_id': 'TA-00002', 'min_order_quantity': '0', 'bracket_pricing': 'Yes', 'quantity': '2', 'supplier': 'S-0066'}, {'annual_usage': '0', 'quote_date': '2013-07-07', 'tube_assembly_id': 'TA-00002', 'min_order_quantity': '0', 'bracket_pricing': 'Yes', 'quantity': '5', 'supplier': 'S-0066'}]
>>> print(list_of_labels[:3])
[21.9059330191461, 12.3412139792904, 6.60182614356538]

Результаты вполне соответствуют ожиданиям. Теперь напишем заготовку функции, переводящей объекты (instances) в фичевекторы (samples).

def is_bracket_pricing(instance):
  if instance['bracket_pricing'] == 'Yes':
    return [1]
  elif instance['bracket_pricing'] == 'No':
    return [0]
  else:
    raise ValueError

def get_quantity(instance):
  return [int(instance['quantity'])]

def to_sample(instance):
  return is_bracket_pricing(instance) + get_quantity(instance)

>>> print(list(map(to_sample, list_of_instances[:3])))
[[1, 1], [1, 2], [1, 5]]

Позже, когда разных фичей станет много, они все переедут в специально отведенный для этого файл features.py, где будут обрастать вариациями, вспомогательными функциями, а в отдельных, особо запущенных случаях — еще и юнит-тестами.

Теперь у нас уже в принципе есть все необходимое для того, чтобы обучить первую, самую простую модель машинного обучения. Маленький (но важный) момент — мы договорились, что будем оптимизироваться под предложенную в условии соревнования метрику

$RMSLE=sqrt{frac1nsum_{i=1}^n(log(p_i + 1) - log(a_i + 1))^2},$

где $p_i$ — предсказанные моделью значение, а $a_i$ — фактические. Для того, чтобы наши усилия по подбору фичей и настроек модели в большей степени соответствовали этой цели, применим к меткам (labels) функцию $f(x)=log(x + 1)$. Дело в том, что большинство регрессионных методов машинного обучения предназначены для минимизации квадратичной ошибки (MSE) — и именно к оптимизации такой метрики качества мы сводим нашу задачу применив к меткам вышеуказанную функцию.

import math

def to_interim_label(label):
  return math.log(label + 1)

def to_final_label(interim_label):
  return math.exp(interim_label) - 1

>>> print(to_final_label(to_interim_label(42)))
42.0

Похоже на то, что мы не ошиблись с функциями перехода от исходных меток к более удобным для оптимизации и обратно и эти функции в самом деле взаимообратны. Теперь инициализируем модель и обучим ее на полученных фичевекторах и промежуточных метках.

>>> model = LinearRegression()
>>> list_of_samples = list(map(to_sample, list_of_instances))
>>> TRAIN_SAMPLES_NUM = 20000
>>> train_samples = list_of_samples[:TRAIN_SAMPLES_NUM]
>>> train_labels = list_of_labels[:TRAIN_SAMPLES_NUM]
>>> model.fit(train_samples, train_labels)

Теперь модель обучена и мы можем проверить, насколько она хорошо работает, сравнив результат ее предсказания и действительные значения на валидационной выборке.

>>> validation_samples = list_of_samples[TRAIN_SAMPLES_NUM:]
>>> validation_labels = list(map(to_interim_label, list_of_labels[TRAIN_SAMPLES_NUM:]))

>>> squared_errors = []
>>> for sample, label in zip(validation_samples, validation_labels):
>>>   prediction = model.predict(numpy.array(sample).reshape(1, -1))[0]
>>>   squared_errors.append((prediction - label) ** 2)

>>> mean_squared_error = sum(squared_errors) / len(squared_errors)
>>> print('Mean Squared Error: {0}'.format(mean_squared_error))
Mean Squared Error: 0.8251727558694294

Велика или мала полученная ошибка? Априори это сказать невозможно, однако с учетом того, что мы используем всего две фичи и одну из самых простых моделей — скорее всего наше предсказание далеко от оптимального. Давайте попробуем поэкспериментировать, например добавив новых фичей. Например в базовом файле с данными помимо уже использованных двух полей (bracket_pricing и quantity) имеется (пусть и с не очень часто отличающимися от 0 значениями) поле min_order_quantity. Давайте попробуем использовать его значение в качестве новой фичи для обучения алгоритма.

def get_min_order_quantity(instance):
  return [int(instance['min_order_quantity'])]

def to_sample(instance):
  return (is_bracket_pricing(instance) + get_quantity(instance)
             + get_min_order_quantity(instance))

Mean Squared Error: 0.8234554779286141

Как видим, ошибка немного уменьшилась, а значит — наш алгоритм улучшился. Не будем останавливатся на достигнутом и добавим одну за одной еще несколько фичей.

def get_annual_usage(instance):
  return [int(instance['annual_usage'])]

def to_sample(instance):
  return (is_bracket_pricing(instance) + get_quantity(instance)
             + get_min_order_quantity(instance) + get_annual_usage(instance))

Mean Squared Error: 0.8227852260998361

Следующее неиспользованное пока поле — quote_date, не имеет простой интерпретации в виде числа или набора чисел фиксированной длины. Поэтому придется немного подумать о том, как использовать его значение в качестве числового входа для алгоритма. Конечно, и год и месяц и день могут помочь в качестве новой фичи, но самым логичным первым приближением выглядит количество дней начиная с определенной даты — например с 1 января нулевого года, как дня, наступившего заведомо раньше самой ранней из встретившихся в файле дат. В первом приближении можно пока посчитать, что в году всегда 365 дней, а в каждом из 12 месяцев — 30. И пусть нас не смутит кажущаяся математическая некорректность этого предположения — мы всегда сможем уточнить формулы позже и посмотреть, улучшит ли соответствующая фича качество предсказания на валидационной выборке.

def get_absolute_date(instance):
  return [365 * int(instance['quote_date'].split('-')[0])
          + 12 * int(instance['quote_date'].split('-')[1])
          + int(instance['quote_date'].split('-')[2])]

def to_sample(instance):
  return (is_bracket_pricing(instance) + get_quantity(instance)
          + get_min_order_quantity(instance) + get_annual_usage(instance)
          + get_absolute_date(instance))

Mean Squared Error: 0.8216646342919645

Как видим, даже не вполне математически и астрономически корректная фича тем не менее помогла улучшить качество предсказания нашей модели. Теперь перейдем к пока еще новому для нас типу признаков, содержащихся в полях tube_assembly_id и supplier. Каждое из этих полей содержит значения идентификаторов предприятия-изготовителя и вендора. Они имеют не бинарную и не количественную природу, а описывают тип объекта из фиксированного списка. В машинном обучении такие свойства объектов и соответствующие им фичи наывают категориальными. Впрочем, категория предприятия-изготовителя сама по себе нам наврядли поможет, так как они не повторяются в файле test_set.csv, и мы совершенно правильно разделили размеченную выборку таким образом чтобы между тренировочной и валидационной частями (практически) не было соответствующего пересечения. Попробуем, тем не менее, извлечь что-то полезное из значения поля supplier. Посмотрим для начала какие вообще соответствующие коды встречаются в файле с размеченными данными.

>>> with open('./data/competition_data/train_set.csv') as input_stream:
...   header_line = input_stream.readline()
...   suppliers = set()
...   for line in input_stream:
...     new_supplier = line.split(',')[1]
...     suppliers.add(new_supplier)
...
>>> print(len(suppliers))
57
>>> print(suppliers)
{'S-0058', 'S-0013', 'S-0050', 'S-0011', 'S-0070', 'S-0104', 'S-0012', 'S-0068', 'S-0041', 'S-0023', 'S-0092', 'S-0095', 'S-0029', 'S-0051', 'S-0111', 'S-0064', 'S-0005', 'S-0096', 'S-0062', 'S-0004', 'S-0059', 'S-0031', 'S-0078', 'S-0106', 'S-0060', 'S-0090', 'S-0072', 'S-0105', 'S-0087', 'S-0080', 'S-0061', 'S-0108', 'S-0042', 'S-0027', 'S-0074', 'S-0081', 'S-0025', 'S-0024', 'S-0030', 'S-0022', 'S-0014', 'S-0054', 'S-0015', 'S-0008', 'S-0007', 'S-0009', 'S-0056', 'S-0026', 'S-0107', 'S-0066', 'S-0018', 'S-0109', 'S-0043', 'S-0046', 'S-0003', 'S-0006', 'S-0097'}

Как видим — их не так уж и много. Для начала можно попробовать поступить самым простым стандартным для таких случаев образом, а именно сопоставить каждому значению поля массив с одной единицей, стоящей на месте соответствующем идентификатору и всеми остальными значениями равными 0.

SUPPLIERS_LIST = ['S-0058', 'S-0013', 'S-0050', 'S-0011', 'S-0070', 'S-0104', 'S-0012', 'S-0068', 'S-0041', 'S-0023', 'S-0092', 'S-0095', 'S-0029', 'S-0051', 'S-0111', 'S-0064', 'S-0005', 'S-0096', 'S-0062', 'S-0004', 'S-0059', 'S-0031', 'S-0078', 'S-0106', 'S-0060', 'S-0090', 'S-0072', 'S-0105', 'S-0087', 'S-0080', 'S-0061', 'S-0108', 'S-0042', 'S-0027', 'S-0074', 'S-0081', 'S-0025', 'S-0024', 'S-0030', 'S-0022', 'S-0014', 'S-0054', 'S-0015', 'S-0008', 'S-0007', 'S-0009', 'S-0056', 'S-0026', 'S-0107', 'S-0066', 'S-0018', 'S-0109', 'S-0043', 'S-0046', 'S-0003', 'S-0006', 'S-0097']

def get_supplier(instance):
  if instance['supplier'] in SUPPLIERS_LIST:
    supplier_index = SUPPLIERS_LIST.index(instance['supplier'])
    result = [0] * supplier_index + [1] + [0] * (len(SUPPLIERS_LIST) - supplier_index - 1)
  else:
    result = [0] * len(SUPPLIERS_LIST)
  return result

def to_sample(instance):
  return (is_bracket_pricing(instance) + get_quantity(instance)
          + get_min_order_quantity(instance) + get_annual_usage(instance)
          + get_absolute_date(instance) + get_supplier(instance))

Mean Squared Error: 0.7992338454746866

Как мы видим, произошло значительное снижение средней ошибки. Скорее всего это означает, что в значении поля, которое мы добавили, заложена ценная для алгоритма информация и стоит попробовать воспользоваться ей позже еще несколько раз, но каким-нибудь менее банальным способом. Мы уже обсудили, что поле tube_assembly вряд ли поможет нам напрямую, однако стоит все-таки попробовать проверить это экспериментально.

def get_assembly(instance):
  assembly_id = int(instance['tube_assembly_id'].split('-')[1])
  result = [0] * assembly_id + [1] + [0] * (25000 - assembly_id - 1)
  return result

def to_sample(instance):
  return (is_bracket_pricing(instance) + get_quantity(instance)
          + get_min_order_quantity(instance) + get_annual_usage(instance)
          + get_absolute_date(instance) + get_supplier(instance)
          + get_assembly(instance))

Разумеется, конкретный способ конвертации индентификатора производителя в фичевектор выглядит несколько неуклюже, особенно с учетом количества возможных значений, однако если у читателя есть более конструктивный вариант использования этого поля напрямую и без привлечения дополнительных данных, то можно обсудить его в комментариях и даже попробовать применить в рамках имеющегося на данный момент кода. А мы вспомним о том, что помимо основного файла с обучающей выборкой в нашем распоряжении имеется еще несколько файлов с вспомогательными данными и попробуем поэкспериментировать с их привлечением. Например, чтобы нам было не очень обидно за относительную (хотя и предсказуемую) неудачу с прямым использованием значения поля tube_assembly_id, можно попытаться взять реванш, воспользовавшись данными, содержащимися в файле specs.csv и в анонимизированной форме описывающими соответствующие этим значениям спецификаторы.

head data/competition_data/specs.csv -n 20
tube_assembly_id,spec1,spec2,spec3,spec4,spec5,spec6,spec7,spec8,spec9,spec10
TA-00001,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00002,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00003,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00004,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00005,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00006,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00007,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00008,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00009,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00010,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00011,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00012,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00013,SP-0004,SP-0069,SP-0080,NA,NA,NA,NA,NA,NA,NA
TA-00014,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00015,SP-0063,SP-0069,SP-0080,NA,NA,NA,NA,NA,NA,NA
TA-00016,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00017,NA,NA,NA,NA,NA,NA,NA,NA,NA,NA
TA-00018,SP-0007,SP-0058,SP-0070,SP-0080,NA,NA,NA,NA,NA,NA
TA-00019,SP-0080,NA,NA,NA,NA,NA,NA,NA,NA,NA

Похоже на то, что имеет место простое перечисление спецификаторов соответствующих каждому значению. Соответственно разумным выглядит вариант в качестве фичевектора выбрать массив из единиц и нулей, в котором единицы стоят на позициях, соответствующих имеющимся спецификаторам, а нули — всем остальным. Для того, чтобы использовать данные из дополнительного файла, нам придется сделать небольшое изменение в структуре кода.

def get_assembly_specs(instance, assembly_to_specs):
  result = [0] * 100
  for spec in assembly_to_specs[instance['tube_assembly_id']]:
    result[int(spec.split('-')[1])] = 1
  return result

def to_sample(instance, additional_data):
  return (is_bracket_pricing(instance) + get_quantity(instance)
          + get_min_order_quantity(instance) + get_annual_usage(instance)
          + get_absolute_date(instance) + get_supplier(instance)
          + get_assembly_specs(instance, additional_data['assembly_to_specs']))

def load_additional_data():
  result = dict()
  assembly_to_specs = dict()
  with open('data/competition_data/specs.csv') as input_stream:
    header_line = input_stream.readline()
    for line in input_stream:
      tube_assembly_id = line.split(',')[0]
      specs = []
      for spec in line.strip().split(',')[1:]:
        if spec != 'NA':
          specs.append(spec)
      assembly_to_specs[tube_assembly_id] = specs

  result['assembly_to_specs'] = assembly_to_specs
  return result

additional_data = load_additional_data()
list_of_samples = list(map(lambda x:to_sample(x, additional_data), list_of_instances))

Mean Squared Error: 0.7754770419953809

Наши труды не пропали зря и целевая метрика улучшилась на 0.024, что на текущем этапе уже совсем неплохо. На этом, пожалуй, можно пока что остановиться с оптимизацией алгоритма и обсудить вопрос простого предоставления удобного API к обученным алгоритмам для прикладного программиста.

Сперва сохраним обученную модель на диск.

with open('./data/model.mdl', 'wb') as output_stream:
  output_stream.write(pickle.dumps(model))

Теперь создадим скрипт generate_response.py, в котором воспользуемся полученными ранее наработками.

import pickle
import numpy
import research

class FinalModel(object):
  def __init__(self, model, to_sample, additional_data):
    self._model = model
    self._to_sample = to_sample
    self._additional_data = additional_data
  def process(self, instance):
    return self._model.predict(numpy.array(self._to_sample(
                                               instance, self._additional_data)).reshape(1, -1))[0]

if __name__ == '__main__':
  with open('./data/model.mdl', 'rb') as input_stream:
    model = pickle.loads(input_stream.read())
  additional_data = research.load_additional_data()
  final_model = FinalModel(model, research.to_sample, additional_data)
  print(final_model.process({'tube_assembly_id':'TA-00001', 'supplier':'S-0066',
                             'quote_date':'2013-06-23', 'annual_usage':'0',
                             'min_order_quantity':'0', 'bracket_pricing':'Yes', 'quantity':'1'}))

2.357692493326624

Теперь аналогичным образом прикладной программист может подгрузить модель и сопутствующие переменные в любом скрипте и использовать для предсказания значения на новых входящих данных.

В следующей серии — дальнейшая оптимизация модели, больше фичей, более сложные алгоритмы и настройка их гиперпараметров, при необходимости — продвинутые валидационные процедуры.

Автор: Андрей Гриненко

Источник


* - обязательные к заполнению поля


https://ajax.googleapis.com/ajax/libs/jquery/3.4.1/jquery.min.js