Flow Matching em PyTorch
Este repositório contém uma implementação simples em PyTorch do artigo Flow Matching for Generative Modeling.
Exemplo de Flow Matching 2D
O gif abaixo demonstra o mapeamento de uma única distribuição Gaussiana para uma distribuição em padrão tabuleiro, com o campo vetorial visualizado.
E aqui está outro exemplo com o conjunto de dados moons.
Primeiros Passos
Clone o repositório e configure o ambiente Python.
git clone https://github.com/keishihara/flow-matching.git
cd flow-matchingCertifique-se de ter o Python 3.12+ instalado.
Instale o uv:
curl -LsSf https://astral.sh/uv/install.sh | shEntão, configure o ambiente:
uv syncConditional Flow Matching [Lipman+ 2023]
Esta é a implementação original do artigo CFM [1]. Alguns componentes do código são adaptados de [2] e [3].
Conjuntos de Dados Sintéticos 2D
Você pode treinar os modelos CFM em conjuntos de dados sintéticos 2D, como checkerboard e moons. Especifique o nome do conjunto de dados usando a opção --dataset. Os parâmetros de treinamento estão predefinidos no script, e as visualizações dos resultados do treinamento são armazenadas no diretório outputs/. Os pontos de verificação dos modelos não estão incluídos, pois são facilmente reproduzíveis com as configurações padrão.
uv run scripts/train_flow_matching_2d.py --dataset checkerboardOs campos vetoriais e amostras geradas, como os exibidos em GIFs no topo deste README, agora podem ser encontrados no diretório outputs/cfm/.
Conjuntos de Dados de Imagens
Você também pode treinar modelos CFM condicionais à classe em conjuntos de dados populares de classificação de imagens. Tanto as amostras geradas quanto os checkpoints do modelo serão armazenados no diretório outputs/cfm. Para uma lista detalhada dos parâmetros de treinamento, execute uv run scripts/train_flow_matching_on_image.py --help.
Para treinar um CFM condicional à classe no conjunto de dados MNIST, execute:
uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnistApós o treinamento, agora você pode gerar amostras com:
uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnistAgora, você deverá conseguir ver as amostras geradas no diretório outputs/cfm/mnist/.
Fluxo Retificado [Liu+ 2023]
Esta é uma implementação do modelo Reflow (2-Fluxo Retificado, especificamente) do artigo Rectified Flow [2].
Dados Sintéticos 2D
Implementamos o Reflow em conjuntos de dados sintéticos 2D, assim como o CFM. Para treinar o reflow, você deve especificar pontos de verificação (checkpoints) CFM pré-treinados, pois o reflow é um modelo de destilação.
Por exemplo, para treinar no conjunto de dados checkerboard com um checkpoint CFM pré-treinado:
uv run scripts/train_reflow_2d.py --dataset checkerboardOs resultados do treinamento, incluindo visualizações do campo vetorial e amostras geradas, são salvos na pasta outputs/reflow/.
Comparação do processo de amostragem entre CFM e Reflow
Para comparar CFM e Reflow em conjuntos de dados 2d, execute:
uv run scripts/plot_comparison_2d.py --dataset checkerboard
Os GIFs resultantes podem ser encontrados na pasta outputs/comparisons/. Abaixo está um exemplo de comparação dos dois métodos no conjunto de dados checkerboard:
Referências
- [1] Lipman, Yaron, et al. "Flow Matching for Generative Modeling." arXiv:2210.02747
- [2] Liu, Xingchao, et al. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." arXiv:2209.03003
- [3] facebookresearch/flow_matching
- [4] atong01/conditional-flow-matching