if001 committed on
Commit da15cde
1 Parent(s): 9e3e2aa
Files changed (1)
  1. sentencepiece_ja.py +27 -22
sentencepiece_ja.py CHANGED
@@ -4,15 +4,30 @@ from typing import Union, List, Optional, Tuple
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 class SentencePieceJA(PreTrainedTokenizer):
-    def __init__(self, model_path = "./tokenizer.json", **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self,
+                 model_path = "./tokenizer.json",
+                 pad = "<PAD>",
+                 bos = "<BOS>",
+                 eos = "<EOS>",
+                 unk = "<UNK>",
+                 mask = "<MASK>",
+                 **kwargs):
         from tokenizers import Tokenizer
-        self._tokenizer = Tokenizer.from_file(model_path)
-        self.__pad_id = self._tokenize("<PAD>")[0]
-        self.__bos_id = self._tokenize("<BOS>")[0]
-        self.__eos_id = self._tokenize("<EOS>")[0]
-        self.__unk_id = self._tokenize("<UNK>")[0]
-        self.__mask_id = self._tokenize("<MASK>")[0]
+        self._tokenizer = Tokenizer.from_file(model_path)
+        super().__init__(
+            pad_token=pad,
+            bos_token=bos,
+            eos_token=eos,
+            unk_token=unk,
+            mask_token=mask,
+            **kwargs)
+        self.add_special_tokens({
+            'pad_token': pad,
+            'bos_token': bos,
+            'eos_token': eos,
+            'unk_token': unk,
+            'mask_token': mask
+        })
 
     def get_vocab(self) -> int:
         return self._tokenizer.get_vocab()
@@ -20,24 +35,14 @@ class SentencePieceJA(PreTrainedTokenizer):
     def vocab_size(self) -> int:
         return self._tokenizer.get_vocab_size()
 
-    def _tokenize(self, text, **kwargs):
-        return self._tokenizer.encode(text).ids
+    def _tokenize(self, text, **kwargs):
+        return self._tokenizer.encode(text).tokens
 
     def _convert_token_to_id(self, token):
-        return token
+        return self._tokenizer.encode(token).ids[0]
 
-    def _convert_id_to_token(self, index: int) -> str:
+    def _convert_id_to_token(self, index: int) -> str:
         return self._tokenizer.decode(index)
-        # return self._tokenizer.id_to_token(index)
-
-    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
-        return tokens
-
-    def convert_ids_to_tokens(
-        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
-    ) -> Union[str, List[str]]:
-        decoded = self._tokenizer.decode(ids)
-        return decoded
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         index = 0
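
Below is a quick usage sketch of the class after this commit (not part of the commit itself). It assumes a trained ./tokenizer.json next to the module whose vocabulary includes the <PAD>, <BOS>, <EOS>, <UNK> and <MASK> pieces; the sample text is only illustrative.

    # Usage sketch, not part of the commit. Assumes ./tokenizer.json exists and
    # defines the <PAD>/<BOS>/<EOS>/<UNK>/<MASK> pieces.
    from sentencepiece_ja import SentencePieceJA

    tok = SentencePieceJA(model_path="./tokenizer.json")

    # __call__ goes through _tokenize (surface tokens from tokenizers.Tokenizer.encode)
    # and then _convert_token_to_id (first id obtained by re-encoding each token).
    enc = tok("こんにちは世界")
    print(enc["input_ids"])

    # The special tokens are registered in __init__ via super().__init__ and
    # add_special_tokens, so pad_token_id resolves through _convert_token_to_id.
    print(tok.pad_token, tok.pad_token_id)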