Web Analytics

flow-matching

⭐ 89 stars Indonesian by keishihara

🌐 Bahasa

Flow Matching di PyTorch

Repositori ini berisi implementasi PyTorch sederhana dari makalah Flow Matching for Generative Modeling.

Contoh Flow Matching 2D

Gif di bawah ini mendemonstrasikan pemetaan distribusi Gaussian tunggal ke distribusi papan catur, dengan visualisasi medan vektor.

Dan, berikut adalah contoh lain dari dataset moons.

Memulai

Clone repositori dan atur lingkungan python.

git clone https://github.com/keishihara/flow-matching.git
cd flow-matching

Pastikan Anda telah menginstal Python 3.12+. Instal uv:

curl -LsSf https://astral.sh/uv/install.sh | sh

Kemudian, atur lingkungan:

uv sync

Conditional Flow Matching [Lipman+ 2023]

Ini adalah implementasi asli dari paper CFM [1]. Beberapa komponen kode diadaptasi dari [2] dan [3].

Dataset Mainan 2D

Anda dapat melatih model CFM pada dataset sintetis 2D seperti checkerboard dan moons. Tentukan nama dataset menggunakan opsi --dataset. Parameter pelatihan sudah ditetapkan di skrip, dan visualisasi hasil pelatihan disimpan di direktori outputs/. Checkpoint model tidak disertakan karena dapat direproduksi dengan mudah menggunakan pengaturan default.

uv run scripts/train_flow_matching_2d.py --dataset checkerboard

Medan vektor dan sampel yang dihasilkan, seperti yang ditampilkan dalam bentuk GIF di bagian atas README ini, sekarang dapat ditemukan di direktori outputs/cfm/.

Dataset Gambar

Anda juga dapat melatih model CFM bersyarat kelas pada dataset klasifikasi gambar yang populer. Baik sampel yang dihasilkan maupun checkpoint model akan disimpan di direktori outputs/cfm. Untuk daftar parameter pelatihan secara detail, jalankan uv run scripts/train_flow_matching_on_image.py --help.

Untuk melatih CFM bersyarat kelas pada dataset MNIST, jalankan:

uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnist

Setelah pelatihan, Anda sekarang dapat menghasilkan sampel dengan:

uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnist
Sekarang, Anda seharusnya dapat melihat sampel yang dihasilkan di direktori outputs/cfm/mnist/.

Rectified Flow [Liu+ 2023]

Ini adalah implementasi model Reflow (lebih spesifiknya 2-Rectified Flow) dari makalah Rectified Flow [2].

Data Sintetis 2D

Kami telah mengimplementasikan Reflow pada dataset sintetis 2d, sama seperti CFM. Untuk melatih reflow, Anda harus menentukan checkpoint CFM yang sudah dilatih sebelumnya karena reflow adalah model distilasi.

Sebagai contoh, untuk melatih pada dataset checkerboard dengan checkpoint CFM yang telah dilatih sebelumnya:

uv run scripts/train_reflow_2d.py --dataset checkerboard

Hasil pelatihan, termasuk visualisasi medan vektor dan sampel yang dihasilkan, disimpan di dalam folder outputs/reflow/.

Perbandingan proses sampling antara CFM dan Reflow

Untuk membandingkan CFM dan Reflow pada dataset 2d, jalankan:

uv run scripts/plot_comparison_2d.py --dataset checkerboard

GIF yang dihasilkan dapat ditemukan di bawah folder outputs/comparisons/. Berikut ini adalah contoh perbandingan dua metode pada dataset checkerboard:

Referensi

--- Tranlated By Open Ai Tx | Last indexed: 2026-01-19 ---