使用 PyTorch 运行 Gemma

在 ai.google.dev 上查看 在 Google Colab 中运行 在 GitHub 上查看源代码

本指南介绍了如何使用 PyTorch 框架运行 Gemma,包括如何使用图片数据提示 Gemma 版本 3 及更高版本的模型。如需详细了解 Gemma PyTorch 实现,请参阅项目代码库的自述文件

设置

以下部分介绍了如何设置开发环境,包括如何获取对 Gemma 模型的访问权限以便从 Kaggle 下载、设置身份验证变量、安装依赖项和导入软件包。

系统要求

此 Gemma Pytorch 库需要 GPU 或 TPU 处理器才能运行 Gemma 模型。标准 Colab CPU Python 运行时和 T4 GPU Python 运行时足以运行 Gemma 1B、2B 和 4B 大小的模型。如需了解其他 GPU 或 TPU 的高级用例,请参阅 Gemma PyTorch 代码库中的 README

在 Kaggle 上获取 Gemma 访问权限

如需完成本教程,您首先需要按照 Gemma 设置中的设置说明操作,该页面会向您展示如何执行以下操作:

  • Kaggle 上获取 Gemma 访问权限。
  • 选择具有足够资源来运行 Gemma 模型的 Colab 运行时。
  • 生成并配置 Kaggle 用户名和 API 密钥。

完成 Gemma 设置后,请继续前往下一部分,您将在其中为 Colab 环境设置环境变量。

设置环境变量

KAGGLE_USERNAMEKAGGLE_KEY 设置环境变量。当系统提示“授予访问权限?”时,请同意提供密文访问权限。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

安装依赖项

pip install -q -U torch immutabledict sentencepiece

下载模型权重

# Choose variant and machine type
VARIANT = '4b-it' 
MACHINE_TYPE = 'cuda'
CONFIG = VARIANT.split('-')[0]
import kagglehub

# Load model weights
weights_dir = kagglehub.model_download(f'google/gemma-3/pyTorch/gemma-3-{VARIANT}')

为模型设置分词器和检查点路径。

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'model.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

配置运行环境

以下部分介绍了如何准备 PyTorch 环境以运行 Gemma。

准备 PyTorch 运行环境

通过克隆 Gemma Pytorch 代码库来准备 PyTorch 模型执行环境。

git clone https://github.com/google/gemma_pytorch.git
Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 239, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (68/68), done.
remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116
Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.
Resolving deltas: 100% (135/135), done.
import sys

sys.path.append('gemma_pytorch/gemma')
from gemma_pytorch.gemma.config import get_model_config
from gemma_pytorch.gemma.gemma3_model import Gemma3ForMultimodalLM

import os
import torch

设置模型配置

在运行模型之前,您必须设置一些配置参数,包括 Gemma 变体、分词器和量化级别。

# Set up model config.
model_config = get_model_config(CONFIG)
model_config.dtype = "float32" if MACHINE_TYPE == "cpu" else "float16"
model_config.tokenizer = tokenizer_path

配置设备上下文

以下代码配置了用于运行模型的设备上下文:

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

实例化并加载模型

加载模型及其权重,以准备运行请求。

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
    model = Gemma3ForMultimodalLM(model_config)
    model.load_state_dict(torch.load(ckpt_path)['model_state_dict'])
    model = model.to(device).eval()
print("Model loading done.")

print('Generating requests in chat mode...')

运行推断

以下示例展示了如何在聊天模式下生成内容以及如何通过多个请求生成内容。

指令调优的 Gemma 模型在训练和推理期间都使用特定的格式化程序进行训练,该格式化程序会使用额外信息来注释指令调优示例。注释 (1) 表示对话中的角色,(2) 表示对话中的回合。

相关注释令牌包括:

  • user:用户对话轮次
  • model:模型回合数
  • <start_of_turn>:对话轮次的开始
  • <start_of_image>:用于图片数据输入的标记
  • <end_of_turn><eos>:对话轮次结束

如需了解详情,请点击此处,阅读有关指令调优 Gemma 模型的提示格式设置的文章。

使用文本生成文本

以下示例代码段演示了如何在多轮对话中使用用户和模型聊天模板,为指令调优的 Gemma 模型设置提示格式。

# Chat templates
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn><eos>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn><eos>\n"

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=256,
)
Chat prompt:
 <start_of_turn>user
What is a good place for travel in the US?<end_of_turn><eos>
<start_of_turn>model
California.<end_of_turn><eos>
<start_of_turn>user
What can I do in California?<end_of_turn><eos>
<start_of_turn>model
"California is a state brimming with diverse activities! To give you a great list, tell me: \n\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \n* **What's your budget like?** \n* **Who are you traveling with?** (family, friends, solo)  \n\nThe more you tell me, the better recommendations I can give! 😊  \n<end_of_turn>"
# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=100,
)
"\n\nA swirling cloud of data, raw and bold,\nIt hums and whispers, a story untold.\nAn LLM whispers, code into refrain,\nCrafting words of rhyme, a lyrical strain.\n\nA world of pixels, logic's vibrant hue,\nFlows through its veins, forever anew.\nThe human touch it seeks, a gentle hand,\nTo mold and shape, understand.\n\nEmotions it might learn, from snippets of prose,\nInspiration it seeks, a yearning"

生成带有图片的文本

在 Gemma 版本 3 及更高版本中,您可以在提示中使用图片。以下示例展示了如何在提示中添加视觉数据。

print('Chat with images...\n')

def read_image(url):
    import io
    import requests
    import PIL

    contents = io.BytesIO(requests.get(url).content)
    return PIL.Image.open(contents)

image = read_image(
    'https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png'
)

print(model.generate(
    [
        [
            '<start_of_turn>user\n',
            image,
            'What animal is in this image?<end_of_turn>\n',
            '<start_of_turn>model\n'
        ]
    ],
    device=device,
    output_len=256,
))

了解详情

现在,您已了解如何在 PyTorch 中使用 Gemma,接下来可以前往 ai.google.dev/gemma 探索 Gemma 的其他众多功能。

另请参阅以下相关资源: