diff --git a/core/src/main/java/org/apache/seata/core/store/db/AbstractDataSourceProvider.java b/core/src/main/java/org/apache/seata/core/store/db/AbstractDataSourceProvider.java index c11a55f0979..bb8306b2b22 100644 --- a/core/src/main/java/org/apache/seata/core/store/db/AbstractDataSourceProvider.java +++ b/core/src/main/java/org/apache/seata/core/store/db/AbstractDataSourceProvider.java @@ -16,6 +16,15 @@ */ package org.apache.seata.core.store.db; +import java.io.File; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.URLClassLoader; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Stream; + import javax.sql.DataSource; import org.apache.seata.common.exception.StoreException; @@ -47,8 +56,20 @@ public abstract class AbstractDataSourceProvider implements DataSourceProvider, */ protected static final Configuration CONFIG = ConfigurationFactory.getInstance(); + private final static String MYSQL_DRIVER_CLASS_NAME = "com.mysql.jdbc.Driver"; + + private final static String MYSQL8_DRIVER_CLASS_NAME = "com.mysql.cj.jdbc.Driver"; + + private final static String MYSQL_DRIVER_FILE_PREFIX = "mysql-connector-java-"; + + private final static Map MYSQL_DRIVER_LOADERS; + private static final long DEFAULT_DB_MAX_WAIT = 5000; + static { + MYSQL_DRIVER_LOADERS = createMysqlDriverClassLoaders(); + } + @Override public void init() { this.dataSource = generate(); @@ -67,7 +88,7 @@ public DataSource generate() { public void validate() { //valid driver class name String driverClassName = getDriverClassName(); - ClassLoader loader = Thread.currentThread().getContextClassLoader(); + ClassLoader loader = getDriverClassLoader(); if (null == loader) { throw new StoreException("class loader set error, you should not use the Bootstrap classloader"); } @@ -124,7 +145,50 @@ protected Long getMaxWait() { } protected ClassLoader getDriverClassLoader() { - return ClassLoader.getSystemClassLoader(); + return MYSQL_DRIVER_LOADERS.getOrDefault(getDriverClassName(), ClassLoader.getSystemClassLoader()); + } + + private static Map createMysqlDriverClassLoaders() { + Map loaders = new HashMap<>(); + String cp = System.getProperty("java.class.path"); + if (cp == null || cp.isEmpty()) { + return loaders; + } + Stream.of(cp.split(File.pathSeparator)) + .map(File::new) + .filter(File::exists) + .map(file -> file.isFile() ? file.getParentFile() : file) + .filter(Objects::nonNull) + .filter(File::isDirectory) + .map(file -> new File(file, "jdbc")) + .filter(File::exists) + .filter(File::isDirectory) + .distinct() + .flatMap(file -> { + File[] files = file.listFiles((f, name) -> name.startsWith(MYSQL_DRIVER_FILE_PREFIX)); + if (files != null) { + return Stream.of(files); + } else { + return Stream.of(); + } + }) + .forEach(file -> { + if (loaders.containsKey(MYSQL8_DRIVER_CLASS_NAME) && loaders.containsKey(MYSQL_DRIVER_CLASS_NAME)) { + return; + } + try { + URL url = file.toURI().toURL(); + ClassLoader loader = new URLClassLoader(new URL[]{url}, ClassLoader.getSystemClassLoader()); + try { + loader.loadClass(MYSQL8_DRIVER_CLASS_NAME); + loaders.putIfAbsent(MYSQL8_DRIVER_CLASS_NAME, loader); + } catch (ClassNotFoundException e) { + loaders.putIfAbsent(MYSQL_DRIVER_CLASS_NAME, loader); + } + } catch (MalformedURLException ignore) { + } + }); + return loaders; } /**