PyTorchによるフローマッチング
このリポジトリには、論文 Flow Matching for Generative Modeling のシンプルなPyTorch実装が含まれています。
2次元フローマッチング例
下のgifは、単一のガウス分布からチェッカーボード分布へのマッピングを示しており、ベクトル場が可視化されています。
そして、こちらはmoonsデータセットの別の例です。
はじめに
リポジトリをクローンし、Python環境をセットアップしてください。
git clone https://github.com/keishihara/flow-matching.git
cd flow-matchingPython 3.12以上がインストールされていることを確認してください。
uvをインストールしてください:
curl -LsSf https://astral.sh/uv/install.sh | sh次に、環境を設定します:
uv sync条件付きフローマッチング [Lipman+ 2023]
これは元のCFM論文実装[1]です。コードの一部は[2]および[3]から適応されています。
2D トイデータセット
checkerboardやmoonsなどの2D合成データセットでCFMモデルを訓練できます。--datasetオプションでデータセット名を指定してください。訓練パラメータはスクリプト内で事前定義されており、訓練結果の可視化はoutputs/ディレクトリに保存されます。モデルのチェックポイントは、デフォルト設定で簡単に再現可能なため含まれていません。
uv run scripts/train_flow_matching_2d.py --dataset checkerboardこのREADMEの冒頭にGIFとして表示されているようなベクトル場と生成サンプルは、現在 outputs/cfm/ ディレクトリにあります。
画像データセット
人気の画像分類データセットでクラス条件付きCFMモデルの学習も可能です。生成されたサンプルとモデルのチェックポイントはどちらも outputs/cfm ディレクトリに保存されます。学習パラメータの詳細な一覧は、uv run scripts/train_flow_matching_on_image.py --help を実行してください。
MNISTデータセットでクラス条件付きCFMを学習するには、次のコマンドを実行します:
uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnistトレーニング後、以下でサンプルを生成できます:
uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnistNow, you should be able to see the generated samples in the outputs/cfm/mnist/ directory.
Rectified Flow [Liu+ 2023]
これはRectified Flow論文[2]のReflowモデル(具体的には2-Rectified Flow)の実装です。
2D Synthetic Data
CFMと同様に、2次元合成データセットでReflowを実装しました。Reflowは蒸留モデルであるため、トレーニングには事前学習済みCFMチェックポイントを指定する必要があります。
例えば、事前学習済みCFMチェックポイントを使ってcheckerboardデータセットでトレーニングする場合:
uv run scripts/train_reflow_2d.py --dataset checkerboardトレーニング結果は、ベクトル場の可視化や生成されたサンプルを含めて、outputs/reflow/ フォルダに保存されます。
CFMとReflowのサンプリングプロセスの比較
2次元データセットでCFMとReflowを比較するには、以下を実行してください:
uv run scripts/plot_comparison_2d.py --dataset checkerboard
生成されたGIFは outputs/comparisons/ フォルダにあります。以下は checkerboard データセットでの2つの手法の比較例です:
参考文献
- [1] Lipman, Yaron, et al. "Flow Matching for Generative Modeling." arXiv:2210.02747
- [2] Liu, Xingchao, et al. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." arXiv:2209.03003
- [3] facebookresearch/flow_matching
- [4] atong01/conditional-flow-matching