הפעלת Gemma באמצעות PyTorch

לצפייה באתר ai.google.dev הרצה ב-Google Colab צפייה במקור ב-GitHub

במדריך הזה מוסבר איך להריץ את Gemma באמצעות מסגרת PyTorch, כולל איך להשתמש בנתוני תמונות כדי ליצור הנחיות למודלים של Gemma מגרסה 3 ואילך. פרטים נוספים על ההטמעה של Gemma PyTorch זמינים במאמר README במאגר הפרויקט.

הגדרה

בקטעים הבאים מוסבר איך להגדיר את סביבת הפיתוח, כולל איך לקבל גישה למודלים של Gemma להורדה מ-Kaggle, איך להגדיר משתני אימות, איך להתקין תלות ואיך לייבא חבילות.

דרישות מערכת

כדי להריץ את מודל Gemma, ספריית Gemma Pytorch הזו דורשת מעבדי GPU או TPU. זמן הריצה הרגיל של Python ב-CPU ב-Colab וזמן הריצה של Python ב-GPU T4 מספיקים להרצת מודלים בגודל 1B,‏ 2B ו-4B של Gemma. לתרחישי שימוש מתקדמים יותר ב-GPU או ב-TPU אחרים, אפשר לעיין בקובץ ה-README במאגר Gemma PyTorch.

גישה ל-Gemma ב-Kaggle

כדי להשלים את המדריך הזה, צריך קודם לפעול לפי הוראות ההגדרה במאמר הגדרת Gemma. במאמר הזה מוסבר איך לבצע את הפעולות הבאות:

  • אפשר לקבל גישה ל-Gemma ב-Kaggle.
  • בוחרים זמן ריצה של Colab עם מספיק משאבים להרצת מודל Gemma.
  • יוצרים ומגדירים שם משתמש ומפתח API ב-Kaggle.

אחרי שתסיימו את ההגדרה של Gemma, תעברו לקטע הבא ותגדירו משתני סביבה לסביבת Colab.

הגדרה של משתני סביבה

מגדירים משתני סביבה ל-KAGGLE_USERNAME ול-KAGGLE_KEY. כשמופיעות ההודעות 'לתת גישה?', מאשרים את הגישה לסוד.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

התקנת יחסי תלות

pip install -q -U torch immutabledict sentencepiece

הורדת משקלים של מודלים

# Choose variant and machine type
VARIANT = '4b-it' 
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT.split('-')[0]
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')

מגדירים את הנתיבים של הטוקנייזר ונקודת הבדיקה של המודל.

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

הגדרת סביבת ההרצה

בקטעים הבאים מוסבר איך להכין סביבת PyTorch להרצת Gemma.

הכנת סביבת ההפעלה של PyTorch

מכינים את סביבת ההפעלה של מודל PyTorch על ידי שיבוט של מאגר Gemma Pytorch.

git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM

import os
import torch

הגדרת תצורת המודל

לפני שמריצים את המודל, צריך להגדיר כמה פרמטרים של הגדרה, כולל הווריאציה של Gemma, טוקנייזר ורמת קוונטיזציה.

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path

הגדרת הקשר של המכשיר

הקוד הבא מגדיר את הקשר של המכשיר להרצת המודל:

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

יצירת מופע וטעינה של המודל

טוענים את המודל עם המשקלים שלו כדי להתכונן להרצת בקשות.

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = Gemma3ForMultimodalLM(model_config)
    model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
    model = model.to(device).eval()
print("Model loading done.")

print('Generating requests in chat mode...')

הרצת הסקת מסקנות

בהמשך מופיעות דוגמאות ליצירה במצב צ'אט וליצירה עם כמה בקשות.

מודלי Gemma שעברו כוונון לפי הוראות אומנו באמצעות מעצב פורמטים ספציפי שמבצע הערות על דוגמאות של כוונון לפי הוראות עם מידע נוסף, גם במהלך האימון וגם במהלך ההסקה. ההערות (1) מציינות את התפקידים בשיחה, ו-(2) מסמנות את התורות בשיחה.

הטוקנים הרלוונטיים של ההערות הם:

  • user: תור של משתמש
  • model: תור של המודל
  • <start_of_turn>: תחילת תור בדיאלוג
  • <start_of_image>: תג לקלט של נתוני תמונות
  • <end_of_turn><eos>: סוף תור בשיחה

מידע נוסף על פורמט הנחיות לשימוש במודלים של Gemma שעברו כוונון להוראות זמין כאן.

יצירת טקסט באמצעות טקסט

קטע הקוד הבא הוא דוגמה שממחישה איך לעצב הנחיה למודל Gemma שעבר כוונון להוראות, באמצעות תבניות צ'אט של משתמשים ומודלים בשיחה מרובת תפניות.

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=256,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

יצירת טקסט עם תמונות

בגרסה 3 של Gemma ואילך, אפשר להשתמש בתמונות בהנחיה. בדוגמה הבאה מוסבר איך לכלול נתונים חזותיים בהנחיה.

print('Chat with images...\n')

def read_image(url):
    import io
    import requests
    import PIL

    contents = io.BytesIO(requests.get(url).content)
    return PIL.Image.open(contents)

image = read_image(
    'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
)

print(model.generate(
    [
        [
            '<start_of_turn>user\n',
            image,
            'What animal is in this image?<end_of_turn>\n',
            '<start_of_turn>model\n'
        ]
    ],
    device=device,
    output_len=256,
))

מידע נוסף

אחרי שלמדתם איך להשתמש ב-Gemma ב-PyTorch, אתם יכולים לבדוק את הדברים הרבים האחרים ש-Gemma יכולה לעשות בכתובת ai.google.dev/gemma.

כדאי לעיין גם במקורות המידע הקשורים הבאים: