diff --git a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java index f8c84c753ef..a253ce3d246 100644 --- a/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java +++ b/engines/ml/lightgbm/src/main/java/ai/djl/ml/lightgbm/LgbmEngineProvider.java @@ -18,6 +18,8 @@ /** {@code LgbmEngineProvider} is the LightGBM implementation of {@link EngineProvider}. */ public class LgbmEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = LgbmEngine.newInstance(); + if (engine == null) { + synchronized (LgbmEngineProvider.class) { + engine = LgbmEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java index 5859f3f344d..19cba32cc71 100644 --- a/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java +++ b/engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbEngineProvider.java @@ -18,6 +18,8 @@ /** {@code XgbEngineProvider} is the XGBoost implementation of {@link EngineProvider}. */ public class XgbEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = XgbEngine.newInstance(); + if (engine == null) { + synchronized (XgbEngineProvider.class) { + engine = XgbEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java index 5f45116f615..f30a6a89252 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxEngineProvider.java @@ -18,6 +18,8 @@ /** {@code MxEngineProvider} is the MXNet implementation of {@link EngineProvider}. */ public class MxEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = MxEngine.newInstance(); + if (engine == null) { + synchronized (MxEngineProvider.class) { + engine = MxEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java index 005c0fa25f1..c673b3dcbf1 100644 --- a/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java +++ b/engines/onnxruntime/onnxruntime-engine/src/main/java/ai/djl/onnxruntime/engine/OrtEngineProvider.java @@ -18,6 +18,8 @@ /** {@code OrtEngineProvider} is the ONNX Runtime implementation of {@link EngineProvider}. */ public class OrtEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = OrtEngine.newInstance(); + if (engine == null) { + synchronized (OrtEngineProvider.class) { + engine = OrtEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java index 59e5cd90724..e2b5bdd35a0 100644 --- a/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java +++ b/engines/paddlepaddle/paddlepaddle-engine/src/main/java/ai/djl/paddlepaddle/engine/PpEngineProvider.java @@ -18,6 +18,8 @@ /** {@code PpEngineProvider} is the PaddlePaddle implementation of {@link EngineProvider}. */ public class PpEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = PpEngine.newInstance(); + if (engine == null) { + synchronized (PpEngineProvider.class) { + engine = PpEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java index 1b9cdd0ab19..57ae6c09d34 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngineProvider.java @@ -18,6 +18,8 @@ /** {@code PtEngineProvider} is the PyTorch implementation of {@link EngineProvider}. */ public class PtEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = PtEngine.newInstance(); + if (engine == null) { + synchronized (PtEngineProvider.class) { + engine = PtEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java index f42f691d222..d964ea5c295 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfEngineProvider.java @@ -18,6 +18,8 @@ /** {@code TfEngineProvider} is the TensorFlow implementation of {@link EngineProvider}. */ public class TfEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = TfEngine.newInstance(); + if (engine == null) { + synchronized (TfEngineProvider.class) { + engine = TfEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java index d92ed9e449d..05a7eceeb41 100644 --- a/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java +++ b/engines/tensorrt/src/main/java/ai/djl/tensorrt/engine/TrtEngineProvider.java @@ -18,6 +18,8 @@ /** {@code TrtEngineProvider} is the TensorRT implementation of {@link EngineProvider}. */ public class TrtEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = TrtEngine.newInstance(); + if (engine == null) { + synchronized (TrtEngineProvider.class) { + engine = TrtEngine.newInstance(); + } + } + return engine; } } diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java index efd9d89e509..96066b380e1 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtEngineTest.java @@ -26,7 +26,7 @@ public void getVersion() { try { Engine engine = Engine.getEngine("TensorRT"); version = engine.getVersion(); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } Assert.assertEquals(version, "8.4.1"); diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java index 09001f0e2da..24d734af54c 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/engine/TrtNDManagerTest.java @@ -28,7 +28,7 @@ public void testNDArray() { Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { diff --git a/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java b/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java index 99cbc6f763e..105e057ba0a 100644 --- a/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java +++ b/engines/tensorrt/src/test/java/ai/djl/tensorrt/integration/TrtTest.java @@ -49,7 +49,7 @@ public void testTrtOnnx() throws ModelException, IOException, TranslateException Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { @@ -75,7 +75,7 @@ public void testTrtUff() throws ModelException, IOException, TranslateException Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } if (!engine.defaultDevice().isGpu()) { @@ -112,7 +112,7 @@ public void testSerializedEngine() throws ModelException, IOException, Translate Engine engine; try { engine = Engine.getEngine("TensorRT"); - } catch (Throwable ignore) { + } catch (Exception ignore) { throw new SkipException("Your os configuration doesn't support TensorRT."); } Device device = engine.defaultDevice(); diff --git a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java index fb61551a3bf..aa0fdb73d21 100644 --- a/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java +++ b/engines/tflite/tflite-engine/src/main/java/ai/djl/tflite/engine/TfLiteEngineProvider.java @@ -18,6 +18,8 @@ /** {@code TfLiteEngineProvider} is the TFLite implementation of {@link EngineProvider}. */ public class TfLiteEngineProvider implements EngineProvider { + private static volatile Engine engine; // NOPMD + /** {@inheritDoc} */ @Override public String getEngineName() { @@ -33,10 +35,11 @@ public int getEngineRank() { /** {@inheritDoc} */ @Override public Engine getEngine() { - return InstanceHolder.INSTANCE; - } - - private static class InstanceHolder { - static final Engine INSTANCE = TfLiteEngine.newInstance(); + if (engine == null) { + synchronized (TfLiteEngineProvider.class) { + engine = TfLiteEngine.newInstance(); + } + } + return engine; } }