Web Analytics

flow-matching

⭐ 82 stars Simplified Chinese by keishihara

PyTorch中的流匹配

本仓库包含论文Flow Matching for Generative Modeling的一个简单PyTorch实现。

2D流匹配示例

下面的gif演示了将单个高斯分布映射到棋盘分布,并可视化了向量场。

这是另一个月亮数据集的示例。

快速开始

克隆仓库并设置Python环境。

git clone https://github.com/keishihara/flow-matching.git
cd flow-matching
确保已安装 Python 3.10 及以上版本。 使用 uv 设置 Python 环境:

uv sync
source .venv/bin/activate
或者,使用 pip

python -m venv .venv
source .venv/bin/activate
pip install -e .

条件流匹配 [Lipman+ 2023]

这是原始的CFM论文实现 [1]。代码的部分组件改编自 [2] 和 [3]。

2D 玩具数据集

您可以在二维合成数据集如 checkerboardmoons 上训练CFM模型。使用 --dataset 选项指定数据集名称。训练参数在脚本中预定义,训练结果的可视化存储在 outputs/ 目录下。模型检查点未包含,因为使用默认设置可以轻松重现。

python train_flow_matching_2d.py --dataset checkerboard

矢量场和生成的样本,如本自述文件顶部以GIF形式展示的内容,现在可以在 outputs/cfm/ 目录中找到。

图像数据集

您也可以在流行的图像分类数据集上训练类条件CFM模型。生成的样本和模型检查点都将存储在 outputs/cfm 目录中。有关训练参数的详细列表,请运行 python train_flow_matching_on_images.py --help

要在MNIST数据集上训练类条件CFM,请运行:

python train_flow_matching_on_image.py --do_train --dataset mnist

训练完成后,您现在可以生成样本,命令如下:

python train_flow_matching_on_image.py --do_sample --dataset mnist
现在,您应该能够在 outputs/cfm/mnist/ 目录中看到生成的样本。

Rectified Flow [Liu+ 2023]

这是 Rectified Flow 论文 [2] 中 Reflow 模型(具体是 2-Rectified Flow)的实现。

2D 合成数据

我们已经在 2D 合成数据集上实现了 Reflow,和 CFM 一样。要训练 Reflow,您必须指定预训练的 CFM 检查点,因为 Reflow 是一个蒸馏模型。

例如,要在带有预训练 CFM 检查点的 checkerboard 数据集上训练:

python train_reflow_2d.py --dataset checkerboard --pretrained-model outputs/cfm/checkerboard/ckpt.pth
训练结果,包括向量场可视化和生成样本,保存在 outputs/reflow/ 文件夹下。

CFM 与 Reflow 采样过程的比较

要在二维数据集上比较 CFM 和 Reflow,运行:

python plot_comparison_2d.py --dataset checkerboard
生成的 GIF 可以在 outputs/comparisons/ 文件夹中找到。下面是 checkerboard 数据集中两种方法的对比示例:

参考文献

--- Tranlated By Open Ai Tx | Last indexed: 2025-12-24 ---