Skip to content

Commit

Permalink
Added the ability to build a project with PyTorch 2.0. (#2553)
Browse files Browse the repository at this point in the history
* Added the ability to build a project with PyTorch 2.0.
Namely, I added the flag -std=c++17 to extra_compile_args
 depending on the version of Torch.

* Lost the condition for the presence of nvcc

* Lost the condition for the presence of nvcc

* Add parse_version

* fix lint

---------

Co-authored-by: Xin Chen <irexyc@gmail.com>
  • Loading branch information
Danil328 and irexyc authored Nov 30, 2023
1 parent db73d55 commit 660af62
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os

from pkg_resources import parse_version
from setuptools import find_packages, setup

EXT_TYPE = ''
try:
import torch
from torch.utils.cpp_extension import BuildExtension
cmd_class = {'build_ext': BuildExtension}
EXT_TYPE = 'torch'
Expand Down Expand Up @@ -139,7 +141,10 @@ def get_extensions():
# to compile those cpp files, so there is no need to add the
# argument
if platform.system() != 'Windows':
extra_compile_args['cxx'] = ['-std=c++14']
if parse_version(torch.__version__) <= parse_version('1.12.1'):
extra_compile_args['cxx'] = ['-std=c++14']
else:
extra_compile_args['cxx'] = ['-std=c++17']

include_dirs = []

Expand All @@ -159,7 +164,10 @@ def get_extensions():
# to compile those cpp files, so there is no need to add the
# argument
if 'nvcc' in extra_compile_args and platform.system() != 'Windows':
extra_compile_args['nvcc'] += ['-std=c++14']
if parse_version(torch.__version__) <= parse_version('1.12.1'):
extra_compile_args['nvcc'] += ['-std=c++14']
else:
extra_compile_args['nvcc'] += ['-std=c++17']

ext_ops = extension(
name=ext_name,
Expand Down

0 comments on commit 660af62

Please sign in to comment.