Flow Matching di PyTorch
Repositori ini berisi implementasi PyTorch sederhana dari makalah Flow Matching for Generative Modeling.
Contoh Flow Matching 2D
Gif di bawah ini mendemonstrasikan pemetaan distribusi Gaussian tunggal ke distribusi papan catur, dengan visualisasi medan vektor.
Dan, berikut adalah contoh lain dari dataset moons.
Memulai
Clone repositori dan atur lingkungan python.
git clone https://github.com/keishihara/flow-matching.git
cd flow-matchingPastikan Anda telah menginstal Python 3.12+.
Instal uv:
curl -LsSf https://astral.sh/uv/install.sh | shKemudian, atur lingkungan:
uv syncConditional Flow Matching [Lipman+ 2023]
Ini adalah implementasi asli dari paper CFM [1]. Beberapa komponen kode diadaptasi dari [2] dan [3].
Dataset Mainan 2D
Anda dapat melatih model CFM pada dataset sintetis 2D seperti checkerboard dan moons. Tentukan nama dataset menggunakan opsi --dataset. Parameter pelatihan sudah ditetapkan di skrip, dan visualisasi hasil pelatihan disimpan di direktori outputs/. Checkpoint model tidak disertakan karena dapat direproduksi dengan mudah menggunakan pengaturan default.
uv run scripts/train_flow_matching_2d.py --dataset checkerboardMedan vektor dan sampel yang dihasilkan, seperti yang ditampilkan dalam bentuk GIF di bagian atas README ini, sekarang dapat ditemukan di direktori outputs/cfm/.
Dataset Gambar
Anda juga dapat melatih model CFM bersyarat kelas pada dataset klasifikasi gambar yang populer. Baik sampel yang dihasilkan maupun checkpoint model akan disimpan di direktori outputs/cfm. Untuk daftar parameter pelatihan secara detail, jalankan uv run scripts/train_flow_matching_on_image.py --help.
Untuk melatih CFM bersyarat kelas pada dataset MNIST, jalankan:
uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnistSetelah pelatihan, Anda sekarang dapat menghasilkan sampel dengan:
uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnist
Sekarang, Anda seharusnya dapat melihat sampel yang dihasilkan di direktori outputs/cfm/mnist/.
Rectified Flow [Liu+ 2023]
Ini adalah implementasi model Reflow (lebih spesifiknya 2-Rectified Flow) dari makalah Rectified Flow [2].
Data Sintetis 2D
Kami telah mengimplementasikan Reflow pada dataset sintetis 2d, sama seperti CFM. Untuk melatih reflow, Anda harus menentukan checkpoint CFM yang sudah dilatih sebelumnya karena reflow adalah model distilasi.
Sebagai contoh, untuk melatih pada dataset checkerboard dengan checkpoint CFM yang telah dilatih sebelumnya:
uv run scripts/train_reflow_2d.py --dataset checkerboardHasil pelatihan, termasuk visualisasi medan vektor dan sampel yang dihasilkan, disimpan di dalam folder outputs/reflow/.
Perbandingan proses sampling antara CFM dan Reflow
Untuk membandingkan CFM dan Reflow pada dataset 2d, jalankan:
uv run scripts/plot_comparison_2d.py --dataset checkerboardGIF yang dihasilkan dapat ditemukan di bawah folder outputs/comparisons/. Berikut ini adalah contoh perbandingan dua metode pada dataset checkerboard:
Referensi
- [1] Lipman, Yaron, dkk. "Flow Matching for Generative Modeling." arXiv:2210.02747
- [2] Liu, Xingchao, dkk. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." arXiv:2209.03003
- [3] facebookresearch/flow_matching
- [4] atong01/conditional-flow-matching