Skip to content

Commit

Permalink
Improved interaction between 'DiscreteDomain.data_values' and 'DiscreteDomain.dtype' attributes. Fixes #411
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Feb 13, 2024
1 parent 35f8671 commit 4729274
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 4 deletions.
33 changes: 30 additions & 3 deletions sklearn2pmml/decoration/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pandas import DataFrame
from pandas import CategoricalDtype, DataFrame
from pandas.api.types import is_object_dtype
from sklearn.base import clone, BaseEstimator, TransformerMixin
try:
Expand All @@ -10,6 +10,7 @@
from sklearn2pmml.util import cast, common_dtype, is_1d, to_numpy

import copy
import itertools
import numpy
import numbers
import pandas
Expand Down Expand Up @@ -201,12 +202,21 @@ def __init__(self, missing_values = None, missing_value_treatment = "as_is", mis
if data_values:
if not with_data:
raise ValueError("Valid values require with_data attribute")
if isinstance(dtype, CategoricalDtype) and data_values != (dtype.categories).tolist():
raise ValueError("Valid values are invalid")
self.data_values = data_values

def _is_ordered(self):
raise NotImplementedError()

def _valid_value_mask(self, X, where):
if hasattr(self, "data_values_"):
data_values = self.data_values_

elif _is_pandas_categorical(self.dtype_):
data_values = self.dtype_.categories
else:
data_values = None
if data_values is not None:
def _isin_mask(x, values):
if hasattr(x, "isin"):
return x.isin(values)
Expand All @@ -231,7 +241,14 @@ def _isin_mask(x, values):

def fit(self, X, y = None):
if self.dtype is not None:
X = cast(X, self.dtype)
if isinstance(self.dtype, str) and self.dtype == "category":
if self.data_values is not None:
dtype = CategoricalDtype(list(itertools.chain.from_iterable(self.data_values)), ordered = self._is_ordered())
else:
dtype = self.dtype
X = cast(X, dtype)
else:
X = cast(X, self.dtype)
self.dtype_ = common_dtype(X)
if self._empty_fit():
return self
Expand Down Expand Up @@ -284,11 +301,21 @@ class CategoricalDomain(DiscreteDomain):

def __init__(self, missing_values = None, missing_value_treatment = "as_is", missing_value_replacement = None, invalid_value_treatment = "return_invalid", invalid_value_replacement = None, with_data = True, with_statistics = False, dtype = None, display_name = None, data_values = None):
    """Initialize the domain; every argument is forwarded unchanged to the parent constructor.

    Raises
    ------
    ValueError
        If ``dtype`` is a pandas ``CategoricalDtype`` whose ``ordered`` flag is set.
    """
    super(CategoricalDomain, self).__init__(missing_values = missing_values, missing_value_treatment = missing_value_treatment, missing_value_replacement = missing_value_replacement, invalid_value_treatment = invalid_value_treatment, invalid_value_replacement = invalid_value_replacement, with_data = with_data, with_statistics = with_statistics, dtype = dtype, display_name = display_name, data_values = data_values)
    # This class reports itself as unordered via _is_ordered(), so an
    # explicitly ordered CategoricalDtype is contradictory; reject it with
    # a message that says what went wrong.
    if isinstance(dtype, CategoricalDtype) and dtype.ordered:
        raise ValueError("CategoricalDomain requires an unordered CategoricalDtype; got an ordered one")

def _is_ordered(self):
return False

class OrdinalDomain(DiscreteDomain):
    """Discrete domain whose categories are ordered.

    Differs from the parent class only in that it reports itself as ordered
    and refuses a pandas ``CategoricalDtype`` that is not ordered.
    """

    def __init__(self, missing_values = None, missing_value_treatment = "as_is", missing_value_replacement = None, invalid_value_treatment = "return_invalid", invalid_value_replacement = None, with_data = True, with_statistics = False, dtype = None, display_name = None, data_values = None):
        """Initialize the domain; every argument is forwarded unchanged to the parent constructor.

        Raises
        ------
        ValueError
            If ``dtype`` is a pandas ``CategoricalDtype`` whose ``ordered`` flag is not set.
        """
        super(OrdinalDomain, self).__init__(missing_values = missing_values, missing_value_treatment = missing_value_treatment, missing_value_replacement = missing_value_replacement, invalid_value_treatment = invalid_value_treatment, invalid_value_replacement = invalid_value_replacement, with_data = with_data, with_statistics = with_statistics, dtype = dtype, display_name = display_name, data_values = data_values)
        # _is_ordered() returns True for this class, so an unordered
        # CategoricalDtype is contradictory; reject it with a message that
        # says what went wrong.
        if isinstance(dtype, CategoricalDtype) and not dtype.ordered:
            raise ValueError("OrdinalDomain requires an ordered CategoricalDtype; got an unordered one")

    def _is_ordered(self):
        """Report that the categories of this domain are ordered (always ``True``)."""
        return True

def _interquartile_range(X, axis):
quartiles = numpy.nanpercentile(X, [25, 75], axis = axis)
Expand Down
14 changes: 13 additions & 1 deletion sklearn2pmml/decoration/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,25 @@ def test_fit_int_missing(self):
self.assertEqual([1, 2, 0], Xt[:, 1].tolist())

def test_fit_int_categorical(self):
domain = clone(CategoricalDomain(dtype = "category"))
self.assertFalse(hasattr(domain, "dtype_"))
X = Series([-1, 0, 1, 0, -1])
Xt = domain.fit_transform(X)
self.assertTrue(isinstance(domain.dtype_, CategoricalDtype))
self.assertEqual([-1, 0, 1], domain.dtype_.categories.tolist())
domain = clone(CategoricalDomain(dtype = CategoricalDtype()))
self.assertIsNone(domain.dtype.categories)
self.assertFalse(hasattr(domain, "dtype_"))
X = Series([-1, 0, 1, 0, -1])
Xt = domain.fit_transform(X)
self.assertIsNone(domain.dtype.categories)
self.assertEqual([-1, 0, 1], domain.dtype_.categories.tolist())
self.assertEqual([-1, 0, 1, 0, -1], Xt.values.tolist())
domain = clone(CategoricalDomain(invalid_value_treatment = "as_value", invalid_value_replacement = 0, data_values = [[0, 1]], dtype = "category"))
self.assertFalse(hasattr(domain, "dtype_"))
Xt = domain.fit_transform(X)
self.assertTrue(isinstance(domain.dtype_, CategoricalDtype))
self.assertEqual([0, 1], domain.dtype_.categories.tolist())
self.assertEqual([0, 0, 1, 0, 0], Xt.values.tolist())

def test_fit_int64(self):
domain = clone(CategoricalDomain())
Expand Down

0 comments on commit 4729274

Please sign in to comment.