![]() |
![]() |
![]() |
En esta guía, se muestra cómo ejecutar Gemma con el framework de PyTorch, incluido cómo usar datos de imágenes para generar instrucciones para los modelos de la versión 3 de Gemma y versiones posteriores. Para obtener más detalles sobre la implementación de Gemma en PyTorch, consulta el archivo README del repositorio del proyecto.
Configuración
En las siguientes secciones, se explica cómo configurar tu entorno de desarrollo, incluido cómo obtener acceso a los modelos de Gemma para descargarlos desde Kaggle, establecer variables de autenticación, instalar dependencias y, luego, importar paquetes.
Requisitos del sistema
Esta biblioteca de Gemma Pytorch requiere procesadores de GPU o TPU para ejecutar el modelo de Gemma. El entorno de ejecución de Python de CPU estándar de Colab y el entorno de ejecución de Python de GPU T4 son suficientes para ejecutar modelos de tamaño Gemma 1B, 2B y 4B. Para casos de uso avanzados de otras GPUs o TPUs, consulta el archivo README en el repo de Gemma PyTorch.
Obtén acceso a Gemma en Kaggle
Para completar este instructivo, primero debes seguir las instrucciones de configuración en Configuración de Gemma, en las que se explica cómo realizar las siguientes acciones:
- Obtén acceso a Gemma en Kaggle.
- Selecciona un entorno de ejecución de Colab con recursos suficientes para ejecutar el modelo de Gemma.
- Genera y configura un nombre de usuario y una clave de API de Kaggle.
Después de completar la configuración de Gemma, pasa a la siguiente sección, en la que establecerás variables de entorno para tu entorno de Colab.
Configure las variables de entorno
Configura las variables de entorno para KAGGLE_USERNAME
y KAGGLE_KEY
. Cuando aparezcan los mensajes "¿Quieres otorgar acceso?", acepta proporcionar acceso a los secretos.
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
Instala dependencias
pip install -q -U torch immutabledict sentencepiece
Descarga los pesos del modelo
# Choose variant and machine type
VARIANT = '4b-it'
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT.split('-')[0]
import kagglehub
# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')
Establece las rutas del tokenizador y del punto de control para el modelo.
# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'
# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'
Configura el entorno de ejecución
En las siguientes secciones, se explica cómo preparar un entorno de PyTorch para ejecutar Gemma.
Prepara el entorno de ejecución de PyTorch
Clona el repositorio de Gemma PyTorch para preparar el entorno de ejecución del modelo de PyTorch.
git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'... remote: Enumerating objects: 239, done. remote: Counting objects: 100% (123/123), done. remote: Compressing objects: 100% (68/68), done. remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116 Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done. Resolving deltas: 100% (135/135), done.
import sys
sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM
import os
import torch
Establece la configuración del modelo
Antes de ejecutar el modelo, debes establecer algunos parámetros de configuración, como la variante de Gemma, el tokenizador y el nivel de cuantificación.
# Set up model config.
model_config = get_model_config(CONFIG)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path
Configura el contexto del dispositivo
El siguiente código configura el contexto del dispositivo para ejecutar el modelo:
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
Crea una instancia del modelo y cárgalo
Carga el modelo con sus pesos para prepararte para ejecutar solicitudes.
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
model = Gemma3ForMultimodalLM(model_config)
model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
model = model.to(device).eval()
print("Model loading done.")
print('Generating requests in chat mode...')
Ejecuta la inferencia
A continuación, se muestran ejemplos de generación en modo de chat y con varias solicitudes.
Los modelos de Gemma ajustados con instrucciones se entrenaron con un formateador específico que anota los ejemplos de ajuste con instrucciones con información adicional, tanto durante el entrenamiento como en la inferencia. Las anotaciones (1) indican los roles en una conversación y (2) delimitan los turnos en una conversación.
Los tokens de anotación relevantes son los siguientes:
user
: Turno del usuariomodel
: Giro del modelo<start_of_turn>
: Comienzo del turno de diálogo<start_of_image>
: Es la etiqueta para la entrada de datos de imágenes.<end_of_turn><eos>
: Fin del turno de diálogo
Para obtener más información, consulta el formato de instrucciones para los modelos de Gemma ajustados con instrucciones aquí.
Genera texto con texto
A continuación, se muestra un fragmento de código de ejemplo que demuestra cómo dar formato a una instrucción para un modelo de Gemma ajustado a instrucciones con plantillas de chat de usuario y modelo en una conversación de varios turnos.
# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"
# Sample formatted prompt
prompt = (
USER_CHAT_TEMPLATE.format(
prompt='What is a good place for travel in the US?'
)
+ MODEL_CHAT_TEMPLATE.format(prompt='California.')
+ USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
+ '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)
model.generate(
USER_CHAT_TEMPLATE.format(prompt=prompt),
device=device,
output_len=256,
)
Chat prompt: <start_of_turn>user What is a good place for travel in the US?<end_of_turn><eos> <start_of_turn>model California.<end_of_turn><eos> <start_of_turn>user What can I do in California?<end_of_turn><eos> <start_of_turn>model "California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo) \n\nThe more you tell me, the better recommendations I can give! 😊 \n<end_of_turn>"
# Generate sample
model.generate(
'Write a poem about an llm writing a poem.',
device=device,
output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"
Genera texto con imágenes
Con la versión 3 y posteriores de Gemma, puedes usar imágenes en tus instrucciones. En el siguiente ejemplo, se muestra cómo incluir datos visuales en tu instrucción.
print('Chat with images...\n')
def read_image(url):
import io
import requests
import PIL
contents = io.BytesIO(requests.get(url).content)
return PIL.Image.open(contents)
image = read_image(
'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
)
print(model.generate(
[
[
'<start_of_turn>user\n',
image,
'What animal is in this image?<end_of_turn>\n',
'<start_of_turn>model\n'
]
],
device=device,
output_len=256,
))
Más información
Ahora que aprendiste a usar Gemma en PyTorch, puedes explorar las muchas otras cosas que Gemma puede hacer en ai.google.dev/gemma.
Consulta también estos otros recursos relacionados: