Web Analytics

flow-matching

⭐ 89 stars Portuguese by keishihara

🌐 Idioma

Flow Matching em PyTorch

Este repositório contém uma implementação simples em PyTorch do artigo Flow Matching for Generative Modeling.

Exemplo de Flow Matching 2D

O gif abaixo demonstra o mapeamento de uma única distribuição Gaussiana para uma distribuição em padrão tabuleiro, com o campo vetorial visualizado.

E aqui está outro exemplo com o conjunto de dados moons.

Primeiros Passos

Clone o repositório e configure o ambiente Python.

git clone https://github.com/keishihara/flow-matching.git
cd flow-matching

Certifique-se de ter o Python 3.12+ instalado. Instale o uv:

curl -LsSf https://astral.sh/uv/install.sh | sh

Então, configure o ambiente:

uv sync

Conditional Flow Matching [Lipman+ 2023]

Esta é a implementação original do artigo CFM [1]. Alguns componentes do código são adaptados de [2] e [3].

Conjuntos de Dados Sintéticos 2D

Você pode treinar os modelos CFM em conjuntos de dados sintéticos 2D, como checkerboard e moons. Especifique o nome do conjunto de dados usando a opção --dataset. Os parâmetros de treinamento estão predefinidos no script, e as visualizações dos resultados do treinamento são armazenadas no diretório outputs/. Os pontos de verificação dos modelos não estão incluídos, pois são facilmente reproduzíveis com as configurações padrão.

uv run scripts/train_flow_matching_2d.py --dataset checkerboard

Os campos vetoriais e amostras geradas, como os exibidos em GIFs no topo deste README, agora podem ser encontrados no diretório outputs/cfm/.

Conjuntos de Dados de Imagens

Você também pode treinar modelos CFM condicionais à classe em conjuntos de dados populares de classificação de imagens. Tanto as amostras geradas quanto os checkpoints do modelo serão armazenados no diretório outputs/cfm. Para uma lista detalhada dos parâmetros de treinamento, execute uv run scripts/train_flow_matching_on_image.py --help.

Para treinar um CFM condicional à classe no conjunto de dados MNIST, execute:

uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnist

Após o treinamento, agora você pode gerar amostras com:

uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnist

Agora, você deverá conseguir ver as amostras geradas no diretório outputs/cfm/mnist/.

Fluxo Retificado [Liu+ 2023]

Esta é uma implementação do modelo Reflow (2-Fluxo Retificado, especificamente) do artigo Rectified Flow [2].

Dados Sintéticos 2D

Implementamos o Reflow em conjuntos de dados sintéticos 2D, assim como o CFM. Para treinar o reflow, você deve especificar pontos de verificação (checkpoints) CFM pré-treinados, pois o reflow é um modelo de destilação.

Por exemplo, para treinar no conjunto de dados checkerboard com um checkpoint CFM pré-treinado:

uv run scripts/train_reflow_2d.py --dataset checkerboard

Os resultados do treinamento, incluindo visualizações do campo vetorial e amostras geradas, são salvos na pasta outputs/reflow/.

Comparação do processo de amostragem entre CFM e Reflow

Para comparar CFM e Reflow em conjuntos de dados 2d, execute:

uv run scripts/plot_comparison_2d.py --dataset checkerboard
Os GIFs resultantes podem ser encontrados na pasta outputs/comparisons/. Abaixo está um exemplo de comparação dos dois métodos no conjunto de dados checkerboard:

Referências

--- Tranlated By Open Ai Tx | Last indexed: 2026-01-19 ---