jykoh commited on
Commit
0b88c5c
1 Parent(s): a03fe94
Files changed (1) hide show
  1. fromage/models.py +1 -1
fromage/models.py CHANGED
@@ -644,7 +644,7 @@ def load_fromage(embeddings_dir: str, model_args_path: str, model_ckpt_path: str
644
  model = model.bfloat16()
645
  model = model.cuda()
646
 
647
- Load pretrained linear mappings and [RET] embeddings.
648
  checkpoint = torch.load(model_ckpt_path)
649
  model.load_state_dict(checkpoint['state_dict'], strict=False)
650
  with torch.no_grad():
 
644
  model = model.bfloat16()
645
  model = model.cuda()
646
 
647
+ # Load pretrained linear mappings and [RET] embeddings.
648
  checkpoint = torch.load(model_ckpt_path)
649
  model.load_state_dict(checkpoint['state_dict'], strict=False)
650
  with torch.no_grad():