From b87bf7188425723d95371f6024499575e313b60b Mon Sep 17 00:00:00 2001 From: Ben Howe Date: Wed, 31 Jan 2024 19:41:53 +0000 Subject: [PATCH] Don't throw fatal exception during 'import cudaq' If the user has GPUs but doesn't have all the right dependencies, we should not throw fatal exceptions in that case. --- python/utils/LinkedLibraryHolder.cpp | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/python/utils/LinkedLibraryHolder.cpp b/python/utils/LinkedLibraryHolder.cpp index c2462aba00..cbd098a18e 100644 --- a/python/utils/LinkedLibraryHolder.cpp +++ b/python/utils/LinkedLibraryHolder.cpp @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2022 - 2023 NVIDIA Corporation & Affiliates. * + * Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. * * All rights reserved. * * * * This source code and the accompanying materials are made available under * @@ -255,10 +255,27 @@ LinkedLibraryHolder::LinkedLibraryHolder() { // Set the default target // If environment variable set with a valid value, use it - // Otherwise, if GPU(s) available, set default to 'nvidia', else to 'qpp-cpu' + // Otherwise, if GPU(s) available and other dependencies are satisfied, set + // default to 'nvidia', else to 'qpp-cpu' defaultTarget = "qpp-cpu"; if (countGPUs() > 0) { - defaultTarget = "nvidia"; + // Before setting the defaultTarget to nvidia, make sure the simulator is + // available. + const std::string nvidiaTarget = "nvidia"; + auto iter = targets.find(nvidiaTarget); + if (iter != targets.end()) { + auto target = iter->second; + if (std::find(availableSimulators.begin(), availableSimulators.end(), + target.simulatorName) != availableSimulators.end()) + defaultTarget = nvidiaTarget; + else + cudaq::info( + "GPU(s) found but cannot select nvidia target because simulator " + "is not available. Are all dependencies installed?"); + } else { + cudaq::info("GPU(s) found but cannot select nvidia target because nvidia " + "target not found."); + } } auto env = std::getenv("CUDAQ_DEFAULT_SIMULATOR"); if (env) {