Exécuter Gemma avec PyTorch

 Afficher sur ai.google.dev Exécuter dans Google Colab Afficher la source sur GitHub

Ce guide explique comment exécuter Gemma à l'aide du framework PyTorch, y compris comment utiliser des données d'image pour inciter les modèles Gemma version 3 et ultérieures. Pour en savoir plus sur l'implémentation Gemma PyTorch, consultez le fichier README du dépôt du projet.

Configuration

Les sections suivantes expliquent comment configurer votre environnement de développement, y compris comment accéder aux modèles Gemma à télécharger depuis Kaggle, définir des variables d'authentification, installer des dépendances et importer des packages.

Configuration requise

Cette bibliothèque Gemma PyTorch nécessite des processeurs GPU ou TPU pour exécuter le modèle Gemma. L'environnement d'exécution Python du processeur Colab standard et l'environnement d'exécution Python du GPU T4 sont suffisants pour exécuter les modèles Gemma de taille 1B, 2B et 4B. Pour les cas d'utilisation avancés d'autres GPU ou TPU, veuillez consulter le fichier README dans le dépôt Gemma PyTorch.

Accéder à Gemma sur Kaggle

Pour suivre ce tutoriel, vous devez d'abord suivre les instructions de configuration de Configuration de Gemma, qui vous expliquent comment effectuer les opérations suivantes :

  • Accédez à Gemma sur Kaggle.
  • Sélectionnez un environnement d'exécution Colab disposant de ressources suffisantes pour exécuter le modèle Gemma.
  • Générez et configurez un nom d'utilisateur et une clé API Kaggle.

Une fois la configuration de Gemma terminée, passez à la section suivante, où vous définirez les variables d'environnement pour votre environnement Colab.

Définir des variables d'environnement

Définissez les variables d'environnement pour KAGGLE_USERNAME et KAGGLE_KEY. Lorsque les messages "Accorder l'accès ?" s'affichent, acceptez d'accorder l'accès aux secrets.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Installer des dépendances

pip install -q -U torch immutabledict sentencepiece

Télécharger les pondérations du modèle

# Choose variant and machine type
VARIANT = '4b-it' 
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT.split('-')[0]
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')

Définissez les chemins d'accès au tokenizer et au point de contrôle pour le modèle.

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Configurer l'environnement d'exécution

Les sections suivantes expliquent comment préparer un environnement PyTorch pour exécuter Gemma.

Préparer l'environnement d'exécution PyTorch

Préparez l'environnement d'exécution du modèle PyTorch en clonant le dépôt Gemma Pytorch.

git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM

import os
import torch

Définir la configuration du modèle

Avant d'exécuter le modèle, vous devez définir certains paramètres de configuration, y compris la variante Gemma, le tokenizer et le niveau de quantification.

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path

Configurer le contexte de l'appareil

Le code suivant configure le contexte de l'appareil pour exécuter le modèle :

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

Instancier et charger le modèle

Chargez le modèle avec ses pondérations pour préparer l'exécution des requêtes.

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = Gemma3ForMultimodalLM(model_config)
    model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
    model = model.to(device).eval()
print("Model loading done.")

print('Generating requests in chat mode...')

Exécuter une inférence

Vous trouverez ci-dessous des exemples de génération en mode chat et avec plusieurs requêtes.

Les modèles Gemma adaptés aux instructions ont été entraînés avec un formateur spécifique qui annote les exemples d'adaptation aux instructions avec des informations supplémentaires, à la fois pendant l'entraînement et l'inférence. Les annotations (1) indiquent les rôles dans une conversation et (2) délimitent les tours de parole.

Les jetons d'annotation pertinents sont les suivants :

  • user : tour de l'utilisateur
  • model : tour de modèle
  • <start_of_turn> : début du tour de dialogue
  • <start_of_image> : tag pour l'entrée de données d'image
  • <end_of_turn><eos> : fin du tour de dialogue

Pour en savoir plus, consultez Mise en forme des requêtes pour les modèles Gemma ajustés aux instructions.

Générer du texte avec du texte

Voici un exemple d'extrait de code montrant comment mettre en forme un prompt pour un modèle Gemma adapté aux instructions à l'aide de modèles de chat utilisateur et de modèle dans une conversation en plusieurs tours.

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=256,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

Générer du texte avec des images

Avec Gemma version 3 et ultérieures, vous pouvez utiliser des images avec votre requête. L'exemple suivant vous montre comment inclure des données visuelles dans votre requête.

print('Chat with images...\n')

def read_image(url):
    import io
    import requests
    import PIL

    contents = io.BytesIO(requests.get(url).content)
    return PIL.Image.open(contents)

image = read_image(
    'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
)

print(model.generate(
    [
        [
            '<start_of_turn>user\n',
            image,
            'What animal is in this image?<end_of_turn>\n',
            '<start_of_turn>model\n'
        ]
    ],
    device=device,
    output_len=256,
))

En savoir plus

Maintenant que vous avez appris à utiliser Gemma dans PyTorch, vous pouvez explorer les nombreuses autres choses que Gemma peut faire sur ai.google.dev/gemma.

Consultez également les ressources associées suivantes :