Web Analytics

flow-matching

⭐ 89 stars French by keishihara

🌐 Langue

Flow Matching avec PyTorch

Ce dépôt contient une implémentation simple sous PyTorch de l’article Flow Matching for Generative Modeling.

Exemple de Flow Matching 2D

Le gif ci-dessous montre le passage d’une distribution gaussienne simple à une distribution en damier, avec visualisation du champ de vecteurs.

Et voici un autre exemple avec le jeu de données moons.

Pour commencer

Clonez le dépôt et configurez l'environnement python.

git clone https://github.com/keishihara/flow-matching.git
cd flow-matching

Assurez-vous d'avoir installé Python 3.12+. Installez uv :

curl -LsSf https://astral.sh/uv/install.sh | sh
Ensuite, configurez l'environnement :

uv sync

Appariement de flux conditionnel [Lipman+ 2023]

Voici l’implémentation originale de l’article CFM [1]. Certains composants du code sont adaptés de [2] et [3].

Jeux de données 2D factices

Vous pouvez entraîner les modèles CFM sur des jeux de données synthétiques 2D tels que checkerboard et moons. Spécifiez le nom du jeu de données via l’option --dataset. Les paramètres d’entraînement sont prédéfinis dans le script, et les visualisations des résultats d’entraînement sont stockées dans le répertoire outputs/. Les points de contrôle du modèle ne sont pas inclus car ils sont facilement reproductibles avec les paramètres par défaut.

uv run scripts/train_flow_matching_2d.py --dataset checkerboard

Les champs de vecteurs et les échantillons générés, comme ceux affichés sous forme de GIFs en haut de ce README, se trouvent désormais dans le répertoire outputs/cfm/.

Jeux de données d’images

Vous pouvez également entraîner des modèles CFM conditionnels sur la classe sur des jeux de données populaires de classification d’images. Les échantillons générés ainsi que les points de contrôle du modèle seront stockés dans le répertoire outputs/cfm. Pour obtenir une liste détaillée des paramètres d’entraînement, exécutez uv run scripts/train_flow_matching_on_image.py --help.

Pour entraîner un CFM conditionnel sur la classe sur le jeu de données MNIST, exécutez :

uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnist
Après l'entraînement, vous pouvez désormais générer des échantillons avec :

uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnist

Now, you should be able to see the generated samples in the outputs/cfm/mnist/ directory.

Flux Rectifié [Liu+ 2023]

Ceci est une implémentation du modèle Reflow (2-Flux Rectifié pour être précis) issu de l’article Rectified Flow [2].

Données Synthétiques 2D

Nous avons implémenté le Reflow sur des ensembles de données synthétiques 2D, de la même manière que le CFM. Pour entraîner le reflow, vous devez spécifier des points de contrôle CFM pré-entraînés, car le reflow est un modèle de distillation.

Par exemple, pour entraîner sur l’ensemble de données checkerboard avec un point de contrôle CFM pré-entraîné :

uv run scripts/train_reflow_2d.py --dataset checkerboard

Les résultats de l'entraînement, y compris les visualisations du champ vectoriel et les échantillons générés, sont enregistrés dans le dossier outputs/reflow/.

Comparaison du processus d'échantillonnage entre CFM et Reflow

Pour comparer CFM et Reflow sur des jeux de données 2D, exécutez :

uv run scripts/plot_comparison_2d.py --dataset checkerboard

Les GIFs résultants peuvent être trouvés dans le dossier outputs/comparisons/. Voici un exemple de comparaison des deux méthodes sur le jeu de données checkerboard :

Références

--- Tranlated By Open Ai Tx | Last indexed: 2026-01-19 ---