Implement A Vision On A LLM
AviSoori1x
Avinash Sooriyarachchi
Motivation
Vision language models have become a topic of great interest in the machine
learning community due to the capabilities displayed by GPT-4, Grok 1.5, Claude 3
and Google Gemini. In addition to these proprietary multimodal (primarily vision-
language) models, there have been a number of highly performant open models
such as LLaVa, Kosmos from Microsoft and most recently, Idefics2 from Hugging
Face.
Although the term vision language model could mean a number of things, the current wave of this class of models tends to demonstrate instruction-following capabilities over both image and text inputs. In essence, you can expect a vision language model to write you a poem about how great sushi is and, given an image, count the number of sushi rolls on the plate. I want to make this clear because there's a rich collection of other types of vision language models, such as CLIP and more recent variations such as SigLIP, that are very important but quite different in how they are used. As a matter of fact, we will look at how components from these architectures are used in the current crop of vision language models.
For the purpose of this blog, I will focus on this type of vision language model, the kind that can be instruction tuned to perform useful tasks. More specifically, I will describe a common architectural pattern that seems to be taking shape and is proving to be highly versatile. It consists of three components:
1. Image Encoder to extract visual features from images. In this case I use a from-scratch implementation of the original vision transformer used in CLIP. This is actually a popular choice in many modern VLMs. The one notable exception is the Fuyu series of models from Adept, which passes the patchified images directly to the projection layer.
2. Vision-Language Projector to bring the image embeddings into the same embedding space as the text embeddings consumed by the decoder; this is typically a simple MLP.
3. A decoder-only language model that autoregressively generates text, conditioned on the projected image embeddings concatenated with the text embeddings.
So, in summary: an image encoder extracts features from a given image; these image embeddings are passed to a vision-language projector, which projects them into the text embedding space; the projected embeddings are concatenated with the embeddings of the text input; and the combined sequence is used by a decoder-only language model to autoregressively generate text.
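In terms of tensor shapes, that flow looks roughly like this (a sketch with illustrative dimensions, not the exact ones used later in the post):

import torch

# Illustrative shapes only; all dimensions below are assumptions.
image_embedding = torch.randn(1, 512)          # output of the vision encoder for one image
projected = torch.randn(1, 1, 128)             # after the vision-language projector: one "visual token"
text_embeddings = torch.randn(1, 32, 128)      # embeddings of a 32-token text prompt
decoder_input = torch.cat([projected, text_embeddings], dim=1)   # (1, 33, 128), fed to the decoder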
When you zoom out, it’s not all that complicated and honestly, quite clever. It’s also
kind of amazing that this works. Just like everything else in deep learning.
(Figure source: https://openai.com/research/clip)
There has been a trend of vision language models getting much better performance by using the vision transformer from an improved version of CLIP known as SigLIP, which uses a sigmoid loss instead of the softmax cross-entropy loss used in CLIP's contrastive learning objective. A great example of a tiny vision language model using the vision transformer from SigLIP and punching way above its weight (literally: it's only 1.6B parameters in total) is moondream 2 by vikhyat: https://github.com/vikhyat/moondream.
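To make that difference concrete, here is a rough sketch of the two contrastive objectives (ignoring CLIP's symmetric two-direction loss and SigLIP's learnable temperature and bias terms):

import torch
import torch.nn.functional as F

N = 8
logits = torch.randn(N, N)                     # image-text similarity scores for a batch of N pairs

# CLIP-style: softmax cross entropy -- each image must pick its matching text out of all N
clip_loss = F.cross_entropy(logits, torch.arange(N))

# SigLIP-style: an independent binary (sigmoid) decision for every image-text pair,
# with +1 labels on the diagonal (matching pairs) and -1 everywhere else
targets = 2 * torch.eye(N) - 1
siglip_loss = -F.logsigmoid(targets * logits).mean()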
However, for the sake of simplicity, we assume the CLIP version is used here; the implementation would be identical either way. In seemore, I use the embedding corresponding to the '[CLS]' token as the feature vector that represents the entire image. This is done for the sake of simplicity. However, it is possible, and likely better, to use all the feature vectors from the last layer of the vision transformer. My assumption is that this would help with tasks such as counting and OCR, where spatial information is available to the decoder in a less compressed form.
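In code, the two pooling choices would look something like this (the shapes are illustrative):

import torch

# vit_out: (B, 1 + num_patches, hidden_dim) -- the [CLS] token followed by the patch tokens
vit_out = torch.randn(2, 37, 512)
cls_embedding = vit_out[:, 0, :]      # (2, 512): a single vector per image (what seemore uses)
patch_embeddings = vit_out[:, 1:, :]  # (2, 36, 512): keep every patch for finer spatial detail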
(Figure source: https://arxiv.org/pdf/2010.11929.pdf)
The patchification is handled by a small module built around a single convolution:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbeddings(nn.Module):
    def __init__(self, img_size, patch_size, hidden_dim=512):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # A conv layer whose kernel size and stride both equal the patch size splits the
        # image into non-overlapping patches and projects each one to hidden_dim channels
        self.conv = nn.Conv2d(in_channels=3, out_channels=hidden_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, X):
        X = self.conv(X)                  # (B, hidden_dim, img_size/patch, img_size/patch)
        X = X.flatten(2).transpose(1, 2)  # (B, num_patches, hidden_dim), i.e. (B, T, C)
        return X
In the above code, the input image is broken down into (img_size // patch_size) ** 2 patches by the convolution layer, and each patch is projected into a vector with a channel dimension (the C in the [B, T, C] shape commonly encountered in PyTorch implementations for 3D tensors) of 512.
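A quick shape check (the image size and patch size here are example values, not necessarily the ones used in seemore):

images = torch.randn(4, 3, 96, 96)                                    # a batch of 4 RGB images
patchify = PatchEmbeddings(img_size=96, patch_size=16, hidden_dim=512)
print(patchify(images).shape)                                          # torch.Size([4, 36, 512]) -> (B, T, C)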
Attention Mechanism across both the vision encoder and
language decoder
Things get interesting when building the components used in the transformer blocks, i.e. the attention head implementation, multi-head attention, the multilayer perceptron found in each transformer block, and the transformer block itself. These components are mostly identical between the vision transformer we are implementing for 'visual token' generation and the decoder language model that generates the actual text output.
The only key difference is the masking applied in each attention head of the decoder language model. This masking preserves the integrity of the autoregressive generation process: it hides any information at positions after the current token, so the model can only attend to the preceding parts of the sequence. Such an attention mechanism is known as causal self-attention.
The lower triangular mask is only applied in the case of a decoder model: the entries of the attention weight matrix W above the diagonal are blanked out before the softmax, whereas in each attention head of the vision encoder no such masked region exists.
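As a standalone illustration of what that mask does:

import torch

# Causal mask for a toy sequence of length 4: positions above the diagonal are set to
# -inf so that, after the softmax, a token cannot attend to anything that comes after it.
T = 4
wei = torch.zeros(T, T)                                 # stand-in for raw attention scores
tril = torch.tril(torch.ones(T, T, dtype=torch.bool))
wei = wei.masked_fill(tril == 0, float('-inf'))
print(torch.softmax(wei, dim=-1))
# row i gives uniform weights over positions 0..i and zero weight to all future positions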
So here I implement these components in such a manner that they can be shared for
both the vision encoder and language decoder by passing in an is_decoder boolean
argument to the class constructor.
The code for causal self-attention and multi-head causal self-attention can be organized as follows. Multi-head self-attention applies multiple attention heads in parallel, each focusing on a separate slice of the channel (the embedding dimension). This improves learning and also makes training more efficient, since the heads can be computed in parallel. Notice that I have used dropout throughout this implementation for regularization, i.e. to prevent overfitting.
class Head(nn.Module):
    def __init__(self, n_embd, head_size, dropout=0.1, is_decoder=False):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.is_decoder = is_decoder

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)
        # Scaled dot-product attention scores
        wei = q @ k.transpose(-2, -1) * (k.shape[-1] ** -0.5)
        if self.is_decoder:
            # If this head is used in the decoder, apply a causal mask
            # to prevent attending to future positions
            tril = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
            wei = wei.masked_fill(tril == 0, float('-inf'))
        wei = self.dropout(F.softmax(wei, dim=-1))
        out = wei @ v
        return out
class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, num_heads, dropout=0.1, is_decoder=False):
        super().__init__()
        head_size = n_embd // num_heads
        self.heads = nn.ModuleList([Head(n_embd, head_size, dropout, is_decoder) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        head_outputs = [h(x) for h in self.heads]
        # Concatenate the outputs from all heads along the last dimension
        out = torch.cat(head_outputs, dim=-1)
        out = self.dropout(self.proj(out))
        return out
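A quick sanity check against the classes above (the sizes are arbitrary): multi-head attention preserves the (B, T, C) shape of its input.

import torch

mha = MultiHeadAttention(n_embd=128, num_heads=8, is_decoder=True)
x = torch.randn(4, 16, 128)    # (batch, sequence length, embedding dimension)
print(mha(x).shape)            # torch.Size([4, 16, 128])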
The multilayer perceptron that follows each multi-head attention module is quite straightforward. Please note that I've noticed GELU being used quite often in vision transformers and ReLU in text transformers, so I have conditional logic to switch between the two based on where this MLP will be inserted. However, it seems that GELU is increasingly being used for both because of the resulting model quality, even though it is more computationally expensive than ReLU.
class MLP(nn.Module):
    def __init__(self, n_embd, dropout=0.1, is_decoder=True):
        super().__init__()
        # ReLU in the text decoder, GELU in the vision encoder
        activation = nn.ReLU() if is_decoder else nn.GELU()
        self.net = nn.Sequential(nn.Linear(n_embd, 4 * n_embd), activation,
                                 nn.Linear(4 * n_embd, n_embd), nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)
class Block(nn.Module):
    def __init__(self, n_embd, num_heads, dropout=0.1, is_decoder=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_embd, num_heads, dropout, is_decoder)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = MLP(n_embd, dropout, is_decoder)

    def forward(self, x):
        # Residual connections around (pre-norm) attention and MLP
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
Now the patchification logic and the attention blocks can be combined to create the vision transformer (ViT):
class ViT(nn.Module):
    def __init__(self, img_size, patch_size, num_hiddens, num_heads, num_blks, emb_dropout, blk_dropout):
        super().__init__()
        self.patch_embedding = PatchEmbeddings(img_size, patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        num_patches = (img_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens))
        self.dropout = nn.Dropout(emb_dropout)
        self.blocks = nn.ModuleList([Block(num_hiddens, num_heads, blk_dropout, is_decoder=False)
                                     for _ in range(num_blks)])
        self.layer_norm = nn.LayerNorm(num_hiddens)

    def forward(self, X):
        x = self.patch_embedding(X)
        # Prepend the learnable [CLS] token and add positional embeddings
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + self.pos_embedding
        x = self.dropout(x)
        for blk in self.blocks:
            x = blk(x)
        # Return the [CLS] embedding as the whole-image representation
        x = self.layer_norm(x[:, 0])
        return x
Overall, the ViT class encapsulates the architecture and forward pass of a Vision
Transformer model. It takes an input image, converts it into patch embeddings, adds
positional information, and processes the embeddings through a series of
transformer blocks to generate a meaningful representation of the image. The final
representation returned is the embedding corresponding to the CLS token, which is
then used to condition the text generation in the language decoder.
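Assuming the hyperparameter names sketched above, using the ViT on a dummy batch would look like this (all values are arbitrary assumptions):

import torch

vit = ViT(img_size=96, patch_size=16, num_hiddens=512, num_heads=8,
          num_blks=4, emb_dropout=0.1, blk_dropout=0.1)
images = torch.randn(2, 3, 96, 96)
print(vit(images).shape)   # torch.Size([2, 512]): one [CLS] embedding per image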
Here's the implementation of the vision-language projection module. It's not too different from the MLP used in the transformer blocks:
class MultiModalProjector(nn.Module):
    def __init__(self, n_embd, image_embed_dim, dropout=0.1):
        super().__init__()
        # A small MLP that maps image embeddings into the text embedding space
        self.net = nn.Sequential(nn.Linear(image_embed_dim, 4 * image_embed_dim), nn.GELU(),
                                 nn.Linear(4 * image_embed_dim, n_embd), nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)
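The projector simply changes the embedding dimensionality. For example, with assumed sizes:

import torch

projector = MultiModalProjector(n_embd=128, image_embed_dim=512)
image_embedding = torch.randn(2, 512)       # [CLS] embeddings from the ViT
print(projector(image_embedding).shape)     # torch.Size([2, 128]): now in the text embedding space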
The final component we need to look at is the decoder language model. Here I've remained within the confines of the modern VLM architecture but deviated a bit in the implementation: I have integrated the projection module into the decoder model class. This is because I built everything from scratch and wanted to retain the causal language model architecture from Andrej Karpathy's makemore. There's no easy way to directly feed in reshaped embeddings in that setup, so I've had to improvise a little. Please keep in mind that when using pretrained models through the Hugging Face API, or any other modern library that lets you work with pretrained large language models, you can feed embeddings directly as input to the model (e.g. using the inputs_embeds parameter: https://huggingface.co/docs/transformers/en/model_doc/gpt2#transformers.GPT2).
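For instance, with a pretrained GPT-2 from the transformers library, a sketch of that approach might look like this (the token ids and the random "image" embedding are placeholders):

import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
token_ids = torch.tensor([[464, 2068, 7586]])                  # arbitrary token ids
text_embeds = model.get_input_embeddings()(token_ids)           # (1, 3, 768)
image_embeds = torch.randn(1, 1, 768)                           # stand-in for a projected image feature
inputs_embeds = torch.cat([image_embeds, text_embeds], dim=1)   # prepend the "visual token"
outputs = model(inputs_embeds=inputs_embeds)
print(outputs.logits.shape)                                     # (1, 4, vocab_size)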
That being said, what I've done here is an interesting exercise, in that it lets you see in pretty simple code how the image embeddings are reshaped by the vision-language projector to match the shape of the text embeddings, and how the two are then combined before being fed to the decoder.
Essentially the text generation is conditioned on the initial image input. This can be
modified in a number of ways to work with interleaved text and images, which will
be useful for multi-turn conversation i.e. chat scenarios using the finetuned VLM. A
number of useful tips can be found in this paper by Apple:
https://arxiv.org/pdf/2403.09611.pdf.
The crucial parts of this decoder implementation are given below. Note how the is_decoder flag is passed as True to use the masked version of the self-attention blocks, resulting in causal scaled dot-product self-attention in the language decoder. Please refer to the GitHub repo linked above for the full implementation.
class DecoderLanguageModel(nn.Module):
    def __init__(self, n_embd, image_embed_dim, vocab_size, num_heads, n_layer, use_images=False):
        super().__init__()
        self.use_images = use_images
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        if use_images:
            # Image projection layer to align image embeddings with text embeddings
            self.image_projection = MultiModalProjector(n_embd, image_embed_dim)
        self.blocks = nn.Sequential(*[Block(n_embd, num_heads, is_decoder=True) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        # (positional embeddings and the generation utilities are omitted here; see the repo)

    def forward(self, idx, image_embeds=None):
        tok_emb = self.token_embedding_table(idx)
        if self.use_images and image_embeds is not None:
            # Project the image embedding into the text embedding space and prepend it
            # to the token embeddings so that generation is conditioned on the image
            img_emb = self.image_projection(image_embeds).unsqueeze(1)
            tok_emb = torch.cat([img_emb, tok_emb], dim=1)
        x = self.ln_f(self.blocks(tok_emb))
        logits = self.lm_head(x)
        return logits
Now that we have our three key components, we can put it all together into a vision language model. The core of the implementation is given below; with the assert statements used for error handling stripped out, it looks very simple. Coming back full circle to the outline I gave at the beginning of the blog, all that's happening here is:
1. Get image features from the vision encoder (here it's a vision transformer, but it could be any model that can produce features from an image input, such as a ResNet or a traditional convolutional neural network; needless to say, performance may suffer).
2. A projection module projects the image tokens into the same embedding space as the text embeddings for the decoder (this projector is integrated with the decoder in this implementation).
3. The decoder language model generates text, conditioned on the projected image embeddings.
class VisionLanguageModel(nn.Module):
    def __init__(self, n_embd, image_embed_dim, vocab_size, n_layer, img_size, patch_size,
                 num_heads, num_blks, emb_dropout, blk_dropout):
        super().__init__()
        # Vision encoder that produces the image embedding
        self.vision_encoder = ViT(img_size, patch_size, image_embed_dim, num_heads, num_blks,
                                  emb_dropout, blk_dropout)
        # Decoder language model with the vision-language projector integrated into it
        self.decoder = DecoderLanguageModel(n_embd, image_embed_dim, vocab_size, num_heads,
                                            n_layer, use_images=True)

    def forward(self, img_array, idx):
        image_embeds = self.vision_encoder(img_array)
        logits = self.decoder(idx, image_embeds)
        return logits
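A quick end-to-end sanity check of the assembled model might look like this (all hyperparameter values are arbitrary assumptions):

import torch

model = VisionLanguageModel(n_embd=128, image_embed_dim=512, vocab_size=1000, n_layer=4,
                            img_size=96, patch_size=16, num_heads=8, num_blks=4,
                            emb_dropout=0.1, blk_dropout=0.1)
images = torch.randn(2, 3, 96, 96)              # batch of 2 RGB images
token_ids = torch.randint(0, 1000, (2, 10))     # batch of 2 prompts, 10 tokens each
logits = model(images, token_ids)
print(logits.shape)                             # (2, 11, 1000): the image "token" plus 10 text tokens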
In practice, rather than training every component from scratch as I have here, the typical recipe is:
1. Get a pretrained vision encoder from SigLIP or CLIP (both come in different sizes). Freeze the weights (i.e. don't update them during the backward pass in training).
2. Get a pretrained decoder-only language model, e.g. anything from TinyLLaMA, Phi-2 etc. all the way to Llama 3 (or even much bigger in the case of GPT-4 and Grok 1.5 etc.). Freeze the weights.
3. Implement a projection module and train a VLM much like the one we have here, but only updating the weights of this projection module. This would effectively be the pretraining phase.
4. Then, during instruction finetuning, keep both the projection module and the decoder language model unfrozen and update the weights of both in the backward pass (see the sketch after this list).
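A minimal sketch of this staged freezing, using stand-ins for the pretrained components (the modules and dimensions below are assumptions, not the seemore API):

import torch
import torch.nn as nn

# Stand-ins for pretrained components; in reality these would be a SigLIP/CLIP vision
# tower and a pretrained decoder-only language model loaded from a checkpoint.
vision_encoder = nn.Linear(768, 768)
decoder = nn.Linear(512, 512)
projector = MultiModalProjector(n_embd=512, image_embed_dim=768)

# Stage 1 (pretraining): freeze the vision encoder and the decoder, train only the projector
for p in vision_encoder.parameters():
    p.requires_grad = False
for p in decoder.parameters():
    p.requires_grad = False
optimizer = torch.optim.AdamW(projector.parameters(), lr=1e-4)

# Stage 2 (instruction finetuning): unfreeze the decoder and train it together with the projector
for p in decoder.parameters():
    p.requires_grad = True
optimizer = torch.optim.AdamW(list(projector.parameters()) + list(decoder.parameters()), lr=2e-5)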
I developed this on Databricks using a single T4 GPU, with MLflow for tracking the loss during training. I wanted to set things up this way so that I can easily scale up to a GPU cluster of any size on Databricks, should I decide to adapt this into a more performance-oriented implementation. However, you can run this anywhere, with or without a GPU. Please note that even the toy training loop with 90 samples will be painfully slow on a CPU.