Extracting language model only
#17
by mariboo
I'm trying to get the language part out of the 90B model. If I do:
```python
import torch
from transformers import MllamaForConditionalGeneration

model_id = "meta-llama/Llama-3.2-90B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
sum(p.numel() for p in model.language_model.parameters())
```
I get 87,666,865,192
but I was expecting around 70B parameters, since the blog post says Llama 3.1 was the starting point?
The language model has 20 cross-attention layers in addition to the 80 decoder-only layers (layers 3, 8, 13, ..., 98). To extract the language decoder only, you have to discard these cross-attention layers, and then you should get about 70B parameters.
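As a sketch of the counting argument: if the cross-attention layers sit at every 5th index starting from 3 (indices 3, 8, 13, ..., 98, an assumption based on the description above), you can compute which of the 100 layers to skip, and sum parameters only over the remaining 80 decoder-only layers. The `model.language_model.model.layers` attribute path below is an assumption about the Mllama module layout, not verified against the library:

```python
# Indices of the 20 cross-attention layers: 3, 8, 13, ..., 98
# (assumption: every 5th layer starting at 3, per the explanation above).
cross_attn_indices = set(range(3, 100, 5))

def decoder_only_param_count(language_model, skip=cross_attn_indices):
    """Sum parameters of the language model, skipping the cross-attention
    layers. Assumes a `.model.layers` list attribute (hypothetical layout)."""
    total = 0
    for idx, layer in enumerate(language_model.model.layers):
        if idx not in skip:
            total += sum(p.numel() for p in layer.parameters())
    # Note: embeddings, the final norm, and the LM head also contribute
    # to the ~70B total and would need to be added separately.
    return total
```

Checking the indices alone: `len(cross_attn_indices)` is 20, leaving 100 - 20 = 80 decoder-only layers, which matches the Llama 3.1 70B layer count.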
Ok, now it makes sense - thanks 👍