Flow Matching in PyTorch
Questo repository contiene una semplice implementazione in PyTorch dell’articolo Flow Matching for Generative Modeling.
Esempio di Flow Matching 2D
La gif qui sotto mostra la mappatura di una singola distribuzione Gaussiana su una distribuzione a scacchiera, con il campo vettoriale visualizzato.
Ecco un altro esempio con il dataset delle lune.
Introduzione
Clona il repository e configura l'ambiente python.
git clone https://github.com/keishihara/flow-matching.git
cd flow-matchingAssicurati di avere installato Python 3.12 o superiore.
Installa uv:
curl -LsSf https://astral.sh/uv/install.sh | shQuindi, configura l'ambiente:
uv syncConditional Flow Matching [Lipman+ 2023]
Questa è l'implementazione originale dell'articolo CFM [1]. Alcuni componenti del codice sono adattati da [2] e [3].
Dataset 2D Toy
È possibile addestrare i modelli CFM su dataset sintetici 2D come checkerboard e moons. Specificare il nome del dataset utilizzando l'opzione --dataset. I parametri di addestramento sono predefiniti nello script, e le visualizzazioni dei risultati dell'addestramento sono memorizzate nella directory outputs/. I checkpoint del modello non sono inclusi in quanto facilmente riproducibili con le impostazioni predefinite.
uv run scripts/train_flow_matching_2d.py --dataset checkerboardI campi vettoriali e i campioni generati, come quelli mostrati come GIF all'inizio di questo README, possono ora essere trovati nella directory outputs/cfm/.
Dataset di Immagini
Puoi anche addestrare modelli CFM condizionati sulla classe su popolari dataset di classificazione di immagini. Sia i campioni generati che i checkpoint dei modelli saranno salvati nella directory outputs/cfm. Per una lista dettagliata dei parametri di addestramento, esegui uv run scripts/train_flow_matching_on_image.py --help.
Per addestrare un modello CFM condizionato sulla classe sul dataset MNIST, esegui:
uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnistDopo l'addestramento, ora puoi generare campioni con:
uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnistOra dovresti essere in grado di vedere i campioni generati nella directory outputs/cfm/mnist/.
Rectified Flow [Liu+ 2023]
Questa è un'implementazione del modello Reflow (2-Rectified Flow per la precisione) dal paper Rectified Flow [2].
Dati Sintetici 2D
Abbiamo implementato Reflow su dataset sintetici 2d, così come il CFM. Per addestrare reflow, devi specificare checkpoint CFM pre-addestrati poiché reflow è un modello di distillazione.
Ad esempio, per addestrare sul dataset checkerboard con un checkpoint CFM pre-addestrato:
uv run scripts/train_reflow_2d.py --dataset checkerboardI risultati dell'addestramento, incluse le visualizzazioni del campo vettoriale e i campioni generati, sono salvati nella cartella outputs/reflow/.
Confronto del processo di campionamento tra CFM e Reflow
Per confrontare CFM e Reflow su dataset 2d, eseguire:
uv run scripts/plot_comparison_2d.py --dataset checkerboardLe GIF risultanti possono essere trovate nella cartella outputs/comparisons/. Di seguito è riportato un esempio di confronto tra i due metodi nel dataset checkerboard:
Riferimenti
- [1] Lipman, Yaron, et al. "Flow Matching for Generative Modeling." arXiv:2210.02747
- [2] Liu, Xingchao, et al. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." arXiv:2209.03003
- [3] facebookresearch/flow_matching
- [4] atong01/conditional-flow-matching