diff --git a/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java b/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java
index af6d460766b..413e73e167d 100644
--- a/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java
+++ b/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java
@@ -187,7 +187,8 @@ protected Model createModel(
Model model = Model.newInstance(name, device, engine);
if (block == null) {
String className = (String) arguments.get("blockFactory");
- BlockFactory factory = ClassLoaderUtils.findImplementation(modelPath, className);
+ BlockFactory factory =
+ ClassLoaderUtils.findImplementation(modelPath, BlockFactory.class, className);
if (factory != null) {
block = factory.newBlock(model, modelPath, arguments);
}
@@ -229,7 +230,7 @@ protected TranslatorFactory getTranslatorFactory(
String factoryClass = (String) arguments.get("translatorFactory");
if (factoryClass != null) {
ClassLoader cl = ClassLoaderUtils.getContextClassLoader();
- factory = ClassLoaderUtils.initClass(cl, factoryClass);
+ factory = ClassLoaderUtils.initClass(cl, TranslatorFactory.class, factoryClass);
}
return factory;
}
diff --git a/api/src/main/java/ai/djl/util/ClassLoaderUtils.java b/api/src/main/java/ai/djl/util/ClassLoaderUtils.java
index 5428915a73b..d9da664c159 100644
--- a/api/src/main/java/ai/djl/util/ClassLoaderUtils.java
+++ b/api/src/main/java/ai/djl/util/ClassLoaderUtils.java
@@ -45,11 +45,12 @@ private ClassLoaderUtils() {}
*
For .class file, this function expects them in classes/your/package/ClassName.class
*
* @param path the path to scan from
+ * @param type the type of the class
* @param className the name of the classes, pass null if name is unknown
* @param the Template T for the output Class
* @return the Class implementation
*/
- public static T findImplementation(Path path, String className) {
+ public static T findImplementation(Path path, Class type, String className) {
try {
Path classesDir = path.resolve("classes");
// we only consider .class files and skip .java files
@@ -75,16 +76,16 @@ public static T findImplementation(Path path, String className) {
(PrivilegedAction)
() -> new URLClassLoader(urls, contextCl));
if (className != null && !className.isEmpty()) {
- return initClass(cl, className);
+ return initClass(cl, type, className);
}
- T implemented = scanDirectory(cl, classesDir);
+ T implemented = scanDirectory(cl, type, classesDir);
if (implemented != null) {
return implemented;
}
for (Path p : jarFiles) {
- implemented = scanJarFile(cl, p);
+ implemented = scanJarFile(cl, type, p);
if (implemented != null) {
return implemented;
}
@@ -95,7 +96,7 @@ public static T findImplementation(Path path, String className) {
return null;
}
- private static T scanDirectory(ClassLoader cl, Path dir) throws IOException {
+ private static T scanDirectory(ClassLoader cl, Class type, Path dir) throws IOException {
if (!Files.isDirectory(dir)) {
logger.trace("Directory not exists: {}", dir);
return null;
@@ -109,7 +110,7 @@ private static T scanDirectory(ClassLoader cl, Path dir) throws IOException
String className = p.toString();
className = className.substring(0, className.lastIndexOf('.'));
className = className.replace(File.separatorChar, '.');
- T implemented = initClass(cl, className);
+ T implemented = initClass(cl, type, className);
if (implemented != null) {
return implemented;
}
@@ -117,7 +118,7 @@ private static T scanDirectory(ClassLoader cl, Path dir) throws IOException
return null;
}
- private static T scanJarFile(ClassLoader cl, Path path) throws IOException {
+ private static T scanJarFile(ClassLoader cl, Class type, Path path) throws IOException {
try (JarFile jarFile = new JarFile(path.toFile())) {
Enumeration en = jarFile.entries();
while (en.hasMoreElements()) {
@@ -126,7 +127,7 @@ private static T scanJarFile(ClassLoader cl, Path path) throws IOException {
if (fileName.endsWith(".class")) {
fileName = fileName.substring(0, fileName.lastIndexOf('.'));
fileName = fileName.replace('/', '.');
- T implemented = initClass(cl, fileName);
+ T implemented = initClass(cl, type, fileName);
if (implemented != null) {
return implemented;
}
@@ -140,15 +141,16 @@ private static T scanJarFile(ClassLoader cl, Path path) throws IOException {
* Loads the specified class and constructs an instance.
*
* @param cl the {@code ClassLoader} to use
+ * @param type the type of the class
* @param className the class to be loaded
* @param the type of the class
* @return an instance of the class, null if the class not found
*/
- @SuppressWarnings("unchecked")
- public static T initClass(ClassLoader cl, String className) {
+ public static T initClass(ClassLoader cl, Class type, String className) {
try {
Class> clazz = Class.forName(className, true, cl);
- Constructor constructor = (Constructor) clazz.getConstructor();
+ Class extends T> sub = clazz.asSubclass(type);
+ Constructor extends T> constructor = sub.getConstructor();
return constructor.newInstance();
} catch (Throwable e) {
logger.trace("Not able to load Object", e);