Skip to content

Commit

Permalink
change deserialize method location to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
anencore94 committed Aug 28, 2021
1 parent c16ba72 commit ecf9e64
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 24 deletions.
6 changes: 3 additions & 3 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, config_file=None, context=None,
self.in_cluster = True

self.api_instance = client.CustomObjectsApi()
self.katib_api_client = api_client.ApiClient()
self.deserializer = utils.Deserializer()

def _is_ipython(self):
"""Returns whether we are running in notebook."""
Expand Down Expand Up @@ -270,7 +270,7 @@ def list_experiments(self, namespace=None):
try:
katibexp = thread.get(constants.APISERVER_TIMEOUT)
result = [
self.katib_api_client.deserialize_data(item, V1beta1Experiment)
self.deserializer.deserialize(item, V1beta1Experiment)
for item in katibexp.get("items")
]

Expand Down Expand Up @@ -341,7 +341,7 @@ def list_trials(self, name=None, namespace=None):
try:
katibtrial = thread.get(constants.APISERVER_TIMEOUT)
result = [
self.katib_api_client.deserialize_data(item, V1beta1Trial)
self.deserializer.deserialize(item, V1beta1Trial)
for item in katibtrial.get("items")
]
except multiprocessing.TimeoutError:
Expand Down
22 changes: 1 addition & 21 deletions sdk/python/v1beta1/kubeflow/katib/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

import atexit
import datetime

import kubernetes
from dateutil.parser import parse
import json
import mimetypes
Expand All @@ -24,7 +22,6 @@

# python 2 and python 3 compatibility library
import six
from kubeflow.katib.rest import RESTResponse
from six.moves.urllib.parse import quote

from kubeflow.katib.configuration import Configuration
Expand Down Expand Up @@ -269,7 +266,6 @@ def deserialize(self, response, response_type):
:return: deserialized object.
"""
assert isinstance(response, RESTResponse)
# handle file downloading
# save response body into a tmp file and return the instance
if response_type == "file":
Expand All @@ -283,18 +279,6 @@ def deserialize(self, response, response_type):

return self.__deserialize(data, response_type)

def deserialize_data(self, data, data_type):
"""Deserializes data into an object.
:param data: object to be deserialized.
:param data_type: class literal for
deserialized object, or string of class name.
:return: deserialized object.
"""
assert isinstance(data, (dict, list, str))
return self.__deserialize(data, data_type)

def __deserialize(self, data, klass):
"""Deserializes dict, list, str into an object.
Expand All @@ -320,12 +304,8 @@ def __deserialize(self, data, klass):
# convert str to class
if klass in self.NATIVE_TYPES_MAPPING:
klass = self.NATIVE_TYPES_MAPPING[klass]
elif klass in dir(kubeflow.katib.models):
klass = getattr(kubeflow.katib.models, klass)
elif klass in dir(kubernetes.client.models):
klass = getattr(kubernetes.client.models, klass)
else:
raise ValueError(f"type: {klass} is not supported to deserialized")
klass = getattr(kubeflow.katib.models, klass)

if klass in self.PRIMITIVE_TYPES:
return self.__deserialize_primitive(data, klass)
Expand Down
174 changes: 174 additions & 0 deletions sdk/python/v1beta1/kubeflow/katib/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import datetime
import os
import re

import kubeflow.katib.models
import kubernetes.client.models
# python 2 and python 3 compatibility library
import six
from dateutil.parser import parse
from kubeflow.katib import rest


def is_running_in_k8s():
Expand All @@ -34,3 +45,166 @@ def set_katib_namespace(katib):
katib_namespace = katib.metadata.namespace
namespace = katib_namespace or get_default_target_namespace()
return namespace


class Deserializer:
"""Deserializer for deserializing data into katib's custom objects.
"""
PRIMITIVE_TYPES = (float, bool, bytes, six.text_type) + six.integer_types
NATIVE_TYPES_MAPPING = {
'int': int,
'long': int if six.PY3 else long, # noqa: F821
'float': float,
'str': str,
'bool': bool,
'date': datetime.date,
'datetime': datetime.datetime,
'object': object,
}

def deserialize(self, data, data_type):
"""Deserializes data into an object.
:param data: object to be deserialized.
:param data_type: class literal for
deserialized object, or string of class name.
:return: deserialized object.
"""
assert isinstance(data, (dict, list, str))

return self.__deserialize(data, data_type)

def __deserialize(self, data, klass):
"""Deserializes dict, list, str into an object.
:param data: dict, list or str.
:param klass: class literal, or string of class name.
:return: object.
"""
if data is None:
return None

if type(klass) == str:
if klass.startswith('list['):
sub_kls = re.match(r'list\[(.*)\]', klass).group(1)
return [self.__deserialize(sub_data, sub_kls)
for sub_data in data]

if klass.startswith('dict('):
sub_kls = re.match(r'dict\(([^,]*), (.*)\)', klass).group(2)
return {k: self.__deserialize(v, sub_kls)
for k, v in six.iteritems(data)}

# convert str to class
if klass in self.NATIVE_TYPES_MAPPING:
klass = self.NATIVE_TYPES_MAPPING[klass]
elif klass in dir(kubeflow.katib.models):
klass = getattr(kubeflow.katib.models, klass)
elif klass in dir(kubernetes.client.models):
klass = getattr(kubernetes.client.models, klass)
else:
raise ValueError(f"type: {klass} is not supported to deserialized")

if klass in self.PRIMITIVE_TYPES:
return self.__deserialize_primitive(data, klass)
elif klass == object:
return self.__deserialize_object(data)
elif klass == datetime.date:
return self.__deserialize_date(data)
elif klass == datetime.datetime:
return self.__deserialize_datetime(data)
else:
return self.__deserialize_model(data, klass)

def __deserialize_primitive(self, data, klass):
"""Deserializes string to primitive type.
:param data: str.
:param klass: class literal.
:return: int, long, float, str, bool.
"""
try:
return klass(data)
except UnicodeEncodeError:
return six.text_type(data)
except TypeError:
return data

def __deserialize_object(self, value):
"""Return an original value.
:return: object.
"""
return value

def __deserialize_date(self, string):
"""Deserializes string to date.
:param string: str.
:return: date.
"""
try:
return parse(string).date()
except ImportError:
return string
except ValueError:
raise rest.ApiException(
status=0,
reason="Failed to parse `{0}` as date object".format(string)
)

def __deserialize_datetime(self, string):
"""Deserializes string to datetime.
The string should be in iso8601 datetime format.
:param string: str.
:return: datetime.
"""
try:
return parse(string)
except ImportError:
return string
except ValueError:
raise rest.ApiException(
status=0,
reason=(
"Failed to parse `{0}` as datetime object"
.format(string)
)
)

def __deserialize_model(self, data, klass):
"""Deserializes list or dict to model.
:param data: dict, list.
:param klass: class literal.
:return: model object.
"""
has_discriminator = False
if (hasattr(klass, 'get_real_child_model')
and klass.discriminator_value_class_map):
has_discriminator = True

if not klass.openapi_types and has_discriminator is False:
return data

kwargs = {}
if (data is not None and
klass.openapi_types is not None and
isinstance(data, (list, dict))):
for attr, attr_type in six.iteritems(klass.openapi_types):
if klass.attribute_map[attr] in data:
value = data[klass.attribute_map[attr]]
kwargs[attr] = self.__deserialize(value, attr_type)

instance = klass(**kwargs)

if has_discriminator:
klass_name = instance.get_real_child_model(data)
if klass_name:
instance = self.__deserialize(data, klass_name)
return instance

0 comments on commit ecf9e64

Please sign in to comment.