Skip to content

Commit

Permalink
Modify tfidf_transformer to enable custom vocabulary and approximate …
Browse files Browse the repository at this point in the history
…sublinear-tf scaling without sparse containers (#777)

* Extend one example in the documentation (WOE)
* Add a function to update initializers in an ONNX model
* Update text_vectoriser.py
* Add unit test and fix formatting

Signed-off-by: adam444555 <a473489548@gmail.com>
  • Loading branch information
adam444555 authored Nov 15, 2021
1 parent b4c679b commit 09be7da
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 20 deletions.
7 changes: 5 additions & 2 deletions skl2onnx/operator_converters/text_vectoriser.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,11 @@ def convert_sklearn_text_vectorizer(scope: Scope, operator: Operator,
"You may raise an issue at "
"https://github.com/onnx/sklearn-onnx/issues.")

stop_words = op.stop_words_ | (
set(op.stop_words) if op.stop_words else set())
if hasattr(op, "stop_words_"):
stop_words = op.stop_words_ | (
set(op.stop_words) if op.stop_words else set())
else:
stop_words = set()

if op.lowercase or stop_words:
if len(operator.input_full_names) != 1:
Expand Down
36 changes: 18 additions & 18 deletions skl2onnx/operator_converters/tfidf_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..common._topology import Scope, Operator
from ..common._container import ModelComponentContainer
from ..common._apply_operation import (
apply_mul, apply_identity, apply_normalizer)
apply_add, apply_log, apply_mul, apply_identity, apply_normalizer)


def convert_sklearn_tfidf_transformer(scope: Scope, operator: Operator,
Expand All @@ -30,23 +30,23 @@ def convert_sklearn_tfidf_transformer(scope: Scope, operator: Operator,
# code scikit-learn
# np.log(X.data, X.data) --> does not apply on null coefficient
# X.data += 1
raise RuntimeError(
"ONNX does not support sparse tensors before opset < 11, "
"sublinear_tf must be False.")

# In case sparse is enabled.
# C = operator.inputs[0].type.shape[1]
# logged = scope.get_unique_variable_name('logged')
# apply_log(scope, data, logged, container)
# if not op.use_idf and op.norm is None:
# loggedplus1 = final
# else:
# loggedplus1 = scope.get_unique_variable_name('loggedplus1')
# ones = scope.get_unique_variable_name('ones')
# cst = np.ones((C,), dtype=float_type)
# container.add_initializer(ones, proto_dtype, [C], cst.flatten())
# apply_add(scope, [logged, ones], loggedplus1, container, broadcast=1)
# data = [loggedplus1]
# ONNX does not support sparse tensors before opset 11, so the result is
# approximated on dense data by reversing the two steps:
# X.data += 1 --> np.log(X.data, X.data), i.e. log(X + 1).
if operator.target_opset < 11:
plus1 = scope.get_unique_variable_name("plus1")
C = operator.inputs[0].type.shape[1]
ones = scope.get_unique_variable_name("ones")
cst = np.ones((C,), dtype=float_type)
container.add_initializer(ones, proto_dtype, [C], cst.flatten())
apply_add(scope, data + [ones], plus1, container, broadcast=1)
plus1logged = scope.get_unique_variable_name("plus1logged")
apply_log(scope, plus1, plus1logged, container)
data = [plus1logged]
else:
# sparse containers have not yet been implemented.
raise RuntimeError(
"ONNX does not support sparse tensors before opset < 11, "
"sublinear_tf must be False.")

if op.use_idf:
cst = op.idf_.astype(float_type)
Expand Down
28 changes: 28 additions & 0 deletions tests/test_sklearn_tfidf_vectorizer_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,34 @@ def test_model_tfidf_vectorizer_nan(self):
assert res.shape == (4, 9)
assert numpy.isnan(res[0, 0])

@unittest.skipIf(
    StrictVersion(onnx.__version__) <= StrictVersion("1.4.1"),
    reason="Requires opset 9.")
def test_model_tfidf_vectorizer11_custom_vocabulary(self):
    """Convert a TfidfVectorizer that was fitted with a fixed vocabulary.

    The vectorizer is built with an explicit ``vocabulary`` and
    ``norm=None``.  After fitting, the test checks that the estimator
    exposes no ``stop_words_`` attribute, converts it to ONNX and hands
    the corpus, the fitted estimator and the ONNX model to
    ``dump_data_and_model``.
    """
    documents = [
        "This is the first document.",
        "This document is the second document.",
        "And this is the third one.",
        "Is this the first document?",
    ]
    # One document per row: a (4, 1) column of strings.
    corpus = numpy.array(documents).reshape((4, 1))

    vocabulary = ["first", "second", "third", "document", "this"]
    vectorizer = TfidfVectorizer(
        vocabulary=vocabulary, ngram_range=(1, 1), norm=None)
    vectorizer.fit(corpus.ravel())

    # Scenario under test: the fitted estimator has no ``stop_words_``.
    self.assertFalse(hasattr(vectorizer, "stop_words_"))

    onnx_model = convert_sklearn(
        vectorizer,
        "TfidfVectorizer",
        [("input", StringTensorType())],
        options=self.get_options(),
        target_opset=TARGET_OPSET,
    )
    self.assertIsNotNone(onnx_model)

    # Expression string passed through unchanged as ``allow_failure``.
    failure_condition = (
        "StrictVersion(onnxruntime.__version__)"
        " <= StrictVersion('0.4.0')"
    )
    dump_data_and_model(
        corpus,
        vectorizer,
        onnx_model,
        basename="SklearnTfidfVectorizer11CustomVocab-OneOff-SklCol",
        allow_failure=failure_condition,
    )


if __name__ == "__main__":
unittest.main()

0 comments on commit 09be7da

Please sign in to comment.