Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add memset #300

Merged
merged 7 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ project adheres to [Semantic Versioning](http://semver.org/).
- Added `cu::Device::getArch()`
- Added `cu::DeviceMemory` constructor to create non-owning slice of another
`cu::DeviceMemory` object
- Added `cu::DeviceMemory::memset()`
- Added `cu::Stream::memsetAsync()`

### Changed

Expand Down
28 changes: 26 additions & 2 deletions include/cudawrappers/cu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,19 @@ class DeviceMemory : public Wrapper<CUdeviceptr> {
offset);
}

void zero(size_t size) { checkCudaCall(cuMemsetD8(_obj, 0, size)); }
void memset(unsigned char value, size_t size) {
csbnw marked this conversation as resolved.
Show resolved Hide resolved
checkCudaCall(cuMemsetD8(_obj, value, size));
}

void memset(unsigned short value, size_t size) {
checkCudaCall(cuMemsetD16(_obj, value, size));
}

void memset(unsigned int value, size_t size) {
checkCudaCall(cuMemsetD32(_obj, value, size));
}

void zero(size_t size) { memset(static_cast<unsigned char>(0), size); }

const void *parameter()
const // used to construct parameter list for launchKernel();
Expand Down Expand Up @@ -692,8 +704,20 @@ class Stream : public Wrapper<CUstream> {
checkCudaCall(cuMemPrefetchAsync(devPtr, size, dstDevice, _obj));
}

void memsetAsync(DeviceMemory &devPtr, unsigned char value, size_t size) {
checkCudaCall(cuMemsetD8Async(devPtr, value, size, _obj));
}

void memsetAsync(DeviceMemory &devPtr, unsigned short value, size_t size) {
checkCudaCall(cuMemsetD16Async(devPtr, value, size, _obj));
}

void memsetAsync(DeviceMemory &devPtr, unsigned int value, size_t size) {
checkCudaCall(cuMemsetD32Async(devPtr, value, size, _obj));
}

void zero(DeviceMemory &devPtr, size_t size) {
checkCudaCall(cuMemsetD8Async(devPtr, 0, size, _obj));
memsetAsync(devPtr, static_cast<unsigned char>(0), size);
}

void launchKernel(Function &function, unsigned gridX, unsigned gridY,
Expand Down
63 changes: 61 additions & 2 deletions tests/test_cu.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <array>
#include <catch2/catch_template_test_macros.hpp>
#include <catch2/catch_test_macros.hpp>
#include <cstring>
#include <iostream>
Expand Down Expand Up @@ -80,7 +81,7 @@ TEST_CASE("Test copying cu::DeviceMemory and cu::HostMemory using cu::Stream",
}
}

TEST_CASE("Test zeroing cu::DeviceMemory", "[zero]") {
TEST_CASE("Test cu::DeviceMemory", "[devicememory]") {
cu::init();
cu::Device device(0);
cu::Context context(CU_CTX_SCHED_BLOCKING_SYNC, device);
Expand Down Expand Up @@ -134,7 +135,7 @@ TEST_CASE("Test zeroing cu::DeviceMemory", "[zero]") {
CHECK(static_cast<bool>(memcmp(src, tgt, size)));
}

SECTION("Test cu::RegisteredMemory") {
SECTION("Test cu::DeviceMemory memcpy asynchronously") {
const size_t N = 3;
const size_t size = N * sizeof(int);

Expand Down Expand Up @@ -204,6 +205,64 @@ TEST_CASE("Test zeroing cu::DeviceMemory", "[zero]") {
}
}

using TestTypes = std::tuple<unsigned char, unsigned short, unsigned int>;
TEMPLATE_LIST_TEST_CASE("Test memset", "[memset]", TestTypes) {
cu::init();
cu::Device device(0);
cu::Context context(CU_CTX_SCHED_BLOCKING_SYNC, device);

SECTION("Test memset cu::DeviceMemory asynchronously") {
const size_t N = 3;
const size_t size = N * sizeof(TestType);
cu::HostMemory a(size);
cu::HostMemory b(size);
TestType value = 0xAA;

// Populate the memory with values
TestType* const a_ptr = static_cast<TestType*>(a);
TestType* const b_ptr = static_cast<TestType*>(b);
for (int i = 0; i < N; i++) {
a_ptr[i] = 0;
b_ptr[i] = value;
}
cu::DeviceMemory mem(size);

cu::Stream stream;
stream.memcpyHtoDAsync(mem, a, size);
stream.memsetAsync(mem, value, N);
stream.memcpyDtoHAsync(b, mem, size);
stream.synchronize();

CHECK(static_cast<bool>(memcmp(a, b, size)));
}

SECTION("Test zeroing cu::DeviceMemory synchronously") {
const size_t N = 3;
const size_t size = N * sizeof(TestType);
cu::HostMemory a(size);
cu::HostMemory b(size);
TestType value = 0xAA;

// Populate the memory with values
TestType* const a_ptr = static_cast<TestType*>(a);
TestType* const b_ptr = static_cast<TestType*>(b);
for (int i = 0; i < N; i++) {
a_ptr[i] = 0;
b_ptr[i] = value;
}
cu::DeviceMemory mem(size);

cu::Stream stream;
stream.memcpyHtoDAsync(mem, a, size);
stream.synchronize();
mem.memset(value, N);
stream.memcpyDtoHAsync(b, mem, size);
stream.synchronize();

CHECK(static_cast<bool>(memcmp(a, b, size)));
}
}

TEST_CASE("Test cu::Stream", "[stream]") {
cu::init();
cu::Device device(0);
Expand Down
Loading