- PVSM.RU - https://www.pvsm.ru -
На днях ученые из MIT показали альтернативу многослойному перцептрону (MLP). MLP с самого момента изобретения глубокого обучения лежит в основе всех нейросетей, какими мы их знаем сегодня. На его идее в том числе построены большие языковые модели и системы компьютерного зрения.
Однако теперь все может измениться. В KAN (Kolmogorov-Arnold Networks) исследователи реализовали перемещение функций активации с нейронов на ребра нейросети, и такой подход показал блестящие результаты.

Идею KAN ученые подчерпнули из теоремы Колмогорова-Арнольда, именно в их честь и названа архитектура. Вообще говоря, исследование очень математичное, в статье 50 страниц с формулами, повсюду термины из мат.анализа, высшей алгебры, функана и прочего.
В общем, если хотите разобраться с тем, как эта сенсация работает, и при этом не сойти с ума, на нашем сайте мы, команда канала Data Secrets [1], написали для вас длинный и интересный разбор. Там мы на пальцах объяснили всю математику, рассказали про строение сети, привели примеры и ответили на вопрос "а почему до этого раньше никто не додумался".
Прочитайте, не пожалеете: https://datasecrets.ru/articles/9 [2].
А эта статья - для тех, кто хочет поиграть с новой архитектурой на практике. Мы рассмотрим несколько примеров кода на Python и понаблюдаем, как KAN справляется с привычными нам задачами машинного обучения. Поехали!
Чтобы участники сообщества могли сразу же потрогать все своими руками, добрые исследователи вместе со статьей представили библиотеку pykan [3], благодаря которой можно запускать KAN из коробки. Именно с ней мы сегодня и будем работать.
Итак, начнем с установки. Библиотеку можно поставить привычно через pip (pip install pykan) или с помощью клонирования репозитория [4]:
git clone https://github.com/KindXiaoming/pykan.git
cd pykan
pip install -e .
# pip install -r requirements.txt # install requirements
Далее импортируем библиотеку с помощью from kan import * и наконец-то переходим к написанию кода!
Ну куда же без задачи регрессии? Ведь именно с нее началось машинное обучение в 50-х годах прошлого века... Ладно, краткие исторические справки оставим на потом.
Давайте загадаем KAN такую загадку: возьмем функцию от двух переменных f(x,y) = exp(sin(pi*x)+y^2) и попросим KAN по входам и выходам функции найти ее формулу. Это так называемая символьная регрессия. Надо сказать, что задача хоть и кажется тривиальной, но обычно математически трудна для нейросетей.
from kan import *
# формируем KAN: 2D входы, 1D выходы, 5 скрытых нейронов,
# кубические сплайны и сетка на 5 точках.
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
# сгенерируем датасет
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
dataset['train_input'].shape, dataset['train_label'].shape
#(torch.Size([1000, 2]), torch.Size([1000, 1]))
Сплайны в KAN – это как раз те самые обучаемые функции на ребрах. В математике сплайн – это такая гладкая кривая, кусочно-полиномиальная функция, которая на разных отрезках задается различными полиномами. Каждый сплайн аппроксимируется с помощью заданного количества точек (сетки). Чем больше точек - тем точнее аппроксимация.
Обучающую и тестовую выборки получили, значит можно обучать. Тут ничего нового – привычный метод train:
# train the model
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.);
Можно визуализировать KAN, который у нас получился:

Давай посмотрим на эту картинку внимательнее. Наверху мы видим сплайн, похожий на экспоненту, а затем слева и справа наблюдаем соотвественно синус и параболу. Ничего не напоминает?
Все верно, если сложить все вместе, то получится формула которую мы загадывали: f(x,y) = exp(sin(pi*x)+y^2). Благодаря тому, что в KAN обучаются не параметры (числа), а функции, он почти идеально справляется с задачей регрессии на сложных функциях и, как показали исследователи, гораздо эффективнее генерализирует данные. В частности, в этой задаче мы получаем метрику r2 равной 0.99.
В статье исследователи также показали, как KAN помогает решать дифференциальные уравнения и (пере)открывает законы физики и математики.
Тут все еще интереснее. Но все по порядку. Снова сгенерируем игрушечный датасет (в сообществе его прозвали "две луны"):
from kan import KAN
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
import torch
import numpy as np
dataset = {}
train_input, train_label = make_moons(n_samples=1000, shuffle=True, noise=0.1, random_state=None)
test_input, test_label = make_moons(n_samples=1000, shuffle=True, noise=0.1, random_state=None)
dataset['train_input'] = torch.from_numpy(train_input)
dataset['test_input'] = torch.from_numpy(test_input)
dataset['train_label'] = torch.from_numpy(train_label[:,None])
dataset['test_label'] = torch.from_numpy(test_label[:,None])
X = dataset['train_input']
y = dataset['train_label']
plt.scatter(X[:,0], X[:,1], c=y[:,0])

Для начала давайте немного развлечемся и решим задачу, как будто это регрессия: будем предсказывать некоторое число, округлять его и сравнивать с реальной меткой класса.
model = KAN(width=[2,1], grid=3, k=3)
def train_acc():
return torch.mean((torch.round(model(dataset['train_input'])[:,0]) == dataset['train_label'][:,0]).float())
def test_acc():
return torch.mean((torch.round(model(dataset['test_input'])[:,0]) == dataset['test_label'][:,0]).float())
results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc));
results['train_acc'][-1], results['test_acc'][-1]
# (1.0, 1.0)
По последней строке видно: KAN справился идеально. Если заглянуть глубже, то мы увидим, что (опять же с помощью обучения функций) сетка вывела для себя "формулу ответа", которая и помогает ей безупречно справится с задачей:

А теперь попробуем по-взрослому, с кросс-энтропией, логитами и argmax. Вот код, в котором мы немного подправляем размерности в датасете и обучаем KAN:
dataset = {}
train_input, train_label = make_moons(n_samples=1000, shuffle=True, noise=0.1, random_state=None)
test_input, test_label = make_moons(n_samples=1000, shuffle=True, noise=0.1, random_state=None)
dataset['train_input'] = torch.from_numpy(train_input)
dataset['test_input'] = torch.from_numpy(test_input)
dataset['train_label'] = torch.from_numpy(train_label)
dataset['test_label'] = torch.from_numpy(test_label)
X = dataset['train_input']
y = dataset['train_label']
model = KAN(width=[2,2], grid=3, k=3)
def train_acc():
return torch.mean((torch.argmax(model(dataset['train_input']), dim=1) == dataset['train_label']).float())
def test_acc():
return torch.mean((torch.argmax(model(dataset['test_input']), dim=1) == dataset['test_label']).float())
results = model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc), loss_fn=torch.nn.CrossEntropyLoss());
Точность в этом случае немного ниже, но все еще достаточно хороша: 0.9660. Кстати, вот так можно посмотреть на формулы KAN (для каждого класса формула своя):
lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
model.auto_symbolic(lib=lib)
formula1, formula2 = model.symbolic_formula()[0]
В данном случае они получаются такими:


В статье мы рассмотрели, как запустить KAN для привычных задач регрессии и классификации и немного заглянули "под капот" архитектуры. Если хотите больше примеров – загляните в документацию [5] или в репозитория проекта, там лежат очень красивые и понятные ноутбуки [6], в которых можно найти туториалы по библиотеке и кейсы использования KAN.
Авторы KAN доказали, что ему требуется во много раз меньше нейронов, чтобы достичь точности MLP. Также KAN гораздо лучше генерализует данные и лучше справляется с аппроксимацией сложных математических функций (мы увидели это на примерах), у него, можно сказать, "технический склад ума".
Однако у архитектуры есть бутылочное горлышко: KAN учится медленнее MLP примерно в 10 раз. Возможно, это станет серьезным камнем преткновения, а возможно инженеры быстро научатся оптимизировать эффективность таких сетей.
Больше новостей из мира машинного обучения можно найти в нашем телеграм-канале. Подписывайтесь, чтобы быть в курсе: @data_secrets [1].
Автор: DataSecrets
Источник [7]
Сайт-источник PVSM.RU: https://www.pvsm.ru
Путь до страницы источника: https://www.pvsm.ru/pertseptron/391333
Ссылки в тексте:
[1] Data Secrets: https://t.me/+Q9dkwn6NfzswYTYy
[2] https://datasecrets.ru/articles/9: https://datasecrets.ru/articles/9
[3] pykan: https://kindxiaoming.github.io/pykan/
[4] репозитория: https://github.com/KindXiaoming/pykan
[5] документацию: https://kindxiaoming.github.io/pykan/index.html
[6] ноутбуки: https://github.com/KindXiaoming/pykan/tree/master/tutorials
[7] Источник: https://habr.com/ru/articles/812147/?utm_source=habrahabr&utm_medium=rss&utm_campaign=812147
Нажмите здесь для печати.