PR #15444: Fixed some issues around compiling on Windows.
Imported from GitHub PR #15444

This PR fixes some issues I bumped into when trying to compile XLA on Windows. I still haven't gotten GPU support to work, but I'm making progress. The CPU-only version compiles fine after some of the changes in this PR. I'll point out some specific issues this PR fixes in comments.

There are also TSL-specific changes that are split out into a separate PR (#15499).
Copybara import of the project:

--
eacee95 by eaplatanios <e.a.platanios@gmail.com>:

Fixed some issues around compiling on Windows.

--
b12e4cf by eaplatanios <e.a.platanios@gmail.com>:

.

--
e23ef17 by eaplatanios <e.a.platanios@gmail.com>:

.

--
bdae19b by eaplatanios <e.a.platanios@gmail.com>:

.

--
2f90e6b by eaplatanios <e.a.platanios@gmail.com>:

.

--
5700979 by eaplatanios <e.a.platanios@gmail.com>:

.

--
a978b1f by eaplatanios <e.a.platanios@gmail.com>:

.

--
d7fe81d by eaplatanios <e.a.platanios@gmail.com>:

.

--
fc40d91 by eaplatanios <e.a.platanios@gmail.com>:

.

--
326aec3 by eaplatanios <e.a.platanios@gmail.com>:

.

--
a7603b7 by eaplatanios <e.a.platanios@gmail.com>:

.

--
edcc97a by eaplatanios <e.a.platanios@gmail.com>:

.

--
cec2448 by eaplatanios <e.a.platanios@gmail.com>:

.

--
df3eb22 by eaplatanios <e.a.platanios@gmail.com>:

.

--
8997345 by eaplatanios <e.a.platanios@gmail.com>:

.

--
219a9f1 by eaplatanios <e.a.platanios@gmail.com>:

.

--
73f3cd7 by eaplatanios <e.a.platanios@gmail.com>:

.

Merging this change closes #15444

FUTURE_COPYBARA_INTEGRATE_REVIEW=#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7
PiperOrigin-RevId: 657913961
eaplatanios authored and copybara-github committed Jul 31, 2024
1 parent 8bef836 commit cb18089
Showing 22 changed files with 117 additions and 59 deletions.
24 changes: 12 additions & 12 deletions xla/backends/profiler/gpu/cupti_buffer_events.cc
@@ -186,18 +186,18 @@ void AddGraphTraceActivityEvent(CuptiEventCollectorDelegate &collector,
   AnnotationMap::AnnotationInfo info = collector.annotation_map.LookUp(
       graph_trace->deviceId, graph_trace->correlationId);
   collector.receive(CuptiTracerEvent{
-      .type = CuptiTracerEventType::CudaGraph,
-      .source = CuptiTracerEventSource::Activity,
-      .name = absl::StrCat("CudaGraphExec:", graph_trace->graphId),
-      .annotation = info.annotation,
-      .nvtx_range = info.nvtx_range,
-      .start_time_ns = graph_trace->start,
-      .end_time_ns = graph_trace->end,
-      .device_id = graph_trace->deviceId,
-      .correlation_id = graph_trace->correlationId,
-      .context_id = graph_trace->contextId,
-      .stream_id = graph_trace->streamId,
-      .graph_id = graph_trace->graphId,
+      /* .type = */ CuptiTracerEventType::CudaGraph,
+      /* .source = */ CuptiTracerEventSource::Activity,
+      /* .name = */ absl::StrCat("CudaGraphExec:", graph_trace->graphId),
+      /* .annotation = */ info.annotation,
+      /* .nvtx_range = */ info.nvtx_range,
+      /* .start_time_ns = */ graph_trace->start,
+      /* .end_time_ns = */ graph_trace->end,
+      /* .device_id = */ graph_trace->deviceId,
+      /* .correlation_id = */ graph_trace->correlationId,
+      /* .context_id = */ graph_trace->contextId,
+      /* .stream_id = */ graph_trace->streamId,
+      /* .graph_id = */ graph_trace->graphId,
   });
 }
 
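The hunk above swaps C++20 designated initializers for positional aggregate initialization with the field names preserved in comments. MSVC in C++17 mode rejects designated initializers, which GCC and Clang tolerate as an extension; the same fix recurs in `pjrt_c_api_wrapper_impl.cc` and `computation_partitioner.cc` below. A minimal sketch of the issue, using an illustrative struct (not XLA code):

```cpp
#include <cstdint>
#include <string>

struct Event {
  int type;
  std::string name;
  uint64_t start_ns;
};

Event MakeEvent() {
  // Fails under MSVC /std:c++17: designated initializers are C++20.
  // return Event{.type = 1, .name = "kernel", .start_ns = 42};

  // Portable C++17 spelling used by this PR: positional initialization,
  // with the field names preserved as comments.
  return Event{/* .type = */ 1, /* .name = */ "kernel", /* .start_ns = */ 42};
}
```

The comment style keeps the initializer order auditable against the struct definition, which is the main thing the designated form was buying.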
2 changes: 1 addition & 1 deletion xla/backends/profiler/gpu/cupti_buffer_events.h
@@ -56,7 +56,7 @@ struct MemcpyDetails {
   int8_t dst_mem_kind;
 
   // ID of the hardware channel on which this operation ran.
-  uint32_t channel_id = -1;
+  uint32_t channel_id = static_cast<uint32_t>(-1);
   // CUpti_ChannelType of the channel above.
   int8_t channel_type = 0;  // CUPTI_CHANNEL_TYPE_INVALID
 };
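The initializer `uint32_t channel_id = -1;` relies on an implicit signed-to-unsigned conversion. That conversion is well-defined in C++, but MSVC flags the signed/unsigned mismatch (warning C4245), which presumably fails warnings-as-errors builds; the explicit cast keeps the sentinel value and silences the diagnostic. A small sketch, assuming that diagnostic is the trigger:

```cpp
#include <cstdint>

struct MemcpyInfo {
  // uint32_t channel_id = -1;                      // warns under MSVC /W4
  uint32_t channel_id = static_cast<uint32_t>(-1);  // explicit "invalid" marker
};

static_assert(static_cast<uint32_t>(-1) == 0xFFFFFFFFu,
              "wraparound to the maximum value is well-defined for unsigned types");
```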
2 changes: 2 additions & 0 deletions xla/hlo/evaluator/hlo_evaluator.cc
@@ -535,7 +535,9 @@ std::optional<DynamicOrStaticInteger> EvaluateWhileLoopParamInitValue(
 
 namespace internal {
 
+#if !defined(_MSC_VER)
 constexpr absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl";
+#endif
 
 std::optional<EvalErrorDetail> ParseEvalErrorDetail(const absl::Status& error) {
   auto error_detail = error.GetPayload(kEvalErrorDetailUrl);
4 changes: 4 additions & 0 deletions xla/hlo/evaluator/hlo_evaluator.h
@@ -530,7 +530,11 @@ enum class EvalErrorDetail : uint32_t {
   kDynamicValueDependence = 0,
 };
 
+#if defined(_MSC_VER)
+extern const absl::string_view kEvalErrorDetailUrl = "EvalErrorDetailUrl";
+#else
 extern const absl::string_view kEvalErrorDetailUrl;
+#endif
 
 std::optional<EvalErrorDetail> ParseEvalErrorDetail(const absl::Status& error);
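A hedged note on why this definition is compiler-conditional: non-MSVC builds keep the single out-of-line `constexpr` definition in `hlo_evaluator.cc` (guarded with `!defined(_MSC_VER)` above), while MSVC gets the initializer directly in the header. MSVC apparently fails to pair the `.cc` file's `constexpr` definition with the header's `extern const` declaration, leaving the symbol unresolved at link time. The PR does not spell out the exact diagnostic, so treat this as the likely rationale rather than a confirmed one.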
2 changes: 1 addition & 1 deletion xla/pjrt/c/pjrt_c_api_wrapper_impl.cc
@@ -2129,7 +2129,7 @@ PJRT_Error* PJRT_Layouts_MemoryLayout_Serialize(
       PJRT_Layouts_MemoryLayout_Serialize_Args_STRUCT_SIZE, args->struct_size));
 
   PJRT_Layouts_SerializedLayout* s_layout = new PJRT_Layouts_SerializedLayout{
-      .serialized = args->layout->layout->Serialize()};
+      /* .serialized = */ args->layout->layout->Serialize()};
   args->serialized_layout = s_layout;
   args->serialized_bytes = s_layout->serialized.data();
   args->serialized_bytes_size = s_layout->serialized.size();
14 changes: 8 additions & 6 deletions xla/pjrt/gpu/se_gpu_pjrt_compiler.cc
@@ -199,13 +199,15 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options,
 #endif
 }
 
-STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, {
-  PjRtRegisterCompiler(
 #if TENSORFLOW_USE_ROCM
-      RocmName(),
+STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, {
+  PjRtRegisterCompiler(RocmName(),
+                       std::make_unique<StreamExecutorGpuCompiler>());
+});
 #else
-      CudaName(),
-#endif
-      std::make_unique<StreamExecutorGpuCompiler>());
+STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, {
+  PjRtRegisterCompiler(CudaName(),
+                       std::make_unique<StreamExecutorGpuCompiler>());
 });
+#endif
 }  // namespace xla
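This reshuffle is more than formatting: the original code placed `#if TENSORFLOW_USE_ROCM` / `#else` / `#endif` inside the argument list of the `STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER` macro invocation. The standard leaves preprocessor directives inside macro arguments undefined, and MSVC's traditional preprocessor in particular does not handle them, so the PR duplicates the whole invocation per branch. A minimal sketch of the rule, with an illustrative macro:

```cpp
// Illustrative registration macro (stands in for the real one).
#define REGISTER(name) void Register_##name()

// Non-portable: directives inside the macro's argument list are undefined
// behavior and rejected by MSVC. (Shown commented out.)
// REGISTER(
// #if USE_ROCM
//     rocm
// #else
//     cuda
// #endif
// );

// Portable: select the whole invocation per branch instead.
#if USE_ROCM
REGISTER(rocm);
#else
REGISTER(cuda);
#endif
```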
8 changes: 4 additions & 4 deletions xla/service/cpu/runtime/conv_impl.h
@@ -41,7 +41,7 @@ void EigenConv2DImpl(
     Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation,
     Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation,
     Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count,
-    std::optional<std::function<void()>> done_callback = std::nullopt) {
+    std::optional<std::function<void()>> done_callback) {
   const Eigen::TensorMap<Eigen::Tensor<const ScalarType, 4, Eigen::RowMajor>,
                          Eigen::Aligned>
       input(lhs, input_batch, input_x, input_y, input_channels);
@@ -129,7 +129,7 @@ void EigenConv3DImpl(
     Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation,
     Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation,
     Eigen::Index feature_group_count,
-    std::optional<std::function<void()>> done_callback = std::nullopt) {
+    std::optional<std::function<void()>> done_callback) {
   using ConstTType =
       Eigen::TensorMap<Eigen::Tensor<const ScalarType, 5, Eigen::RowMajor>,
                        Eigen::Aligned>;
@@ -223,7 +223,7 @@ void EigenConv3DImpl(
       Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation,   \
       Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation,    \
       Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, \
-      std::optional<std::function<void()>> done_callback = std::nullopt)
+      std::optional<std::function<void()>> done_callback)
 
 CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half);
 CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float);
@@ -249,7 +249,7 @@ CONV2D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, float);
       Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation,    \
       Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation,    \
       Eigen::Index feature_group_count,                            \
-      std::optional<std::function<void()>> done_callback = std::nullopt)
+      std::optional<std::function<void()>> done_callback)
 
 CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half);
 CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float);
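The `= std::nullopt` defaults are dropped from both the implementations and the `CONV2D_EXTERN_TEMPLATE`/`CONV3D_EXTERN_TEMPLATE` declaration macros, and the four runtime entry-point files below now pass `std::nullopt` explicitly. A likely reading: the default argument was being repeated across redeclarations of the same functions (including in extern-template declarations, where defaults are not permitted), which MSVC rejects where Clang and GCC were lenient. A minimal sketch of the redeclaration rule, with illustrative names:

```cpp
#include <functional>
#include <optional>

// The default may be stated on the first declaration only.
template <typename T>
void Conv(T* out, std::optional<std::function<void()>> done = std::nullopt);

// Ill-formed: repeating the default on a later declaration/definition.
// template <typename T>
// void Conv(T* out, std::optional<std::function<void()>> done = std::nullopt) { ... }

template <typename T>
void Conv(T* out, std::optional<std::function<void()>> done) {
  if (done.has_value()) (*done)();
}

void Use(float* buf) {
  Conv(buf, std::nullopt);  // callers now spell the argument explicitly
}
```

Dropping the default entirely, rather than keeping it in one place, is the more conservative fix given the macro-generated declarations.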
6 changes: 4 additions & 2 deletions xla/service/cpu/runtime_conv2d.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "xla/service/cpu/runtime_conv2d.h"
 
+#include <optional>
+
 #define EIGEN_USE_THREADS
 
 #include "absl/base/dynamic_annotations.h"
@@ -41,7 +43,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF32(
       kernel_channels, kernel_filters, output_rows, output_cols, row_stride,
       col_stride, padding_top, padding_bottom, padding_left, padding_right,
       lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation,
-      feature_group_count);
+      feature_group_count, std::nullopt);
 }
 
 ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF16(
@@ -63,5 +65,5 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF16(
       kernel_channels, kernel_filters, output_rows, output_cols, row_stride,
       col_stride, padding_top, padding_bottom, padding_left, padding_right,
       lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation,
-      feature_group_count);
+      feature_group_count, std::nullopt);
 }
6 changes: 4 additions & 2 deletions xla/service/cpu/runtime_conv3d.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "xla/service/cpu/runtime_conv3d.h"
 
+#include <optional>
+
 #define EIGEN_USE_THREADS
 
 #include "absl/base/dynamic_annotations.h"
@@ -44,7 +46,7 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF32(
       y_stride, z_stride, padding_x_before, padding_x_after, padding_y_before,
       padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation,
       lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation,
-      rhs_z_dilation, feature_group_count);
+      rhs_z_dilation, feature_group_count, std::nullopt);
 }
 
 ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF16(
@@ -69,5 +71,5 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF16(
       y_stride, z_stride, padding_x_before, padding_x_after, padding_y_before,
       padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation,
       lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation,
-      rhs_z_dilation, feature_group_count);
+      rhs_z_dilation, feature_group_count, std::nullopt);
 }
6 changes: 4 additions & 2 deletions xla/service/cpu/runtime_single_threaded_conv2d.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "xla/service/cpu/runtime_single_threaded_conv2d.h"
 
+#include <optional>
+
 #include "absl/base/dynamic_annotations.h"
 #include "xla/service/cpu/runtime/conv_impl.h"
 
@@ -35,7 +37,7 @@ __xla_cpu_runtime_EigenSingleThreadedConv2DF16(
       kernel_filters, output_rows, output_cols, row_stride, col_stride,
       padding_top, padding_bottom, padding_left, padding_right,
       lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation,
-      feature_group_count);
+      feature_group_count, std::nullopt);
 }
 
 ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
@@ -55,5 +57,5 @@ __xla_cpu_runtime_EigenSingleThreadedConv2DF32(
       kernel_filters, output_rows, output_cols, row_stride, col_stride,
       padding_top, padding_bottom, padding_left, padding_right,
       lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation,
-      feature_group_count);
+      feature_group_count, std::nullopt);
 }
6 changes: 4 additions & 2 deletions xla/service/cpu/runtime_single_threaded_conv3d.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "xla/service/cpu/runtime_single_threaded_conv3d.h"
 
+#include <optional>
+
 #include "absl/base/dynamic_annotations.h"
 #include "xla/service/cpu/runtime/conv_impl.h"
 
@@ -38,7 +40,7 @@ __xla_cpu_runtime_EigenSingleThreadedConv3DF32(
       z_stride, padding_x_before, padding_x_after, padding_y_before,
       padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation,
       lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation,
-      rhs_z_dilation, feature_group_count);
+      rhs_z_dilation, feature_group_count, std::nullopt);
 }
 
 ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void
@@ -61,5 +63,5 @@ __xla_cpu_runtime_EigenSingleThreadedConv3DF16(
       z_stride, padding_x_before, padding_x_after, padding_y_before,
       padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation,
       lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation,
-      rhs_z_dilation, feature_group_count);
+      rhs_z_dilation, feature_group_count, std::nullopt);
 }
12 changes: 6 additions & 6 deletions xla/service/gpu/fusions/mlir/computation_partitioner.cc
@@ -300,12 +300,12 @@ PartitionedComputation::PartitionedComputation(
       absl::StrJoin(roots, "_", [](std::string* out, const auto* root) {
         absl::StrAppend(out, root->name());
       })));
-  subgraphs_.push_back(
-      Subgraph{.name = std::move(name),
-               .instructions = {instructions.begin(), instructions.end()},
-               .roots = std::move(roots),
-               .index_ranges = std::move(ranges),
-               .root_indexing = std::move(root_indexing)});
+  subgraphs_.push_back(Subgraph{
+      /* .name = */ std::move(name),
+      /* .instructions = */ {instructions.begin(), instructions.end()},
+      /* .roots = */ std::move(roots),
+      /* .index_ranges = */ std::move(ranges),
+      /* .root_indexing = */ std::move(root_indexing)});
 }
 
 for (const auto& subgraph : subgraphs_) {
38 changes: 33 additions & 5 deletions xla/service/gpu/kernels/BUILD
@@ -9,6 +9,7 @@ load("//xla:xla.bzl", "xla_cc_binary")
 load("//xla/service/gpu:build_defs.bzl", "gpu_kernel_library")
 load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured")
 load("//xla/tests:build_defs.bzl", "xla_test")
+load("//xla/tsl:tsl.bzl", "if_windows")
 
 package(
     # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
@@ -325,7 +326,10 @@ cc_library(
 cuda_library(
     name = "cutlass_gemm_adaptor",
     hdrs = if_cuda_is_configured(["cutlass_gemm_adaptor.cu.h"]),
-    copts = ["-Wno-unknown-attributes"],  # __grid_constant__ is not supported by clang
+    copts = if_windows(
+        [],
+        ["-Wno-unknown-attributes"],
+    ),  # __grid_constant__ is not supported by clang
     deps = if_cuda_is_configured([
         ":cutlass_gemm",
         "@cutlass_archive//:cutlass",
@@ -367,7 +371,13 @@ cc_library(
 cuda_library(
     name = "cutlass_gemm_kernel_bf16xbf16_to_bf16",
    srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16.cu.cc"]),
-    copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"],
+    copts = [
+        "-mllvm",
+        "-unroll-threshold=100000",
+    ] + if_windows(
+        [],
+        ["-Wno-unknown-attributes"],
+    ),
     deps = if_cuda_is_configured([
         ":cutlass_gemm_adaptor",
         "@cutlass_archive//:cutlass",
@@ -378,7 +388,13 @@ cuda_library(
 cuda_library(
     name = "cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80",
     srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80.cu.cc"]),
-    copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"],
+    copts = [
+        "-mllvm",
+        "-unroll-threshold=100000",
+    ] + if_windows(
+        [],
+        ["-Wno-unknown-attributes"],
+    ),
     deps = if_cuda_is_configured([
         ":cutlass_gemm_adaptor",
         "@cutlass_archive//:cutlass",
@@ -389,7 +405,16 @@ cuda_library(
 cuda_library(
     name = "cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90",
     srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90.cu.cc"]),
-    copts = ["-Wno-ctad-maybe-unsupported -Wno-unknown-attributes -mllvm -unroll-threshold=100000"],
+    copts = [
+        "-mllvm",
+        "-unroll-threshold=100000",
+    ] + if_windows(
+        [],
+        [
+            "-Wno-ctad-maybe-unsupported",
+            "-Wno-unknown-attributes",
+        ],
+    ),
     deps = if_cuda_is_configured([
         ":cutlass_gemm_adaptor",
         ":cutlass_gemm_epilogue",
@@ -401,7 +426,10 @@ cuda_library(
 cuda_library(
     name = "cutlass_gemm_kernel_f32xf32_to_f32",
     srcs = if_cuda_is_configured(["cutlass_gemm_kernel_f32xf32_to_f32.cu.cc"]),
-    copts = ["-Wno-unknown-attributes"],
+    copts = if_windows(
+        [],
+        ["-Wno-unknown-attributes"],
+    ),
    deps = if_cuda_is_configured([
         ":cutlass_gemm_adaptor",
         "@cutlass_archive//:cutlass",
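Two distinct fixes are folded into these BUILD changes: the Clang-only `-Wno-*` suppressions are dropped on Windows via `if_windows(...)`, since MSVC does not recognize them, and copts that previously packed several flags into one string (e.g. `"-Wno-unknown-attributes -mllvm -unroll-threshold=100000"`) are split into one flag per list element, the conventional Bazel form in which each entry reaches the compiler as a distinct argument.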
5 changes: 3 additions & 2 deletions xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h
@@ -199,8 +199,9 @@ namespace adaptor_3x {
 template <typename Tag>
 static std::optional<Dim3> ClusterDim() {
   typename Traits<Tag>::Kernel::DispatchPolicy::ClusterShape cluster;
-  return Dim3{cute::get<0>(cluster), cute::get<1>(cluster),
-              cute::get<2>(cluster)};
+  return Dim3{static_cast<uint32_t>(cute::get<0>(cluster)),
+              static_cast<uint32_t>(cute::get<1>(cluster)),
+              static_cast<uint32_t>(cute::get<2>(cluster))};
 }
 
 template <typename Tag>
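The extra casts address list-initialization narrowing: braced init rejects implicit conversions that can lose information, and MSVC is strict about diagnosing this as an error (C2397) when packing the `cute` shape components into `Dim3`'s `uint32_t` fields. A sketch of the rule with illustrative types (`Dim3`'s real definition lives elsewhere in XLA):

```cpp
#include <cstdint>

struct Dim3 {
  uint32_t x, y, z;
};

Dim3 MakeDim(long long a, long long b, long long c) {
  // return Dim3{a, b, c};  // ill-formed: narrowing conversion in braced init
  return Dim3{static_cast<uint32_t>(a), static_cast<uint32_t>(b),
              static_cast<uint32_t>(c)};
}
```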
4 changes: 4 additions & 0 deletions xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc
@@ -101,7 +101,11 @@ KernelArgsPacking ArgsPacking(int32_t m, int32_t n, int32_t k,
   // object constructed in the storage. For now we ignore it, and it's textbook
   // definition of UB, but for CUTLASS kernels we use today it's perfectly safe.
   struct Params {
+#if defined(_MSC_VER)
+    alignas(64) std::byte storage[1024];
+#else
     alignas(128) std::byte storage[1024];
+#endif
   };
 
   return [=](const se::Kernel& kernel, const se::KernelArgs& args) -> Packed {
2 changes: 1 addition & 1 deletion xla/service/gpu/model/gpu_collective_performance_model.cc
@@ -136,7 +136,7 @@ float GpuPerformanceWithCollectiveModel::GetNvlinkBw(
 }
 
 /*static*/ bool GpuPerformanceWithCollectiveModel::InitNvml() {
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA && defined(PLATFORM_POSIX)
   void* libhandle = dlopen("libnvidia-ml.so.1", RTLD_NOW);
   CHECK(libhandle != nullptr) << "Failed to open libnvidia-ml.so.1";
 
2 changes: 2 additions & 0 deletions xla/service/gpu/model/gpu_collective_performance_model.h
@@ -26,7 +26,9 @@ limitations under the License.
 #include "xla/stream_executor/device_description.h"
 
 #if GOOGLE_CUDA
+#if defined(PLATFORM_POSIX)
 #include <dlfcn.h>
+#endif
 
 #include "third_party/gpus/cuda/nvml/include/nvml.h"
 // Below is a list of function pointers to be used
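`<dlfcn.h>` and `dlopen` are POSIX APIs with no Windows counterpart, so the NVML bootstrap is now compiled only when `PLATFORM_POSIX` is defined; this PR does not add a Windows code path. For context only, a hedged sketch of what a portable shim could look like, using the Win32 equivalents `LoadLibraryA`/`GetProcAddress` (not part of this change):

```cpp
#if defined(_WIN32)
#include <windows.h>
void* OpenLibrary(const char* name) {
  // HMODULE is a pointer type, so this round-trips safely on Windows.
  return reinterpret_cast<void*>(LoadLibraryA(name));
}
void* GetSymbol(void* handle, const char* sym) {
  // Function-to-object pointer casts are conditionally supported; fine on Win32.
  return reinterpret_cast<void*>(
      GetProcAddress(reinterpret_cast<HMODULE>(handle), sym));
}
#else
#include <dlfcn.h>
void* OpenLibrary(const char* name) { return dlopen(name, RTLD_NOW); }
void* GetSymbol(void* handle, const char* sym) { return dlsym(handle, sym); }
#endif
```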
2 changes: 1 addition & 1 deletion xla/service/gpu/stream_executor_util.cc
@@ -436,7 +436,7 @@ static void InitializeTypedBuffer(se::Stream* stream,
 
   // Use a large prime number to fragment the accesses.
   constexpr int host_buffer_size = 10069;
-  static std::vector<T>* host_buffer = [] {
+  static std::vector<T>* host_buffer = [&] {
     auto* ret = new std::vector<T>(host_buffer_size);
     // Default-seeded random numbers.
     std::mt19937 gen;
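The capture list changes from `[]` to `[&]` even though `host_buffer_size` is `constexpr` and, by the standard, not odr-used inside the lambda (so no capture should be required). This reads as a workaround for MSVC, which in some versions still refuses to resolve the enclosing constexpr local without a capture; that interpretation is an assumption, not something stated in the PR. A sketch:

```cpp
#include <vector>

std::vector<int>* MakeBuffer() {
  constexpr int kSize = 10069;
  // auto make = [] { return new std::vector<int>(kSize); };  // reportedly rejected by MSVC
  auto make = [&] { return new std::vector<int>(kSize); };    // accepted everywhere
  return make();
}
```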