From eb7d80b502cd89349776c4f211bb5652b1fd46ad Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Wed, 30 Mar 2022 14:19:37 -0700 Subject: [PATCH] [api] Fixes loading BlockFactory bug Change-Id: I9fe68ae6af4be85bcd8a626cfa28534f8d141798 --- .../djl/repository/zoo/BaseModelLoader.java | 5 ++-- .../java/ai/djl/util/ClassLoaderUtils.java | 24 ++++++++++--------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java b/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java index af6d460766b..413e73e167d 100644 --- a/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java +++ b/api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java @@ -187,7 +187,8 @@ protected Model createModel( Model model = Model.newInstance(name, device, engine); if (block == null) { String className = (String) arguments.get("blockFactory"); - BlockFactory factory = ClassLoaderUtils.findImplementation(modelPath, className); + BlockFactory factory = + ClassLoaderUtils.findImplementation(modelPath, BlockFactory.class, className); if (factory != null) { block = factory.newBlock(model, modelPath, arguments); } @@ -229,7 +230,7 @@ protected TranslatorFactory getTranslatorFactory( String factoryClass = (String) arguments.get("translatorFactory"); if (factoryClass != null) { ClassLoader cl = ClassLoaderUtils.getContextClassLoader(); - factory = ClassLoaderUtils.initClass(cl, factoryClass); + factory = ClassLoaderUtils.initClass(cl, TranslatorFactory.class, factoryClass); } return factory; } diff --git a/api/src/main/java/ai/djl/util/ClassLoaderUtils.java b/api/src/main/java/ai/djl/util/ClassLoaderUtils.java index 5428915a73b..d9da664c159 100644 --- a/api/src/main/java/ai/djl/util/ClassLoaderUtils.java +++ b/api/src/main/java/ai/djl/util/ClassLoaderUtils.java @@ -45,11 +45,12 @@ private ClassLoaderUtils() {} *

For .class file, this function expects them in classes/your/package/ClassName.class * * @param path the path to scan from + * @param type the type of the class * @param className the name of the classes, pass null if name is unknown * @param the Template T for the output Class * @return the Class implementation */ - public static T findImplementation(Path path, String className) { + public static T findImplementation(Path path, Class type, String className) { try { Path classesDir = path.resolve("classes"); // we only consider .class files and skip .java files @@ -75,16 +76,16 @@ public static T findImplementation(Path path, String className) { (PrivilegedAction) () -> new URLClassLoader(urls, contextCl)); if (className != null && !className.isEmpty()) { - return initClass(cl, className); + return initClass(cl, type, className); } - T implemented = scanDirectory(cl, classesDir); + T implemented = scanDirectory(cl, type, classesDir); if (implemented != null) { return implemented; } for (Path p : jarFiles) { - implemented = scanJarFile(cl, p); + implemented = scanJarFile(cl, type, p); if (implemented != null) { return implemented; } @@ -95,7 +96,7 @@ public static T findImplementation(Path path, String className) { return null; } - private static T scanDirectory(ClassLoader cl, Path dir) throws IOException { + private static T scanDirectory(ClassLoader cl, Class type, Path dir) throws IOException { if (!Files.isDirectory(dir)) { logger.trace("Directory not exists: {}", dir); return null; @@ -109,7 +110,7 @@ private static T scanDirectory(ClassLoader cl, Path dir) throws IOException String className = p.toString(); className = className.substring(0, className.lastIndexOf('.')); className = className.replace(File.separatorChar, '.'); - T implemented = initClass(cl, className); + T implemented = initClass(cl, type, className); if (implemented != null) { return implemented; } @@ -117,7 +118,7 @@ private static T scanDirectory(ClassLoader cl, Path dir) throws IOException return null; } - private static T scanJarFile(ClassLoader cl, Path path) throws IOException { + private static T scanJarFile(ClassLoader cl, Class type, Path path) throws IOException { try (JarFile jarFile = new JarFile(path.toFile())) { Enumeration en = jarFile.entries(); while (en.hasMoreElements()) { @@ -126,7 +127,7 @@ private static T scanJarFile(ClassLoader cl, Path path) throws IOException { if (fileName.endsWith(".class")) { fileName = fileName.substring(0, fileName.lastIndexOf('.')); fileName = fileName.replace('/', '.'); - T implemented = initClass(cl, fileName); + T implemented = initClass(cl, type, fileName); if (implemented != null) { return implemented; } @@ -140,15 +141,16 @@ private static T scanJarFile(ClassLoader cl, Path path) throws IOException { * Loads the specified class and constructs an instance. * * @param cl the {@code ClassLoader} to use + * @param type the type of the class * @param className the class to be loaded * @param the type of the class * @return an instance of the class, null if the class not found */ - @SuppressWarnings("unchecked") - public static T initClass(ClassLoader cl, String className) { + public static T initClass(ClassLoader cl, Class type, String className) { try { Class clazz = Class.forName(className, true, cl); - Constructor constructor = (Constructor) clazz.getConstructor(); + Class sub = clazz.asSubclass(type); + Constructor constructor = sub.getConstructor(); return constructor.newInstance(); } catch (Throwable e) { logger.trace("Not able to load Object", e);