Note: Alpa is no longer actively maintained; this repository was archived on Oct 19, 2024 and is now read-only. It remains available as a research artifact. The core auto-parallelization algorithm in Alpa has been merged into XLA, which is still maintained: https://github.com/openxla/xla/tree/main/xla/hlo/experimental/auto_sharding


Documentation | Slack

Alpa is a system for training and serving large-scale neural networks.

Scaling neural networks to hundreds of billions of parameters has enabled dramatic breakthroughs such as GPT-3, but training and serving these large-scale models requires complex distributed systems techniques. Alpa aims to automate large-scale distributed training and serving with just a few lines of code.

The key features of Alpa include:

💻 Automatic Parallelization. Alpa automatically parallelizes users' single-device code on distributed clusters with data, operator, and pipeline parallelism (a brief sketch follows this list).

🚀 Excellent Performance. Alpa achieves linear scaling on training models with billions of parameters on distributed clusters.

✨ Tight Integration with Machine Learning Ecosystem. Alpa is backed by open-source, high-performance, and production-ready libraries such as Jax, XLA, and Ray.
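
As a rough sketch of how these parallelism modes are exposed, the snippet below passes a parallelization method to alpa.parallelize. ShardParallel and PipeshardParallel are taken from Alpa's documented API, but the exact arguments shown here are illustrative assumptions and may differ across versions; consult the documentation for the authoritative signatures.

import alpa

# A minimal sketch, assuming alpa.ShardParallel and alpa.PipeshardParallel
# are available as parallelization methods in your Alpa version.

# Intra-operator parallelism only: Alpa shards tensors and operators
# across all devices automatically (data + operator parallelism).
@alpa.parallelize(method=alpa.ShardParallel())
def shard_train_step(state, batch):
    ...

# Pipeline + intra-operator parallelism: Alpa also slices the model
# into pipeline stages and shards each stage across its device mesh.
@alpa.parallelize(method=alpa.PipeshardParallel(num_micro_batches=16))
def pipeshard_train_step(state, batch):
    ...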

Serving

The code below shows how to use the huggingface/transformers interface with the Alpa distributed backend for large-model inference. Detailed documentation is in Serving OPT-175B using Alpa.

from transformers import AutoTokenizer
from llm_serving.model.wrapper import get_model

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
tokenizer.add_bos_token = False

# Load the model. Alpa automatically downloads the weights to the specified path
model = get_model(model_name="alpa/opt-2.7b", path="~/opt_weights/")

# Generate
prompt = "Paris is the capital city of"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output = model.generate(input_ids=input_ids, max_length=256, do_sample=True)
generated_string = tokenizer.batch_decode(output, skip_special_tokens=True)

print(generated_string)

Training

Use Alpa's @alpa.parallelize decorator to scale your single-device training code to distributed clusters. Check out the documentation site and the examples folder for installation instructions, tutorials, examples, and more.

import alpa
import jax.numpy as jnp
from jax import grad

# Parallelize the training step in Jax by simply using a decorator
@alpa.parallelize
def train_step(model_state, batch):
    def loss_func(params):
        out = model_state.forward(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)

    grads = grad(loss_func)(model_state.params)
    new_model_state = model_state.apply_gradient(grads)
    return new_model_state

# The training loop now automatically runs on your designated cluster
model_state = create_train_state()
for batch in data_loader:
    model_state = train_step(model_state, batch)
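
On a multi-node setup, Alpa is typically initialized against a Ray cluster before the parallelized step is called. The lines below are a minimal sketch assuming a running Ray cluster; alpa.init(cluster="ray") follows Alpa's documented setup, but refer to the installation guide for the exact steps in your environment.

import alpa

# A minimal sketch: connect Alpa to an existing Ray cluster so the
# @alpa.parallelize-decorated train_step runs on all cluster GPUs.
alpa.init(cluster="ray")

# ... then run the training loop shown above.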

Learning more

Getting Involved

License

Alpa is licensed under the Apache-2.0 license.