Skip to content

Commit

Permalink
Allow model versions to be strings (#197)
Browse files Browse the repository at this point in the history
* Versions can be strings

* Reformatted

* Fixed some tests

* Minor changes in management library, now check strings to be grouped for invalid characters

* VersionedModelId as a class

* Temporary commit to show hash problem

* Partial fix

* Functional

* Formatted

* Fixup

* Fix failing tests

* Addressed comments
  • Loading branch information
nishadsingh1 authored and dcrankshaw committed Jun 20, 2017
1 parent 5a7d172 commit e71d594
Show file tree
Hide file tree
Showing 25 changed files with 497 additions and 296 deletions.
25 changes: 16 additions & 9 deletions clipper_admin/clipper_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def deploy_model(self,
----------
name : str
The name to assign this model.
version : int
version : Any object with a string representation (with __str__ implementation)
The version to assign this model.
model_data : str or BaseEstimator
The trained model to add to Clipper. This can either be a
Expand Down Expand Up @@ -470,6 +470,7 @@ def deploy_model(self,
warn("%s is invalid model format" % str(type(model_data)))
return False

version = str(version)
vol = "{model_repo}/{name}/{version}".format(
model_repo=MODEL_REPO, name=name, version=version)
# publish model to Clipper and verify success before copying model
Expand Down Expand Up @@ -509,13 +510,14 @@ def register_external_model(self,
----------
name : str
The name to assign this model.
version : int
version : Any object with a string representation (with __str__ implementation)
The version to assign this model.
input_type : str
One of "integers", "floats", "doubles", "bytes", or "strings".
labels : list of str, optional
A list of strings annotating the model.
"""
version = str(version)
return self._publish_new_model(name, version, labels, input_type,
EXTERNALLY_MANAGED_MODEL,
EXTERNALLY_MANAGED_MODEL)
Expand Down Expand Up @@ -586,7 +588,7 @@ def deploy_pyspark_model(self,
----------
name : str
The name to assign this model.
version : int
version : Any object with a string representation (with __str__ implementation)
The version to assign this model.
predict_function : function
A function that takes three arguments, a SparkContext, the ``model`` parameter and
Expand Down Expand Up @@ -679,7 +681,7 @@ def deploy_predict_function(self,
----------
name : str
The name to assign this model.
version : int
version : Any object with a string representation (with __str__ implementation)
The version to assign this model.
predict_function : function
The prediction function. Any state associated with the function should be
Expand Down Expand Up @@ -766,7 +768,7 @@ def get_model_info(self, model_name, model_version):
----------
model_name : str
The name of the model to look up
model_version : int
model_version : Any object with a string representation (with __str__ implementation)
The version of the model to look up
Returns
Expand All @@ -776,6 +778,7 @@ def get_model_info(self, model_name, model_version):
If no model with name `model_name@model_version` is
registered with Clipper, None is returned.
"""
model_version = str(model_version)
url = "http://%s:1338/admin/get_model" % self.host
req_json = json.dumps({
"model_name": model_name,
Expand Down Expand Up @@ -826,7 +829,7 @@ def get_container_info(self, model_name, model_version, replica_id):
----------
model_name : str
The name of the container to look up
model_version : int
model_version : Any object with a string representation (with __str__ implementation)
The version of the container to look up
replica_id : int
The container replica to look up
Expand All @@ -837,6 +840,7 @@ def get_container_info(self, model_name, model_version, replica_id):
A dictionary with the specified container's info.
If no corresponding container is registered with Clipper, None is returned.
"""
model_version = str(model_version)
url = "http://%s:1338/admin/get_container" % self.host
req_json = json.dumps({
"model_name": model_name,
Expand Down Expand Up @@ -970,7 +974,7 @@ def add_container(self, model_name, model_version):
----------
model_name : str
The name of the model
model_version : int
model_version : Any object with a string representation (with __str__ implementation)
The version of the model
Returns
Expand All @@ -979,6 +983,7 @@ def add_container(self, model_name, model_version):
True if the container was added successfully and False
if the container could not be added.
"""
model_version = str(model_version)
with hide("warnings", "output", "running"):
# Look up model info in Redis
if self.redis_ip == DEFAULT_REDIS_IP:
Expand Down Expand Up @@ -1024,7 +1029,7 @@ def add_container(self, model_name, model_version):
mv=model_version,
mip=model_input_type,
clipper_label=CLIPPER_DOCKER_LABEL,
mv_label="%s=%s:%d" % (CLIPPER_MODEL_CONTAINER_LABEL,
mv_label="%s=%s:%s" % (CLIPPER_MODEL_CONTAINER_LABEL,
model_name, model_version),
restart_policy=restart_policy))
result = self._execute_root(add_container_cmd)
Expand Down Expand Up @@ -1101,14 +1106,16 @@ def set_model_version(self, model_name, model_version, num_containers=0):
----------
model_name : str
The name of the model
model_version : int
model_version : Any object with a string representation (with __str__ implementation)
The version of the model. Note that `model_version`
must be a model version that has already been deployed.
num_containers : int
The number of new containers to start with the newly
selected model version.
"""
model_version = str(model_version)

url = "http://%s:%d/admin/set_model_version" % (
self.host, CLIPPER_MANAGEMENT_PORT)
req_json = json.dumps({
Expand Down
2 changes: 1 addition & 1 deletion integration-tests/clipper_manager_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_model_version_sets_correctly(self):
models_list_contains_correct_version = False
for model_info in all_models:
version = model_info["model_version"]
if version == self.model_version_1:
if version == str(self.model_version_1):
models_list_contains_correct_version = True
self.assertTrue(model_info["is_current_version"])

Expand Down
2 changes: 1 addition & 1 deletion src/benchmarks/src/end_to_end_bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ void send_predictions(
cifar_input,
100000,
clipper::DefaultOutputSelectionPolicy::get_name(),
{std::make_pair(SKLEARN_MODEL_NAME, 1)}});
{VersionedModelId(SKLEARN_MODEL_NAME, "1")}});
futures.push_back(std::move(future));
}

Expand Down
27 changes: 13 additions & 14 deletions src/frontends/src/query_frontend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,16 @@ class RequestHandler {
event_type);
if (event_type == "set") {
std::string model_name = key;
int new_version = clipper::redis::get_current_model_version(
redis_connection_, key);
if (new_version >= 0) {
boost::optional<std::string> new_version =
clipper::redis::get_current_model_version(redis_connection_,
key);
if (new_version) {
std::unique_lock<std::mutex> l(current_model_versions_mutex_);
current_model_versions_[key] = new_version;
current_model_versions_[key] = *new_version;
} else {
clipper::log_error_formatted(
LOGGING_TAG_QUERY_FRONTEND,
"Model version change for model {} was invalid (-1).", key);
"Model version change for model {} was invalid.", key);
}
}
});
Expand Down Expand Up @@ -223,18 +224,16 @@ class RequestHandler {
for (std::string model_name : model_names) {
auto model_version = clipper::redis::get_current_model_version(
redis_connection_, model_name);
if (model_version >= 0) {
if (model_version) {
std::unique_lock<std::mutex> l(current_model_versions_mutex_);
current_model_versions_[model_name] = model_version;
current_model_versions_[model_name] = *model_version;
model_names_with_version.push_back(model_name + "@" + *model_version);
} else {
clipper::log_error_formatted(
LOGGING_TAG_QUERY_FRONTEND,
"Found model {} with invalid version number {}.", model_name,
model_version);
throw std::runtime_error("Invalid model version number");
"Found model {} with missing current version.", model_name);
throw std::runtime_error("Invalid model version");
}
model_names_with_version.push_back(model_name + "@v" +
std::to_string(model_version));
}
if (model_names.size() > 0) {
clipper::log_info_formatted(LOGGING_TAG_QUERY_FRONTEND,
Expand Down Expand Up @@ -496,7 +495,7 @@ class RequestHandler {
/**
* Returns a copy of the map containing current model names and versions.
*/
std::unordered_map<std::string, int> get_current_model_versions() {
std::unordered_map<std::string, std::string> get_current_model_versions() {
return current_model_versions_;
}

Expand All @@ -506,7 +505,7 @@ class RequestHandler {
redox::Redox redis_connection_;
redox::Subscriber redis_subscriber_;
std::mutex current_model_versions_mutex_;
std::unordered_map<std::string, int> current_model_versions_;
std::unordered_map<std::string, std::string> current_model_versions_;
};

} // namespace query_frontend
17 changes: 9 additions & 8 deletions src/frontends/src/query_frontend_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class MockQueryProcessor {
public:
MockQueryProcessor() = default;
boost::future<Response> predict(Query query) {
Response response(query, 3, 5, Output("-1.0", {std::make_pair("m", 1)}),
Response response(query, 3, 5, Output("-1.0", {VersionedModelId("m", "1")}),
false, boost::optional<std::string>{});
return boost::make_ready_future(response);
}
Expand Down Expand Up @@ -280,24 +280,25 @@ TEST_F(QueryFrontendTest, TestReadModelsAtStartup) {
// Add multiple models (some with multiple versions)
std::vector<std::string> labels{"ads", "images", "experimental", "other",
"labels"};
VersionedModelId model1 = std::make_pair("m", 1);
VersionedModelId model1 = VersionedModelId("m", "1");
std::string container_name = "clipper/test_container";
std::string model_path = "/tmp/models/m/1";
ASSERT_TRUE(add_model(*redis_, model1, InputType::Ints, labels,
container_name, model_path));
VersionedModelId model2 = std::make_pair("m", 2);
VersionedModelId model2 = VersionedModelId("m", "2");
std::string model_path2 = "/tmp/models/m/2";
ASSERT_TRUE(add_model(*redis_, model2, InputType::Ints, labels,
container_name, model_path2));
VersionedModelId model3 = std::make_pair("n", 3);
VersionedModelId model3 = VersionedModelId("n", "3");
std::string model_path3 = "/tmp/models/n/3";
ASSERT_TRUE(add_model(*redis_, model3, InputType::Ints, labels,
container_name, model_path3));

// Set m@v2 and n@v3 as current model versions
set_current_model_version(*redis_, "m", 2);
set_current_model_version(*redis_, "n", 3);
std::unordered_map<std::string, int> expected_models = {{"m", 2}, {"n", 3}};
set_current_model_version(*redis_, "m", "2");
set_current_model_version(*redis_, "n", "3");
std::unordered_map<std::string, std::string> expected_models = {{"m", "2"},
{"n", "3"}};

RequestHandler<MockQueryProcessor> rh2_("127.0.0.1", 1337, 8);
EXPECT_EQ(rh2_.get_current_model_versions(), expected_models);
Expand All @@ -306,7 +307,7 @@ TEST_F(QueryFrontendTest, TestReadModelsAtStartup) {
TEST_F(QueryFrontendTest, TestReadInvalidModelVersionAtStartup) {
std::vector<std::string> labels{"ads", "images", "experimental", "other",
"labels"};
VersionedModelId model1 = std::make_pair("m", 1);
VersionedModelId model1 = VersionedModelId("m", "1");
std::string container_name = "clipper/test_container";
std::string model_path = "/tmp/models/m/1";
ASSERT_TRUE(add_model(*redis_, model1, InputType::Ints, labels,
Expand Down
3 changes: 1 addition & 2 deletions src/libclipper/include/clipper/containers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ class ActiveContainers {
// A mapping of models to their replicas. The replicas
// for each model are represented as a map keyed on replica id.
std::unordered_map<VersionedModelId,
std::map<int, std::shared_ptr<ModelContainer>>,
decltype(&versioned_model_hash)>
std::map<int, std::shared_ptr<ModelContainer>>>
containers_;
};
}
Expand Down
41 changes: 37 additions & 4 deletions src/libclipper/include/clipper/datatypes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
#include <string>
#include <vector>

#include <boost/functional/hash.hpp>
#include <boost/optional.hpp>
#include <boost/thread.hpp>

namespace clipper {

using ByteBuffer = std::vector<uint8_t>;
using VersionedModelId = std::pair<std::string, int>;
using QueryId = long;
using FeedbackAck = bool;

Expand All @@ -28,11 +29,32 @@ enum class RequestType {
FeedbackRequest = 1,
};

size_t versioned_model_hash(const VersionedModelId &key);
std::string versioned_model_to_str(const VersionedModelId &model);
std::string get_readable_input_type(InputType type);
InputType parse_input_type(std::string type_string);

class VersionedModelId {
public:
VersionedModelId(const std::string name, const std::string id);

std::string get_name() const;
std::string get_id() const;
std::string serialize() const;
static VersionedModelId deserialize(std::string);

VersionedModelId(const VersionedModelId &) = default;
VersionedModelId &operator=(const VersionedModelId &) = default;

VersionedModelId(VersionedModelId &&) = default;
VersionedModelId &operator=(VersionedModelId &&) = default;

bool operator==(const VersionedModelId &rhs) const;
bool operator!=(const VersionedModelId &rhs) const;

private:
std::string name_;
std::string id_;
};

class Output {
public:
Output(const std::string y_hat,
Expand Down Expand Up @@ -384,5 +406,16 @@ class PredictionResponse {
} // namespace rpc

} // namespace clipper

namespace std {
template <>
struct hash<clipper::VersionedModelId> {
typedef std::size_t result_type;
std::size_t operator()(const clipper::VersionedModelId &vm) const {
std::size_t seed = 0;
boost::hash_combine(seed, vm.get_name());
boost::hash_combine(seed, vm.get_id());
return seed;
}
};
}
#endif // CLIPPER_LIB_DATATYPES_H
28 changes: 23 additions & 5 deletions src/libclipper/include/clipper/redis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@ namespace redis {

const std::string LOGGING_TAG_REDIS = "REDIS";

/**
* Elements of this vector should not appear as substrings of any input Clipper
* object value that will be grouped into one entry.
* This list should be updated to reflect all delimiters and characters added in
* `labels_to_str` or `models_to_str`.
*/
const std::vector<std::string> prohibited_group_strings = {
ITEM_DELIMITER, ITEM_PART_CONCATENATOR};

/**
* Use this function to validate inputs that will be grouped before submitting
* them to functions in this library.
* @return Whether or not `value` contains any elements of `probhited_strings`
* as substrings.
*/
bool contains_prohibited_chars_for_group(std::string value);

/**
* Issues a command to Redis and checks return code.
* \return Returns true if the command was successful.
Expand Down Expand Up @@ -91,10 +108,11 @@ std::string models_to_str(const std::vector<VersionedModelId>& models);
std::vector<VersionedModelId> str_to_models(const std::string& model_str);

bool set_current_model_version(redox::Redox& redis,
const std::string& model_name, int version);
const std::string& model_name,
const std::string& version);

int get_current_model_version(redox::Redox& redis,
const std::string& model_name);
boost::optional<std::string> get_current_model_version(
redox::Redox& redis, const std::string& model_name);

/**
* Adds a model into the model table. This will
Expand Down Expand Up @@ -142,8 +160,8 @@ std::unordered_map<std::string, std::string> get_model(
* \return Returns a list of model versions. If the
* model was not found, an empty list will be returned.
*/
std::vector<int> get_model_versions(redox::Redox& redis,
const std::string& model_name);
std::vector<std::string> get_model_versions(redox::Redox& redis,
const std::string& model_name);

/**
* Looks up model names listed in the model table. Since a call to KEYS may
Expand Down
3 changes: 1 addition & 2 deletions src/libclipper/include/clipper/rpc_service.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ class RPCService {
std::atomic_bool active_;
// The next available message id
int message_id_ = 0;
std::unordered_map<VersionedModelId, int, decltype(&versioned_model_hash)>
replica_ids_;
std::unordered_map<VersionedModelId, int> replica_ids_;
std::shared_ptr<metrics::Histogram> msg_queueing_hist_;

std::function<void(VersionedModelId, int)> container_ready_callback_;
Expand Down
Loading

0 comments on commit e71d594

Please sign in to comment.