
Getting gradient of loss during inference #3371

Open
LalchandPandia opened this issue Jan 28, 2025 · 4 comments

@LalchandPandia

I am fine-tuning Llama 2 using accelerate + DeepSpeed ZeRO-3. During evaluation, which runs after every checkpointing step, I need to compute the gradient of the loss w.r.t. certain input ids. As I understand it, the embedding matrix is sharded, and when I try to get the gradient I get an error saying that grad is set to None. Is there a cleaner way to do this using accelerate APIs?
My code:

import torch
import deepspeed

def token_gradients(model, input_ids, targets):

    """
    Computes gradients of the loss with respect to the coordinates.
    
    Parameters
    ----------
    model : Transformer Model
        The transformer model to be used.
    input_ids : torch.Tensor
        The input sequence in the form of token ids.
    input_slice : slice
        The slice of the input sequence for which gradients need to be computed.
    targets :torch.Tensor
        The target sequence in the form of token ids .
    loss_slice : slice
        The slice of the logits to be used for computing the loss.

    Returns
    -------
    torch.Tensor
        The gradients of each token in the input_slice with respect to the loss.
    """
    valid_positions = (targets != -100).nonzero(as_tuple=True)[0]
    input_slice = slice(0, valid_positions[0].item())
    end_input_slice = valid_positions[-1].item()
    print('Luke input_slice valid_positions ',input_slice, ' end_input_slice ',end_input_slice)


    #embed_weights = get_embedding_matrix(model)
    embeddings = model.get_input_embeddings()
    with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
        embedding_weights = embeddings.weight
        embedding_size = embedding_weights.shape[0]
        print('embed_weights ',embedding_weights.shape)
    one_hot = torch.zeros(
        input_ids[input_slice].shape[0],
        embedding_size,
        device=model.device,
        dtype=embeddings.weight.dtype
    )
    print('input_ids[input_slice].shape[0] ',input_ids[input_slice].shape[0])
    #print('embed_weights.shape  ',' embedding_size ',embedding_size)
    print('one_hot.shape ',one_hot.shape)
    one_hot.scatter_(
        1,
        input_ids[input_slice].unsqueeze(1),
        torch.ones(one_hot.shape[0], 1, device=model.device, dtype=embeddings.weight.dtype)
    )
    one_hot.requires_grad_()
    print('one_hot.shape ',one_hot.shape)
    #print('embeddings.weight ',embeddings.weight)
    with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
        #this contains embeddings for all ids present in input_slice
        input_embeds = (one_hot @ embeddings.weight)
        print('input_embeds grad ', input_embeds.grad)
        input_ids = input_ids.cpu().tolist()
        embeds = embeddings.weight[input_ids[:end_input_slice+1], :].detach()

        #Now we stitch the input_embeddings for all input ids before the input slice starts
        #embedings for all ids in input slice
        #embeddings for all ids after input slice:
        full_embeds = torch.cat(
            [
                embeds[:input_slice.start,:],
                input_embeds,
                embeds[input_slice.stop:,:]
            ],
            dim=0)
        full_embeds = full_embeds.unsqueeze(0)
        print('Luke full_embeds ',full_embeds.shape)


        logits = model(inputs_embeds=full_embeds).logits
        print('Luke logits ',logits.shape)
        #calculate the loss over the logit positions corresponding to every token in the vocabulary
        loss = torch.nn.CrossEntropyLoss()(logits[0,:,:], targets[:end_input_slice+1])
        print('Luke loss ',loss.shape)

        loss.backward()
        print(one_hot.grad.shape)
    return one_hot.grad.clone()
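
For reference, a minimal diagnostic sketch (assuming the same model, embeddings, and one_hot as in the function above) that can show whether the graph is being recorded at all and whether the weight is actually materialized when it is used; under ZeRO-3 the local parameter outside a GatheredParameters context is only a placeholder shard:

import torch
import deepspeed

# Hypothetical diagnostics, to be run inside token_gradients right after one_hot is built.
print('grad enabled:', torch.is_grad_enabled())           # False if called under no_grad()/inference_mode()
print('local weight shape:', embeddings.weight.shape)     # under ZeRO-3 typically an empty placeholder
with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
    print('gathered weight shape:', embeddings.weight.shape)       # full (vocab_size, hidden_size)
    print('weight requires_grad:', embeddings.weight.requires_grad)
print('one_hot requires_grad:', one_hot.requires_grad)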
@LalchandPandia
Author

@muellerzr @ArthurZucker

@BenjaminBossan
Member

Just to make sure I understand: did you confirm that the gradient being None is due to the use of DeepSpeed and accelerate? That is, if you run the same code on a single device, is the gradient present? The reason I'm asking is that at evaluation time we typically don't calculate gradients, as they're not needed. For instance, your call might be inside a torch.inference_mode() context. Since you don't show the whole code, it's not possible to tell.
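
A tiny, self-contained illustration of that point (plain PyTorch, no accelerate or DeepSpeed involved):

import torch

w = torch.randn(4, 3, requires_grad=True)

with torch.no_grad():
    y = (w * 2).sum()
print(y.grad_fn)                   # None: no graph was recorded, backward() would fail

with torch.no_grad():
    with torch.enable_grad():      # enable_grad() can locally override no_grad() ...
        y = (w * 2).sum()
print(y.grad_fn)                   # a SumBackward node: graph recorded again

with torch.inference_mode():
    y = (w * 2).sum()              # ... but inference_mode() is not lifted by enable_grad()
print(y.grad_fn)                   # None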

@LalchandPandia
Author

@BenjaminBossan
Below is code for a single-GPU setup that works perfectly:

import torch
from transformers import ( AutoModelForCausalLM, AutoTokenizer)
def token_gradients(model, input_ids, targets):

    valid_positions = (targets != -100).nonzero(as_tuple=True)[0]
    input_slice = slice(0, valid_positions[0].item())
    end_input_slice = valid_positions[-1].item()

    embeddings = model.get_input_embeddings()
    embedding_weights = embeddings.weight
    embedding_size = embedding_weights.shape[0]
    print('embed_weights ',embedding_weights.shape)
    one_hot = torch.zeros(
        input_ids[input_slice].shape[0],
        embedding_size,
        device=model.device,
        dtype=embeddings.weight.dtype
    )
    one_hot.scatter_(
        1,
        input_ids[input_slice].unsqueeze(1),
        torch.ones(one_hot.shape[0], 1, device=model.device, dtype=embeddings.weight.dtype)
    )
    one_hot.requires_grad_()

    input_embeds = (one_hot @ embeddings.weight)
    input_embeds.requires_grad_()
    input_embeds.retain_grad()
    input_ids = input_ids.cpu().tolist()
    embeds = embeddings.weight[input_ids[:end_input_slice+1],:].detach()
    print('embeds shape ',embeds.shape)

    full_embeds = torch.cat(
        [
            embeds[:input_slice.start,:],
            input_embeds,
            embeds[input_slice.stop:,:]
        ],
        dim=0)
    full_embeds = full_embeds.unsqueeze(0)
    logits = model(inputs_embeds=full_embeds).logits
    #calculate the loss over the logit positions corresponding to every token in the vocabulary
    loss = torch.nn.CrossEntropyLoss()(logits[0,:,:], targets[:end_input_slice+1])

    loss.backward()
    print(one_hot.grad.shape)
    print('one hot grad ',one_hot.grad)
    print('input embeds grad ',input_embeds.grad)
    return one_hot.grad.clone()
device = "cuda"

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf").to(device)
print('model.device ',model.device)
model.eval()
input = torch.tensor([[    1,   894, 29901,  5122, 10753,   304, 14294,   670,  6567,  9098,491, 14051, 10549,   963, 29889,  8449, 19309,  7101,   674,  7738, 278,  1556, 12871, 29973,    13, 22550, 29901, 15589,  5112,  1516]])
target = torch.tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,-100,  -100,  -100,  -100,  -100, 22550, 29901, 15589,  5112,  1516])
token_gradients(model, input, target)

But with the changes for the distributed setup, it fails:

import deepspeed

def token_gradients(model, input_ids, targets):
    valid_positions = (targets != -100).nonzero(as_tuple=True)[0]
    input_slice = slice(0, valid_positions[0].item())
    end_input_slice = valid_positions[-1].item()
    embeddings = model.get_input_embeddings()
    with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
        embedding_size = embeddings.weight.shape
    one_hot = torch.zeros(
        input_ids[input_slice].shape[0],
        embedding_size[0],
        device=model.device,
        dtype=embeddings.weight.dtype
    )
    one_hot.scatter_(
        1,
        input_ids[input_slice].unsqueeze(1),
        torch.ones(one_hot.shape[0], 1, device=model.device, dtype=embeddings.weight.dtype)
    )
    one_hot.requires_grad_()
    with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
        input_embeds = (one_hot @ embeddings.weight)
        input_embeds.requires_grad_()
        input_embeds.retain_grad()
        input_ids = input_ids.cpu().tolist()
        embeds = embeddings.weight[input_ids[:end_input_slice+1],:].detach()
        full_embeds = torch.cat(
            [
                embeds[:input_slice.start,:],
                input_embeds,
                embeds[input_slice.stop:,:]
            ],
            dim=0)
        full_embeds = full_embeds.unsqueeze(0)
        logits = model(inputs_embeds=full_embeds).logits
        loss = torch.nn.CrossEntropyLoss()(logits[0,:,:], targets[:end_input_slice+1])
        print('Luke loss ',loss.shape)

        loss.backward()
    return one_hot.grad.clone()

The loss in this case is an empty tensor.
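
It may help to print a few shapes and flags right before loss.backward() in the distributed version (a sketch using the same local names as above) to see where things diverge from the single-GPU run:

# Possible checks just before loss.backward() in the ZeRO-3 version:
print('grad enabled   ', torch.is_grad_enabled())
print('full_embeds    ', full_embeds.shape, full_embeds.requires_grad)
print('logits         ', logits.shape, logits.grad_fn)
print('targets slice  ', targets[:end_input_slice+1].shape)
print('loss           ', loss.shape, loss.requires_grad, loss.grad_fn)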

@LalchandPandia
Author

Logits in the single-GPU setup (without accelerate + DeepSpeed) is a tensor with a grad_fn. But with multiple GPUs there is no grad_fn.

Single GPU:
logits tensor([[[-12.9832, -7.4134, -0.4327, ..., -6.8297, -8.0879, -7.5863],
[ -6.5046, -3.2412, 4.3043, ..., -1.0471, -4.2429, -2.2271]]],
device='cuda:0', grad_fn=)

Multiple GPU with accelerate:
logits tensor([[[-12.9832, -7.4134, -0.4327, ..., -6.8297, -8.0879, -7.5863],
[ -6.5046, -3.2412, 4.3043, ..., -1.0471, -4.2429, -2.2271]]],
device='cuda:0')
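
If the surrounding evaluation loop runs under torch.no_grad(), one possible workaround is to force grad mode locally around the call (a sketch only; it does not address the ZeRO-3 sharding itself, and it has no effect if the caller used torch.inference_mode()):

import torch

def token_gradients_with_grad(model, input_ids, targets):
    # Re-enable autograd even if the caller disabled it with torch.no_grad();
    # note this does NOT work inside torch.inference_mode().
    with torch.enable_grad():
        return token_gradients(model, input_ids, targets)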
