"""GPU compute type definitions and utilities.
This module defines the available GPU types for serverless compute and provides
utilities for working with GPU configurations and resource allocation.
"""
from enum import Enum
from serverless_gpu.exceptions import GPUTypeError
class GPUType(Enum):
    """Enumeration of available GPU types for serverless compute.

    This enum defines the GPU types that can be used for distributed computing
    on the serverless GPU platform. Construction by value is lenient: besides
    the exact member values, ``GPUType(...)`` accepts member names, known
    aliases, any letter case, and surrounding whitespace (see ``_missing_``).

    Attributes:
        H100: NVIDIA H100 80GB GPU instances.
        A10: NVIDIA A10 GPU instances.
    """

    H100 = "h100_80gb"
    A10 = "a10"

    @classmethod
    def values(cls) -> "list[str]":
        """Get all GPU type values.

        Returns:
            List[str]: List of all available GPU type values, in definition order.
        """
        return [gpu_type.value for gpu_type in cls]

    @classmethod
    def _missing_(cls, value: object):
        """Handle missing enum values with flexible string matching.

        Invoked by ``GPUType(value)`` when ``value`` is not exactly equal to a
        member value. Only strings can match. Matching ignores letter case and
        surrounding whitespace, and is tried in this order:

        1. known aliases (``"a10g"`` -> ``A10``);
        2. member values (e.g. ``"H100_80GB"`` -> ``H100``);
        3. member names (e.g. ``"h100"`` -> ``H100``).

        Args:
            value (object): The value to match against enum members.

        Returns:
            GPUType: The matching GPU type if found.

        Raises:
            GPUTypeError: If the value doesn't match any GPU type.
        """
        if isinstance(value, str):
            key = value.strip().lower()
            # Alternate spellings that are neither a member value nor a name.
            # Kept as a local: a dict assigned in the Enum class body would
            # itself be turned into an enum member.
            aliases = {"a10g": cls.A10}
            if key in aliases:
                return aliases[key]
            # First try matching against the enum values.
            for member in cls:
                if member.value.lower() == key:
                    return member
            # Then try matching against the enum names.
            for member in cls:
                if member.name.lower() == key:
                    return member
        raise GPUTypeError(f"Invalid GPU type: '{value}'. Available GPU types: {cls.values()}")
class DisplayGpuType(Enum):
    """Human-readable display labels for each GPU type.

    Member *names* are the GPU type value strings (``"a10"``,
    ``"h100_80gb"``), and member *values* are the labels shown to users.
    Look a label up with ``DisplayGpuType[gpu_type.value].value`` or, more
    conveniently, ``DisplayGpuType.from_gpu_type(gpu_type).value``.
    """

    a10 = "A10"
    # The H100 label includes the per-node GPU count (8).
    h100_80gb = "8xH100"

    @classmethod
    def from_gpu_type(cls, gpu_type):
        """Return the display member for a GPU type.

        Args:
            gpu_type: A GPU type enum member (its ``.value`` is used as the
                key) or a raw GPU type value string such as ``"a10"``.

        Returns:
            DisplayGpuType: The matching display member.

        Raises:
            KeyError: If there is no display label for ``gpu_type``.
        """
        # Enum members carry the key in ``.value``; plain strings are the key.
        key = getattr(gpu_type, "value", gpu_type)
        return cls[key]
def get_gpus_per_node(gpu_type: "GPUType | str") -> int:
    """Get the number of GPUs available per node for a given GPU type.

    Args:
        gpu_type (GPUType | str): The GPU type to query. Non-``GPUType`` input
            is resolved through ``GPUType(...)``, so value strings, member
            names and aliases (e.g. ``"h100_80gb"``, ``"H100"``, ``"a10g"``)
            are accepted as well as enum members.

    Returns:
        int: The number of GPUs per node (8 for H100, 1 for A10).

    Raises:
        GPUTypeError: If the GPU type is not supported.
    """
    # TODO: If possible, do not hardcode.
    gpus_per_node = {GPUType.H100: 8, GPUType.A10: 1}
    # A plain string never compares equal to an Enum member, so normalise it
    # first; GPUType's lenient lookup raises GPUTypeError for unknown input.
    if not isinstance(gpu_type, GPUType):
        gpu_type = GPUType(gpu_type)
    try:
        return gpus_per_node[gpu_type]
    except KeyError:
        # A GPUType member that has no per-node count registered above.
        raise GPUTypeError(f"Invalid GPU type: {gpu_type}. Available types: {GPUType.values()}") from None