Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[api] Fixes loading BlockFactory bug #1547

Merged
merged 1 commit into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ protected Model createModel(
Model model = Model.newInstance(name, device, engine);
if (block == null) {
String className = (String) arguments.get("blockFactory");
BlockFactory factory = ClassLoaderUtils.findImplementation(modelPath, className);
BlockFactory factory =
ClassLoaderUtils.findImplementation(modelPath, BlockFactory.class, className);
if (factory != null) {
block = factory.newBlock(model, modelPath, arguments);
}
Expand Down Expand Up @@ -229,7 +230,7 @@ protected TranslatorFactory getTranslatorFactory(
String factoryClass = (String) arguments.get("translatorFactory");
if (factoryClass != null) {
ClassLoader cl = ClassLoaderUtils.getContextClassLoader();
factory = ClassLoaderUtils.initClass(cl, factoryClass);
factory = ClassLoaderUtils.initClass(cl, TranslatorFactory.class, factoryClass);
}
return factory;
}
Expand Down
24 changes: 13 additions & 11 deletions api/src/main/java/ai/djl/util/ClassLoaderUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ private ClassLoaderUtils() {}
* <p>For .class file, this function expects them in classes/your/package/ClassName.class
*
* @param path the path to scan from
* @param type the type of the class
* @param className the name of the classes, pass null if name is unknown
* @param <T> the Template T for the output Class
* @return the Class implementation
*/
public static <T> T findImplementation(Path path, String className) {
public static <T> T findImplementation(Path path, Class<T> type, String className) {
try {
Path classesDir = path.resolve("classes");
// we only consider .class files and skip .java files
Expand All @@ -75,16 +76,16 @@ public static <T> T findImplementation(Path path, String className) {
(PrivilegedAction<ClassLoader>)
() -> new URLClassLoader(urls, contextCl));
if (className != null && !className.isEmpty()) {
return initClass(cl, className);
return initClass(cl, type, className);
}

T implemented = scanDirectory(cl, classesDir);
T implemented = scanDirectory(cl, type, classesDir);
if (implemented != null) {
return implemented;
}

for (Path p : jarFiles) {
implemented = scanJarFile(cl, p);
implemented = scanJarFile(cl, type, p);
if (implemented != null) {
return implemented;
}
Expand All @@ -95,7 +96,7 @@ public static <T> T findImplementation(Path path, String className) {
return null;
}

private static <T> T scanDirectory(ClassLoader cl, Path dir) throws IOException {
private static <T> T scanDirectory(ClassLoader cl, Class<T> type, Path dir) throws IOException {
if (!Files.isDirectory(dir)) {
logger.trace("Directory not exists: {}", dir);
return null;
Expand All @@ -109,15 +110,15 @@ private static <T> T scanDirectory(ClassLoader cl, Path dir) throws IOException
String className = p.toString();
className = className.substring(0, className.lastIndexOf('.'));
className = className.replace(File.separatorChar, '.');
T implemented = initClass(cl, className);
T implemented = initClass(cl, type, className);
if (implemented != null) {
return implemented;
}
}
return null;
}

private static <T> T scanJarFile(ClassLoader cl, Path path) throws IOException {
private static <T> T scanJarFile(ClassLoader cl, Class<T> type, Path path) throws IOException {
try (JarFile jarFile = new JarFile(path.toFile())) {
Enumeration<JarEntry> en = jarFile.entries();
while (en.hasMoreElements()) {
Expand All @@ -126,7 +127,7 @@ private static <T> T scanJarFile(ClassLoader cl, Path path) throws IOException {
if (fileName.endsWith(".class")) {
fileName = fileName.substring(0, fileName.lastIndexOf('.'));
fileName = fileName.replace('/', '.');
T implemented = initClass(cl, fileName);
T implemented = initClass(cl, type, fileName);
if (implemented != null) {
return implemented;
}
Expand All @@ -140,15 +141,16 @@ private static <T> T scanJarFile(ClassLoader cl, Path path) throws IOException {
* Loads the specified class and constructs an instance.
*
* @param cl the {@code ClassLoader} to use
* @param type the type of the class
* @param className the class to be loaded
* @param <T> the type of the class
* @return an instance of the class, null if the class not found
*/
@SuppressWarnings("unchecked")
public static <T> T initClass(ClassLoader cl, String className) {
public static <T> T initClass(ClassLoader cl, Class<T> type, String className) {
try {
Class<?> clazz = Class.forName(className, true, cl);
Constructor<T> constructor = (Constructor<T>) clazz.getConstructor();
Class<? extends T> sub = clazz.asSubclass(type);
Constructor<? extends T> constructor = sub.getConstructor();
return constructor.newInstance();
} catch (Throwable e) {
logger.trace("Not able to load Object", e);
Expand Down