diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc index da78849d40cde..a46244265ef20 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc @@ -137,7 +137,12 @@ int32_t MemorySparseTable::Load(const std::string& path, size_t feature_value_size = _value_accesor->GetAccessorInfo().size / sizeof(float); +#ifdef PADDLE_WITH_HETERPS + int thread_num = _real_local_shard_num; +#else int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15; +#endif + omp_set_num_threads(thread_num); #pragma omp parallel for schedule(dynamic) for (int i = 0; i < _real_local_shard_num; ++i) { @@ -168,11 +173,6 @@ int32_t MemorySparseTable::Load(const std::string& path, int parse_size = _value_accesor->ParseFromString(++end, value.data()); value.resize(parse_size); - // for debug - for (int ii = 0; ii < parse_size; ++ii) { - VLOG(2) << "MemorySparseTable::load key: " << key << " value " << ii - << ": " << value.data()[ii] << " local_shard: " << i; - } } read_channel->close(); if (err_no == -1) { @@ -340,7 +340,7 @@ int32_t MemorySparseTable::Save(const std::string& dirname, size_t file_start_idx = _avg_local_shard_num * _shard_idx; -#ifdef PADDLE_WITH_GPU_GRAPH +#ifdef PADDLE_WITH_HETERPS int thread_num = _real_local_shard_num; #else int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;