Web Analytics

flow-matching

⭐ 89 stars German by keishihara

🌐 Sprache

Flow Matching in PyTorch

Dieses Repository enthält eine einfache PyTorch-Implementierung des Papers Flow Matching for Generative Modeling.

2D Flow Matching Beispiel

Das folgende GIF demonstriert die Abbildung einer einzelnen Gaußschen Verteilung auf eine Schachbrettverteilung, wobei das Vektorfeld visualisiert wird.

Und hier ist ein weiteres Beispiel mit dem Moons-Datensatz.

Erste Schritte

Klonen Sie das Repository und richten Sie die Python-Umgebung ein.

git clone https://github.com/keishihara/flow-matching.git
cd flow-matching

Stellen Sie sicher, dass Python 3.12+ installiert ist. Installieren Sie uv:

curl -LsSf https://astral.sh/uv/install.sh | sh

Richten Sie anschließend die Umgebung ein:

uv sync

Conditional Flow Matching [Lipman+ 2023]

Dies ist die Originalimplementierung des CFM-Papers [1]. Einige Komponenten des Codes wurden aus [2] und [3] übernommen.

2D Toy Datasets

Sie können die CFM-Modelle auf 2D-synthetischen Datensätzen wie checkerboard und moons trainieren. Geben Sie den Namen des Datensatzes mit der Option --dataset an. Trainingsparameter sind im Skript vordefiniert, und Visualisierungen der Trainingsergebnisse werden im Verzeichnis outputs/ gespeichert. Modell-Checkpoints sind nicht enthalten, da sie mit den Standardeinstellungen leicht reproduzierbar sind.

uv run scripts/train_flow_matching_2d.py --dataset checkerboard

Die Vektorfelder und generierten Beispiele, wie sie als GIFs oben in diesem README angezeigt werden, sind jetzt im Verzeichnis outputs/cfm/ zu finden.

Bilddatensätze

Sie können auch klassenkonditionale CFM-Modelle auf bekannten Bildklassifikationsdatensätzen trainieren. Sowohl die generierten Beispiele als auch die Modell-Checkpoints werden im Verzeichnis outputs/cfm gespeichert. Eine detaillierte Liste der Trainingsparameter erhalten Sie mit uv run scripts/train_flow_matching_on_image.py --help.

Um ein klassenkonditionales CFM auf dem MNIST-Datensatz zu trainieren, führen Sie aus:

uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnist

Nach dem Training können Sie jetzt Proben erzeugen mit:

uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnist

Nun sollten Sie die generierten Beispiele im Verzeichnis outputs/cfm/mnist/ sehen können.

Rectified Flow [Liu+ 2023]

Dies ist eine Implementierung des Reflow-Modells (genauer gesagt 2-Rectified Flow) aus dem Rectified Flow Paper [2].

2D Synthetische Daten

Wir haben den Reflow auf 2D-synthetischen Datensätzen implementiert, genau wie beim CFM. Um den Reflow zu trainieren, müssen Sie vortrainierte CFM-Checkpoints angeben, da Reflow ein Distillationsmodell ist.

Zum Beispiel, um auf dem Datensatz checkerboard mit einem vortrainierten CFM-Checkpoint zu trainieren:

uv run scripts/train_reflow_2d.py --dataset checkerboard

Die Trainingsergebnisse, einschließlich Vektorfeld-Visualisierungen und generierter Muster, werden im Ordner outputs/reflow/ gespeichert.

Vergleich des Sampling-Prozesses zwischen CFM und Reflow

Um CFM und Reflow auf 2D-Datensätzen zu vergleichen, führen Sie Folgendes aus:

uv run scripts/plot_comparison_2d.py --dataset checkerboard

Die resultierenden GIFs sind im Ordner outputs/comparisons/ zu finden. Unten ist ein Beispielvergleich der beiden Methoden im checkerboard-Datensatz:

Referenzen

--- Tranlated By Open Ai Tx | Last indexed: 2026-01-19 ---