🌐 Sprache
Flow Matching in PyTorch
Dieses Repository enthält eine einfache PyTorch-Implementierung des Papers Flow Matching for Generative Modeling.
2D Flow Matching Beispiel
Das folgende GIF demonstriert die Abbildung einer einzelnen Gaußschen Verteilung auf eine Schachbrettverteilung, wobei das Vektorfeld visualisiert wird.
Und hier ist ein weiteres Beispiel mit dem Moons-Datensatz.
Erste Schritte
Klonen Sie das Repository und richten Sie die Python-Umgebung ein.
git clone https://github.com/keishihara/flow-matching.git
cd flow-matchingStellen Sie sicher, dass Python 3.12+ installiert ist.
Installieren Sie uv:
curl -LsSf https://astral.sh/uv/install.sh | shRichten Sie anschließend die Umgebung ein:
uv syncConditional Flow Matching [Lipman+ 2023]
Dies ist die Originalimplementierung des CFM-Papers [1]. Einige Komponenten des Codes wurden aus [2] und [3] übernommen.
2D Toy Datasets
Sie können die CFM-Modelle auf 2D-synthetischen Datensätzen wie checkerboard und moons trainieren. Geben Sie den Namen des Datensatzes mit der Option --dataset an. Trainingsparameter sind im Skript vordefiniert, und Visualisierungen der Trainingsergebnisse werden im Verzeichnis outputs/ gespeichert. Modell-Checkpoints sind nicht enthalten, da sie mit den Standardeinstellungen leicht reproduzierbar sind.
uv run scripts/train_flow_matching_2d.py --dataset checkerboardDie Vektorfelder und generierten Beispiele, wie sie als GIFs oben in diesem README angezeigt werden, sind jetzt im Verzeichnis outputs/cfm/ zu finden.
Bilddatensätze
Sie können auch klassenkonditionale CFM-Modelle auf bekannten Bildklassifikationsdatensätzen trainieren. Sowohl die generierten Beispiele als auch die Modell-Checkpoints werden im Verzeichnis outputs/cfm gespeichert. Eine detaillierte Liste der Trainingsparameter erhalten Sie mit uv run scripts/train_flow_matching_on_image.py --help.
Um ein klassenkonditionales CFM auf dem MNIST-Datensatz zu trainieren, führen Sie aus:
uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnistNach dem Training können Sie jetzt Proben erzeugen mit:
uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnistNun sollten Sie die generierten Beispiele im Verzeichnis outputs/cfm/mnist/ sehen können.
Rectified Flow [Liu+ 2023]
Dies ist eine Implementierung des Reflow-Modells (genauer gesagt 2-Rectified Flow) aus dem Rectified Flow Paper [2].
2D Synthetische Daten
Wir haben den Reflow auf 2D-synthetischen Datensätzen implementiert, genau wie beim CFM. Um den Reflow zu trainieren, müssen Sie vortrainierte CFM-Checkpoints angeben, da Reflow ein Distillationsmodell ist.
Zum Beispiel, um auf dem Datensatz checkerboard mit einem vortrainierten CFM-Checkpoint zu trainieren:
uv run scripts/train_reflow_2d.py --dataset checkerboardDie Trainingsergebnisse, einschließlich Vektorfeld-Visualisierungen und generierter Muster, werden im Ordner outputs/reflow/ gespeichert.
Vergleich des Sampling-Prozesses zwischen CFM und Reflow
Um CFM und Reflow auf 2D-Datensätzen zu vergleichen, führen Sie Folgendes aus:
uv run scripts/plot_comparison_2d.py --dataset checkerboardDie resultierenden GIFs sind im Ordner outputs/comparisons/ zu finden. Unten ist ein Beispielvergleich der beiden Methoden im checkerboard-Datensatz:
Referenzen
- [1] Lipman, Yaron, et al. „Flow Matching for Generative Modeling.“ arXiv:2210.02747
- [2] Liu, Xingchao, et al. „Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow.“ arXiv:2209.03003
- [3] facebookresearch/flow_matching
- [4] atong01/conditional-flow-matching