Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed some issues around compiling on Windows. #15444

Closed

Conversation

eaplatanios
Copy link
Contributor

@eaplatanios eaplatanios commented Jul 29, 2024

This PR fixes some issues I bumped into when trying to compile XLA on Windows. I still haven't gotten GPU support to work but I'm making progress. The CPU only version compiles fine after some of the changes in this PR. I'll point out some specific issues this PR fixes in comments.

There are also TSL-specific changes that are pulled in a separate PR (#15499).

configured_version = "%{cuda_version}"
configured_major = int(configured_version.split('.')[0])
configured_minor = int(configured_version.split('.')[1])
# Strip "64_" which appears in the CUDA version on Windows.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is mostly self-explanatory. In Windows, the CUDA version looks something like 64_121. There are other parts of the build that already handle such version numbers but it was not being handled properly here.

.context_id = graph_trace->contextId,
.stream_id = graph_trace->streamId,
.graph_id = graph_trace->graphId,
/* .type = */ CuptiTracerEventType::CudaGraph,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XLA is configured to build using C++ 17. However, this is a C++ 20 feature, resulting in the following error when trying to compile on Windows:

error C7555: use of designated initializers requires at least '/std:c++20'

@@ -56,7 +56,7 @@ struct MemcpyDetails {
int8_t dst_mem_kind;

// ID of the hardware channel on which this operation ran.
uint32_t channel_id = -1;
uint32_t channel_id = static_cast<uint32_t>(-1);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This resulted in an implicit type narrowing error (I believe it was C2397). The explicit static cast fixes it.

@@ -41,7 +41,7 @@ void EigenConv2DImpl(
Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation,
Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation,
Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count,
std::optional<std::function<void()>> done_callback = std::nullopt) {
Copy link
Contributor Author

@eaplatanios eaplatanios Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This resulted in this error:

error C2765: 'tensorflow::xla::EigenConv2DImpl': an explicit specialization or instantiation of a function template cannot have any default arguments

I just removed a couple default arguments that were causing this error and propagated them at call sites where they were missing.

@@ -199,13 +199,17 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options,
#endif
}

#if TENSORFLOW_USE_ROCM
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The nested macro seems to not be supported by MSVC. Pushing the inner ifdef outside the other macro seems to work and doesn't change the behavior/functionality of the code here.

alignas(128) std::byte storage[1024];
#endif
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This results in the following error on Windows:

error C2719: '<args_0>': formal parameter with requested alignment of 128 won't be aligned

cc @dimvar who previously made the change from 64 to 128.

@NaiyerRizz NaiyerRizz requested a review from ezhulenev July 30, 2024 06:59
@NaiyerRizz NaiyerRizz self-assigned this Jul 30, 2024
@eaplatanios
Copy link
Contributor Author

I also have some changes for third_party/tsl. I assume I should make those directly to the openxla/tsl repository? If I do that, how do I get the submodule here to update since it appears to be a copy of the code?

@ezhulenev ezhulenev requested a review from ddunl July 30, 2024 15:55
@ezhulenev
Copy link
Member

@ddunl

@ddunl
Copy link
Member

ddunl commented Jul 30, 2024

Thanks!! LGTM, for TSL changes I think it's easiest to open a separate PR on this repo which does the TSL edits in the copy. We are in limbo with TSL right now, some of it has been moved here to xla/tsl and I'm still moving the rest. It'd be good if the third_party/gpus/cuda change could be a part of that TSL PR also.

@eaplatanios
Copy link
Contributor Author

Thanks!! LGTM, for TSL changes I think it's easiest to open a separate PR on this repo which does the TSL edits in the copy. We are in limbo with TSL right now, some of it has been moved here to xla/tsl and I'm still moving the rest. It'd be good if the third_party/gpus/cuda change could be a part of that TSL PR also.

Thanks that's super helpful! I'll go ahead and open that PR as well in this repo.

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
This saves one register and a few instructions in the hot loop.

name                                     old time/op          new time/op          delta
BM_SelectAndScatterF32/128/process_time   377µs ± 4%           371µs ± 2%  -1.73%
BM_SelectAndScatterF32/256/process_time  1.55ms ± 4%          1.52ms ± 2%  -1.98%
BM_SelectAndScatterF32/512/process_time  6.64ms ± 4%          6.58ms ± 4%  -0.93%

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 657602607
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
name                                     old cpu/op   new cpu/op   delta
BM_SelectAndScatterF32/128/process_time   373µs ± 2%   337µs ± 2%   -9.74%
BM_SelectAndScatterF32/256/process_time  1.54ms ± 3%  1.39ms ± 4%  -10.04%
BM_SelectAndScatterF32/512/process_time  7.08ms ± 7%  6.42ms ± 6%   -9.29%

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 657676415
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 657802573
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
This saves one register and a few instructions in the hot loop.

name                                     old time/op          new time/op          delta
BM_SelectAndScatterF32/128/process_time   377µs ± 4%           371µs ± 2%  -1.73%
BM_SelectAndScatterF32/256/process_time  1.55ms ± 4%          1.52ms ± 2%  -1.98%
BM_SelectAndScatterF32/512/process_time  6.64ms ± 4%          6.58ms ± 4%  -0.93%

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 657602607
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
name                                     old cpu/op   new cpu/op   delta
BM_SelectAndScatterF32/128/process_time   889µs ± 1%   740µs ± 3%  -16.70%
BM_SelectAndScatterF32/256/process_time  3.64ms ± 2%  3.00ms ± 1%  -17.64%
BM_SelectAndScatterF32/512/process_time  15.3ms ± 1%  13.1ms ± 3%  -14.61%

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 657693426
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
name                                     old cpu/op   new cpu/op   delta
BM_SelectAndScatterF32/128/process_time   385µs ± 2%   378µs ± 4%  -1.82%
BM_SelectAndScatterF32/256/process_time  1.58ms ± 2%  1.56ms ± 2%  -1.77%
BM_SelectAndScatterF32/512/process_time  7.24ms ± 4%  7.07ms ± 6%  -2.39%

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 657437409
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
This can be disabled with the flag --xla_gpu_mlir_emitter_level, setting it to
any value < 4.
Change some tests to still use the old emitters. We have separate IR tests for
the new emitters, and keeping the old tests running with the old emitters ensures
we still have coverage for the old emitters, in case we need to rollback.
One notable change with enabling emitter level 4 is that the heuristic to avoid
code duplication due to cache invalidation is disabled. This was always a
a workaround, and the new emitters fixed the problem. This is the most common
source of why the tests behave differently between the old and the new emitters.

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 653901032
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 655567725
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
…eline.

Since Shardy is inside the middle of the XLA pipeline, after converting down to HLO, we need to run the Shardy export pipeline to preserve the SDY ops and sharding attributes for when we come back from HLO to MLIR when Shardy propagation is run.

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 649052479
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
This saves one register and a few instructions in the hot loop.

name                                     old time/op          new time/op          delta
BM_SelectAndScatterF32/128/process_time   377µs ± 4%           371µs ± 2%  -1.73%
BM_SelectAndScatterF32/256/process_time  1.55ms ± 4%          1.52ms ± 2%  -1.98%
BM_SelectAndScatterF32/512/process_time  6.64ms ± 4%          6.58ms ± 4%  -0.93%

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 657602607
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
name                                     old cpu/op   new cpu/op   delta
BM_SelectAndScatterF32/128/process_time   889µs ± 1%   740µs ± 3%  -16.70%
BM_SelectAndScatterF32/256/process_time  3.64ms ± 2%  3.00ms ± 1%  -17.64%
BM_SelectAndScatterF32/512/process_time  15.3ms ± 1%  13.1ms ± 3%  -14.61%

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 657693426
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
name                                     old cpu/op   new cpu/op   delta
BM_SelectAndScatterF32/128/process_time   385µs ± 2%   378µs ± 4%  -1.82%
BM_SelectAndScatterF32/256/process_time  1.58ms ± 2%  1.56ms ± 2%  -1.77%
BM_SelectAndScatterF32/512/process_time  7.24ms ± 4%  7.07ms ± 6%  -2.39%

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 657437409
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
Imported from GitHub PR openxla/xla#15444

This PR fixes some issues I bumped into when trying to compile XLA on Windows. I still haven't gotten GPU support to work but I'm making progress. The CPU only version compiles fine after some of the changes in this PR. I'll point out some specific issues this PR fixes in comments.

There are also TSL-specific changes that are pulled in a separate PR (#15499).
Copybara import of the project:

--
eacee95f41abc49a21516ee389861d84a40eca85 by eaplatanios <e.a.platanios@gmail.com>:

Fixed some issues around compiling on Windows.

--
b12e4cf0d23c2690111125a651e486ec6a112e54 by eaplatanios <e.a.platanios@gmail.com>:

.

--
e23ef176de72cf04555242174a19a407884f3f0e by eaplatanios <e.a.platanios@gmail.com>:

.

--
bdae19b9e15c396985703bb7e88a4db6fcddc7f6 by eaplatanios <e.a.platanios@gmail.com>:

.

--
2f90e6ba564f92fafa564b104ed0ce82b7642563 by eaplatanios <e.a.platanios@gmail.com>:

.

--
57009793b74c4d7d51fb39547a70a3ec142dadab by eaplatanios <e.a.platanios@gmail.com>:

.

--
a978b1f7f70d49f1426fe46b107fdcc3618e3085 by eaplatanios <e.a.platanios@gmail.com>:

.

--
d7fe81dc9cf909a6a8d70e2be8cfffca4063493e by eaplatanios <e.a.platanios@gmail.com>:

.

--
fc40d919619330bce596555613e425cb6267eea4 by eaplatanios <e.a.platanios@gmail.com>:

.

--
326aec3fd73a67ca3c667cfeb5c88a8ffa52eb3d by eaplatanios <e.a.platanios@gmail.com>:

.

--
a7603b7e1be990ff012440c74bd2c2ecbc2b1e2f by eaplatanios <e.a.platanios@gmail.com>:

.

--
edcc97a67016584c285d84ac732952c572283119 by eaplatanios <e.a.platanios@gmail.com>:

.

--
cec244808a8df163f9a803db450ca2bebdda9315 by eaplatanios <e.a.platanios@gmail.com>:

.

--
df3eb2215eea9076cb352378c5745e113df7cc7d by eaplatanios <e.a.platanios@gmail.com>:

.

--
8997345fd1e1aa6f55e445615460124c6e14417c by eaplatanios <e.a.platanios@gmail.com>:

.

--
219a9f1bff7fb12c3407ab2e47512560001900fe by eaplatanios <e.a.platanios@gmail.com>:

.

--
73f3cd7e0135ec05c97595f795ec318fb635bd32 by eaplatanios <e.a.platanios@gmail.com>:

.

Merging this change closes #15444

PiperOrigin-RevId: 657937707
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
name                                     old cpu/op   new cpu/op   delta
BM_SelectAndScatterF32/128/process_time   373µs ± 2%   337µs ± 2%   -9.74%
BM_SelectAndScatterF32/256/process_time  1.54ms ± 3%  1.39ms ± 4%  -10.04%
BM_SelectAndScatterF32/512/process_time  7.08ms ± 7%  6.42ms ± 6%   -9.29%

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 657676415
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 657960174
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Jul 31, 2024
Updates LLVM usage to match
[42d641ef5cc4](llvm/llvm-project@42d641ef5cc4)

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#15444 from eaplatanios:u/eaplatanios/cpp-17-fixes 73f3cd7e0135ec05c97595f795ec318fb635bd32
PiperOrigin-RevId: 657972192
copybara-service bot pushed a commit that referenced this pull request Aug 12, 2024
This is following up on #15444. There's still one issue blocking Windows support that I haven't resolved. I've described it here. Any suggestions/advice on how to proceed for that one would be helpful. cc @hawkinsp @ddunl and also fyi @metab0t.

This closes #15499.

PiperOrigin-RevId: 662267938
copybara-service bot pushed a commit that referenced this pull request Aug 13, 2024
This is following up on #15444. There's still one issue blocking Windows support that I haven't resolved. I've described it here. Any suggestions/advice on how to proceed for that one would be helpful. cc @hawkinsp @ddunl and also fyi @metab0t.

This closes #15499.

PiperOrigin-RevId: 662267938
copybara-service bot pushed a commit that referenced this pull request Aug 14, 2024
This is following up on #15444. There's still one issue blocking Windows support that I haven't resolved. I've described it here. Any suggestions/advice on how to proceed for that one would be helpful. cc @hawkinsp @ddunl and also fyi @metab0t.

This closes #15499.

PiperOrigin-RevId: 662267938
copybara-service bot pushed a commit that referenced this pull request Aug 14, 2024
This is following up on #15444. There's still one issue blocking Windows support that I haven't resolved. I've described it here. Any suggestions/advice on how to proceed for that one would be helpful. cc @hawkinsp @ddunl and also fyi @metab0t.

This closes #15499.

PiperOrigin-RevId: 662267938
copybara-service bot pushed a commit that referenced this pull request Aug 14, 2024
This is following up on #15444. There's still one issue blocking Windows support that I haven't resolved. I've described it here. Any suggestions/advice on how to proceed for that one would be helpful. cc @hawkinsp @ddunl and also fyi @metab0t.

This closes #15499.

PiperOrigin-RevId: 663079395
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants