Skip to content

Commit

Permalink
fix EV incremental restore.
Browse files Browse the repository at this point in the history
  • Loading branch information
candyzone committed Jul 2, 2024
1 parent 2325297 commit 261ccfb
Showing 1 changed file with 23 additions and 20 deletions.
43 changes: 23 additions & 20 deletions tensorflow/core/kernels/kv_variable_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1229,7 +1229,7 @@ Status EVRestoreDynamically(EmbeddingVar<K, V>* ev,
}
}

reader->LookupHeader(tensor_key, sizeof(K) * key_shape.dim_size(0));
st = reader->LookupHeader(tensor_key, sizeof(K) * key_shape.dim_size(0));
if (!st.ok()) {
break;
}
Expand Down Expand Up @@ -1272,41 +1272,24 @@ Status EVRestoreDynamically(EmbeddingVar<K, V>* ev,
return st;
}

TensorShape part_offset_shape, part_filter_offset_shape;
DataType part_offset_type, part_filter_offset_type;
TensorShape part_offset_shape;
DataType part_offset_type;
string offset_tensor_name = tensor_name + part_offset_tensor_suffix;
string offset_filter_tensor_name =
tensor_name + "-partition_filter_offset";
st = reader->LookupDtypeAndShape(offset_tensor_name,
&part_offset_type, &part_offset_shape);
if (!st.ok()) {
LOG(FATAL) << "EV restoring fail:" << st.ToString();
}
st = reader->LookupDtypeAndShape(offset_filter_tensor_name,
&part_filter_offset_type, &part_filter_offset_shape);
if (!st.ok()) {
LOG(FATAL) << "EV restoring fail:" << st.ToString();
}
Tensor part_offset_tensor(cpu_allocator(),
part_offset_type, part_offset_shape);
if (!st.ok()) {
LOG(FATAL) << "EV restoring fail:" << st.ToString();
}
Tensor part_filter_offset_tensor(cpu_allocator(),
part_offset_type, part_offset_shape);
if (!st.ok()) {
LOG(FATAL) << "EV restoring fail:" << st.ToString();
}
st = reader->Lookup(offset_tensor_name, &part_offset_tensor);
if (!st.ok()) {
LOG(FATAL) << "EV restoring fail:" << st.ToString();
}
auto part_offset_flat = part_offset_tensor.flat<int32>();
st = reader->Lookup(offset_filter_tensor_name, &part_filter_offset_tensor);
if (!st.ok()) {
LOG(FATAL) << "EV restoring fail:" << st.ToString();
}
auto part_filter_offset_flat = part_filter_offset_tensor.flat<int32>();

for (size_t i = 0; i < loaded_parts.size(); i++) {
int subpart_id = loaded_parts[i];
Expand Down Expand Up @@ -1429,6 +1412,26 @@ Status EVRestoreDynamically(EmbeddingVar<K, V>* ev,
}

if (restore_filter_flag) {
TensorShape part_filter_offset_shape;
DataType part_filter_offset_type;
string offset_filter_tensor_name =
tensor_name + "-partition_filter_offset";
st = reader->LookupDtypeAndShape(offset_filter_tensor_name,
&part_filter_offset_type, &part_filter_offset_shape);
if (!st.ok()) {
LOG(FATAL) << "EV restoring fail:" << st.ToString();
}
Tensor part_filter_offset_tensor(cpu_allocator(),
part_filter_offset_type, part_filter_offset_shape);
if (!st.ok()) {
LOG(FATAL) << "EV restoring fail:" << st.ToString();
}
st = reader->Lookup(offset_filter_tensor_name, &part_filter_offset_tensor);
if (!st.ok()) {
LOG(FATAL) << "EV restoring fail:" << st.ToString();
}
auto part_filter_offset_flat = part_filter_offset_tensor.flat<int32>();

int subpart_filter_offset = part_filter_offset_flat(subpart_id);
int64 key_filter_part_offset = subpart_filter_offset * sizeof(K);
int64 version_filter_part_offset = subpart_filter_offset * sizeof(int64);
Expand Down

0 comments on commit 261ccfb

Please sign in to comment.