Web Analytics

flow-matching

⭐ 89 stars Italian by keishihara

🌐 Lingua

Flow Matching in PyTorch

Questo repository contiene una semplice implementazione in PyTorch dell’articolo Flow Matching for Generative Modeling.

Esempio di Flow Matching 2D

La gif qui sotto mostra la mappatura di una singola distribuzione Gaussiana su una distribuzione a scacchiera, con il campo vettoriale visualizzato.

Ecco un altro esempio con il dataset delle lune.

Introduzione

Clona il repository e configura l'ambiente python.

git clone https://github.com/keishihara/flow-matching.git
cd flow-matching

Assicurati di avere installato Python 3.12 o superiore. Installa uv:

curl -LsSf https://astral.sh/uv/install.sh | sh

Quindi, configura l'ambiente:

uv sync

Conditional Flow Matching [Lipman+ 2023]

Questa è l'implementazione originale dell'articolo CFM [1]. Alcuni componenti del codice sono adattati da [2] e [3].

Dataset 2D Toy

È possibile addestrare i modelli CFM su dataset sintetici 2D come checkerboard e moons. Specificare il nome del dataset utilizzando l'opzione --dataset. I parametri di addestramento sono predefiniti nello script, e le visualizzazioni dei risultati dell'addestramento sono memorizzate nella directory outputs/. I checkpoint del modello non sono inclusi in quanto facilmente riproducibili con le impostazioni predefinite.

uv run scripts/train_flow_matching_2d.py --dataset checkerboard

I campi vettoriali e i campioni generati, come quelli mostrati come GIF all'inizio di questo README, possono ora essere trovati nella directory outputs/cfm/.

Dataset di Immagini

Puoi anche addestrare modelli CFM condizionati sulla classe su popolari dataset di classificazione di immagini. Sia i campioni generati che i checkpoint dei modelli saranno salvati nella directory outputs/cfm. Per una lista dettagliata dei parametri di addestramento, esegui uv run scripts/train_flow_matching_on_image.py --help.

Per addestrare un modello CFM condizionato sulla classe sul dataset MNIST, esegui:

uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnist

Dopo l'addestramento, ora puoi generare campioni con:

uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnist

Ora dovresti essere in grado di vedere i campioni generati nella directory outputs/cfm/mnist/.

Rectified Flow [Liu+ 2023]

Questa è un'implementazione del modello Reflow (2-Rectified Flow per la precisione) dal paper Rectified Flow [2].

Dati Sintetici 2D

Abbiamo implementato Reflow su dataset sintetici 2d, così come il CFM. Per addestrare reflow, devi specificare checkpoint CFM pre-addestrati poiché reflow è un modello di distillazione.

Ad esempio, per addestrare sul dataset checkerboard con un checkpoint CFM pre-addestrato:

uv run scripts/train_reflow_2d.py --dataset checkerboard

I risultati dell'addestramento, incluse le visualizzazioni del campo vettoriale e i campioni generati, sono salvati nella cartella outputs/reflow/.

Confronto del processo di campionamento tra CFM e Reflow

Per confrontare CFM e Reflow su dataset 2d, eseguire:

uv run scripts/plot_comparison_2d.py --dataset checkerboard

Le GIF risultanti possono essere trovate nella cartella outputs/comparisons/. Di seguito è riportato un esempio di confronto tra i due metodi nel dataset checkerboard:

Riferimenti

--- Tranlated By Open Ai Tx | Last indexed: 2026-01-19 ---