Skip to content

Commit

Permalink
MSVC fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
wkpark committed Dec 29, 2023
1 parent 4c55dc5 commit 2716adf
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 17 deletions.
76 changes: 60 additions & 16 deletions torch/_inductor/codecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def get_path(basename: str, extension: str, specified_dir: str = ""):
subdir = os.path.join(cache_dir(), specified_dir)
else:
subdir = os.path.join(cache_dir(), basename[1:3])
path = os.path.join(subdir, f"{basename}.{extension}")
path = os.path.join(subdir, f"{basename}.{extension}").replace(os.sep, "/")
return basename, subdir, path


Expand Down Expand Up @@ -431,7 +431,10 @@ def cpp_compiler_search(search):
)
with lock:
cxx = install_gcc_via_conda()
subprocess.check_output([cxx, "--version"])
if cxx == "cl":
subprocess.check_output([cxx])
else:
subprocess.check_output([cxx, "--version"])
return cxx
except (subprocess.SubprocessError, FileNotFoundError, ImportError):
continue
Expand Down Expand Up @@ -504,7 +507,12 @@ class VecISA:
__attribute__((aligned(64))) float in_out_ptr0[16] = {0.0};
extern "C" void __avx_chk_kernel() {
#ifdef _MSC_VER
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif
extern "C" DLLEXPORT void __avx_chk_kernel() {
auto tmp0 = at::vec::Vectorized<float>(1);
auto tmp1 = tmp0.exp();
tmp1.store(in_out_ptr0);
Expand Down Expand Up @@ -543,7 +551,7 @@ def __bool__(self):
lock_dir = get_lock_dir()
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
with lock:
output_path = input_path[:-3] + "so"
output_path = input_path[:-3] + ("so" if sys.platform != "win32" else "dll")
build_cmd = shlex.split(
cpp_compile_command(
input_path, output_path, warning_all=False, vec_isa=self
Expand Down Expand Up @@ -647,6 +655,10 @@ def pick_vec_isa():


def get_shared(shared=True):
    """Return the compiler flags that request a shared-library build.

    Args:
        shared: When false, no shared-library flag is emitted on any
            platform.

    Returns:
        A flag string (possibly empty) to splice into the compile command:

        * Windows with ``cl``, ``clang`` or ``clang-cl``: always ``""``.
          NOTE(review): presumably the DLL output is requested through
          MSVC-style options elsewhere in the command line -- confirm
          against the command builder.
        * Windows with any other compiler: ``"-shared"`` (no ``-fPIC``).
        * Every other platform: ``"-shared -fPIC"``.
    """
    if sys.platform == "win32":
        # The compiler is only queried on Windows; other platforms never
        # reach cpp_compiler() from this function.
        if cpp_compiler() in ("cl", "clang", "clang-cl"):
            return ""
        return "-shared" if shared else ""
    return "-shared -fPIC" if shared else ""


Expand All @@ -655,6 +667,8 @@ def get_warning_all_flag(warning_all=True):


def cpp_flags():
    """Return the C++ language-standard flags for the configured compiler.

    Returns:
        ``"/std:c++17"`` when ``cpp_compiler()`` reports ``cl`` or
        ``clang-cl`` (MSVC-style option syntax); otherwise
        ``"-std=c++17 -Wno-unused-variable"``.
    """
    if cpp_compiler() in ("cl", "clang-cl"):
        return "/std:c++17"
    return "-std=c++17 -Wno-unused-variable"


Expand All @@ -664,6 +678,8 @@ def cpp_wrapper_flags():

def optimization_flags():
base_flags = "-O3 -ffast-math -fno-finite-math-only"
if cpp_compiler() in ["cl", "clang-cl"]:
base_flags = "/nologo /O2 /fp:fast"
if config.is_fbcode():
# FIXME: passing `-fopenmp` adds libgomp.so to the generated shared library's dependencies.
# This causes `dlopen` to fail in fbcode, because libgomp does not exist in the default paths.
Expand All @@ -674,6 +690,8 @@ def optimization_flags():
# Per https://mac.r-project.org/openmp/ the right way to pass `openmp` flags on macOS is via `-Xclang`
# Also, `-march=native` is an unrecognized option on M1
base_flags += " -Xclang"
elif sys.platform == "win32":
pass
else:
if platform.machine() == "ppc64le":
base_flags += " -mcpu=native"
Expand All @@ -682,12 +700,15 @@ def optimization_flags():

# Internal cannot find libgomp.so
if not config.is_fbcode():
base_flags += " -fopenmp"
if cpp_compiler() in ["cl", "clang-cl"]:
base_flags += " /openmp"
else:
base_flags += " -fopenmp"
return base_flags


def use_custom_generated_macros():
    """Return the preprocessor define enabling C10's custom generated macros.

    The macro name is attached directly to ``-D`` (no space), so the define
    stays a single argument once the command string is split into tokens.

    Returns:
        The string ``"-DC10_USING_CUSTOM_GENERATED_MACROS"``.
    """
    return "-DC10_USING_CUSTOM_GENERATED_MACROS"


def use_fb_internal_macros():
Expand Down Expand Up @@ -844,6 +865,9 @@ def get_include_and_linking_paths(
else:
libs = ["omp"] if config.is_fbcode() else ["gomp"]

if sys.platform == "win32" and "gomp" in libs:
libs.pop(libs.index("gomp"))

# third party libs
if config.is_fbcode():
ipaths.append(build_paths.sleef())
Expand All @@ -859,9 +883,13 @@ def get_include_and_linking_paths(
# (later on, we copy the include paths from cpp_extensions into our remote dir)
ipaths.append("include")

ipaths = " ".join(["-I" + p for p in ipaths])
lpaths = " ".join(["-L" + p for p in lpaths])
libs = " ".join(["-l" + p for p in libs])
ipaths = " ".join([f'-I"{p}"' for p in ipaths])
lpaths = " ".join([f'-L"{p}"' for p in lpaths])
libs = " ".join([f'-l"{p}"' for p in libs])
if sys.platform == "win32":
ipaths = ipaths.replace(os.sep, "/")
lpaths = lpaths.replace(os.sep, "/")
libs = libs.replace(os.sep, "/")
return ipaths, lpaths, libs, macros


Expand Down Expand Up @@ -892,6 +920,10 @@ def cpp_compile_command(
inp_name = input
out_name = output
linker_paths = "" # let the compiler pick

out_dir = ""
if cpp_compiler() in ["cl", "clang-cl"]:
out_dir = "/Fe:" + os.path.dirname(out_name) + "/"
return re.sub(
r"[ \n]+",
" ",
Expand All @@ -903,7 +935,8 @@ def cpp_compile_command(
{use_custom_generated_macros()}
{use_fb_internal_macros()}
{use_standard_sys_dir_headers()}
-o {out_name}
{out_dir}
{"-o " if "cl" not in cpp_compiler() else "/LDd /OUT:"}"{out_name}"
""",
).strip()

Expand Down Expand Up @@ -953,7 +986,7 @@ def compile(cls, graph, source_code, cuda):
lock_dir = get_lock_dir()
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
with lock:
output_so = os.path.splitext(input_path)[0] + ".so"
output_so = os.path.splitext(input_path)[0] + (".so" if sys.platform != "win32" else ".dll")

if not os.path.exists(output_so):
cmd = shlex.split(
Expand Down Expand Up @@ -1011,9 +1044,16 @@ def cpp_prefix():
# everything that we compile into a folder for remote compilation.
return f'#include "{os.path.basename(filename)}"'
else:
filename = filename.replace(os.sep, "/")
return f'#include "{filename}"'


@functools.lru_cache(None)
def output_encoding():
    """Return the locale's preferred text encoding, computed once.

    The result is memoized by ``functools.lru_cache``, so the locale is
    queried only on the first call.

    Returns:
        The encoding name reported by ``locale.getpreferredencoding()``.
    """
    # Imported lazily so the locale module is only loaded when needed.
    import locale

    return locale.getpreferredencoding()


# Given a path to an input cpp file and an output path,
# Attempts to compile the file, storing the output in "output_path"
def compile_file(input_path, output_path, cmd) -> None:
Expand Down Expand Up @@ -1045,7 +1085,7 @@ def compile_file(input_path, output_path, cmd) -> None:
else:
subprocess.check_output(cmd, stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e:
output = e.output.decode("utf-8")
output = e.output.decode(output_encoding())
openmp_problem = "'omp.h' file not found" in output or "libomp" in output
if openmp_problem and sys.platform == "darwin":
instruction = (
Expand Down Expand Up @@ -1095,15 +1135,19 @@ def load(cls, source_code):
lock_dir = get_lock_dir()
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
with lock:
output_path = input_path[:-3] + "so"
output_path = input_path[:-3] + ("so" if sys.platform != "win32" else "dll")
if not os.path.exists(output_path):
cmd = shlex.split(
cpp_compile_command(
input=input_path, output=output_path, vec_isa=picked_vec_isa
)
)
compile_file(input_path, output_path, cmd)
cls.cache[key] = cls._load_library(output_path)
if sys.platform == "win32":
#cls.cache[key] = cls._load_library(os.path.join(".", os.path.basename(output_path)))
cls.cache[key] = cls._load_library(output_path)
else:
cls.cache[key] = cls._load_library(output_path)
cls.cache[key].key = key

return cls.cache[key]
Expand All @@ -1128,7 +1172,7 @@ def load_by_key_path(cls, key, path, linemap=()):
if key not in cls.cache:
with open(path) as f:
try:
code = compile(f.read(), path, "exec")
code = compile(f.read(), path.replace(os.sep, "/"), "exec")
except Exception as e:
raise RuntimeError(
f"Failed to import {path}\n{type(e).__name__}: {e}"
Expand Down Expand Up @@ -1183,7 +1227,7 @@ def load(cls, source_code, func_name, key, cuda):
if not os.path.exists(cpp_wrapper_dir):
os.makedirs(cpp_wrapper_dir)

ext = "so"
ext = "so" if sys.platform != "win32" else "dll"
filepath = os.path.join(cpp_wrapper_dir, f"{name}.{ext}")
log.debug("Cpp wrapper code path %s", filepath)

Expand Down
7 changes: 6 additions & 1 deletion torch/_inductor/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2858,7 +2858,12 @@ def codegen_define_and_call(self, wrapper):
kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
code.writeline(codecache.cpp_prefix())

code.writeline(f'extern "C" void {kernel_decl_name}({arg_defs})')
code.writeline('#ifdef _MSC_VER')
code.writeline(' #define DLLEXPORT __declspec(dllexport)')
code.writeline('#else')
code.writeline(' #define DLLEXPORT')
code.writeline('#endif')
code.writeline(f'extern "C" DLLEXPORT void {kernel_decl_name}({arg_defs})')
with code.indent():
if enable_kernel_profile:
graph_id = V.graph.graph_id
Expand Down

0 comments on commit 2716adf

Please sign in to comment.