Skip to content

Commit

Permalink
[FLINK-36653] Fix OnlineLogisticRegressionModel updating logic
Browse files Browse the repository at this point in the history
  • Loading branch information
yunfengzhou-hub committed Nov 4, 2024
1 parent 44f71f2 commit 6b7f9a6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ public void processElement2(StreamRecord<LogisticRegressionModelData> streamReco
LogisticRegressionModelData modelData = streamRecord.getValue();
coefficient = modelData.coefficient;
modelDataVersion = modelData.modelVersion;
servable =
new LogisticRegressionModelServable(
new LogisticRegressionModelData(coefficient, modelDataVersion));
ParamUtils.updateExistingParams(servable, params);
for (Row dataPoint : bufferedPointsState.get()) {
processElement(new StreamRecord<>(dataPoint));
}
Expand All @@ -160,7 +164,7 @@ public void processElement(StreamRecord<Row> streamRecord) throws Exception {
if (servable == null) {
servable =
new LogisticRegressionModelServable(
new LogisticRegressionModelData(coefficient, 0L));
new LogisticRegressionModelData(coefficient, modelDataVersion));
ParamUtils.updateExistingParams(servable, params);
}
Vector features = (Vector) dataPoint.getField(servable.getFeaturesCol());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,14 @@
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
Expand All @@ -72,6 +77,7 @@
import static org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel.MODEL_DATA_VERSION_GAUGE_KEY;

/** Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}. */
@RunWith(Parameterized.class)
public class OnlineLogisticRegressionTest extends TestLogger {
@Rule public final TemporaryFolder tempFolder = new TemporaryFolder();

Expand Down Expand Up @@ -142,10 +148,21 @@ public class OnlineLogisticRegressionTest extends TestLogger {
Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.)
};

private static final int defaultParallelism = 4;
/**
 * Supplies the parameter sets for the parameterized runner: one single-element row per
 * parallelism value (1 and 4) that the tests are executed with.
 */
@Parameters
public static Collection<Object[]> data() {
    final Object[][] parallelismCases = {{1}, {4}};
    return Arrays.asList(parallelismCases);
}

@Parameter public int defaultParallelism;
private static final int numTaskManagers = 2;
private static final int numSlotsPerTaskManager = 2;

/**
 * Configuration shared by the tests in this class. It is built as a plain {@link Configuration}
 * instance rather than through double-brace initialization, which would create an anonymous
 * subclass.
 */
private static final Configuration config = createConfig();

/** Builds the shared configuration: REST bind port range and checkpointing after tasks finish. */
private static Configuration createConfig() {
    Configuration configuration = new Configuration();
    configuration.set(RestOptions.BIND_PORT, "18081-19091");
    configuration.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
    return configuration;
}
private long currentModelDataVersion;

private InMemorySourceFunction<Row> trainDenseSource;
Expand All @@ -170,9 +187,6 @@ public class OnlineLogisticRegressionTest extends TestLogger {

@BeforeClass
public static void beforeClass() throws Exception {
Configuration config = new Configuration();
config.set(RestOptions.BIND_PORT, "18081-19091");
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
reporter = InMemoryReporter.create();
reporter.addToConfiguration(config);

Expand All @@ -184,17 +198,17 @@ public static void beforeClass() throws Exception {
.setNumSlotsPerTaskManager(numSlotsPerTaskManager)
.build());
miniCluster.start();
}

@Before
public void before() throws Exception {
env = StreamExecutionEnvironment.getExecutionEnvironment(config);
env.getConfig().enableObjectReuse();
env.setParallelism(defaultParallelism);
env.enableCheckpointing(100);
env.setRestartStrategy(RestartStrategies.noRestart());
tEnv = StreamTableEnvironment.create(env);
}

@Before
public void before() throws Exception {
currentModelDataVersion = 0;

trainDenseSource = new InMemorySourceFunction<>();
Expand Down Expand Up @@ -562,6 +576,10 @@ public void testInitWithLogisticRegression() throws Exception {

@Test
public void testBatchSizeLessThanParallelism() {
if (defaultParallelism < 2) {
return;
}

try {
new OnlineLogisticRegression()
.setInitialModelData(initDenseModel)
Expand Down

0 comments on commit 6b7f9a6

Please sign in to comment.