पायटॉर्च में फ्लो मैचिंग
यह रिपॉजिटरी पेपर Flow Matching for Generative Modeling का एक सरल पायटॉर्च इम्प्लीमेंटेशन प्रदान करती है।
2D फ्लो मैचिंग उदाहरण
नीचे दिया गया gif एक एकल गॉसियन वितरण को एक चेकरबोर्ड वितरण में मैप करने का प्रदर्शन करता है, जिसमें वेक्टर फील्ड को विज़ुअलाइज़ किया गया है।
और, यहाँ moons डेटासेट का एक और उदाहरण है।
शुरुआत कैसे करें
रिपॉजिटरी को क्लोन करें और python वातावरण सेटअप करें।
git clone https://github.com/keishihara/flow-matching.git
cd flow-matchingसुनिश्चित करें कि आपके पास Python 3.12+ स्थापित है।
uv इंस्टॉल करें:
curl -LsSf https://astral.sh/uv/install.sh | sh
फिर, परिवेश सेट करें:uv syncकंडीशनल फ्लो मैचिंग [लिपमैन+ 2023]
यह मूल CFM पेपर इम्प्लीमेंटेशन है [1]। कोड के कुछ घटक [2] और [3] से अनुकूलित किए गए हैं।
2डी टॉय डेटासेट्स
आप 2डी सिंथेटिक डेटासेट्स जैसे चेकरबोर्ड और मून्स पर CFM मॉडल्स को ट्रेन कर सकते हैं। डेटासेट का नाम --dataset विकल्प के द्वारा निर्दिष्ट करें। ट्रेनिंग पैरामीटर्स स्क्रिप्ट में पूर्वनिर्धारित हैं, और ट्रेनिंग परिणामों के विज़ुअलाइज़ेशन outputs/ डायरेक्टरी में स्टोर होते हैं। मॉडल चेकपॉइंट्स शामिल नहीं हैं क्योंकि वे डिफ़ॉल्ट सेटिंग्स के साथ आसानी से दोबारा बनाए जा सकते हैं।
uv run scripts/train_flow_matching_2d.py --dataset checkerboardवेक्टर फील्ड्स और जेनरेटेड सैम्पल्स, जैसे कि इस README के ऊपर GIFs के रूप में दिखाए गए हैं, अब outputs/cfm/ डायरेक्टरी में पाए जा सकते हैं।
इमेज डेटासेट्स
आप लोकप्रिय इमेज क्लासिफिकेशन डेटासेट्स पर क्लास-कंडीशनल CFM मॉडल्स भी ट्रेन कर सकते हैं। दोनों, जेनरेटेड सैम्पल्स और मॉडल चेकपॉइंट्स, outputs/cfm डायरेक्टरी में स्टोर किए जाएंगे। ट्रेनिंग पैरामीटर्स की विस्तृत सूची के लिए, uv run scripts/train_flow_matching_on_image.py --help चलाएं।
MNIST डेटासेट पर क्लास-कंडीशनल CFM ट्रेन करने के लिए, निम्नलिखित कमांड चलाएं:
uv run scripts/train_flow_matching_on_image.py --do_train --dataset mnist
प्रशिक्षण के बाद, अब आप नमूने उत्पन्न कर सकते हैं:uv run scripts/train_flow_matching_on_image.py --do_sample --dataset mnist
अब, आप outputs/cfm/mnist/ डायरेक्टरी में जेनरेट किए गए नमूनों को देख सकते हैं।
रेक्टिफाइड फ्लो [Liu+ 2023]
यह रेक्टिफाइड फ्लो पेपर [2] से Reflow मॉडल (विशेष रूप से 2-Rectified Flow) का एक इम्प्लीमेंटेशन है।
2D सिंथेटिक डेटा
हमने Reflow को 2D सिंथेटिक डेटासेट्स पर लागू किया है, ठीक उसी तरह जैसे CFM पर किया गया था। Reflow को ट्रेन करने के लिए आपको प्री-ट्रेंड CFM चेकपॉइंट्स को निर्दिष्ट करना होगा क्योंकि Reflow एक डिस्टिलेशन मॉडल है।
उदाहरण के लिए, checkerboard डेटासेट पर प्री-ट्रेंड CFM चेकपॉइंट के साथ ट्रेन करने के लिए:
uv run scripts/train_reflow_2d.py --dataset checkerboardप्रशिक्षण परिणाम, जिसमें वेक्टर क्षेत्र दृश्यावलोकन और उत्पन्न नमूने शामिल हैं, outputs/reflow/ फ़ोल्डर के अंतर्गत सहेजे जाते हैं।
CFM और Reflow के बीच सैंपलिंग प्रक्रिया की तुलना
CFM और Reflow की तुलना करने के लिए 2d डेटासेट्स पर, निम्नलिखित चलाएँ:
uv run scripts/plot_comparison_2d.py --dataset checkerboard
उत्पन्न GIFs को outputs/comparisons/ फोल्डर में पाया जा सकता है। नीचे checkerboard डेटासेट में दो विधियों की तुलना का एक उदाहरण दिया गया है:
संदर्भ
- [1] लिपमैन, यारोन, आदि। "Flow Matching for Generative Modeling." arXiv:2210.02747
- [2] लियू, शिंगचाओ, आदि। "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow." arXiv:2209.03003
- [3] facebookresearch/flow_matching
- [4] atong01/conditional-flow-matching