lklimkiewicz commited on
Commit
0b117e9
1 Parent(s): f8ed495

Upload OPTForMusicGeneration

Browse files
Files changed (2) hide show
  1. config.json +3 -0
  2. model.py +16 -0
config.json CHANGED
@@ -7,6 +7,9 @@
7
  "OPTForMusicGeneration"
8
  ],
9
  "attention_dropout": 0.0,
 
 
 
10
  "bos_token_id": 1,
11
  "do_layer_norm_before": true,
12
  "do_sample": true,
 
7
  "OPTForMusicGeneration"
8
  ],
9
  "attention_dropout": 0.0,
10
+ "auto_map": {
11
+ "AutoModel": "model.OPTForMusicGeneration"
12
+ },
13
  "bos_token_id": 1,
14
  "do_layer_norm_before": true,
15
  "do_sample": true,
model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import OPTForCausalLM, OPTConfig, AutoModel
2
+ import torch
3
+
4
+ from miditok import TokSequence
5
+
6
+
7
+ class OPTForMusicGeneration(OPTForCausalLM):
8
+
9
+ def generate_music(self, **kwargs):
10
+ input = torch.tensor([[self.config.bos_token_id]], device=self.device)
11
+ midi = self.generate(input, **kwargs)
12
+ generated_ts = TokSequence(ids=midi.tolist()[0], ids_bpe_encoded=True)
13
+ return generated_ts
14
+
15
+
16
+ OPTForMusicGeneration.register_for_auto_class("AutoModel")