PyTorch ile Akış Eşleştirme
Bu depo, Flow Matching for Generative Modeling makalesinin basit bir PyTorch uygulamasını içermektedir.
2D Akış Eşleştirme Örneği
Aşağıdaki gif, tek bir Gauss dağılımının dama tahtası dağılımına eşlenmesini ve vektör alanının görselleştirilmesini göstermektedir.
Ve işte, moons veri setinin başka bir örneği.
Başlarken
Depoyu klonlayın ve python ortamını kurun.
git clone https://github.com/keishihara/flow-matching.git
cd flow-matchingPython 3.12+ kurulu olduğundan emin olun.
uv'yi yükleyin:
curl -LsSf https://astral.sh/uv/install.sh | shDaha sonra, ortamı kurun:
uv syncKoşullu Akış Eşleştirme [Lipman+ 2023]
Bu orijinal CFM makalesinin uygulamasıdır [1]. Kodun bazı bileşenleri [2] ve [3] kaynaklarından uyarlanmıştır.
2D Oyuncak Veri Setleri
CFM modellerini checkerboard ve moons gibi 2D sentetik veri setleri üzerinde eğitebilirsiniz. Veri seti adını --dataset seçeneğiyle belirtebilirsiniz. Eğitim parametreleri betikte önceden tanımlanmıştır ve eğitim sonuçlarının görselleştirmeleri outputs/ dizininde saklanır. Model kontrol noktaları dahil edilmemiştir çünkü varsayılan ayarlarla kolayca yeniden üretilebilirler.
uv run scripts/train_flow_matching_2d.py --dataset checkerboardREADME'nin en üstünde GIF olarak gösterilen vektör alanları ve üretilen örnekler artık outputs/cfm/ dizininde bulunabilir.
Görüntü Veri Setleri
Popüler görüntü sınıflandırma veri setlerinde de sınıf koşullu CFM modelleri eğitebilirsiniz. Hem üretilen örnekler hem de model kontrol noktaları outputs/cfm dizininde saklanacaktır. Eğitim parametrelerinin ayrıntılı listesi için uv run scripts/train_flow_matching_on_image.py --help komutunu çalıştırın.
MNIST veri setinde sınıf koşullu CFM eğitmek için şunu çalıştırın:
uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnistEğitimden sonra, artık örnekler üretebilirsiniz:
uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnistArtık oluşturulan örnekleri outputs/cfm/mnist/ dizininde görebilmelisiniz.
Düzeltmeli Akış [Liu+ 2023]
Bu, Düzeltmeli Akış makalesinden (özellikle 2-Düzeltmeli Akış) Reflow modelinin bir uygulamasıdır [2].
2D Sentetik Veri
Reflow'u, tıpkı CFM'de olduğu gibi, 2d sentetik veri kümeleri üzerinde uyguladık. Reflow'u eğitmek için, reflow bir distilasyon modeli olduğu için önceden eğitilmiş CFM kontrol noktalarını belirtmeniz gerekir.
Örneğin, önceden eğitilmiş bir CFM kontrol noktası ile checkerboard veri kümesinde eğitim yapmak için:
uv run scripts/train_reflow_2d.py --dataset checkerboardEğitim sonuçları, vektör alanı görselleştirmeleri ve üretilen örnekler dahil olmak üzere outputs/reflow/ klasörü altında kaydedilir.
CFM ve Reflow arasında örnekleme sürecinin karşılaştırılması
CFM ve Reflow'u 2 boyutlu veri kümelerinde karşılaştırmak için şunu çalıştırın:
uv run scripts/plot_comparison_2d.py --dataset checkerboard
Ortaya çıkan GIF'ler outputs/comparisons/ klasöründe bulunabilir. Aşağıda, checkerboard veri setinde iki yöntemin örnek bir karşılaştırması gösterilmektedir:
Kaynaklar
- [1] Lipman, Yaron, ve diğerleri. "Flow Matching for Generative Modeling." arXiv:2210.02747
- [2] Liu, Xingchao, ve diğerleri. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." arXiv:2209.03003
- [3] facebookresearch/flow_matching
- [4] atong01/conditional-flow-matching