revamp packArgs to use TypeSwitch
Signed-off-by: Alex McCaskey <amccaskey@nvidia.com>
amccaskey committed Mar 25, 2024
1 parent 5c9f69d commit 40d535b
Showing 4 changed files with 202 additions and 242 deletions.
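
The commit replaces a hand-rolled if/else chain over MLIR return types with llvm::TypeSwitch, LLVM's utility for dispatching on a type's dynamic kind, where each .Case deduces the concrete type from its lambda parameter. A minimal, self-contained sketch of that pattern, assuming an MLIR build; the helper name "describe" is illustrative and not part of this commit:

    #include "llvm/ADT/TypeSwitch.h"
    #include "mlir/IR/BuiltinTypes.h"
    #include <string>

    // Illustrative sketch of the TypeSwitch dispatch style adopted below;
    // `describe` is a hypothetical helper, not CUDA-Q code.
    static std::string describe(mlir::Type ty) {
      return llvm::TypeSwitch<mlir::Type, std::string>(ty)
          .Case([](mlir::IntegerType t) {
            // One Case covers every integer width; i1 is special-cased inside
            // the lambda, just as the py_alt_launch_kernel.cpp hunk does.
            return std::string(t.getWidth() == 1 ? "bool" : "int");
          })
          .Case([](mlir::Float64Type) { return std::string("double"); })
          .Default([](mlir::Type) { return std::string("unsupported"); });
    }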
1 change: 0 additions & 1 deletion python/runtime/cudaq/algorithms/py_observe_async.cpp
@@ -111,7 +111,6 @@ observe_result pyObservePar(const PyParType &type, py::object &kernel,
   kernel.attr("compile")();
 
   // Ensure the user input is correct.
-  // auto validatedArgs = validateInputArguments(kernel, args);
   auto &platform = cudaq::get_platform();
   if (!platform.supports_task_distribution())
     throw std::runtime_error(
9 changes: 6 additions & 3 deletions python/runtime/cudaq/algorithms/py_sample_async.cpp
@@ -51,14 +51,17 @@ for more information on this programming pattern.)#")
       [&](py::object &kernel, py::args args, std::size_t shots,
           std::size_t qpu_id) {
         auto &platform = cudaq::get_platform();
-        auto kernelName = kernel.attr("name").cast<std::string>();
         if (py::hasattr(kernel, "compile"))
           kernel.attr("compile")();
 
+        auto kernelName = kernel.attr("name").cast<std::string>();
+        auto kernelMod = kernel.attr("module").cast<MlirModule>();
+        auto kernelFunc = getKernelFuncOp(kernelMod, kernelName);
+
         args = simplifiedValidateInputArguments(args);
         auto *argData = new cudaq::OpaqueArguments();
-        cudaq::packArgs(*argData, args);
-        auto kernelMod = kernel.attr("module").cast<MlirModule>();
+        cudaq::packArgs(*argData, args, kernelFunc,
+            [](OpaqueArguments &, py::object &) { return false; });
 
         // The function below will be executed multiple times
         // if the kernel has conditional feedback. In that case,
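
The packArgs call now receives the kernel's FuncOp and a trailing callable; every call site in this commit passes a callable that simply returns false. Purely as an illustration of the caller-supplied fallback style this call shape suggests, and under the assumption that returning false means the callable handled nothing, here is a small sketch; the names are hypothetical and do not describe cudaq::packArgs internals:

    #include <functional>
    #include <stdexcept>
    #include <string>
    #include <vector>

    // Hypothetical packer that consults a caller-supplied fallback for values
    // it does not handle itself; not the cudaq::packArgs implementation.
    using Packed = std::vector<std::string>;
    using Fallback = std::function<bool(Packed &, const std::string &)>;

    void packValue(Packed &out, const std::string &kind,
                   const Fallback &fallback) {
      if (kind == "int" || kind == "float") {
        out.push_back(kind); // a kind the packer handles natively
        return;
      }
      if (fallback(out, kind)) // give the caller a chance at anything else
        return;
      throw std::invalid_argument("cannot pack value of kind: " + kind);
    }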
62 changes: 40 additions & 22 deletions python/runtime/cudaq/platform/py_alt_launch_kernel.cpp
@@ -106,26 +106,40 @@ jitAndCreateArgs(const std::string &name, MlirModule module,
   // We need to append the return type to the OpaqueArguments here
   // so that we get a spot in the `rawArgs` memory for the
   // altLaunchKernel function to dump the result
-  if (!isa<NoneType>(returnType)) {
-    if (returnType.isInteger(64)) {
-      py::args returnVal = py::make_tuple(py::int_(0));
-      packArgs(runtimeArgs, returnVal);
-    } else if (returnType.isInteger(1)) {
-      py::args returnVal = py::make_tuple(py::bool_(0));
-      packArgs(runtimeArgs, returnVal);
-    } else if (isa<FloatType>(returnType)) {
-      py::args returnVal = py::make_tuple(py::float_(0.0));
-      packArgs(runtimeArgs, returnVal);
-    } else {
-      std::string msg;
-      {
-        llvm::raw_string_ostream os(msg);
-        returnType.print(os);
-      }
-      throw std::runtime_error(
-          "Unsupported CUDA Quantum kernel return type - " + msg + ".\n");
-    }
-  }
+  if (!isa<NoneType>(returnType))
+    TypeSwitch<Type, void>(returnType)
+        .Case([&](IntegerType type) {
+          if (type.getIntOrFloatBitWidth() == 1) {
+            bool *ourAllocatedArg = new bool();
+            *ourAllocatedArg = 0;
+            runtimeArgs.emplace_back(ourAllocatedArg, [](void *ptr) {
+              delete static_cast<bool *>(ptr);
+            });
+            return;
+          }
+
+          long *ourAllocatedArg = new long();
+          *ourAllocatedArg = 0;
+          runtimeArgs.emplace_back(ourAllocatedArg, [](void *ptr) {
+            delete static_cast<long *>(ptr);
+          });
+        })
+        .Case([&](Float64Type type) {
+          double *ourAllocatedArg = new double();
+          *ourAllocatedArg = 0.;
+          runtimeArgs.emplace_back(ourAllocatedArg, [](void *ptr) {
+            delete static_cast<double *>(ptr);
+          });
+        })
+        .Default([](Type ty) {
+          std::string msg;
+          {
+            llvm::raw_string_ostream os(msg);
+            ty.print(os);
+          }
+          throw std::runtime_error(
+              "Unsupported CUDA Quantum kernel return type - " + msg + ".\n");
+        });
 
   void *rawArgs = nullptr;
   std::size_t size = 0;
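
Instead of packing a default-constructed Python value to reserve space for the kernel's return value, the new code above allocates the slot directly and registers a matching deleter via emplace_back. A minimal standalone illustration of that owning pointer-plus-deleter idiom, using a plain container as a stand-in for cudaq::OpaqueArguments:

    #include <functional>
    #include <utility>
    #include <vector>

    // Sketch of the emplace-with-deleter idiom from the hunk above; ArgSlots
    // is a stand-in, not cudaq::OpaqueArguments.
    struct ArgSlots {
      std::vector<std::pair<void *, std::function<void(void *)>>> slots;
      void emplace_back(void *ptr, std::function<void(void *)> deleter) {
        slots.emplace_back(ptr, std::move(deleter));
      }
      ~ArgSlots() {
        for (auto &[ptr, del] : slots)
          del(ptr); // each slot is freed by the deleter registered with it
      }
    };

    void reserveDoubleSlot(ArgSlots &args) {
      auto *slot = new double(0.0); // space for a double return value
      args.emplace_back(slot, [](void *p) { delete static_cast<double *>(p); });
    }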
@@ -348,8 +362,10 @@ void bindAltLaunchKernel(py::module &mod) {
   mod.def("synthesize", [](py::object kernel, py::args runtimeArgs) {
     MlirModule module = kernel.attr("module").cast<MlirModule>();
     auto name = kernel.attr("name").cast<std::string>();
+    auto kernelFuncOp = getKernelFuncOp(module, name);
     cudaq::OpaqueArguments args;
-    cudaq::packArgs(args, runtimeArgs);
+    cudaq::packArgs(args, runtimeArgs, kernelFuncOp,
+        [](OpaqueArguments &, py::object &) { return false; });
     return synthesizeKernel(name, module, args);
   });
 
@@ -360,8 +376,10 @@ void bindAltLaunchKernel(py::module &mod) {
           kernel.attr("compile")();
         MlirModule module = kernel.attr("module").cast<MlirModule>();
         auto name = kernel.attr("name").cast<std::string>();
+        auto kernelFuncOp = getKernelFuncOp(module, name);
         cudaq::OpaqueArguments args;
-        cudaq::packArgs(args, runtimeArgs);
+        cudaq::packArgs(args, runtimeArgs, kernelFuncOp,
+            [](OpaqueArguments &, py::object &) { return false; });
         return getQIRLL(name, module, args, profile);
       },
       py::arg("kernel"), py::kw_only(), py::arg("profile") = "");