summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--candle-examples/examples/marian-mt/convert_slow_tokenizer.py20
2 files changed, 17 insertions, 4 deletions
diff --git a/.gitignore b/.gitignore
index 58ff2515..50036de7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -40,3 +40,4 @@ candle-wasm-examples/*/package-lock.json
candle-wasm-examples/**/config*.json
.DS_Store
.idea/*
+__pycache__
diff --git a/candle-examples/examples/marian-mt/convert_slow_tokenizer.py b/candle-examples/examples/marian-mt/convert_slow_tokenizer.py
index 8ae32f79..33a887b6 100644
--- a/candle-examples/examples/marian-mt/convert_slow_tokenizer.py
+++ b/candle-examples/examples/marian-mt/convert_slow_tokenizer.py
@@ -43,6 +43,14 @@ def import_protobuf(error_message=""):
else:
raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
+def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
+ if add_prefix_space:
+ prepend_scheme = "always"
+ if hasattr(original_tokenizer, "legacy") and not original_tokenizer.legacy:
+ prepend_scheme = "first"
+ else:
+ prepend_scheme = "never"
+ return prepend_scheme
class SentencePieceExtractor:
"""
@@ -519,13 +527,15 @@ class SpmConverter(Converter):
)
def pre_tokenizer(self, replacement, add_prefix_space):
- return pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
+ prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
+ return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
def post_processor(self):
return None
def decoder(self, replacement, add_prefix_space):
- return decoders.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space)
+ prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
+ return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
def converted(self) -> Tokenizer:
tokenizer = self.tokenizer(self.proto)
@@ -636,7 +646,8 @@ class DebertaV2Converter(SpmConverter):
list_pretokenizers = []
if self.original_tokenizer.split_by_punct:
list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
- list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space))
+ prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
+ list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme))
return pre_tokenizers.Sequence(list_pretokenizers)
def normalizer(self, proto):
@@ -929,10 +940,11 @@ class PegasusConverter(SpmConverter):
return proto.trainer_spec.unk_id + self.original_tokenizer.offset
def pre_tokenizer(self, replacement, add_prefix_space):
+ prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
return pre_tokenizers.Sequence(
[
pre_tokenizers.WhitespaceSplit(),
- pre_tokenizers.Metaspace(replacement=replacement, add_prefix_space=add_prefix_space),
+ pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme),
]
)