Flow Matching en PyTorch
Este repositorio contiene una implementación sencilla en PyTorch del artículo Flow Matching for Generative Modeling.
Ejemplo de Flow Matching en 2D
El gif a continuación muestra el mapeo de una sola distribución Gaussiana a una distribución de tablero de ajedrez, con el campo vectorial visualizado.
Y aquí tienes otro ejemplo con el conjunto de datos moons.
Primeros Pasos
Clona el repositorio y configura el entorno de Python.
git clone https://github.com/keishihara/flow-matching.git
cd flow-matchingAsegúrate de tener instalado Python 3.12+.
Instala uv:
curl -LsSf https://astral.sh/uv/install.sh | shLuego, configura el entorno:
uv syncEmparejamiento de Flujo Condicional [Lipman+ 2023]
Esta es la implementación original del artículo CFM [1]. Algunos componentes del código están adaptados de [2] y [3].
Conjuntos de Datos Sintéticos 2D
Puede entrenar los modelos CFM en conjuntos de datos sintéticos 2D como checkerboard y moons. Especifique el nombre del conjunto de datos usando la opción --dataset. Los parámetros de entrenamiento están predefinidos en el script, y las visualizaciones de los resultados de entrenamiento se almacenan en el directorio outputs/. No se incluyen puntos de control del modelo ya que son fácilmente reproducibles con la configuración predeterminada.
uv run scripts/train_flow_matching_2d.py --dataset checkerboardLos campos vectoriales y las muestras generadas, como las que se muestran como GIFs en la parte superior de este README, ahora se pueden encontrar en el directorio outputs/cfm/.
Conjuntos de datos de imágenes
También puedes entrenar modelos CFM condicionales por clase en conjuntos de datos de clasificación de imágenes populares. Tanto las muestras generadas como los puntos de control del modelo se almacenarán en el directorio outputs/cfm. Para obtener una lista detallada de los parámetros de entrenamiento, ejecuta uv run scripts/train_flow_matching_on_image.py --help.
Para entrenar un CFM condicional por clase en el conjunto de datos MNIST, ejecuta:
uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnistDespués del entrenamiento, ahora puedes generar muestras con:
uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnist
Ahora, deberías poder ver las muestras generadas en el directorio outputs/cfm/mnist/.
Flujo Rectificado [Liu+ 2023]
Esta es una implementación del modelo Reflow (Flujo Rectificado 2 para ser específicos) del artículo sobre Flujo Rectificado [2].
Datos Sintéticos 2D
Hemos implementado Reflow en conjuntos de datos sintéticos 2D, igual que CFM. Para entrenar el reflow, debes especificar puntos de control CFM preentrenados ya que reflow es un modelo de destilación.
Por ejemplo, para entrenar en el conjunto de datos checkerboard con un punto de control CFM preentrenado:
uv run scripts/train_reflow_2d.py --dataset checkerboardLos resultados del entrenamiento, incluyendo visualizaciones del campo vectorial y muestras generadas, se guardan en la carpeta outputs/reflow/.
Comparación del proceso de muestreo entre CFM y Reflow
Para comparar CFM y Reflow en conjuntos de datos 2d, ejecute:
uv run scripts/plot_comparison_2d.py --dataset checkerboardLos GIF resultantes se pueden encontrar en la carpeta outputs/comparisons/. A continuación, un ejemplo de comparación de los dos métodos en el conjunto de datos checkerboard:
Referencias
- [1] Lipman, Yaron, et al. "Flow Matching for Generative Modeling." arXiv:2210.02747
- [2] Liu, Xingchao, et al. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." arXiv:2209.03003
- [3] facebookresearch/flow_matching
- [4] atong01/conditional-flow-matching