PyTorch中的流匹配
本仓库包含了论文 用于生成建模的流匹配 的一个简单PyTorch实现。
2D流匹配示例
下方的动图演示了将单一高斯分布映射到棋盘格分布的过程,并可视化了矢量场。
下面是月牙形(moons)数据集的另一个示例。
快速开始
克隆该仓库并配置 Python 环境。
git clone https://github.com/keishihara/flow-matching.git
cd flow-matching确保已安装 Python 3.12 及以上版本。
安装 uv:
curl -LsSf https://astral.sh/uv/install.sh | sh然后,设置环境:
uv sync条件流匹配 [Lipman+ 2023]
这是原始的CFM论文实现 [1]。代码的部分组件改编自 [2] 和 [3]。
2D 玩具数据集
您可以在二维合成数据集如 checkerboard 和 moons 上训练CFM模型。使用 --dataset 选项指定数据集名称。训练参数在脚本中预定义,训练结果的可视化存储在 outputs/ 目录下。模型检查点未包含,因为使用默认设置可以轻松重现。
uv run scripts/train_flow_matching_2d.py --dataset checkerboard现在,像本 README 顶部显示为 GIF 的那些向量场和生成样本,可以在 outputs/cfm/ 目录下找到。
图像数据集
您还可以在流行的图像分类数据集上训练类别条件的 CFM 模型。生成的样本和模型检查点都将存储在 outputs/cfm 目录中。要获取详细的训练参数列表,请运行 uv run scripts/train_flow_matching_on_image.py --help。
要在 MNIST 数据集上训练类别条件的 CFM,请运行:
uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnist训练完成后,您现在可以生成样本,命令如下:
uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnist
现在,您应该能够在 outputs/cfm/mnist/ 目录中看到生成的样本。
Rectified Flow [Liu+ 2023]
这是 Rectified Flow 论文 [2] 中 Reflow 模型(具体是 2-Rectified Flow)的实现。
2D 合成数据
我们已经在 2D 合成数据集上实现了 Reflow,和 CFM 一样。要训练 Reflow,您必须指定预训练的 CFM 检查点,因为 Reflow 是一个蒸馏模型。
例如,要在带有预训练 CFM 检查点的 checkerboard 数据集上训练:
uv run scripts/train_reflow_2d.py --dataset checkerboard
训练结果,包括向量场可视化和生成样本,保存在 outputs/reflow/ 文件夹下。CFM 与 Reflow 采样过程的比较
要在二维数据集上比较 CFM 和 Reflow,运行:
uv run scripts/plot_comparison_2d.py --dataset checkerboard
生成的 GIF 可以在 outputs/comparisons/ 文件夹中找到。下面是 checkerboard 数据集中两种方法的对比示例:
参考文献
- [1] Lipman, Yaron, 等人. "用于生成建模的流匹配." arXiv:2210.02747
- [2] Liu, Xingchao, 等人. "流直且快:学习使用校正流生成和传输数据." arXiv:2209.03003
- [3] facebookresearch/flow_matching
- [4] atong01/conditional-flow-matching