PyTorchにおけるフローマッチング
このリポジトリは論文Flow Matching for Generative Modelingの簡単なPyTorch実装を含んでいます。
2Dフローマッチングの例
以下のgifは、単一のガウス分布をチェッカーボード分布にマッピングし、ベクトル場を可視化したものです。
こちらはムーンデータセットの別の例です。
はじめに
リポジトリをクローンし、Python環境をセットアップしてください。
git clone https://github.com/keishihara/flow-matching.git
cd flow-matchingPython 3.10以上がインストールされていることを確認してください。
uvを使用してPython環境をセットアップするには:
uv sync
source .venv/bin/activateあるいは、pipを使用して:
python -m venv .venv
source .venv/bin/activate
pip install -e .条件付きフローマッチング [Lipman+ 2023]
これは元のCFM論文実装[1]です。コードの一部は[2]および[3]から適応されています。
2D トイデータセット
checkerboardやmoonsなどの2D合成データセットでCFMモデルを訓練できます。--datasetオプションでデータセット名を指定してください。訓練パラメータはスクリプト内で事前定義されており、訓練結果の可視化はoutputs/ディレクトリに保存されます。モデルのチェックポイントは、デフォルト設定で簡単に再現可能なため含まれていません。
python train_flow_matching_2d.py --dataset checkerboardベクトル場と生成されたサンプルは、このREADMEの上部にGIFとして表示されているもののように、現在outputs/cfm/ディレクトリにあります。
画像データセット
また、人気のある画像分類データセットでクラス条件付きCFMモデルをトレーニングすることもできます。生成されたサンプルとモデルのチェックポイントは両方ともoutputs/cfmディレクトリに保存されます。トレーニングパラメータの詳細なリストはpython train_flow_matching_on_images.py --helpを実行してください。
MNISTデータセットでクラス条件付きCFMをトレーニングするには、以下を実行します。
python train_flow_matching_on_image.py --do_train --dataset mnistトレーニング後、以下でサンプルを生成できます:
python train_flow_matching_on_image.py --do_sample --dataset mnistNow, you should be able to see the generated samples in the outputs/cfm/mnist/ directory.
Rectified Flow [Liu+ 2023]
これはRectified Flow論文[2]のReflowモデル(具体的には2-Rectified Flow)の実装です。
2D Synthetic Data
CFMと同様に、2次元合成データセットでReflowを実装しました。Reflowは蒸留モデルであるため、トレーニングには事前学習済みCFMチェックポイントを指定する必要があります。
例えば、事前学習済みCFMチェックポイントを使ってcheckerboardデータセットでトレーニングする場合:
python train_reflow_2d.py --dataset checkerboard --pretrained-model outputs/cfm/checkerboard/ckpt.pthトレーニング結果は、ベクトル場の可視化や生成されたサンプルを含めて、outputs/reflow/ フォルダに保存されます。
CFMとReflowのサンプリングプロセスの比較
2次元データセットでCFMとReflowを比較するには、以下を実行してください:
python plot_comparison_2d.py --dataset checkerboard
生成されたGIFは outputs/comparisons/ フォルダにあります。以下は checkerboard データセットでの2つの手法の比較例です:
参考文献
- [1] Lipman, Yaron, et al. "Flow Matching for Generative Modeling." arXiv:2210.02747
- [2] Liu, Xingchao, et al. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." arXiv:2209.03003
- [3] facebookresearch/flow_matching
- [4] atong01/conditional-flow-matching