Грокаем PyTorch

@
Привет, Хабр!

У нас в предзаказе появилась долгожданная книга о библиотеке PyTorch.



Поскольку весь необходимый базовый материал о PyTorch вы узнаете из этой книги, мы напоминаем о пользепроцесса под названием «grokking» или «углубленное постижение» той темы, которую вы хотите усвоить. В сегодняшней публикации мы расскажем, как Кай Арулкумаран (Kai Arulkumaran) грокнул PyTorch (без картинок). Добро пожаловать под кат.

PyTorch – это гибкий фреймворк для глубокого обучения, обеспечивающий автоматическое различение объектов при помощи динамических нейронных сетей (то есть, сетей, использующих динамическое управление потоком, например, инструкции if и циклы while). PyTorch поддерживает GPU-ускорение, распределенное обучение, различные виды оптимизации и еще множество других приятных возможностей. Здесь я изложил некоторые мысли о том, как, на мой взгляд, следует использовать PyTorch; здесь не охвачены все аспекты библиотеки и рекомендуемые практики, но, надеюсь, этот текст окажется вам полезен.

Нейронные сети – это подкласс вычислительных графов. Вычислительные графы получают на вход данные, далее эти данные маршрутизируются (и могут преобразовываться) на узлах, где и происходит их обработка. В глубоком обучении нейроны (узлы) обычно преобразуют данные, применяя к ним параметры и дифференцируемые функции, так, чтобы параметры можно было оптимизировать для минимизации потерь методом градиентного спуска. В более широком смысле отмечу, что функции могут быть стохастическими, а граф – динамическим. Таким образом, тогда как нейронные сети хорошо вписываются в парадигму программирования потоков данных (dataflow programming), API PyTorch ориентирован на парадигму императивного программирования, а такой способ трактовки создаваемых программ гораздо более привычен. Именно поэтому код PyTorch проще читается, по нему проще судить об устройстве сложных программ, что, однако, не требует серьезно поступаться производительностью: на самом деле, PyTorch достаточно быстр и предусматривает множество оптимизаций, о которых вы, как конечный пользователь, можете совершенно не волноваться (однако, если они вам действительно интересны, можете копнуть поглубже и познакомиться с ними).

Остальная часть этой статьи является разбором официального примера на датасете MNIST. Здесь мы грокаем PyTorch, поэтому разбираться в статье рекомендую только после знакомства с официальными руководствами для начинающих. Для удобства код представлен в виде небольших фрагментов, снабженных комментариями, то есть, не распределен на отдельные функции/файлы, которые вы привыкли видеть в чистом модульном коде.

Импорты


import argparse import os import torch from torch import nn, optim from torch.nn import functional as F from torch.utils.data import DataLoader from torchvision import datasets, transforms

Все это вполне стандартные импорты, за исключением модулей torchvision, особенно активно используемых для решения задач, связанных с компьютерным зрением.

Настройка


parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser.add_argument('--batch-size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)') parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs to train (default: 10)') parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate (default: 0.01)') parser.add_argument('--momentum', type=float, default=0.5, metavar='M', help='SGD momentum (default: 0.5)') parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument('--save-interval', type=int, default=10, metavar='N', help='how many batches to wait before checkpointing') parser.add_argument('--resume', action='store_true', default=False, help='resume training from checkpoint') args = parser.parse_args() use_cuda = torch.cuda.is_available() and not args.no_cuda device = torch.device('cuda' if use_cuda else 'cpu') torch.manual_seed(args.seed) if use_cuda: torch.cuda.manual_seed(args.seed)

argparse – это стандартный способ обращения с аргументами командной строки в Python.

Если нужно писать код, рассчитанный на работу на разных устройствах (пользуясь GPU-ускорением, когда оно доступно, но при его отсутствии откатываясь обратно к вычислениям на CPU), то выберите и сохраните подходящий torch.device, при помощи которого можно определить, где должны храниться тензоры. Подробнее о создании такого кода см. в официальной документации. Подход PyTorch – отдавать подбор устройств под контроль пользователя, что может показаться нежелательным в простых примерах. Однако, такой подход значительно упрощает работу, когда приходится иметь дело с тензорами, что а) удобно при отладке b) позволяет эффективно использовать устройства вручную.

Для воспроизводимости экспериментов необходимо установить случайные начальные значения для всех компонентов, использующих случайную генерацию чисел (в том числе, random или numpy, если и они у вас используются). Обратите внимание: cuDNN использует недетерминированные алгоритмы и по желанию отключается при помощи torch.backends.cudnn.enabled = False.

Данные


data_path = os.path.join(os.path.expanduser('~'), '.torch', 'datasets', 'mnist') train_data = datasets.MNIST(data_path, train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])) test_data = datasets.MNIST(data_path, train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])) train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True) test_loader = DataLoader(test_data, batch_size=args.batch_size, num_workers=4, pin_memory=True)


Поскольку модели torchvision сохраняются под ~/.torch/models/, я предпочитаю хранить датасеты torchvision под ~/.torch/datasets. Это мое авторское соглашение, но им очень удобно пользоваться в проектах, разрабатываемых на базе MNIST, CIFAR-10, т.д. В целом, датасеты следует хранить отдельно от кода, если вы собираетесь переиспользовать несколько датасетов.

torchvision.transforms содержит множество удобных вариантов преобразований для отдельных изображений, например, обрезку и нормализацию.

В DataLoader есть множество опций, но, кроме batch_size и shuffle, также следует иметь в виду num_workers и pin_memory, они помогают повысить эффективность. num_workers > 0 использует субпроцессы для асинхронной загрузки данных, а не блокирует под это главный процесс. Типичный пример использования – загрузка данных (например, изображений) с диска и, возможно, их преобразования; все это может делаться параллельно, вместе с сетевой обработкой данных. Степень обработки, возможно, потребуется настроить, чтобы a) минимизировать количество работников и, следовательно, объем использования CPU и RAM (каждый работник загружает отдельную порцию, а не отдельные образцы, входящие в порцию) b) минимизировать длительность ожидания данных в сети. pin_memory использует закрепленную память (pinned memory) (в противовес подкачиваемой) для ускорения любых операций переноса данных из RAM в GPU (и ничего не делает с кодом, относящимся только к CPU).

Модель


class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, kernel_size=5) self.conv2_drop = nn.Dropout2d() self.fc1 = nn.Linear(320, 50) self.fc2 = nn.Linear(50, 10) def forward(self, x): x = F.relu(F.max_pool2d(self.conv1(x), 2)) x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) x = x.view(-1, 320) x = F.relu(self.fc1(x)) x = self.fc2(x) return F.log_softmax(x, dim=1) model = Net().to(device) optimiser = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) if args.resume: model.load_state_dict(torch.load('model.pth')) optimiser.load_state_dict(torch.load('optimiser.pth'))

Сетевая инициализация обычно распространяется на переменные членов, слои, в которых содержатся обучаемые параметры и, может быть, на отдельные обучаемые параметры и необучаемые буферы. Затем при прямом проходе они используются в сочетании с функциями из F, чисто функциональными, не содержащими параметров. Некоторым нравится работать с чисто функциональными сетями (напр., держать параметры и использовать F.conv2d вместо nn.Conv2d) или сети, целиком состоящие из слоев (напр., nn.ReLU вместо F.relu).

.to(device) – удобный способ отправлять параметры устройства (и буферы) на GPU, если в качестве device задан GPU, так как в противном случае (если в качестве device задан CPU) ничего делаться не будет. Важно перенести параметры устройства на соответствующее устройство, прежде, чем передавать их оптимизатору; в противном случае оптимизатор не сможет правильно отслеживать параметры!

И нейронные сети (nn.Module), и оптимизаторы (optim.Optimizer) умеют сохранять и загружать свое внутреннее состояние, и делать это рекомендуется с помощью .load_state_dict(state_dict) – перезагрузить состояние обоих бывает нужно, чтобы возобновить обучение на основе ранее сохраненных словарей состояний. Сохранение всего объекта целиком может быть чревато ошибками. Если вы сохранили тензоры на GPU и хотите загрузить их на CPU или другой GPU, то проще всего загружать их непосредственно на CPU при помощи опцииmap_location, напр., torch.load('model.pth', map_location='cpu').

Вот еще некоторые моменты, не показанные здесь, но заслуживающие упоминания, связаны с тем, что при прямом проходе можно использовать поток управления (напр., выполнение инструкции if может зависеть от переменной члена или от самих данных. Кроме того, совершенно допустимо посреди процесса выводить (print) тензоры, что значительно упрощает отладку. Наконец, при прямом проходе может использоваться множество аргументов. Проиллюстрирую этот момент коротким листингом, не привязанным ни к какой конкретной идее:

def forward(self, x, hx, drop=False): hx2 = self.rnn(x, hx) print(hx.mean().item(), hx.var().item()) if hx.max.item() > 10 or self.can_drop and drop: return hx else: return hx2

Обучение


model.train() train_losses = [] for i, (data, target) in enumerate(train_loader): data = data.to(device=device, non_blocking=True) target = target.to(device=device, non_blocking=True) optimiser.zero_grad() output = model(data) loss = F.nll_loss(output, target) loss.backward() train_losses.append(loss.item()) optimiser.step() if i % 10 == 0: print(i, loss.item()) torch.save(model.state_dict(), 'model.pth') torch.save(optimiser.state_dict(), 'optimiser.pth') torch.save(train_losses, 'train_losses.pth')

Сетевые модули по умолчанию ставятся в режим обучения – что в определенной степени отражается на работе модулей, больше всего – на прореживании и пакетной нормализации. Так или иначе, лучше задавать такие вещи вручную при помощи .train(), который просачивает флаг «training» до всех дочерних модулей.

Здесь метод .to() не только принимает устройство, но и устанавливает non_blocking=True, обеспечивая таким образом асинхронное копирование данных на GPU из закрепленной памяти, позволяя CPU сохранять работоспособность при переносе данных; в противном случае non_blocking=True попросту не вариант.

Прежде чем собрать новый набор градиентов при помощи loss.backward() и выполнить обратное распространение при помощи optimiser.step(), необходимо вручную обнулить градиенты оптимизируемых параметров при помощи optimiser.zero_grad(). По умолчанию PyTorch накапливает градиенты, что очень удобно, если у вас не хватает ресурсов, чтобы вычислить все нужные вам градиенты за один проход.

PyTorch использует «магнитофонную» систему автоматических градиентов – собирает информацию о том, какие операции и в каком порядке производились над тензорами, а затем воспроизводит их в обратном направлении, чтобы выполнить дифференциацию в обратном порядке (reverse-mode differentiation). Вот почему он такой супер-гибкий и допускает произвольные вычислительные графы. Если ни один из этих тензоров не требует градиентов (приходится установить requires_grad=True, создавая тензор для этой цели), то никакой граф не сохраняется! Однако, у сетей обычно есть параметры, требующие градиентов, поэтому любые вычисления, выполняемые на основе вывода сети, будут сохраняться в графе. Итак, если вы хотите сохранять данные, результирующие после этого шага, то понадобится вручную отключить градиенты или (более распространенный подход), сохранить эту информацию как число Python (при помощи .item() в скаляре PyTorch) или массив numpy. Подробнее об autograd рассказано в официальной документации.

Один из способов сократить вычислительный граф — пользоваться .detach(), когда проходится скрытое состояние при обучении RNN с усеченной версией backpropagation-through-time. Это также удобно при дифференциации потерь, когда один из компонентов является выводом другой сети, но эта другая сеть не должна оптимизироваться относительно потерь. В качестве примера приведу обучение дискриминативной части на материале вывода генерирующей при работе с GAN, либо обучение политики в алгоритме актор-критик с использованием целевой функции в качестве базовой (напр. A2C). Еще один прием, предотвращающий вычисление градиентов, эффективный при обучении GAN (обучение генерирующей части на материале дискриминативной) и типичный при тонкой настройке – циклический перебор параметров сети, при котором задано param.requires_grad = False.

Важно не только регистрировать результаты в консоли/файле логов, но и ставить контрольные точки в параметрах модели (и состоянии оптимизатора) просто на всякий случай. Также можно пользоваться torch.save() для сохранения обычных Python-объектов, либо воспользоваться другим стандартным решением – встроенным pickle.

Тестирование


model.eval() test_loss, correct = 0, 0 with torch.no_grad(): for data, target in test_loader: data = data.to(device=device, non_blocking=True) target = target.to(device=device, non_blocking=True) output = model(data) test_loss += F.nll_loss(output, target, reduction='sum').item() pred = output.argmax(1, keepdim=True) correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_data) acc = correct / len(test_data) print(acc, test_loss)

В ответ на .train() сети нужно явно переводить в режим оценки (evaluation mode) при помощи .eval().

Как упоминалось выше, при использовании сети обычно составляется вычислительный граф. Чтобы этого не происходило, пользуйтесь менеджером контекста no_grad при помощи with torch.no_grad().

Еще немного


Это дополнительный раздел, в который я вынес еще несколько полезных отступлений.
Вот официальная документация, поясняющая работу с памятью.

Ошибки CUDA? Исправлять их тяжко, и обычно они связаны с логическими неувязками, по которым на CPU выводятся более вразумительные сообщения об ошибках, чем на GPU. Лучше всего, если, планируя работать с GPU, вы сможете быстро переключаться между CPU и GPU. Более общий совет по разработке – организовать код так, чтобы его можно было быстро проверить перед запуском полноценного задания. Например, подготовьте небольшой или синтетический датасет, прогоните одну эпоху train + test, т.д. Если дело в ошибке CUDA, либо вы совсем никак не можете переключиться на CPU, установите CUDA_LAUNCH_BLOCKING=1. Так запуски ядра CUDA станут синхронными, и вы станете получать более точные сообщения об ошибках.

Замечание о torch.multiprocessing или просто об одновременном запуске множества сценариев PyTorch. Поскольку PyTorch использует многопоточные библиотеки BLAS для ускорения вычислений линейной алгебры на CPU, обычно при этом задействовано несколько ядер. Если вы хотите делать несколько вещей одновременно, с использованием многопоточной обработки или нескольких сценариев, может быть целесообразно вручную сократить их количество, установив для переменной окружения OMP_NUM_THREADS значение 1 или другое невысокое значение. Таким образом снижается вероятность пробуксовки процессора. В официальной документации есть и другие замечания по поводу многопоточной обработки.
Попробуйте дерево связей для поиска пересечений между медийными объектами. Узнайте, пересекаются ли Искусственные нейронные сети и , и что именно их связывает
Смотреть

Хотите больше?

Получите полный доступ к новостям и аналитике бесплатно и без рекламы.

Анализ статьи

×
Организации
Упоминаются
АО "ДСЕТ"
Организации
Технологии
Упоминаются