Web Analytics

flow-matching

⭐ 89 stars Dutch by keishihara

🌐 Taal

Flow Matching in PyTorch

Deze repository bevat een eenvoudige PyTorch-implementatie van het artikel Flow Matching for Generative Modeling.

2D Flow Matching Voorbeeld

De onderstaande gif demonstreert het omzetten van een enkele Gaussische distributie naar een schaakbord-distributie, waarbij het vectorveld wordt gevisualiseerd.

En hier is een ander voorbeeld van de moons-dataset.

Aan de slag

Kloon de repository en stel de python-omgeving in.

git clone https://github.com/keishihara/flow-matching.git
cd flow-matching

Zorg ervoor dat je Python 3.12+ hebt geïnstalleerd. Installeer uv:

curl -LsSf https://astral.sh/uv/install.sh | sh
Stel vervolgens de omgeving in:

uv sync

Conditionele Flow Matching [Lipman+ 2023]

Dit is de originele CFM paperimplementatie [1]. Sommige onderdelen van de code zijn aangepast uit [2] en [3].

2D Speelgoeddatasets

U kunt de CFM-modellen trainen op 2D synthetische datasets zoals checkerboard en moons. Geef de naam van de dataset op met de optie --dataset. Trainingsparameters zijn vooraf ingesteld in het script en visualisaties van de trainingsresultaten worden opgeslagen in de map outputs/. Modelcheckpoints zijn niet inbegrepen omdat ze eenvoudig te reproduceren zijn met de standaardinstellingen.

uv run scripts/train_flow_matching_2d.py --dataset checkerboard

De vectorvelden en gegenereerde samples, zoals weergegeven in de GIFs bovenaan deze README, zijn nu te vinden in de map outputs/cfm/.

Afbeeldingendatasets

Je kunt ook klasse-geconditioneerde CFM-modellen trainen op populaire beeldclassificatiedatasets. Zowel de gegenereerde samples als de modelcheckpoints worden opgeslagen in de map outputs/cfm. Voor een gedetailleerde lijst van trainingsparameters, voer uv run scripts/train_flow_matching_on_image.py --help uit.

Om een klasse-geconditioneerde CFM op de MNIST-dataset te trainen, voer uit:

uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnist

Na de training kun je nu voorbeelden genereren met:

uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnist

Nu zou je de gegenereerde samples moeten kunnen zien in de map outputs/cfm/mnist/.

Gecorrigeerde Flow [Liu+ 2023]

Dit is een implementatie van het Reflow-model (specifiek 2-Gecorrigeerde Flow) uit het Rectified Flow-paper [2].

2D Synthetische Data

We hebben Reflow geïmplementeerd op 2d synthetische datasets, hetzelfde als de CFM. Om de reflow te trainen, moet je vooraf getrainde CFM-checkpoints opgeven omdat reflow een distillatiemodel is.

Bijvoorbeeld, om te trainen op de checkerboard dataset met een vooraf getraind CFM-checkpoint:

uv run scripts/train_reflow_2d.py --dataset checkerboard

De trainingsresultaten, inclusief visualisaties van vectorvelden en gegenereerde samples, worden opgeslagen in de map outputs/reflow/.

Vergelijking van het samplingproces tussen CFM en Reflow

Om CFM en Reflow op 2D-datasets te vergelijken, voer uit:

uv run scripts/plot_comparison_2d.py --dataset checkerboard
De resulterende GIF's zijn te vinden in de map outputs/comparisons/. Hieronder staat een voorbeeldvergelijking van de twee methoden in de checkerboard dataset:

Referenties

--- Tranlated By Open Ai Tx | Last indexed: 2026-01-19 ---