Extracting language model only

#17
by mariboo - opened

I'm trying to get the language part out of the 90B model. If I do:

import torch
from transformers import MllamaForConditionalGeneration

model_id = "meta-llama/Llama-3.2-90B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
sum(p.numel() for p in model.language_model.parameters())

I get 87,666,865,192 parameters, but I was expecting around 70B, since the blog post says Llama 3.1 was the starting point?

The language model has 20 cross-attention layers in addition to the 80 decoder-only layers; the cross-attention layers sit at indices 3, 8, 13, ..., 98. To extract the language decoder only, you have to discard those layers, and you should then get roughly 70B parameters.
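The filtering logic can be sketched without loading the model. This is a hedged, self-contained illustration using toy per-layer counts (the real per-layer sizes and the exact attribute holding the cross-attention indices, e.g. `model.config.text_config.cross_attention_layers`, should be checked against the actual checkpoint); only the index pattern 3, 8, 13, ..., 98 comes from the thread above.

```python
# Cross-attention layers are interleaved every 5th position among the
# 100 total layers: indices 3, 8, 13, ..., 98 -> 20 layers.
cross_attention_layers = set(range(3, 100, 5))

# Toy stand-in for iterating model.language_model.model.layers:
# (layer_index, parameter_count). Counts are placeholders, not real sizes.
layers = [(i, 1_000 if i in cross_attention_layers else 2_000)
          for i in range(100)]

# Skip the cross-attention layers, leaving the 80 decoder-only layers.
decoder_params = sum(n for i, n in layers
                     if i not in cross_attention_layers)

print(len(cross_attention_layers))  # 20 cross-attention layers
print(decoder_params)               # 80 decoder layers * 2000 = 160000
```

With the real model, the same `if i not in cross_attention_layers` filter applied while summing `p.numel()` over each layer's parameters (plus the embedding and head tensors outside the layer stack) should land near the expected 70B.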

Ok, now it makes sense - thanks 👍
