File size: 579 Bytes
e7d5680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch

from opensora.registry import MODELS


@MODELS.register_module("classes")
class ClassEncoder:
    """Conditioning "encoder" that turns class labels into integer id tensors.

    It holds no learned weights: ``encode`` simply converts each incoming
    label to an ``int`` class id and packs the ids into a ``torch.long``
    tensor on ``self.device``.

    Args:
        num_classes: Number of real classes. The id ``num_classes`` itself is
            what ``null`` emits (presumably a reserved "unconditional" class,
            e.g. for classifier-free guidance — confirm against the model's
            label embedder).
        model_max_length: Stored as-is; not used by this class.
        device: Device on which all returned tensors are created.
        dtype: Stored as ``self.dtype``; it is NOT applied to the label
            tensors, which are always ``torch.long``.
    """

    def __init__(self, num_classes, model_max_length=None, device="cuda", dtype=torch.float):
        self.num_classes = num_classes
        # Left unset here; presumably filled in by the owning model — TODO confirm.
        self.y_embedder = None

        self.model_max_length = model_max_length
        self.output_dim = None
        self.device = device
        # Previously accepted and silently dropped; kept so callers can inspect it.
        self.dtype = dtype

    def encode(self, text):
        """Convert a batch of class labels into a tensor of class ids.

        Args:
            text: Iterable of values accepted by ``int()`` (e.g. ``"3"`` or ``3``).

        Returns:
            ``dict(y=ids)`` where ``ids`` is a 1-D ``torch.long`` tensor with one
            entry per element of ``text``, located on ``self.device``.

        Raises:
            ValueError / TypeError: if an element cannot be converted by ``int()``.
        """
        labels = [int(t) for t in text]
        # Explicit dtype: torch.tensor([]) would otherwise infer float32, so an
        # empty batch would come back with a non-integer dtype.
        return dict(y=torch.tensor(labels, dtype=torch.long, device=self.device))

    def null(self, n):
        """Return a 1-D ``torch.long`` tensor of ``n`` copies of ``num_classes``.

        Args:
            n: Number of null labels to produce.

        Returns:
            Tensor of shape ``(n,)`` on ``self.device`` filled with ``self.num_classes``.
        """
        # Explicit dtype for the same reason as in ``encode`` (n == 0 case).
        return torch.tensor([self.num_classes] * n, dtype=torch.long, device=self.device)