การจับคู่โฟลว์ใน PyTorch
คลังนี้ประกอบด้วยการใช้งาน PyTorch อย่างง่ายของบทความ Flow Matching for Generative Modeling
ตัวอย่างการจับคู่โฟลว์แบบ 2D
ภาพ gif ด้านล่างแสดงการแมปการกระจายแบบเกาส์เซียนเดี่ยวไปยังการกระจายแบบกระดานหมากรุก พร้อมกับการแสดงภาพฟิลด์เวกเตอร์
และนี่คือตัวอย่างชุดข้อมูล moons อีกตัวอย่างหนึ่ง
เริ่มต้นใช้งาน
โคลน repository และตั้งค่าสภาพแวดล้อม python
git clone https://github.com/keishihara/flow-matching.git
cd flow-matchingตรวจสอบให้แน่ใจว่าคุณได้ติดตั้ง Python 3.12+ แล้ว
ติดตั้ง uv:
curl -LsSf https://astral.sh/uv/install.sh | sh
จากนั้น ตั้งค่าสภาพแวดล้อม:uv syncConditional Flow Matching [Lipman+ 2023]
นี่คือการนำเสนอการใช้งานต้นฉบับของ CFM [1] ส่วนประกอบบางส่วนของโค้ดถูกนำมาจาก [2] และ [3]
ชุดข้อมูลของเล่น 2D
คุณสามารถฝึกโมเดล CFM บนชุดข้อมูลสังเคราะห์ 2D เช่น checkerboard และ moons กำหนดชื่อชุดข้อมูลโดยใช้ตัวเลือก --dataset พารามิเตอร์การฝึกอบรมถูกกำหนดไว้ล่วงหน้าในสคริปต์ และผลการฝึกจะถูกจัดเก็บไว้ในไดเรกทอรี outputs/ จุดตรวจสอบของโมเดลจะไม่ถูกรวมไว้ เนื่องจากสามารถสร้างซ้ำได้ง่ายด้วยการตั้งค่าเริ่มต้น
uv run scripts/train_flow_matching_2d.py --dataset checkerboardฟิลด์เวกเตอร์และตัวอย่างที่สร้างขึ้น เช่นเดียวกับตัวอย่างที่แสดงเป็น GIF ที่ด้านบนของ README นี้ สามารถพบได้ในไดเรกทอรี outputs/cfm/
ชุดข้อมูลภาพ
คุณยังสามารถฝึกโมเดล CFM แบบกำหนดคลาสบนชุดข้อมูลการจำแนกภาพยอดนิยมได้ ตัวอย่างที่สร้างขึ้นและจุดตรวจโมเดลจะถูกจัดเก็บไว้ในไดเรกทอรี outputs/cfm สำหรับรายการพารามิเตอร์การฝึกที่ละเอียด ให้รันคำสั่ง uv run scripts/train_flow_matching_on_image.py --help
หากต้องการฝึกโมเดล CFM แบบกำหนดคลาสบนชุดข้อมูล MNIST ให้รันคำสั่ง:
uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnist
หลังจากการฝึกอบรม คุณสามารถสร้างตัวอย่างได้โดยใช้:uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnist
ตอนนี้ คุณควรจะสามารถเห็นตัวอย่างที่ถูกสร้างขึ้นในไดเรกทอรี outputs/cfm/mnist/ ได้แล้ว
Rectified Flow [Liu+ 2023]
นี่คือการนำเสนอโมเดล Reflow (โดยเฉพาะ 2-Rectified Flow) จากงานวิจัย Rectified Flow [2]
ข้อมูลสังเคราะห์ 2 มิติ
เราได้ทำการติดตั้ง Reflow บนชุดข้อมูลสังเคราะห์ 2 มิติ เช่นเดียวกับ CFM ในการฝึกสอน reflow คุณต้องระบุเช็คพอยต์ CFM ที่ผ่านการฝึกมาแล้ว เนื่องจาก reflow เป็นโมเดลการกลั่น
ตัวอย่างเช่น หากต้องการฝึกบนชุดข้อมูล checkerboard โดยใช้เช็คพอยต์ CFM ที่ผ่านการฝึกมาแล้ว:
uv run scripts/train_reflow_2d.py --dataset checkerboard
ผลการฝึกอบรม รวมถึงการแสดงภาพสนามเวกเตอร์และตัวอย่างที่สร้างขึ้น จะถูกบันทึกไว้ในโฟลเดอร์ outputs/reflow/การเปรียบเทียบกระบวนการสุ่มตัวอย่างระหว่าง CFM และ Reflow
ในการเปรียบเทียบ CFM และ Reflow บนชุดข้อมูล 2d ให้รัน:
uv run scripts/plot_comparison_2d.py --dataset checkerboard
ไฟล์ GIF ที่ได้สามารถพบได้ในโฟลเดอร์ outputs/comparisons/ ด้านล่างนี้เป็นตัวอย่างการเปรียบเทียบของสองวิธีในชุดข้อมูล checkerboard:
อ้างอิง
- [1] Lipman, Yaron, et al. "Flow Matching for Generative Modeling." arXiv:2210.02747
- [2] Liu, Xingchao, et al. "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." arXiv:2209.03003
- [3] facebookresearch/flow_matching
- [4] atong01/conditional-flow-matching