Машинное обучение с помощью TMVA (ROOT)

в 12:29, , рубрики: c++, machine learning, машинное обучение

В последние пару лет только и слышно о том, что Python и scikit-learn являются неким золотым стандартом в data science. В то же время многие профессиональные разработчики жалуются, что в Python кривой способ обращения к базовым классам и т.п. И вообще им не нравится, что нельзя заниматься машинным обучением на родном C++.
Об одной из библиотек, написанных на С++, я и хотел бы рассказать.

TMVA (Toolkit for Multivariate Data Analysis with ROOT) — open-source библиотека алгоритмов машинного обучения, которая идёт в дополнение к пакету анализа больших данных ROOT, соответственно устанавливается вместе с ним. Про установку подробно написано в мануале, поэтому мы не будем рассматривать этот момент.
Основным сайтом проекта до недавнего времени считался TMVA, но, как мы видим, на нём уже давненько не было никаких обновлений. Это не повод для скепсиса и паники, т.к. теперь его развитием занимается новая команда разработчиков из CERN. С информацией о проекте можно ознакомиться здесь: New TMVA.
CERN (Европейская организация по ядерным исследованиям) была первопроходцем в создании ПО для анализа больших объёмов данных. Именно там была разработана объектно-ориентированная библиотека ROOT, которая нашла применение не только в мире физики.
В ROOT'е данные хранятся в очень экономичном формате *.root, но можно работать и с любым текстовым форматом. Для простоты используем при работе с TMVA обычный тектовый формат csv/txt.
К сожалению, на данный момент, в TMVA используются только алгоритмы обучения с учителем, т.е. вам понадобится иметь некоторые сигнальные данные (информация о плохих или хороших заёмщиках, показатели болезней и т.д.) на вход, дабы алгоритм распознал и отскорил необходимый массив (Background/шум), обучившись на примере.
Так выглядят корреляционные матрицы в TMVA
image

Итак, представим, что у нас уже установлен ROOT и есть 2 текстовых файла: с "хорошими" и теми, кого нужно классифицировать (либо построить регрессию для прогнозирования). Для того, чтобы подать как инпут эти 2 файла, необходимо привести заголовок файла к необходимому формату:
id/F:Param1/I:Param2/I:Param3/F

В TMVA 2 типа данных: Float и Integer (в Reader'e только float)
В качестве разделителя переменных по умолчанию идёт знак запятой.
Ознакомиться со списком алгоритмов можно в User Guide

Давайте перейдём к коду.

#include "TMVA/Types.h"
#include "TMVA/Factory.h"
#include "TMVA/Tools.h"
using std::cout;
std::string outputListFileName;

void Churn_model_BDT()
{
std::cout << std::endl;

std::cout << "===> Start TMVAClassification" << std::endl;
//Создаём выходное ROOT-дерево, в котором у нас будет содержаться информация о модели: корреляционные матрицы, корреляции параметров, RO-кривые)
TFile* outFputFile = new TFile("ChurnModel.root", "RECREATE"); 
//Создаём для записи файл построения модели (генерируемый методом MakeClass, он будет расположен в директории weights, вместе с xml файлом
TMVA::Factory *factory = new TMVA::Factory("TMVAClassification_Model",outFputFile,"V:!Silent:Color:Transformations=I:DrawProgressBar:AnalysisType=Classification");

//читаем файлы
TString sigFile="Signal.csv";
TString bkgFile ="Background.csv";

cout << ">>>> Adding variables phasen";
factory->AddVariable("Param1",'I');
factory->AddVariable("Param2",'I');
factory->AddVariable("Param3",'F');
//Id в моём случае будет просто проверочной переменной
factory->AddSpectator("id", 'F');
Double_t sigWeight = 1.0; // overall weight for all signal events
Double_t bkgWeight = 1.0; // overall weight for all background events
factory->SetInputTrees( sigFile, bkgFile, sigWeight, bkgWeight );

cout << ">>>> Cuttingn";
//Отбираем значения для параметра Param1 и Param3;может пригодиться если данные с каким-то шумом
TCut preselectionCut("Param1 > 0. && Param3<350.0");
TCut mycutS = "";
//Можем взять каждое n-ое событие в Background,если данных очень много, а ноутбук не тянет
TCut mycutB = "id%100==0";
//Задаём объём тренировочного и тестового дерева
factory->PrepareTrainingAndTestTree(mycutS, mycutB, "nTrain_Signal=16000:nTest_Signal=1451:nTrain_Background=800000:nTest_Background=118416:VerboseLevel=Debug");

//Выбираем модель Boosted Decision and Regression Trees, вводим параметры
factory->BookMethod(TMVA::Types::kBDT, "BDT", "MaxDepth=5:NTrees=2000:MinNodeSize=9%:PruneStrength=10:SeparationType=GiniIndex");
//Выводим help для метода
factory->PrintHelpMessage("BDT");
//тренируем,тестируем и оцениваем модель
cout << ">>>> doing TrainAllMethodsn";
factory->TrainAllMethods();
cout << ">>>> doing TestAllMethodsn";
factory->TestAllMethods();
cout << ">>>> doing EvaluateAllMethodsn";
factory->EvaluateAllMethods();

 // Save the output
   outFputFile->Close();

   std::cout << "===> Wrote root file: " << outFputFile->GetName() << std::endl;
   std::cout << "===> TMVAClassification is done!" << std::endl;

   delete factory;
}

Запустить макрос можно командой из терминала "root Churn_Model_BDT.C".
После того, как всё досчитается, в консоли можно открыть ROOT-браузер, командой "TBrowser b;" и полюбоваться множеством симпатичных графиков.
В следующей статье я хочу рассказать про то, как написать Reader модели, который позволяет применять полученную модель на любых других данных и выгрузить отскоренный массив с определённой отсечкой скор-балла.
Также отвечу на вопросы по текущему материалу в комментариях.

Автор: Fontanka135

Источник



https://ajax.googleapis.com/ajax/libs/jquery/3.4.1/jquery.min.js