![]() |
![]() |
![]() |
In diesem Leitfaden wird beschrieben, wie Sie Gemma mit dem PyTorch-Framework ausführen. Außerdem wird erläutert, wie Sie Bilddaten für das Prompting von Gemma-Modellen ab Version 3 verwenden. Weitere Informationen zur Gemma-PyTorch-Implementierung finden Sie in der README-Datei des Projekt-Repositorys.
Einrichtung
In den folgenden Abschnitten wird beschrieben, wie Sie Ihre Entwicklungsumgebung einrichten. Dazu gehört, wie Sie Zugriff auf Gemma-Modelle zum Herunterladen von Kaggle erhalten, Authentifizierungsvariablen festlegen, Abhängigkeiten installieren und Pakete importieren.
Systemanforderungen
Für diese Gemma-PyTorch-Bibliothek sind GPU- oder TPU-Prozessoren erforderlich, um das Gemma-Modell auszuführen. Die standardmäßige Colab-CPU-Python-Laufzeit und die T4-GPU-Python-Laufzeit reichen für die Ausführung von Gemma-Modellen mit 1B, 2B und 4B aus. Informationen zu erweiterten Anwendungsfällen für andere GPUs oder TPUs finden Sie in der README-Datei im Gemma PyTorch-Repository.
Zugriff auf Gemma auf Kaggle erhalten
Um diese Anleitung durchzuarbeiten, müssen Sie zuerst die Einrichtungsanleitung unter Gemma einrichten befolgen. Dort wird beschrieben, wie Sie Folgendes tun:
- Kaggle
- Wählen Sie eine Colab-Laufzeit mit ausreichend Ressourcen aus, um das Gemma-Modell auszuführen.
- Generieren und konfigurieren Sie einen Kaggle-Nutzernamen und einen API-Schlüssel.
Nachdem Sie die Gemma-Einrichtung abgeschlossen haben, fahren Sie mit dem nächsten Abschnitt fort, in dem Sie Umgebungsvariablen für Ihre Colab-Umgebung festlegen.
Umgebungsvariablen festlegen
Legen Sie Umgebungsvariablen für KAGGLE_USERNAME
und KAGGLE_KEY
fest. Wenn Sie mit der Meldung „Zugriff gewähren?“ aufgefordert werden, stimmen Sie zu, den Zugriff auf das Secret zu gewähren.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Abhängigkeiten installieren
pip install -q -U torch immutabledict sentencepiece
Modellgewichte herunterladen
# Choose variant and machine type
VARIANT = '4b-it'
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT.split('-')[0]
import kagglehub
# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')
Legen Sie die Tokenizer- und Prüfpunktepfade für das Modell fest.
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'
# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'
Laufzeitumgebung konfigurieren
In den folgenden Abschnitten wird erläutert, wie Sie eine PyTorch-Umgebung für die Ausführung von Gemma vorbereiten.
PyTorch-Laufzeitumgebung vorbereiten
Bereiten Sie die Ausführungsumgebung für das PyTorch-Modell vor, indem Sie das Gemma-PyTorch-Repository klonen.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'... remote: Enumerating objects: 239, done. remote: Counting objects: 100% (123/123), done. remote: Compressing objects: 100% (68/68), done. remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116 Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done. Resolving deltas: 100% (135/135), done.
import sys
sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM
import os
import torch
Modellkonfiguration festlegen
Bevor Sie das Modell ausführen, müssen Sie einige Konfigurationsparameter festlegen, darunter die Gemma-Variante, den Tokenizer und die Quantisierungsstufe.
# Set up model config.
model_config = get_model_config(CONFIG)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path
Gerätekontext konfigurieren
Mit dem folgenden Code wird der Gerätekontext für die Ausführung des Modells konfiguriert:
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
Modell instanziieren und laden
Laden Sie das Modell mit seinen Gewichten, um Anfragen auszuführen.
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
model = Gemma3ForMultimodalLM(model_config)
model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
model = model.to(device).eval()
print("Model loading done.")
print('Generating requests in chat mode...')
Inferenz ausführen
Unten finden Sie Beispiele für die Generierung im Chatmodus und mit mehreren Anfragen.
Die anweisungsoptimierten Gemma-Modelle wurden mit einem bestimmten Formatierungsprogramm trainiert, das Beispiele für die Anweisungsoptimierung sowohl während des Trainings als auch bei der Inferenz mit zusätzlichen Informationen versieht. Die Anmerkungen (1) geben Rollen in einer Unterhaltung an und (2) grenzen die einzelnen Beiträge in einer Unterhaltung ab.
Die relevanten Annotationstokens sind:
user
: Nutzerzugmodel
: Modellantwort<start_of_turn>
: Beginn des Dialogbeitrags<start_of_image>
: Tag für die Eingabe von Bilddaten<end_of_turn><eos>
: Ende des Dialogbeitrags
Text mit Text generieren
Das folgende Code-Snippet zeigt, wie Sie einen Prompt für ein auf Anweisungen abgestimmtes Gemma-Modell mit Chatvorlagen für Nutzer und Modelle in einer Konversation mit mehreren Durchgängen formatieren.
# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"
# Sample formatted prompt
prompt = (
USER_CHAT_TEMPLATE.format(
prompt='What is a good place for travel in the US?'
)
+ MODEL_CHAT_TEMPLATE.format(prompt='California.')
+ USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
+ '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)
model.generate(
USER_CHAT_TEMPLATE.format(prompt=prompt),
device=device,
output_len=256,
)
Chat prompt: <start_of_turn>user What is a good place for travel in the US?<end_of_turn><eos> <start_of_turn>model California.<end_of_turn><eos> <start_of_turn>user What can I do in California?<end_of_turn><eos> <start_of_turn>model "California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo) \n\nThe more you tell me, the better recommendations I can give! 😊 \n<end_of_turn>"
# Generate sample
model.generate(
'Write a poem about an llm writing a poem.',
device=device,
output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"
Text mit Bildern generieren
Ab Gemma-Version 3 können Sie Bilder in Ihren Prompt einfügen. Das folgende Beispiel zeigt, wie Sie visuelle Daten in Ihren Prompt einfügen.
print('Chat with images...\n')
def read_image(url):
import io
import requests
import PIL
contents = io.BytesIO(requests.get(url).content)
return PIL.Image.open(contents)
image = read_image(
'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
)
print(model.generate(
[
[
'<start_of_turn>user\n',
image,
'What animal is in this image?<end_of_turn>\n',
'<start_of_turn>model\n'
]
],
device=device,
output_len=256,
))
Weitere Informationen
Nachdem Sie nun gelernt haben, wie Sie Gemma in PyTorch verwenden, können Sie die vielen anderen Möglichkeiten, die Gemma bietet, unter ai.google.dev/gemma erkunden.
Sehen Sie sich auch diese anderen verwandten Ressourcen an: