Web Analytics

flow-matching

⭐ 89 stars Polish by keishihara

🌐 Język

Dopasowywanie przepływu w PyTorch

To repozytorium zawiera prostą implementację w PyTorch artykułu Flow Matching for Generative Modeling.

Przykład dopasowywania przepływu w 2D

Poniższy gif demonstruje mapowanie pojedynczego rozkładu normalnego na rozkład szachownicowy, z wizualizacją pola wektorowego.

A tutaj znajduje się inny przykład z użyciem zbioru danych moons.

Pierwsze kroki

Sklonuj repozytorium i skonfiguruj środowisko Pythona.

git clone https://github.com/keishihara/flow-matching.git
cd flow-matching

Upewnij się, że masz zainstalowanego Pythona w wersji 3.12+. Zainstaluj uv:

curl -LsSf https://astral.sh/uv/install.sh | sh
Następnie skonfiguruj środowisko:

uv sync

Conditional Flow Matching [Lipman+ 2023]

To jest oryginalna implementacja artykułu CFM [1]. Niektóre komponenty kodu zostały zaadaptowane z [2] oraz [3].

Dwuwymiarowe zabawkowe zbiory danych

Modele CFM można trenować na dwuwymiarowych syntetycznych zbiorach danych takich jak checkerboard oraz moons. Nazwę zbioru danych należy określić za pomocą opcji --dataset. Parametry treningu są zdefiniowane w skrypcie, a wizualizacje wyników treningu są zapisywane w katalogu outputs/. Punkty kontrolne modeli nie są dołączone, ponieważ można je łatwo odtworzyć przy domyślnych ustawieniach.

uv run scripts/train_flow_matching_2d.py --dataset checkerboard

Pola wektorowe i wygenerowane próbki, takie jak te wyświetlane jako GIF-y na górze tego pliku README, można teraz znaleźć w katalogu outputs/cfm/.

Zbiory danych obrazów

Możesz także trenować modele CFM warunkowane klasami na popularnych zbiorach danych do klasyfikacji obrazów. Zarówno wygenerowane próbki, jak i punkty kontrolne modeli będą przechowywane w katalogu outputs/cfm. Aby uzyskać szczegółową listę parametrów treningowych, uruchom uv run scripts/train_flow_matching_on_image.py --help.

Aby wytrenować model CFM warunkowany klasami na zbiorze danych MNIST, uruchom:

uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnist

Po zakończeniu treningu możesz teraz generować próbki za pomocą:

uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnist

Teraz powinieneś zobaczyć wygenerowane próbki w katalogu outputs/cfm/mnist/.

Rectified Flow [Liu+ 2023]

Jest to implementacja modelu Reflow (dokładniej 2-Rectified Flow) z artykułu Rectified Flow [2].

Dane syntetyczne 2D

Zaimplementowaliśmy Reflow na syntetycznych zbiorach danych 2D, tak samo jak CFM. Aby wytrenować reflow, musisz podać wstępnie wytrenowane punkty kontrolne CFM, ponieważ reflow jest modelem destylacji.

Na przykład, aby trenować na zbiorze danych checkerboard z wstępnie wytrenowanym punktem kontrolnym CFM:

uv run scripts/train_reflow_2d.py --dataset checkerboard

Wyniki treningu, w tym wizualizacje pól wektorowych oraz wygenerowane próbki, są zapisywane w folderze outputs/reflow/.

Porównanie procesu próbkowania między CFM a Reflow

Aby porównać CFM i Reflow na zbiorach danych 2D, uruchom:

uv run scripts/plot_comparison_2d.py --dataset checkerboard
Wynikowe pliki GIF można znaleźć w folderze outputs/comparisons/. Poniżej znajduje się przykładowe porównanie dwóch metod w zbiorze danych checkerboard:

Odniesienia

--- Tranlated By Open Ai Tx | Last indexed: 2026-01-19 ---