Web Analytics

flow-matching

⭐ 89 stars Simplified Chinese by keishihara

🌐 语言

PyTorch中的流匹配

本仓库包含了论文 用于生成建模的流匹配 的一个简单PyTorch实现。

2D流匹配示例

下方的动图演示了将单一高斯分布映射到棋盘格分布的过程,并可视化了矢量场。

下面是月牙形(moons)数据集的另一个示例。

快速开始

克隆该仓库并配置 Python 环境。

git clone https://github.com/keishihara/flow-matching.git
cd flow-matching

确保已安装 Python 3.12 及以上版本。 安装 uv

curl -LsSf https://astral.sh/uv/install.sh | sh

然后,设置环境:

uv sync

条件流匹配 [Lipman+ 2023]

这是原始的CFM论文实现 [1]。代码的部分组件改编自 [2] 和 [3]。

2D 玩具数据集

您可以在二维合成数据集如 checkerboardmoons 上训练CFM模型。使用 --dataset 选项指定数据集名称。训练参数在脚本中预定义,训练结果的可视化存储在 outputs/ 目录下。模型检查点未包含,因为使用默认设置可以轻松重现。

uv run scripts/train_flow_matching_2d.py --dataset checkerboard

现在,像本 README 顶部显示为 GIF 的那些向量场和生成样本,可以在 outputs/cfm/ 目录下找到。

图像数据集

您还可以在流行的图像分类数据集上训练类别条件的 CFM 模型。生成的样本和模型检查点都将存储在 outputs/cfm 目录中。要获取详细的训练参数列表,请运行 uv run scripts/train_flow_matching_on_image.py --help

要在 MNIST 数据集上训练类别条件的 CFM,请运行:

uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnist

训练完成后,您现在可以生成样本,命令如下:

uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnist
现在,您应该能够在 outputs/cfm/mnist/ 目录中看到生成的样本。

Rectified Flow [Liu+ 2023]

这是 Rectified Flow 论文 [2] 中 Reflow 模型(具体是 2-Rectified Flow)的实现。

2D 合成数据

我们已经在 2D 合成数据集上实现了 Reflow,和 CFM 一样。要训练 Reflow,您必须指定预训练的 CFM 检查点,因为 Reflow 是一个蒸馏模型。

例如,要在带有预训练 CFM 检查点的 checkerboard 数据集上训练:

uv run scripts/train_reflow_2d.py --dataset checkerboard
训练结果,包括向量场可视化和生成样本,保存在 outputs/reflow/ 文件夹下。

CFM 与 Reflow 采样过程的比较

要在二维数据集上比较 CFM 和 Reflow,运行:

uv run scripts/plot_comparison_2d.py --dataset checkerboard
生成的 GIF 可以在 outputs/comparisons/ 文件夹中找到。下面是 checkerboard 数据集中两种方法的对比示例:

参考文献

--- Tranlated By Open Ai Tx | Last indexed: 2026-01-19 ---