Сопоставление потоков в PyTorch
Этот репозиторий содержит простую реализацию на PyTorch статьи Flow Matching for Generative Modeling.
Пример сопоставления потока в 2D
Ниже показан gif, демонстрирующий отображение одномерного гауссовского распределения в шахматное распределение с визуализацией векторного поля.
А вот еще один пример с набором данных "луны".
Начало работы
Клонируйте репозиторий и настройте python-окружение.
git clone https://github.com/keishihara/flow-matching.git
cd flow-matchingУбедитесь, что у вас установлен Python версии 3.12 или выше.
Установите uv:
curl -LsSf https://astral.sh/uv/install.sh | shЗатем настройте среду:
uv syncConditional Flow Matching [Lipman+ 2023]
Это оригинальная реализация статьи CFM [1]. Некоторые компоненты кода адаптированы из [2] и [3].
2D игрушечные датасеты
Вы можете обучать модели CFM на 2D синтетических датасетах, таких как checkerboard и moons. Укажите название датасета с помощью опции --dataset. Параметры обучения заранее определены в скрипте, а визуализации результатов обучения сохраняются в директории outputs/. Контрольные точки модели не включены, так как они легко воспроизводимы с настройками по умолчанию.
uv run scripts/train_flow_matching_2d.py --dataset checkerboardВекторные поля и сгенерированные образцы, такие как те, что показаны в GIF-файлах в верхней части этого README, теперь можно найти в директории outputs/cfm/.
Наборы изображений
Вы также можете обучать классово-условные модели CFM на популярных наборах данных для классификации изображений. Как сгенерированные образцы, так и контрольные точки модели будут сохраняться в директории outputs/cfm. Для получения подробного списка параметров обучения выполните команду uv run scripts/train_flow_matching_on_image.py --help.
Чтобы обучить классово-условную CFM на наборе данных MNIST, выполните:
uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnist
После обучения вы теперь можете генерировать образцы с помощью:uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnist
Теперь вы должны увидеть сгенерированные образцы в директории outputs/cfm/mnist/.
Исправленный поток [Liu+ 2023]
Это реализация модели Reflow (а именно 2-Rectified Flow) из статьи Rectified Flow [2].
2D синтетические данные
Мы реализовали Reflow на 2d синтетических наборах данных, аналогично CFM. Для обучения reflow необходимо указать заранее обученные контрольные точки CFM, так как reflow является моделью дистилляции.
Например, чтобы обучить на наборе данных checkerboard с заранее обученной контрольной точкой CFM:
uv run scripts/train_reflow_2d.py --dataset checkerboardРезультаты обучения, включая визуализации векторных полей и сгенерированные образцы, сохраняются в папке outputs/reflow/.
Сравнение процесса семплирования между CFM и Reflow
Для сравнения CFM и Reflow на двумерных датасетах выполните:
uv run scripts/plot_comparison_2d.py --dataset checkerboardПолучившиеся GIF-файлы можно найти в папке outputs/comparisons/. Ниже приведен пример сравнения двух методов на датасете checkerboard:
Ссылки
- [1] Липман, Ярон и др. "Flow Matching for Generative Modeling." arXiv:2210.02747
- [2] Лю, Синчао и др. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." arXiv:2209.03003
- [3] facebookresearch/flow_matching
- [4] atong01/conditional-flow-matching