Web Analytics

flow-matching

⭐ 89 stars Spanish by keishihara

🌐 Idioma

Flow Matching en PyTorch

Este repositorio contiene una implementación sencilla en PyTorch del artículo Flow Matching for Generative Modeling.

Ejemplo de Flow Matching en 2D

El gif a continuación muestra el mapeo de una sola distribución Gaussiana a una distribución de tablero de ajedrez, con el campo vectorial visualizado.

Y aquí tienes otro ejemplo con el conjunto de datos moons.

Primeros Pasos

Clona el repositorio y configura el entorno de Python.

git clone https://github.com/keishihara/flow-matching.git
cd flow-matching

Asegúrate de tener instalado Python 3.12+. Instala uv:

curl -LsSf https://astral.sh/uv/install.sh | sh

Luego, configura el entorno:

uv sync

Emparejamiento de Flujo Condicional [Lipman+ 2023]

Esta es la implementación original del artículo CFM [1]. Algunos componentes del código están adaptados de [2] y [3].

Conjuntos de Datos Sintéticos 2D

Puede entrenar los modelos CFM en conjuntos de datos sintéticos 2D como checkerboard y moons. Especifique el nombre del conjunto de datos usando la opción --dataset. Los parámetros de entrenamiento están predefinidos en el script, y las visualizaciones de los resultados de entrenamiento se almacenan en el directorio outputs/. No se incluyen puntos de control del modelo ya que son fácilmente reproducibles con la configuración predeterminada.

uv run scripts/train_flow_matching_2d.py --dataset checkerboard

Los campos vectoriales y las muestras generadas, como las que se muestran como GIFs en la parte superior de este README, ahora se pueden encontrar en el directorio outputs/cfm/.

Conjuntos de datos de imágenes

También puedes entrenar modelos CFM condicionales por clase en conjuntos de datos de clasificación de imágenes populares. Tanto las muestras generadas como los puntos de control del modelo se almacenarán en el directorio outputs/cfm. Para obtener una lista detallada de los parámetros de entrenamiento, ejecuta uv run scripts/train_flow_matching_on_image.py --help.

Para entrenar un CFM condicional por clase en el conjunto de datos MNIST, ejecuta:

uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnist

Después del entrenamiento, ahora puedes generar muestras con:

uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnist
Ahora, deberías poder ver las muestras generadas en el directorio outputs/cfm/mnist/.

Flujo Rectificado [Liu+ 2023]

Esta es una implementación del modelo Reflow (Flujo Rectificado 2 para ser específicos) del artículo sobre Flujo Rectificado [2].

Datos Sintéticos 2D

Hemos implementado Reflow en conjuntos de datos sintéticos 2D, igual que CFM. Para entrenar el reflow, debes especificar puntos de control CFM preentrenados ya que reflow es un modelo de destilación.

Por ejemplo, para entrenar en el conjunto de datos checkerboard con un punto de control CFM preentrenado:

uv run scripts/train_reflow_2d.py --dataset checkerboard

Los resultados del entrenamiento, incluyendo visualizaciones del campo vectorial y muestras generadas, se guardan en la carpeta outputs/reflow/.

Comparación del proceso de muestreo entre CFM y Reflow

Para comparar CFM y Reflow en conjuntos de datos 2d, ejecute:

uv run scripts/plot_comparison_2d.py --dataset checkerboard

Los GIF resultantes se pueden encontrar en la carpeta outputs/comparisons/. A continuación, un ejemplo de comparación de los dos métodos en el conjunto de datos checkerboard:

Referencias

--- Tranlated By Open Ai Tx | Last indexed: 2026-01-19 ---