Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to customize env on local backend #686

Merged
merged 1 commit into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dev/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,8 @@ curl -X 'POST' \
}'
```

To set custom environment variables when running Lighter in local mode, add a `conf` entry whose key
starts with the `lighter.local.env.` prefix (e.g. `lighter.local.env.FOO`) to the JSON you submit.

The Lighter UI can be accessed at [http://localhost:8080/lighter](http://localhost:8080/lighter).\
You can also explore the Lighter API by visiting the Swagger UI at [http://localhost:8080/swagger-ui/](http://localhost:8080/swagger-ui/).
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ public final class Constants {
// Deploy mode and master values shared by the backends.
public static final String DEPLOY_MODE_CLUSTER = "cluster";
public static final String DEPLOY_MODE_CLIENT = "client";
public static final String MASTER_YARN = "yarn";

// Prefix of submitted conf keys that are addressed to Lighter itself;
// keys starting with it are not forwarded to the Spark launcher as Spark conf.
public static final String LIGHTER_CONF_PREFIX = "lighter.";

// Names of the environment variables that the backends set for the launched application.
public static final String LIGHTER_SESSION_ID_ENV_NAME = "LIGHTER_SESSION_ID";
public static final String PY_GATEWAY_HOST_ENV_NAME = "PY_GATEWAY_HOST";
public static final String PY_GATEWAY_PORT_ENV_NAME = "PY_GATEWAY_PORT";

// Constants holder only: no instances.
private Constants() {
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.Map;
import java.util.function.Consumer;

import static com.exacaster.lighter.backend.Constants.LIGHTER_CONF_PREFIX;
import static org.apache.spark.launcher.SparkLauncher.DRIVER_MEMORY;
import static org.apache.spark.launcher.SparkLauncher.EXECUTOR_CORES;
import static org.apache.spark.launcher.SparkLauncher.EXECUTOR_MEMORY;
Expand Down Expand Up @@ -53,6 +54,18 @@ public Waitable launch() {
return EmptyWaitable.INSTANCE;
}

/**
 * Returns the default configuration entries held by this app.
 *
 * <p>The field itself is returned, not a copy. {@code buildLauncher} applies these
 * entries to the launcher before the submitted conf and the backend configuration.
 *
 * @return the {@code configDefaults} map
 */
public Map<String, String> getConfigDefaults() {
    return this.configDefaults;
}

/**
 * Returns the backend-specific configuration entries held by this app.
 *
 * <p>The field itself is returned, not a copy. {@code buildLauncher} applies these
 * entries to the launcher last, after the defaults and the submitted conf.
 *
 * @return the {@code backendConfiguration} map
 */
public Map<String, String> getBackendConfiguration() {
    return this.backendConfiguration;
}

/**
 * Returns the environment variables held by this app.
 *
 * <p>The field itself is returned, not a copy. The same map is handed to the
 * {@code SparkLauncher} constructor in {@code buildLauncher}.
 *
 * @return the {@code envVariables} map
 */
public Map<String, String> getEnvVariables() {
    return this.envVariables;
}

private SparkLauncher buildLauncher() {
var submitParams = application.getSubmitParams();
var launcher = new SparkLauncher(envVariables)
Expand All @@ -77,7 +90,11 @@ private SparkLauncher buildLauncher() {
submitParams.getPyFiles().forEach(launcher::addPyFile);

configDefaults.forEach(launcher::setConf);
submitParams.getConf().forEach(launcher::setConf);
submitParams.getConf().forEach((key, val) -> {
if (!key.startsWith(LIGHTER_CONF_PREFIX)) {
launcher.setConf(key, val);
}
});
backendConfiguration.forEach(launcher::setConf);

return launcher;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package com.exacaster.lighter.backend.kubernetes;

import static com.exacaster.lighter.backend.Constants.DEPLOY_MODE_CLUSTER;
import static com.exacaster.lighter.backend.Constants.LIGHTER_SESSION_ID_ENV_NAME;
import static com.exacaster.lighter.backend.Constants.PY_GATEWAY_HOST_ENV_NAME;
import static com.exacaster.lighter.backend.Constants.PY_GATEWAY_PORT_ENV_NAME;
import static java.util.Optional.ofNullable;
import static org.apache.spark.launcher.SparkLauncher.DEPLOY_MODE;
import static org.apache.spark.launcher.SparkLauncher.SPARK_MASTER;
Expand Down Expand Up @@ -66,9 +69,9 @@ Map<String, String> getBackendConfiguration(Application application) {
SPARK_MASTER, properties.getMaster(),
"spark.kubernetes.driver.label." + SPARK_APP_TAG_LABEL, application.getId(),
"spark.kubernetes.executor.label." + SPARK_APP_TAG_LABEL, application.getId(),
"spark.kubernetes.driverEnv.PY_GATEWAY_PORT", String.valueOf(conf.getPyGatewayPort()),
"spark.kubernetes.driverEnv.PY_GATEWAY_HOST", host,
"spark.kubernetes.driverEnv.LIGHTER_SESSION_ID", application.getId()
"spark.kubernetes.driverEnv." + PY_GATEWAY_PORT_ENV_NAME, String.valueOf(conf.getPyGatewayPort()),
"spark.kubernetes.driverEnv." + PY_GATEWAY_HOST_ENV_NAME, host,
"spark.kubernetes.driverEnv." + LIGHTER_SESSION_ID_ENV_NAME, application.getId()
));
props.putAll(STATIC_SUBMIT_PROPS);
return props;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,24 @@
import com.google.common.cache.CacheBuilder;

import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;

import static com.exacaster.lighter.backend.Constants.DEPLOY_MODE_CLIENT;
import static com.exacaster.lighter.backend.Constants.LIGHTER_CONF_PREFIX;
import static com.exacaster.lighter.backend.Constants.LIGHTER_SESSION_ID_ENV_NAME;
import static com.exacaster.lighter.backend.Constants.PY_GATEWAY_HOST_ENV_NAME;
import static com.exacaster.lighter.backend.Constants.PY_GATEWAY_PORT_ENV_NAME;
import static org.apache.spark.launcher.SparkLauncher.CHILD_PROCESS_LOGGER_NAME;
import static org.apache.spark.launcher.SparkLauncher.DEPLOY_MODE;
import static org.apache.spark.launcher.SparkLauncher.SPARK_MASTER;

public class LocalBackend implements Backend {

private final static String LOCAL_ENV_CONF_PREFIX = LIGHTER_CONF_PREFIX + "local.env.";

private final AppConfiguration conf;

private final Cache<String, LocalApp> activeApps = CacheBuilder.newBuilder()
Expand Down Expand Up @@ -58,10 +65,12 @@ public void kill(Application application) {
}

@Override
public SparkApp prepareSparkApplication(Application application, Map<String, String> configDefaults,
public SparkApp prepareSparkApplication(Application application,
Map<String, String> configDefaults,
Consumer<Throwable> errorHandler) {
var localApp = new LocalApp(application, errorHandler);
activeApps.put(application.getId(), localApp);

return new SparkApp(
application,
configDefaults,
Expand All @@ -70,14 +79,30 @@ public SparkApp prepareSparkApplication(Application application, Map<String, Str
SPARK_MASTER, "local[*]",
CHILD_PROCESS_LOGGER_NAME, localApp.getLoggerName()
),
Map.of("LIGHTER_SESSION_ID", application.getId(),
"PY_GATEWAY_PORT", conf.getPyGatewayPort().toString(),
"PY_GATEWAY_HOST", "localhost"
),
buildEnvironment(application),
localApp
);
}

/**
 * Builds the environment variable map for an application launched on the local backend.
 *
 * <p>Every submitted conf entry whose key starts with {@code LOCAL_ENV_CONF_PREFIX}
 * is copied into the result with that prefix removed from the key. The Lighter-managed
 * variables (session id, Python gateway port and host) are added afterwards, so they
 * replace any user-supplied entry with the same name.
 *
 * @param application the application whose submitted conf and id are read
 * @return a new mutable map of environment variable names to values
 */
private Map<String, String> buildEnvironment(Application application) {
    var env = new HashMap<String, String>();
    application.getSubmitParams().getConf().forEach((confKey, confValue) -> {
        if (confKey.startsWith(LOCAL_ENV_CONF_PREFIX)) {
            // Strip the prefix as a literal. String.replaceFirst would interpret the
            // prefix as a regular expression, where each '.' matches any character.
            env.put(confKey.substring(LOCAL_ENV_CONF_PREFIX.length()), confValue);
        }
    });
    // Added last on purpose: Lighter's own variables take precedence over user entries.
    env.putAll(Map.of(
            LIGHTER_SESSION_ID_ENV_NAME, application.getId(),
            PY_GATEWAY_PORT_ENV_NAME, conf.getPyGatewayPort().toString(),
            PY_GATEWAY_HOST_ENV_NAME, "localhost"
    ));
    return env;
}

Optional<LocalApp> handleForApp(Application application) {
return Optional.ofNullable(activeApps.getIfPresent(application.getId()));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package com.exacaster.lighter.backend.yarn;

import static com.exacaster.lighter.backend.Constants.DEPLOY_MODE_CLUSTER;
import static com.exacaster.lighter.backend.Constants.LIGHTER_SESSION_ID_ENV_NAME;
import static com.exacaster.lighter.backend.Constants.MASTER_YARN;
import static com.exacaster.lighter.backend.Constants.PY_GATEWAY_HOST_ENV_NAME;
import static com.exacaster.lighter.backend.Constants.PY_GATEWAY_PORT_ENV_NAME;
import static org.apache.hadoop.yarn.api.records.ApplicationId.fromString;
import static org.apache.spark.launcher.SparkLauncher.DEPLOY_MODE;
import static org.apache.spark.launcher.SparkLauncher.SPARK_MASTER;
Expand Down Expand Up @@ -119,9 +122,9 @@ Map<String, String> getBackendConfiguration(Application application) {
SPARK_MASTER, MASTER_YARN,
"spark.yarn.tags", "lighter," + application.getId(),
"spark.yarn.submit.waitAppCompletion", "false",
"spark.yarn.appMasterEnv.PY_GATEWAY_PORT", String.valueOf(conf.getPyGatewayPort()),
"spark.yarn.appMasterEnv.PY_GATEWAY_HOST", host,
"spark.yarn.appMasterEnv.LIGHTER_SESSION_ID", application.getId()
"spark.yarn.appMasterEnv." + PY_GATEWAY_PORT_ENV_NAME, String.valueOf(conf.getPyGatewayPort()),
"spark.yarn.appMasterEnv." + PY_GATEWAY_HOST_ENV_NAME, host,
"spark.yarn.appMasterEnv." + LIGHTER_SESSION_ID_ENV_NAME, application.getId()
));
if (!props.containsKey("spark.kerberos.keytab") && yarnProperties.getKerberos() != null) {
props.put("spark.kerberos.keytab", yarnProperties.getKerberos().getKeytab());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ class LocalBackendTest extends Specification {

then:
appHandle != null
appHandle.getEnvVariables() == [
FOO: "bar",
PY_GATEWAY_HOST: "localhost",
PY_GATEWAY_PORT: "5432",
LIGHTER_SESSION_ID: "1"
]

when:
def info = backend.getInfo(app)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class AppConfigurationTest extends Specification {
appConfiguration.sessionConfiguration.permanentSessions.size() == 1
appConfiguration.sessionConfiguration.permanentSessions.get(0).id == "permanentId1"
appConfiguration.sessionConfiguration.permanentSessions.get(0).submitParams.conf == [
"spark.kubernetes.namespace": "spark",
"spark.kubernetes.driver.secrets.spark-secret": "/etc/secret"
"spark.kubernetes.namespace": "spark",
"spark.kubernetes.driver.secrets.spark-secret": "/etc/secret"
]
appConfiguration.batchDefaultConf != null
appConfiguration.batchDefaultConf.get("spark.driver.cores") == "1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class Factories {
"",
"",
0,
null, null, 1, null, null, null, null, null, null, null
null, null, 1, null, null, null, null, null, null,
["lighter.local.env.FOO": "bar"]
)
}

Expand All @@ -31,7 +32,7 @@ class Factories {
.setAppInfo("info")
.setCreatedAt(LocalDateTime.MAX)
.setId("1")
.setSubmitParams(null)
.setSubmitParams(submitParams())
.build()
}

Expand Down
Loading