Skip to content

Commit

Permalink
Add option.tensor_parallel_degree in serving.properties
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Oct 15, 2022
1 parent 6e42d17 commit df9d12a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
12 changes: 8 additions & 4 deletions engines/python/src/main/java/ai/djl/python/engine/PyEnv.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ public class PyEnv {
private int predictTimeout;
private int modelLoadingTimeout;
private int tensorParallelDegree;
private int mpiProcesses;
private Map<String, String> envs;
private Map<String, String> initParameters;
private boolean initialized;
Expand Down Expand Up @@ -219,6 +218,12 @@ public void setPythonExecutable(String pythonExecutable) {
* @return the tensor parallel degree
*/
public int getTensorParallelDegree() {
if (tensorParallelDegree == 0) {
String value = Utils.getenv("TENSOR_PARALLEL_DEGREE");
if (value != null) {
tensorParallelDegree = Integer.parseInt(value);
}
}
return tensorParallelDegree;
}

Expand All @@ -229,12 +234,11 @@ public int getTensorParallelDegree() {
*/
public void setTensorParallelDegree(int tensorParallelDegree) {
this.tensorParallelDegree = tensorParallelDegree;
int gpuCount = CudaUtils.getGpuCount();
mpiProcesses = gpuCount / tensorParallelDegree;
}

int getMpiWorkers() {
return mpiProcesses;
int gpuCount = CudaUtils.getGpuCount();
return gpuCount / getTensorParallelDegree();
}

/**
Expand Down
17 changes: 12 additions & 5 deletions engines/python/src/main/java/ai/djl/python/engine/PyModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
}
String entryPoint = null;
if (options != null) {
logger.debug("options in serving.properties for model: {}", modelName);
for (Map.Entry<String, ?> entry : options.entrySet()) {
String key = entry.getKey();
String value = (String) entry.getValue();
logger.debug("{}={}", key, value);
switch (key) {
case "pythonExecutable":
pyEnv.setPythonExecutable(value);
Expand Down Expand Up @@ -112,6 +114,9 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
case "parallel_loading":
parallelLoading = Boolean.parseBoolean(value);
break;
case "tensor_parallel_degree":
pyEnv.setTensorParallelDegree(Integer.parseInt(value));
break;
case "handler":
pyEnv.setHandler(value);
break;
Expand All @@ -134,10 +139,8 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
}
pyEnv.setEntryPoint(entryPoint);
if (pyEnv.isMpiMode()) {
int partitions;
if (System.getenv("TENSOR_PARALLEL_DEGREE") != null) {
partitions = Integer.parseInt(System.getenv("TENSOR_PARALLEL_DEGREE"));
} else {
int partitions = pyEnv.getTensorParallelDegree();
if (partitions == 0) {
// TODO: avoid use hardcoded "partitioned_model_" name
try (Stream<Path> stream = Files.list(modelPath)) {
partitions =
Expand All @@ -154,10 +157,10 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
throw new FileNotFoundException(
"partitioned_model_ file not found in: " + modelPath);
}
pyEnv.setTensorParallelDegree(partitions);
}
logger.info("Loading model in MPI model with TP: {}.", partitions);

pyEnv.setTensorParallelDegree(partitions);
int mpiWorkers = pyEnv.getMpiWorkers();
if (mpiWorkers <= 0) {
throw new EngineException(
Expand Down Expand Up @@ -235,6 +238,7 @@ private Path findModelFile(String prefix) {
}

private void createAllPyProcesses(int mpiWorkers) {
long begin = System.currentTimeMillis();
ExecutorService pool = null;
List<Future<?>> futures = new ArrayList<>();
if (parallelLoading) {
Expand All @@ -245,6 +249,7 @@ private void createAllPyProcesses(int mpiWorkers) {
PyProcess worker = new PyProcess(this, pyEnv, i);
workerQueue.offer(worker);
if (pool != null) {
logger.debug("Submitting to pool: {}", i);
futures.add(pool.submit(worker::startPythonProcess));
} else {
worker.startPythonProcess();
Expand All @@ -264,6 +269,8 @@ private void createAllPyProcesses(int mpiWorkers) {
}
}
}
long duration = System.currentTimeMillis() - begin;
logger.info("{} model loaded in {} ms.", modelName, duration);
}

private void shutdown() {
Expand Down

0 comments on commit df9d12a

Please sign in to comment.