diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index d099290a316..2f8fe2d3be8 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -119,6 +119,19 @@ jobs: Copy-Item -Path "$env:GITHUB_WORKSPACE\swiftshader-install\vulkan-1.dll" -Destination 'build-x64\tests' cd build-x64; ctest -C Release --output-on-failure -j 4 + - name: x64-simpleomp + run: | + mkdir build-x64-simpleomp; cd build-x64-simpleomp + cmake -T ${{ matrix.toolset-version }},host=x64 -A x64 -Dprotobuf_DIR="$env:GITHUB_WORKSPACE\protobuf-install\cmake" -DNCNN_SHARED_LIB=OFF -DNCNN_VULKAN=ON -DNCNN_SIMPLEOMP=ON -DNCNN_BUILD_TESTS=ON .. + cmake --build . --config Release -j 4 + - name: x64-simpleomp-test + if: matrix.vs-version != 'vs2015' && matrix.vs-version != 'vs2017' + run: | + echo "[Processor]`nThreadCount=1`n" > build-x64-simpleomp/tests/Release/SwiftShader.ini + Copy-Item -Path "$env:GITHUB_WORKSPACE\swiftshader-install\vulkan-1.dll" -Destination 'build-x64-simpleomp\tests' + cd build-x64-simpleomp; ctest -C Release --output-on-failure -j 4 +# Copy-Item -Path "build-x64-simpleomp\src\Release\ncnn.dll" -Destination 'build-x64-simpleomp\tests' + - name: x64-sse2 run: | mkdir build-x64-sse2; cd build-x64-sse2 @@ -144,3 +157,7 @@ jobs: run: | Copy-Item -Path "build-x86\src\Release\ncnn.dll" -Destination 'build-x86\tests' cd build-x86; ctest -C Release --output-on-failure -j 4 + + + + diff --git a/CMakeLists.txt b/CMakeLists.txt index 0f32a80c86e..9734b97ac19 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -687,7 +687,7 @@ if(NCNN_VULKAN) endif() endif() - if (TARGET glslang AND TARGET SPIRV) + if(TARGET glslang AND TARGET SPIRV) get_property(glslang_location TARGET glslang PROPERTY LOCATION) get_property(SPIRV_location TARGET SPIRV PROPERTY LOCATION) message(STATUS "Found glslang: ${glslang_location} (found version \"${glslang_VERSION}\")") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 803c34a780d..5aa7206ac8b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -10,8 +10,8 @@ function(ncnn_src_group ncnn_src_string folder) string(REGEX REPLACE "/" "\\\\" _target_folder "${folder}") foreach(_file IN LISTS ${_ncnn_src_list}) - source_group ("${_target_folder}" FILES "${_file}") - endforeach () + source_group("${_target_folder}" FILES "${_file}") + endforeach() endfunction() set(ncnn_SRCS @@ -215,11 +215,11 @@ endif() target_include_directories(ncnn PUBLIC - $ - $ - $ + $ + $ + $ PRIVATE - $) + $) if(NCNN_OPENMP) if(NOT NCNN_SIMPLEOMP) @@ -237,6 +237,8 @@ if(NCNN_OPENMP) if(NCNN_SIMPLEOMP) if(IOS OR APPLE) target_compile_options(ncnn PRIVATE -Xpreprocessor -fopenmp) + elseif(MSVC OR WIN32) + target_compile_options(ncnn PRIVATE /openmp) else() target_compile_options(ncnn PRIVATE -fopenmp) endif() @@ -262,7 +264,11 @@ if(NCNN_THREADS) target_link_libraries(ncnn PUBLIC Threads::Threads) endif() if(NCNN_SIMPLEOMP OR NCNN_SIMPLESTL) - target_link_libraries(ncnn PUBLIC pthread) + if(MSVC OR WIN32) + message(STATUS "Using Native WIN32 THREAD MODEL") + else() + target_link_libraries(ncnn PUBLIC pthread) + endif() endif() endif() @@ -398,7 +404,7 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") else() if(NOT CMAKE_SYSTEM_NAME MATCHES "WASI") target_compile_options(ncnn PRIVATE -msse2 -msse) - endif () + endif() if(CMAKE_SYSTEM_NAME MATCHES "Emscripten|WASI") target_compile_options(ncnn PRIVATE -msimd128) endif() diff --git a/src/simpleomp.cpp b/src/simpleomp.cpp index 6de069a9b6f..73719efac79 100644 --- a/src/simpleomp.cpp +++ b/src/simpleomp.cpp @@ -61,11 +61,52 @@ extern "C" typedef void (*kmpc_micro_30)(int32_t* gtid, int32_t* tid, void*, voi extern "C" typedef void (*kmpc_micro_31)(int32_t* gtid, int32_t* tid, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); #endif // __clang__ +#if _WIN32 +extern "C" typedef void (*win_kmpc_micro)(int32_t* gtid, int32_t* tid, ...); +extern "C" typedef void (*win_kmpc_micro_0)(); +extern "C" typedef void (*win_kmpc_micro_1)(void*); +extern "C" typedef void (*win_kmpc_micro_2)(void*, void*); +extern "C" typedef void (*win_kmpc_micro_3)(void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_4)(void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_5)(void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_6)(void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_7)(void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_8)(void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_9)(void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_10)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_11)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_12)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_13)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_14)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_15)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_16)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_17)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_18)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_19)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_20)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_21)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_22)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_23)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_24)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_25)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_26)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_27)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_28)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_29)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_30)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +extern "C" typedef void (*win_kmpc_micro_31)(void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*, void*); +#endif // _WIN32 + #ifdef __cplusplus extern "C" { #endif +#ifdef _WIN32 +static BOOL CALLBACK init_g_kmp_global(PINIT_ONCE InitOnce, PVOID Parameter, PVOID* Context); +#else static void init_g_kmp_global(); +#endif // _WIN32 + static void* kmp_threadfunc(void* args); #ifdef __cplusplus @@ -83,6 +124,11 @@ class KMPTask kmpc_micro fn; int argc; void** argv; +#elif _WIN32 + // vcomp abi + void (*fn)(void*); + int argc; + void** argv; #else // libgomp abi void (*fn)(void*); @@ -210,11 +256,19 @@ class KMPGlobal void try_init() { +#ifdef _WIN32 + InitOnceExecuteOnce(&is_initialized, init_g_kmp_global, NULL, NULL); +#else pthread_once(&is_initialized, init_g_kmp_global); +#endif // _WIN32 } public: +#ifdef _WIN32 + static INIT_ONCE is_initialized; +#else static pthread_once_t is_initialized; +#endif // _WIN32 void init() { @@ -248,10 +302,14 @@ class KMPGlobal tasks[i].fn = 0; tasks[i].argc = 0; tasks[i].argv = (void**)0; +#elif _WIN32 + tasks[i].fn = 0; + tasks[i].argc = 0; + tasks[i].argv = (void**)0; #else tasks[i].fn = 0; tasks[i].data = 0; -#endif +#endif // __clang__ tasks[i].num_threads = kmp_max_threads; tasks[i].thread_num = i + 1; tasks[i].num_threads_to_wait = 0; @@ -288,17 +346,29 @@ class KMPGlobal } // namespace ncnn +#ifdef _WIN32 +INIT_ONCE ncnn::KMPGlobal::is_initialized = INIT_ONCE_STATIC_INIT; +#else pthread_once_t ncnn::KMPGlobal::is_initialized = PTHREAD_ONCE_INIT; +#endif // _WIN32 static ncnn::KMPGlobal g_kmp_global; static ncnn::ThreadLocalStorage tls_num_threads; static ncnn::ThreadLocalStorage tls_thread_num; +#ifdef _WIN32 +static BOOL CALLBACK init_g_kmp_global(PINIT_ONCE InitOnce, PVOID Parameter, PVOID* Context) +{ + g_kmp_global.init(); + return TRUE; +} +#else static void init_g_kmp_global() { g_kmp_global.init(); } +#endif // _WIN32 #ifdef __cplusplus extern "C" { @@ -456,6 +526,10 @@ static int kmp_invoke_microtask(kmpc_micro fn, int gtid, int tid, int argc, void } #endif // __clang__ +#if _WIN32 +void CDECL _vcomp_fork_call_wrapper(void* wrapper, int nargs, void** args); +#endif // _WIN32 + static void* kmp_threadfunc(void* args) { #if __clang__ @@ -479,6 +553,8 @@ static void* kmp_threadfunc(void* args) #if __clang__ kmp_invoke_microtask(task->fn, task->thread_num, tid, task->argc, task->argv); +#elif _WIN32 + _vcomp_fork_call_wrapper(task->fn, task->argc, task->argv); #else task->fn(task->data); #endif @@ -650,8 +726,305 @@ void __kmpc_for_static_fini(void* /*loc*/, int32_t gtid) // NCNN_LOGE("__kmpc_for_static_fini"); (void)gtid; } -#else // __clang__ +#elif _WIN32 // __clang__ +int CDECL omp_in_parallel(void) +{ + // NCNN_LOGE("omp_in_parallel() is called!"); + return TRUE; +} +void CDECL _vcomp_set_num_threads(int num_threads) +{ + // NCNN_LOGE("_vcomp_set_num_threads(%d)\n", num_threads); + if (num_threads >= 1) + omp_set_num_threads(num_threads); +} +void CDECL _vcomp_fork_call_wrapper(void* wrapper, int nargs, void** args) +{ + switch (nargs) + { + case 0: + (*(win_kmpc_micro_0)wrapper)(); + break; + case 1: + (*(win_kmpc_micro_1)wrapper)(args[0]); + break; + case 2: + (*(win_kmpc_micro_2)wrapper)(args[0], args[1]); + break; + case 3: + (*(win_kmpc_micro_3)wrapper)(args[0], args[1], args[2]); + break; + case 4: + (*(win_kmpc_micro_4)wrapper)(args[0], args[1], args[2], args[3]); + break; + case 5: + (*(win_kmpc_micro_5)wrapper)(args[0], args[1], args[2], args[3], args[4]); + break; + case 6: + (*(win_kmpc_micro_6)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5]); + break; + case 7: + (*(win_kmpc_micro_7)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); + break; + case 8: + (*(win_kmpc_micro_8)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]); + break; + case 9: + (*(win_kmpc_micro_9)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8]); + break; + case 10: + (*(win_kmpc_micro_10)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9]); + break; + case 11: + (*(win_kmpc_micro_11)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10]); + break; + case 12: + (*(win_kmpc_micro_12)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11]); + break; + case 13: + (*(win_kmpc_micro_13)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12]); + break; + case 14: + (*(win_kmpc_micro_14)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13]); + break; + case 15: + (*(win_kmpc_micro_15)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14]); + break; + case 16: + (*(win_kmpc_micro_16)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15]); + break; + case 17: + (*(win_kmpc_micro_17)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15], args[16]); + break; + case 18: + (*(win_kmpc_micro_18)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15], args[16], args[17]); + break; + case 19: + (*(win_kmpc_micro_19)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15], args[16], args[17], args[18]); + break; + case 20: + (*(win_kmpc_micro_20)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15], args[16], args[17], args[18], args[19]); + break; + case 21: + (*(win_kmpc_micro_21)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15], args[16], args[17], args[18], args[19], args[20]); + break; + case 22: + (*(win_kmpc_micro_22)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15], args[16], args[17], args[18], args[19], args[20], args[21]); + break; + case 23: + (*(win_kmpc_micro_23)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15], args[16], args[17], args[18], args[19], args[20], args[21], args[22]); + break; + case 24: + (*(win_kmpc_micro_24)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15], args[16], args[17], args[18], args[19], args[20], args[21], args[22], args[23]); + break; + case 25: + (*(win_kmpc_micro_25)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15], args[16], args[17], args[18], args[19], args[20], args[21], args[22], args[23], args[24]); + break; + case 26: + (*(win_kmpc_micro_26)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15], args[16], args[17], args[18], args[19], args[20], args[21], args[22], args[23], args[24], args[25]); + break; + case 27: + (*(win_kmpc_micro_27)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15], args[16], args[17], args[18], args[19], args[20], args[21], args[22], args[23], args[24], args[25], args[26]); + break; + case 28: + (*(win_kmpc_micro_28)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15], args[16], args[17], args[18], args[19], args[20], args[21], args[22], args[23], args[24], args[25], args[26], args[27]); + break; + case 29: + (*(win_kmpc_micro_29)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15], args[16], args[17], args[18], args[19], args[20], args[21], args[22], args[23], args[24], args[25], args[26], args[27], args[28]); + break; + case 30: + (*(win_kmpc_micro_30)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15], args[16], args[17], args[18], args[19], args[20], args[21], args[22], args[23], args[24], args[25], args[26], args[27], args[28], args[29]); + break; + case 31: + (*(win_kmpc_micro_31)wrapper)(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], args[9], args[10], args[11], args[12], args[13], args[14], args[15], args[16], args[17], args[18], args[19], args[20], args[21], args[22], args[23], args[24], args[25], args[26], args[27], args[28], args[29], args[30]); + break; + default: + // this line should not be touched + break; + } +} + +void CDECL _vcomp_for_static_simple_init(unsigned int first, unsigned int last, int step, + BOOL increment, unsigned int* begin, unsigned int* end) +{ + unsigned int iterations, per_thread, remaining; + int num_threads = omp_get_num_threads(); + int thread_num = omp_get_thread_num(); + // NCNN_LOGE("inside _vcomp_for_static_simple_init(), the thread_num is %d", thread_num); + + if (num_threads == 1) + { + *begin = first; + *end = last; + return; + } + + if (step <= 0) + { + *begin = 0; + *end = increment ? -1 : 1; + return; + } + + if (increment) + iterations = 1 + (last - first) / step; + else + { + iterations = 1 + (first - last) / step; + step *= -1; + } + + per_thread = iterations / num_threads; + remaining = iterations - per_thread * num_threads; + + if (thread_num < remaining) + per_thread++; + else if (per_thread) + first += remaining * step; + else + { + *begin = first; + *end = first - step; + return; + } + + *begin = first + per_thread * thread_num * step; + *end = *begin + (per_thread - 1) * step; +} +void CDECL _vcomp_for_static_simple_init_i8(ULONG64 first, ULONG64 last, LONG64 step, + BOOL increment, ULONG64* begin, ULONG64* end) +{ + ULONG64 iterations, per_thread, remaining; + int num_threads = omp_get_num_threads(); + int thread_num = omp_get_thread_num(); + + if (num_threads == 1) + { + *begin = first; + *end = last; + return; + } + + if (step <= 0) + { + *begin = 0; + *end = increment ? -1 : 1; + return; + } + + if (increment) + iterations = 1 + (last - first) / step; + else + { + iterations = 1 + (first - last) / step; + step *= -1; + } + + per_thread = iterations / num_threads; + remaining = iterations - per_thread * num_threads; + + if (thread_num < remaining) + per_thread++; + else if (per_thread) + first += remaining * step; + else + { + *begin = first; + *end = first - step; + return; + } + + *begin = first + per_thread * thread_num * step; + *end = *begin + (per_thread - 1) * step; +} + +void CDECL _vcomp_for_static_end(void) +{ + // NCNN_LOGE("MSVC _vcomp_for_static_end() is called!"); + /* nothing to do here */ +} + +// this func will be called when cl.exe encounters a parallel region +void WINAPIV _vcomp_fork(BOOL ifval, int nargs, void* wrapper, ...) +{ + g_kmp_global.try_init(); + + int num_threads = omp_get_num_threads(); + + // for nested parallel region feature, not supported here + if (!ifval) + num_threads = 1; + + // build argv + void* argv[32]; + { + va_list ap; + va_start(ap, wrapper); + for (int i = 0; i < nargs; i++) + argv[i] = va_arg(ap, void*); + va_end(ap); + } + + if (num_threads == 0) + { + num_threads = omp_get_max_threads(); + } + + if (g_kmp_global.kmp_max_threads == 1 || num_threads == 1) + { + for (unsigned i = 0; i < num_threads; i++) + { + tls_num_threads.set(reinterpret_cast((size_t)num_threads)); + tls_thread_num.set(reinterpret_cast((size_t)i)); + + _vcomp_fork_call_wrapper(wrapper, nargs, argv); + } + + return; + } + + int num_threads_to_wait = num_threads - 1; + ncnn::Mutex finish_lock; + ncnn::ConditionVariable finish_condition; + + // TODO portable stack allocation + ncnn::KMPTask* tasks = (ncnn::KMPTask*)alloca((num_threads - 1) * sizeof(ncnn::KMPTask)); + for (int i = 0; i < num_threads - 1; i++) + { + tasks[i].fn = (void (*)(void*))wrapper; + tasks[i].argc = nargs; + tasks[i].argv = (void**)argv; + tasks[i].num_threads = num_threads; + tasks[i].thread_num = i + 1; + tasks[i].num_threads_to_wait = &num_threads_to_wait; + tasks[i].finish_lock = &finish_lock; + tasks[i].finish_condition = &finish_condition; + } + + // dispatch 1 ~ num_threads + g_kmp_global.kmp_task_queue->dispatch(tasks, num_threads - 1); + + // dispatch 0 + { + tls_num_threads.set(reinterpret_cast((size_t)num_threads)); + tls_thread_num.set(reinterpret_cast((size_t)0)); + + _vcomp_fork_call_wrapper(wrapper, nargs, argv); + } + + // wait for finished + { + finish_lock.lock(); + if (num_threads_to_wait != 0) + { + finish_condition.wait(finish_lock); + } + finish_lock.unlock(); + } +} + +#else static ncnn::ThreadLocalStorage tls_parallel_context; struct parallel_context @@ -661,11 +1034,9 @@ struct parallel_context ncnn::ConditionVariable finish_condition; ncnn::KMPTask* tasks; }; - void GOMP_parallel_start(void (*fn)(void*), void* data, unsigned num_threads) { g_kmp_global.try_init(); - // NCNN_LOGE("GOMP_parallel_start %p %p %u", fn, data, num_threads); if (num_threads == 0) { @@ -794,6 +1165,7 @@ void GOMP_parallel(void (*fn)(void*), void* data, unsigned num_threads, unsigned finish_lock.unlock(); } } + #endif // __clang__ #ifdef __cplusplus