Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ET-VK] Organize utils #2323

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
// @lint-ignore-every CLANGTIDY
// facebook-security-vulnerable-integer-sign-conversion

#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>

namespace at {
namespace native {
namespace vulkan {
Expand Down
4 changes: 1 addition & 3 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/Context.h>
#include <ATen/native/vulkan/api/Tensor.h>
#include <ATen/native/vulkan/api/Types.h>
#include <ATen/native/vulkan/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/GraphConfig.h>

Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/GraphConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/Context.h>
#include <ATen/native/vulkan/api/api.h>

namespace at {
namespace native {
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h>

namespace at {
namespace native {
Expand Down
4 changes: 1 addition & 3 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/Context.h>
#include <ATen/native/vulkan/api/Tensor.h>
#include <ATen/native/vulkan/api/Types.h>
#include <ATen/native/vulkan/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>

Expand Down
2 changes: 0 additions & 2 deletions backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h>

namespace at {
namespace native {
namespace vulkan {
Expand Down
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/PrepackNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>

namespace at {
namespace native {
Expand Down
4 changes: 1 addition & 3 deletions backends/vulkan/runtime/graph/ops/PrepackNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/Context.h>
#include <ATen/native/vulkan/api/Tensor.h>
#include <ATen/native/vulkan/api/Types.h>
#include <ATen/native/vulkan/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>

Expand Down
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h>

#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

namespace at {
namespace native {
namespace vulkan {
Expand Down
2 changes: 0 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>

namespace at {
namespace native {
namespace vulkan {
Expand Down
7 changes: 4 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/Staging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

namespace at {
namespace native {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

#include <ATen/native/vulkan/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>

namespace at {
namespace native {
namespace vulkan {
Expand Down Expand Up @@ -75,27 +73,6 @@ uint32_t dim_at(const vTensor& v_in) {
return dim_at<N>(v_in.sizes());
}

/*
* For most global work group sizes, returns {4, 4, 4}, but adjusts the size for
* 2D global work group sizes. Always maintains a total of 64 invocations
*/
api::utils::uvec3 adaptive_work_group_size(
const api::utils::uvec3& global_work_group);

template <typename T>
T extract_scalar(const Value& value) {
if (value.isInt()) {
return static_cast<T>(value.toInt());
}
if (value.isDouble()) {
return static_cast<T>(value.toDouble());
}
if (value.isBool()) {
return static_cast<T>(value.toBool());
}
VK_THROW("Cannot extract scalar from Value with type ", value.type());
}

} // namespace vulkan
} // namespace native
} // namespace at
Expand Down
39 changes: 39 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>

namespace at {
namespace native {
namespace vulkan {

template <typename T>
T extract_scalar(const Value& value) {
if (value.isInt()) {
return static_cast<T>(value.toInt());
}
if (value.isDouble()) {
return static_cast<T>(value.toDouble());
}
if (value.isBool()) {
return static_cast<T>(value.toBool());
}
VK_THROW("Cannot extract scalar from Value with type ", value.type());
}

} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>

namespace at {
namespace native {
Expand All @@ -29,6 +31,14 @@ api::utils::uvec3 adaptive_work_group_size(
return local_group_size;
}

api::utils::ivec4 get_size_as_ivec4(const vTensor& t) {
return api::utils::make_ivec4(
{dim_at<Dim4D::Width>(t),
dim_at<Dim4D::Height>(t),
dim_at<Dim4D::Channel>(t),
dim_at<Dim4D::Batch>(t)});
}

} // namespace vulkan
} // namespace native
} // namespace at
28 changes: 28 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#ifdef USE_VULKAN_API

#include <ATen/native/vulkan/api/api.h>

namespace at {
namespace native {
namespace vulkan {

api::utils::uvec3 adaptive_work_group_size(
const api::utils::uvec3& global_work_group);

api::utils::ivec4 get_size_as_ivec4(const vTensor& t);

} // namespace vulkan
} // namespace native
} // namespace at

#endif /* USE_VULKAN_API */
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,12 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h>

namespace at {
namespace native {
namespace vulkan {

api::utils::ivec4 get_size_as_ivec4(const vTensor& t) {
return api::utils::make_ivec4(
{dim_at<Dim4D::Width>(t),
dim_at<Dim4D::Height>(t),
dim_at<Dim4D::Channel>(t),
dim_at<Dim4D::Batch>(t)});
}

void bind_tensor_to_descriptor_set(
vTensor& tensor,
api::PipelineBarrier& pipeline_barrier,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ namespace at {
namespace native {
namespace vulkan {

api::utils::ivec4 get_size_as_ivec4(const vTensor& t);

void bind_tensor_to_descriptor_set(
vTensor& tensor,
api::PipelineBarrier& pipeline_barrier,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,36 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>
// @lint-ignore-every CLANGTIDY facebook-security-vulnerable-memcpy

#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/Utils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>

#include <cstring>

namespace at {
namespace native {
namespace vulkan {

template <typename T>
void memcpy_to_mapping_impl(
const void* src,
api::MemoryMap& dst_mapping,
const size_t nbytes) {
T* data_ptr = dst_mapping.template data<T>();
memcpy(data_ptr, reinterpret_cast<const T*>(src), nbytes);
}

template <typename T>
void memcpy_from_mapping_impl(
api::MemoryMap& src_mapping,
void* dst,
const size_t nbytes) {
T* data_ptr = src_mapping.template data<T>();
memcpy(reinterpret_cast<T*>(dst), data_ptr, nbytes);
}

void memcpy_to_mapping(
const void* src,
api::MemoryMap& dst_mapping,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,49 +12,10 @@

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <cstring>

namespace at {
namespace native {
namespace vulkan {

//
// Functions to memcpy data into staging buffer
//

void memcpy_to_mapping(
const void* src,
api::MemoryMap& dst_mapping,
const size_t nbytes,
const api::ScalarType dtype);
void memcpy_from_mapping(
const api::MemoryMap& src_mapping,
void* dst,
const size_t nbytes,
const api::ScalarType dtype);

//
// Utility functions for memcpy
//

template <typename T>
void memcpy_to_mapping_impl(
const void* src,
api::MemoryMap& dst_mapping,
const size_t nbytes) {
T* data_ptr = dst_mapping.template data<T>();
memcpy(data_ptr, reinterpret_cast<const T*>(src), nbytes);
}

template <typename T>
void memcpy_from_mapping_impl(
api::MemoryMap& src_mapping,
void* dst,
const size_t nbytes) {
T* data_ptr = src_mapping.template data<T>();
memcpy(reinterpret_cast<T*>(dst), data_ptr, nbytes);
}

//
// Functions to copy data into and out of a staging buffer
//
Expand Down
Loading
Loading