Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Windows MT layer bug fixes #3364

Merged
merged 4 commits into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions build/meson/tests/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,6 @@ test('test-zstream-1',
test('test-zstream-3',
zstreamtest,
args: ['--newapi', '-t1', ZSTREAM_TESTTIME] + FUZZER_FLAGS,
# --newapi dies on Windows with "exit status 3221225477 or signal 3221225349 SIGinvalid"
should_fail: host_machine_os == os_windows,
timeout: 120)
test('test-longmatch', longmatch, timeout: 36)
test('test-invalidDictionaries', invalidDictionaries) # should be fast
Expand Down
2 changes: 1 addition & 1 deletion lib/common/pool.c
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ static void POOL_join(POOL_ctx* ctx) {
/* Join all of the threads */
{ size_t i;
for (i = 0; i < ctx->threadCapacity; ++i) {
ZSTD_pthread_join(ctx->threads[i], NULL); /* note : could fail */
ZSTD_pthread_join(ctx->threads[i]); /* note : could fail */
} }
}

Expand Down
64 changes: 50 additions & 14 deletions lib/common/threading.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,39 +34,75 @@ int g_ZSTD_threading_useless_symbol;

/* === Implementation === */

typedef struct {
void* (*start_routine)(void*);
void* arg;
int initialized;
ZSTD_pthread_cond_t initialized_cond;
ZSTD_pthread_mutex_t initialized_mutex;
} ZSTD_thread_params_t;

static unsigned __stdcall worker(void *arg)
{
ZSTD_pthread_t* const thread = (ZSTD_pthread_t*) arg;
thread->arg = thread->start_routine(thread->arg);
ZSTD_thread_params_t* const thread_param = (ZSTD_thread_params_t*)arg;
void* (*start_routine)(void*) = thread_param->start_routine;
void* thread_arg = thread_param->arg;

/* Signal main thread that we are running and do not depend on its memory anymore */
ZSTD_pthread_mutex_lock(&thread_param->initialized_mutex);
thread_param->initialized = 1;
ZSTD_pthread_mutex_unlock(&thread_param->initialized_mutex);
ZSTD_pthread_cond_signal(&thread_param->initialized_cond);

terrelln marked this conversation as resolved.
Show resolved Hide resolved
start_routine(thread_arg);

return 0;
}

int ZSTD_pthread_create(ZSTD_pthread_t* thread, const void* unused,
void* (*start_routine) (void*), void* arg)
{
ZSTD_thread_params_t thread_param;
int error = 0;
(void)unused;
thread->arg = arg;
thread->start_routine = start_routine;
thread->handle = (HANDLE) _beginthreadex(NULL, 0, worker, thread, 0, NULL);

if (!thread->handle)
thread_param.start_routine = start_routine;
thread_param.arg = arg;
thread_param.initialized = 0;

terrelln marked this conversation as resolved.
Show resolved Hide resolved
/* Setup thread initialization synchronization */
error |= ZSTD_pthread_cond_init(&thread_param.initialized_cond, NULL);
error |= ZSTD_pthread_mutex_init(&thread_param.initialized_mutex, NULL);
if(error)
return -1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd recommend against this pattern here. We need to destroy whichever one we initialized correctly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two functions can't really fail, so we will never hit this error condition, it's just there because we have a return value (to align with pthreads) and we have to handle the return value to not get compilation errors.
Still, I'll add cleanup for for compatibility.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could an assert() be used if only to document that this part of the code path, or this condition, can never be reached ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a comment and handled the errors in any case. If you think an assert would be better let me know.

ZSTD_pthread_mutex_lock(&thread_param.initialized_mutex);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why lock this mutex here? I don't see any reason to hold it while creating the thread. So lets minimize the scope and move it to line 84.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are correct, don't know why it's there, may have moved stuff and missed it.


/* Spawn thread */
*thread = (HANDLE)_beginthreadex(NULL, 0, worker, &thread_param, 0, NULL);
if (!thread)
return errno;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to unlock (until it is moved below the thread creation), and destroy the mutex/cond in this case.

else
return 0;

/* Wait for thread to be initialized */
while(!thread_param.initialized) {
ZSTD_pthread_cond_wait(&thread_param.initialized_cond, &thread_param.initialized_mutex);
}
ZSTD_pthread_mutex_unlock(&thread_param.initialized_mutex);
ZSTD_pthread_mutex_destroy(&thread_param.initialized_mutex);
ZSTD_pthread_cond_destroy(&thread_param.initialized_cond);

return 0;
}

int ZSTD_pthread_join(ZSTD_pthread_t thread, void **value_ptr)
int ZSTD_pthread_join(ZSTD_pthread_t thread)
{
DWORD result;

if (!thread.handle) return 0;
if (!thread) return 0;

result = WaitForSingleObject(thread.handle, INFINITE);
CloseHandle(thread.handle);
result = WaitForSingleObject(thread, INFINITE);
CloseHandle(thread);

switch (result) {
case WAIT_OBJECT_0:
if (value_ptr) *value_ptr = thread.arg;
return 0;
case WAIT_ABANDONED:
return EINVAL;
Expand Down
12 changes: 4 additions & 8 deletions lib/common/threading.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,12 @@ extern "C" {
#define ZSTD_pthread_cond_broadcast(a) WakeAllConditionVariable((a))

/* ZSTD_pthread_create() and ZSTD_pthread_join() */
typedef struct {
HANDLE handle;
void* (*start_routine)(void*);
void* arg;
} ZSTD_pthread_t;
typedef HANDLE ZSTD_pthread_t;

int ZSTD_pthread_create(ZSTD_pthread_t* thread, const void* unused,
void* (*start_routine) (void*), void* arg);

int ZSTD_pthread_join(ZSTD_pthread_t thread, void** value_ptr);
int ZSTD_pthread_join(ZSTD_pthread_t thread);

/**
* add here more wrappers as required
Expand Down Expand Up @@ -98,7 +94,7 @@ int ZSTD_pthread_join(ZSTD_pthread_t thread, void** value_ptr);

#define ZSTD_pthread_t pthread_t
#define ZSTD_pthread_create(a, b, c, d) pthread_create((a), (b), (c), (d))
#define ZSTD_pthread_join(a, b) pthread_join((a),(b))
#define ZSTD_pthread_join(a) pthread_join((a),NULL)

#else /* DEBUGLEVEL >= 1 */

Expand All @@ -123,7 +119,7 @@ int ZSTD_pthread_cond_destroy(ZSTD_pthread_cond_t* cond);

#define ZSTD_pthread_t pthread_t
#define ZSTD_pthread_create(a, b, c, d) pthread_create((a), (b), (c), (d))
#define ZSTD_pthread_join(a, b) pthread_join((a),(b))
#define ZSTD_pthread_join(a) pthread_join((a),NULL)

#endif

Expand Down
4 changes: 2 additions & 2 deletions tests/fuzzer.c
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,8 @@ static int threadPoolTests(void) {

ZSTD_pthread_create(&t1, NULL, threadPoolTests_compressionJob, &p1);
ZSTD_pthread_create(&t2, NULL, threadPoolTests_compressionJob, &p2);
ZSTD_pthread_join(t1, NULL);
ZSTD_pthread_join(t2, NULL);
ZSTD_pthread_join(t1);
ZSTD_pthread_join(t2);

assert(!memcmp(decodedBuffer, decodedBuffer2, CNBuffSize));
free(decodedBuffer2);
Expand Down