Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cupy_to_tensor to also infer uint8 and int8 dtypes #1621

Merged
merged 9 commits into from
Apr 23, 2024
Merged
1 change: 1 addition & 0 deletions morpheus/_lib/include/morpheus/objects/dtype.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ struct DType
}

private:
char byte_order_char() const;
char type_char() const;

TypeId m_type_id;
Expand Down
60 changes: 43 additions & 17 deletions morpheus/_lib/src/objects/dtype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include "morpheus/utilities/string_util.hpp" // for MORPHEUS_CONCAT_STR

#include <cudf/types.hpp>
#include <glog/logging.h> // for CHECK

#include <map>
#include <sstream> // Needed by MORPHEUS_CONCAT_STR
Expand All @@ -30,7 +29,7 @@

namespace {
const std::map<char, std::map<size_t, morpheus::TypeId>> StrToTypeId = {
{'?', {{1, morpheus::TypeId::BOOL8}}},
{'b', {{1, morpheus::TypeId::BOOL8}}},

{'i',
{{1, morpheus::TypeId::INT8},
Expand Down Expand Up @@ -100,14 +99,7 @@ std::string DType::name() const

std::string DType::type_str() const
{
if (m_type_id != TypeId::BOOL8 && m_type_id != TypeId::STRING)
{
return MORPHEUS_CONCAT_STR("<" << this->type_char() << this->item_size());
}
else
{
return std::string{this->type_char()};
}
return MORPHEUS_CONCAT_STR(this->byte_order_char() << this->type_char() << this->item_size());
}

// Cudf representation
Expand Down Expand Up @@ -214,19 +206,22 @@ DType DType::from_cudf(cudf::type_id tid)
case cudf::type_id::EMPTY:
case cudf::type_id::NUM_TYPE_IDS:
default:
throw std::runtime_error("Not supported");
throw std::invalid_argument("Not supported");
}
}

DType DType::from_numpy(const std::string& numpy_str)
{
CHECK(!numpy_str.empty()) << "Cannot create DataType from empty string";
if (numpy_str.empty())
{
throw std::invalid_argument("Cannot create DataType from empty string");
}

char type_char = numpy_str[0];
size_t size_start = 1;

// Can start with < or > or none
if (numpy_str[0] == '<' || numpy_str[0] == '>')
// Can start with <, >, | or none
if (numpy_str[0] == '<' || numpy_str[0] == '>' || numpy_str[0] == '|')
{
type_char = numpy_str[1];
size_start = 2;
Expand All @@ -241,11 +236,17 @@ DType DType::from_numpy(const std::string& numpy_str)
// Now lookup in the map
auto found_type = StrToTypeId.find(type_char);

CHECK(found_type != StrToTypeId.end()) << "Type char '" << type_char << "' not supported";
if (found_type == StrToTypeId.end())
{
throw std::invalid_argument(MORPHEUS_CONCAT_STR("Type char '" << type_char << "' not supported"));
}

auto found_enum = found_type->second.find(dtype_size);

CHECK(found_enum != found_type->second.end()) << "Type str '" << type_char << dtype_size << "' not supported";
if (found_enum == found_type->second.end())
{
throw std::invalid_argument(MORPHEUS_CONCAT_STR("Type str '" << type_char << dtype_size << "' not supported"));
}

return {found_enum->second};
}
Expand Down Expand Up @@ -299,6 +300,31 @@ DType DType::from_triton(const std::string& type_str)
}
else
{
throw std::invalid_argument("Not supported");
}
}

char DType::byte_order_char() const
{
switch (m_type_id)
{
case TypeId::BOOL8:
case TypeId::INT8:
case TypeId::UINT8:
return '|';
case TypeId::INT16:
case TypeId::UINT16:
case TypeId::INT32:
case TypeId::UINT32:
case TypeId::INT64:
case TypeId::UINT64:
case TypeId::FLOAT32:
case TypeId::FLOAT64:
return '<';
case TypeId::EMPTY:
case TypeId::NUM_TYPE_IDS:
case TypeId::STRING:
default:
throw std::runtime_error("Not supported");
}
}
Expand All @@ -318,7 +344,7 @@ char DType::type_char() const
case TypeId::UINT64:
return 'u';
case TypeId::BOOL8:
return '?';
return 'b';
case TypeId::FLOAT32:
case TypeId::FLOAT64:
return 'f';
Expand Down
6 changes: 6 additions & 0 deletions morpheus/_lib/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ add_morpheus_test(
modules/test_data_loader_module.cpp
)

add_morpheus_test(
NAME objects
FILES
objects/test_dtype.cpp
)

add_morpheus_test(
NAME deserializers
FILES
Expand Down
Loading
Loading