Cristian283 commited on
Commit
53fb5c9
1 Parent(s): bfb1bd0
Files changed (1) hide show
  1. Jd +18 -0
Jd ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
2
+
3
+ class TextToImageGenerator(torch.nn.Module):
4
+ def __init__(self, model_name="gpt2"):
5
+ super(TextToImageGenerator, self).__init__()
6
+ self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
7
+ self.gpt2 = GPT2LMHeadModel.from_pretrained(model_name)
8
+
9
+ def forward(self, input_text):
10
+ input_ids = self.tokenizer(input_text, return_tensors="pt")["input_ids"]
11
+ output = self.gpt2(input_ids, return_dict=True)
12
+ return output.logits
13
+
14
+ # Instanciar el modelo
15
+ model = TextToImageGenerator()
16
+
17
+ # Imprimir la arquitectura del modelo
18
+ print(model)