-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Phi] Add phi device context pool (#40635)
* add phi device context pool * change year * fix compile error * fix operator = error * refine init impl * polish details * refine init impl
- Loading branch information
Showing
8 changed files
with
194 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "paddle/phi/common/place.h" | ||
#include "paddle/phi/core/macros.h" | ||
#include "paddle/utils/flat_hash_map.h" | ||
|
||
namespace phi { | ||
class DeviceContext; | ||
class CPUContext; | ||
class GPUContext; | ||
} // namespace phi | ||
|
||
namespace paddle { | ||
namespace experimental { | ||
|
||
template <AllocationType T> | ||
struct DefaultDeviceContextType; | ||
|
||
template <> | ||
struct DefaultDeviceContextType<AllocationType::CPU> { | ||
using TYPE = phi::CPUContext; | ||
}; | ||
|
||
template <> | ||
struct DefaultDeviceContextType<AllocationType::GPU> { | ||
using TYPE = phi::GPUContext; | ||
}; | ||
|
||
/** | ||
* The DeviceContextPool here is just a mirror of the DeviceContextPool in | ||
* fluid, and does not manage the life cycle of the DeviceContext. | ||
* It is mainly used for external custom operator calls and high-performance | ||
* C++ APIs. | ||
* | ||
* Since DeviceContextPool in fluid is a global singleton, it always exists | ||
* in program running, so DeviceContextPool here can always access the correct | ||
* DeviceContext pointer. | ||
* | ||
* In order not to depend on the fluid's DeviceContextPool, | ||
* the DeviceContextPool here needs to be initialized in the fluid, and cannot | ||
* be initialized by itself. | ||
*/ | ||
class DeviceContextPool { | ||
public: | ||
static DeviceContextPool& Instance(); | ||
|
||
const phi::DeviceContext* Get(const Place& place) const; | ||
|
||
phi::DeviceContext* GetMutable(const Place& place); | ||
|
||
template <AllocationType T> | ||
const typename DefaultDeviceContextType<T>::TYPE* Get( | ||
const Place& place) const { | ||
return reinterpret_cast<const typename DefaultDeviceContextType<T>::TYPE*>( | ||
Get(place)); | ||
} | ||
|
||
private: | ||
DeviceContextPool(); | ||
paddle::flat_hash_map<Place, const phi::DeviceContext*, Place::Hash> | ||
context_map_; | ||
|
||
DISABLE_COPY_AND_ASSIGN(DeviceContextPool); | ||
}; | ||
|
||
} // namespace experimental | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/phi/api/include/context_pool.h" | ||
|
||
#include "paddle/phi/backends/all_context.h" | ||
#include "paddle/phi/core/enforce.h" | ||
|
||
namespace paddle { | ||
namespace experimental { | ||
|
||
DeviceContextPool& DeviceContextPool::Instance() { | ||
static DeviceContextPool g_device_context_pool; | ||
return g_device_context_pool; | ||
} | ||
|
||
const phi::DeviceContext* DeviceContextPool::Get(const Place& place) const { | ||
auto it = context_map_.find(place); | ||
PADDLE_ENFORCE_NE( | ||
it, | ||
context_map_.end(), | ||
phi::errors::NotFound("The DeviceContext of %s does not exists.", place)); | ||
return it->second; | ||
} | ||
|
||
phi::DeviceContext* DeviceContextPool::GetMutable(const Place& place) { | ||
return const_cast<phi::DeviceContext*>(Get(place)); | ||
} | ||
|
||
DeviceContextPool::DeviceContextPool() { | ||
// We need to make sure that the correct value exists | ||
// whenever we get the DeviceContext from DeviceContextPool | ||
const auto& device_contexts = | ||
paddle::platform::DeviceContextPool::Instance().device_contexts(); | ||
for (const auto& pair : device_contexts) { | ||
// only get CPU and GPU DeviceContext now, add other DeviceContext type | ||
// later if needed | ||
if (platform::is_cpu_place(pair.first) | ||
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) | ||
|| | ||
platform::is_gpu_place(pair.first)) { | ||
#else | ||
) { | ||
#endif | ||
const phi::DeviceContext* dev_ctx = pair.second.get().get(); | ||
VLOG(3) << "Init phi DeviceContextPool: insert {" << pair.first << ", " | ||
<< dev_ctx << "}"; | ||
context_map_[pair.first] = dev_ctx; | ||
} | ||
} | ||
} | ||
|
||
} // namespace experimental | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters