Gemma mit PyTorch ausführen

Auf ai.google.dev ansehen In Google Colab ausführen Quelle auf GitHub ansehen

In diesem Leitfaden wird beschrieben, wie Sie Gemma mit dem PyTorch-Framework ausführen. Außerdem wird erläutert, wie Sie Bilddaten für das Prompting von Gemma-Modellen ab Version 3 verwenden. Weitere Informationen zur Gemma-PyTorch-Implementierung finden Sie in der README-Datei des Projekt-Repositorys.

Einrichtung

In den folgenden Abschnitten wird beschrieben, wie Sie Ihre Entwicklungsumgebung einrichten. Dazu gehört, wie Sie Zugriff auf Gemma-Modelle zum Herunterladen von Kaggle erhalten, Authentifizierungsvariablen festlegen, Abhängigkeiten installieren und Pakete importieren.

Systemanforderungen

Für diese Gemma-PyTorch-Bibliothek sind GPU- oder TPU-Prozessoren erforderlich, um das Gemma-Modell auszuführen. Die standardmäßige Colab-CPU-Python-Laufzeit und die T4-GPU-Python-Laufzeit reichen für die Ausführung von Gemma-Modellen mit 1B, 2B und 4B aus. Informationen zu erweiterten Anwendungsfällen für andere GPUs oder TPUs finden Sie in der README-Datei im Gemma PyTorch-Repository.

Zugriff auf Gemma auf Kaggle erhalten

Um diese Anleitung durchzuarbeiten, müssen Sie zuerst die Einrichtungsanleitung unter Gemma einrichten befolgen. Dort wird beschrieben, wie Sie Folgendes tun:

  • Kaggle
  • Wählen Sie eine Colab-Laufzeit mit ausreichend Ressourcen aus, um das Gemma-Modell auszuführen.
  • Generieren und konfigurieren Sie einen Kaggle-Nutzernamen und einen API-Schlüssel.

Nachdem Sie die Gemma-Einrichtung abgeschlossen haben, fahren Sie mit dem nächsten Abschnitt fort, in dem Sie Umgebungsvariablen für Ihre Colab-Umgebung festlegen.

Umgebungsvariablen festlegen

Legen Sie Umgebungsvariablen für KAGGLE_USERNAME und KAGGLE_KEY fest. Wenn Sie mit der Meldung „Zugriff gewähren?“ aufgefordert werden, stimmen Sie zu, den Zugriff auf das Secret zu gewähren.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Abhängigkeiten installieren

pip install -q -U torch immutabledict sentencepiece

Modellgewichte herunterladen

# Choose variant and machine type
VARIANT = '4b-it' 
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT.split('-')[0]
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')

Legen Sie die Tokenizer- und Prüfpunktepfade für das Modell fest.

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Laufzeitumgebung konfigurieren

In den folgenden Abschnitten wird erläutert, wie Sie eine PyTorch-Umgebung für die Ausführung von Gemma vorbereiten.

PyTorch-Laufzeitumgebung vorbereiten

Bereiten Sie die Ausführungsumgebung für das PyTorch-Modell vor, indem Sie das Gemma-PyTorch-Repository klonen.

git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM

import os
import torch

Modellkonfiguration festlegen

Bevor Sie das Modell ausführen, müssen Sie einige Konfigurationsparameter festlegen, darunter die Gemma-Variante, den Tokenizer und die Quantisierungsstufe.

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path

Gerätekontext konfigurieren

Mit dem folgenden Code wird der Gerätekontext für die Ausführung des Modells konfiguriert:

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

Modell instanziieren und laden

Laden Sie das Modell mit seinen Gewichten, um Anfragen auszuführen.

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = Gemma3ForMultimodalLM(model_config)
    model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
    model = model.to(device).eval()
print("Model loading done.")

print('Generating requests in chat mode...')

Inferenz ausführen

Unten finden Sie Beispiele für die Generierung im Chatmodus und mit mehreren Anfragen.

Die anweisungsoptimierten Gemma-Modelle wurden mit einem bestimmten Formatierungsprogramm trainiert, das Beispiele für die Anweisungsoptimierung sowohl während des Trainings als auch bei der Inferenz mit zusätzlichen Informationen versieht. Die Anmerkungen (1) geben Rollen in einer Unterhaltung an und (2) grenzen die einzelnen Beiträge in einer Unterhaltung ab.

Die relevanten Annotationstokens sind:

  • user: Nutzerzug
  • model: Modellantwort
  • <start_of_turn>: Beginn des Dialogbeitrags
  • <start_of_image>: Tag für die Eingabe von Bilddaten
  • <end_of_turn><eos>: Ende des Dialogbeitrags

Weitere Informationen zum Formatieren von Prompts für Gemma-Modelle, die auf Anweisungen abgestimmt sind

Text mit Text generieren

Das folgende Code-Snippet zeigt, wie Sie einen Prompt für ein auf Anweisungen abgestimmtes Gemma-Modell mit Chatvorlagen für Nutzer und Modelle in einer Konversation mit mehreren Durchgängen formatieren.

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=256,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

Text mit Bildern generieren

Ab Gemma-Version 3 können Sie Bilder in Ihren Prompt einfügen. Das folgende Beispiel zeigt, wie Sie visuelle Daten in Ihren Prompt einfügen.

print('Chat with images...\n')

def read_image(url):
    import io
    import requests
    import PIL

    contents = io.BytesIO(requests.get(url).content)
    return PIL.Image.open(contents)

image = read_image(
    'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
)

print(model.generate(
    [
        [
            '<start_of_turn>user\n',
            image,
            'What animal is in this image?<end_of_turn>\n',
            '<start_of_turn>model\n'
        ]
    ],
    device=device,
    output_len=256,
))

Weitere Informationen

Nachdem Sie nun gelernt haben, wie Sie Gemma in PyTorch verwenden, können Sie die vielen anderen Möglichkeiten, die Gemma bietet, unter ai.google.dev/gemma erkunden.

Sehen Sie sich auch diese anderen verwandten Ressourcen an: