Flow Matching trong PyTorch
Kho lưu trữ này chứa một hiện thực đơn giản bằng PyTorch của bài báo Flow Matching for Generative Modeling.
Ví dụ Flow Matching 2D
Ảnh động bên dưới minh họa việc ánh xạ một phân phối Gaussian đơn đến phân phối dạng bàn cờ, với trường véc-tơ được trực quan hóa.
Và đây là một ví dụ khác với tập dữ liệu moons.
Bắt đầu
Sao chép kho lưu trữ và thiết lập môi trường python.
git clone https://github.com/keishihara/flow-matching.git
cd flow-matchingĐảm bảo bạn đã cài đặt Python 3.12+.
Cài đặt uv:
curl -LsSf https://astral.sh/uv/install.sh | sh
Sau đó, thiết lập môi trường:uv syncConditional Flow Matching [Lipman+ 2023]
Đây là bản triển khai gốc của bài báo CFM [1]. Một số thành phần của mã nguồn được điều chỉnh từ [2] và [3].
Bộ dữ liệu 2D Toy
Bạn có thể huấn luyện các mô hình CFM trên các bộ dữ liệu tổng hợp 2D như checkerboard và moons. Chỉ định tên bộ dữ liệu bằng tùy chọn --dataset. Tham số huấn luyện đã được định nghĩa sẵn trong script, và hình ảnh minh họa kết quả huấn luyện sẽ được lưu trong thư mục outputs/. Các checkpoint của mô hình không được đính kèm vì có thể dễ dàng tái tạo với các thiết lập mặc định.
uv run scripts/train_flow_matching_2d.py --dataset checkerboardCác trường vectơ và các mẫu được tạo ra, giống như những mẫu được hiển thị dưới dạng GIF ở đầu README này, hiện có thể được tìm thấy trong thư mục outputs/cfm/.
Bộ dữ liệu hình ảnh
Bạn cũng có thể huấn luyện các mô hình CFM theo điều kiện lớp trên các bộ dữ liệu phân loại hình ảnh phổ biến. Cả các mẫu đã tạo và các checkpoint của mô hình đều sẽ được lưu trữ trong thư mục outputs/cfm. Để xem danh sách chi tiết các tham số huấn luyện, hãy chạy uv run scripts/train_flow_matching_on_image.py --help.
Để huấn luyện CFM theo điều kiện lớp trên bộ dữ liệu MNIST, hãy chạy:
uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnistSau khi huấn luyện, bạn có thể tạo các mẫu với:
uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnistBây giờ, bạn sẽ có thể nhìn thấy các mẫu đã được tạo ra trong thư mục outputs/cfm/mnist/.
Rectified Flow [Liu+ 2023]
Đây là một bản triển khai của mô hình Reflow (cụ thể là 2-Rectified Flow) từ bài báo Rectified Flow [2].
Dữ liệu tổng hợp 2D
Chúng tôi đã triển khai Reflow trên các bộ dữ liệu tổng hợp 2D, giống như CFM. Để huấn luyện reflow, bạn cần chỉ định checkpoint CFM đã được huấn luyện trước vì reflow là một mô hình chưng cất.
Ví dụ, để huấn luyện trên bộ dữ liệu checkerboard với checkpoint CFM đã được huấn luyện trước:
uv run scripts/train_reflow_2d.py --dataset checkerboardKết quả huấn luyện, bao gồm các hình ảnh trực quan hóa trường vector và các mẫu được tạo ra, được lưu trong thư mục outputs/reflow/.
So sánh quá trình lấy mẫu giữa CFM và Reflow
Để so sánh CFM và Reflow trên các bộ dữ liệu 2d, hãy chạy:
uv run scripts/plot_comparison_2d.py --dataset checkerboardCác ảnh GIF kết quả có thể được tìm thấy trong thư mục outputs/comparisons/. Dưới đây là ví dụ so sánh hai phương pháp trong tập dữ liệu checkerboard:
Tài liệu tham khảo
- [1] Lipman, Yaron, et al. "Flow Matching for Generative Modeling." arXiv:2210.02747
- [2] Liu, Xingchao, et al. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." arXiv:2209.03003
- [3] facebookresearch/flow_matching
- [4] atong01/conditional-flow-matching