-rw-r--r-- | .gitignore                                                    |  1 | +
-rw-r--r-- | candle-examples/examples/marian-mt/convert_slow_tokenizer.py | 20 | ++++++++++++++++----
2 files changed, 17 insertions(+), 4 deletions(-)
diff --git a/.gitignore b/.gitignore
--- a/.gitignore
+++ b/.gitignore
@@ -40,3 +40,4 @@ candle-wasm-examples/*/package-lock.json
 candle-wasm-examples/**/config*.json
 .DS_Store
 .idea/*
+__pycache__
diff --git a/candle-examples/examples/marian-mt/convert_slow_tokenizer.py b/candle-examples/examples/marian-mt/convert_slow_tokenizer.py
index 8ae32f79..33a887b6 100644
--- a/candle-examples/examples/marian-mt/convert_slow_tokenizer.py
+++ b/candle-examples/examples/marian-mt/convert_slow_tokenizer.py
@@ -43,6 +43,14 @@ def import_protobuf(error_message=""):
     else:
         raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
 
+def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
+    if add_prefix_space:
+        prepend_scheme = "always"
+        if hasattr(original_tokenizer, "legacy") and not original_tokenizer.legacy:
+            prepend_scheme = "first"
+    else:
+        prepend_scheme = "never"
+    return prepend_scheme
 
 class SentencePieceExtractor:
     """
@@ -519,13 +527,15 @@ class SpmConverter(Converter):
         )
 
     def pre_tokenizer(self, replacement, add_prefix_space):
-        return pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
+        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
+        return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
 
     def post_processor(self):
         return None
 
     def decoder(self, replacement, add_prefix_space):
-        return decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
+        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
+        return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
 
     def converted(self) -> Tokenizer:
         tokenizer = self.tokenizer(self.proto)
@@ -636,7 +646,8 @@ class DebertaV2Converter(SpmConverter):
         list_pretokenizers = []
         if self.original_tokenizer.split_by_punct:
             list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
-        list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space))
+        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
+        list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme))
         return pre_tokenizers.Sequence(list_pretokenizers)
 
     def normalizer(self, proto):
@@ -929,10 +940,11 @@ class PegasusConverter(SpmConverter):
         return proto.trainer_spec.unk_id + self.original_tokenizer.offset
 
     def pre_tokenizer(self, replacement, add_prefix_space):
+        prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
         return pre_tokenizers.Sequence(
             [
                 pre_tokenizers.WhitespaceSplit(),
-                pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
+                pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme),
             ]
         )
 
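For context on what the change does: the converter used to pass the boolean add_prefix_space straight to Metaspace; this patch maps it onto the prepend_scheme argument that newer releases of the tokenizers library accept ("always", "first", or "never"), with "first" chosen for tokenizers that set legacy=False. Below is a minimal sketch of the behavioral difference, assuming a tokenizers version (0.14+) where Metaspace takes prepend_scheme; the input string and the commented outputs are illustrative, not taken from this diff:

    from tokenizers import pre_tokenizers

    # "always": every segment reaching Metaspace gets the replacement
    # character prepended. "first": only the segment at the very start of
    # the input does -- the scheme _get_prepend_scheme picks for
    # non-legacy tokenizers.
    always = pre_tokenizers.Sequence([
        pre_tokenizers.WhitespaceSplit(),
        pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always"),
    ])
    first = pre_tokenizers.Sequence([
        pre_tokenizers.WhitespaceSplit(),
        pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="first"),
    ])

    print(always.pre_tokenize_str("Hello world"))
    # [('▁Hello', (0, 5)), ('▁world', (6, 11))]
    print(first.pre_tokenize_str("Hello world"))
    # [('▁Hello', (0, 5)), ('world', (6, 11))]

Keeping add_prefix_space in the converter method signatures while translating it internally preserves the existing call sites; only the construction of the Metaspace pre-tokenizers and decoders changes.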