Executar o Gemma usando o PyTorch

Ver em ai.google.dev Executar no Google Colab Ver código-fonte no GitHub

Este guia mostra como executar o Gemma usando o framework PyTorch, incluindo como usar dados de imagem para fazer solicitações aos modelos da versão 3 e mais recentes do Gemma. Para mais detalhes sobre a implementação do Gemma PyTorch, consulte o README (em inglês) do repositório do projeto.

Configuração

As seções a seguir explicam como configurar seu ambiente de desenvolvimento, incluindo como acessar os modelos da Gemma para fazer o download do Kaggle, definir variáveis de autenticação, instalar dependências e importar pacotes.

Requisitos do sistema

Essa biblioteca Gemma Pytorch exige processadores de GPU ou TPU para executar o modelo Gemma. O ambiente de execução padrão do Python na CPU do Colab e o ambiente de execução do Python na GPU T4 são suficientes para executar modelos de 1B, 2B e 4B do Gemma. Para casos de uso avançados de outras GPUs ou TPUs, consulte o LEIA-ME no repositório Gemma PyTorch.

Acessar o Gemma no Kaggle

Para concluir este tutorial, primeiro siga as instruções de configuração em Configuração da Gemma, que mostram como fazer o seguinte:

  • Acesse a Gemma no Kaggle.
  • Selecione um ambiente de execução do Colab com recursos suficientes para executar o modelo do Gemma.
  • Gere e configure um nome de usuário e uma chave de API do Kaggle.

Depois de concluir a configuração do Gemma, passe para a próxima seção, em que você vai definir variáveis de ambiente para seu ambiente do Colab.

Defina as variáveis de ambiente

Defina as variáveis de ambiente para KAGGLE_USERNAME e KAGGLE_KEY. Quando aparecerem as mensagens "Conceder acesso?", concorde em fornecer acesso ao Secret.

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Instalar dependências

pip install -q -U torch immutabledict sentencepiece

Baixar pesos do modelo

# Choose variant and machine type
VARIANT = '4b-it' 
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT.split('-')[0]
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')

Defina os caminhos do tokenizador e do checkpoint para o modelo.

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

Configurar o ambiente de execução

As seções a seguir explicam como preparar um ambiente do PyTorch para executar a Gemma.

Preparar o ambiente de execução do PyTorch

Prepare o ambiente de execução do modelo PyTorch clonando o repositório Gemma Pytorch.

git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM

import os
import torch

Definir a configuração do modelo

Antes de executar o modelo, é necessário definir alguns parâmetros de configuração, incluindo a variante do Gemma, o tokenizador e o nível de quantização.

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path

Configurar o contexto do dispositivo

O código a seguir configura o contexto do dispositivo para executar o modelo:

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

Instanciar e carregar o modelo

Carregue o modelo com os pesos dele para se preparar para executar solicitações.

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = Gemma3ForMultimodalLM(model_config)
    model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
    model = model.to(device).eval()
print("Model loading done.")

print('Generating requests in chat mode...')

Executar inferência

Confira abaixo exemplos de geração no modo de chat e com várias solicitações.

Os modelos Gemma ajustados por instruções foram treinados com um formatador específico que anota exemplos de ajuste de instruções com informações extras, tanto durante o treinamento quanto na inferência. As anotações (1) indicam papéis em uma conversa e (2) delineiam turnos em uma conversa.

Os tokens de anotação relevantes são:

  • user: turno do usuário
  • model: turno do modelo
  • <start_of_turn>: início da vez de diálogo
  • <start_of_image>: tag para entrada de dados de imagem
  • <end_of_turn><eos>: fim da vez da conversa

Para mais informações, leia sobre a formatação de comandos para modelos Gemma ajustados por instrução aqui.

Gerar texto com texto

Confira a seguir um exemplo de snippet de código que demonstra como formatar um comando para um modelo da Gemma ajustado por instruções usando modelos de chat de usuário e modelo em uma conversa de várias rodadas.

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=256,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

Gerar texto com imagens

Com a versão 3 e mais recentes do Gemma, você pode usar imagens com seu comando. O exemplo a seguir mostra como incluir dados visuais no comando.

print('Chat with images...\n')

def read_image(url):
    import io
    import requests
    import PIL

    contents = io.BytesIO(requests.get(url).content)
    return PIL.Image.open(contents)

image = read_image(
    'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
)

print(model.generate(
    [
        [
            '<start_of_turn>user\n',
            image,
            'What animal is in this image?<end_of_turn>\n',
            '<start_of_turn>model\n'
        ]
    ],
    device=device,
    output_len=256,
))

Saiba mais

Agora que você aprendeu a usar o Gemma no Pytorch, confira as muitas outras coisas que o Gemma pode fazer em ai.google.dev/gemma.

Confira também estes outros recursos relacionados: