Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Android example version compatibility #1603

Merged
merged 12 commits into from
Jan 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

- **Add support for ** `workload_id` **and** `group_id` **in Driver API** ([#1595](https://github.com/adap/flower/pull/1595))

- **Make Android example compatible with** `flwr >= 1.0.0` **and the latest versions of Android** ([#1603](https://github.com/adap/flower/pull/1603))

### Incompatible changes

## v1.2.0 (2023-01-13)
Expand Down
2 changes: 1 addition & 1 deletion examples/android/apk/download_apk.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
wget https://www.dropbox.com/s/e14t3e9py3mr73v/flwr_android_client.apk?dl=1
wget https://www.dropbox.com/s/ii0vwrjrpupifiv/flower-client.apk?dl=0
4 changes: 4 additions & 0 deletions examples/android/client/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
.gradle
.idea
local.properties
app/src/main/assets
42 changes: 22 additions & 20 deletions examples/android/client/app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@ apply plugin: 'com.google.protobuf'
apply plugin: 'com.android.application'

android {
compileSdkVersion 29
buildToolsVersion "29.0.0"
compileSdkVersion 33

defaultConfig {
applicationId "flwr.android_client"
// API level 14+ is required for TLS since Google Play Services v10.2
minSdkVersion 24
targetSdkVersion 29
targetSdkVersion 33
versionCode 1
versionName "1.0"
}
Expand All @@ -30,21 +29,24 @@ android {
enabled = true
}

aaptOptions {
noCompress "tflite"
}

lintOptions {
namespace 'flwr.android_client'
androidResources {
noCompress 'tflite'
}
lint {
disable 'GoogleAppIndexingWarning', 'HardcodedText', 'InvalidPackage'
textOutput file('stdout')
textReport true
textOutput "stdout"
}
}

def grpc_version = '1.43.0'

protobuf {
protoc { artifact = 'com.google.protobuf:protoc:3.11.0' }
plugins {
grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.27.2' // CURRENT_GRPC_VERSION
grpc { artifact = "io.grpc:protoc-gen-grpc-java:$grpc_version"
}
}
generateProtoTasks {
Expand All @@ -61,28 +63,28 @@ protobuf {
}

dependencies {
implementation 'androidx.appcompat:appcompat:1.0.0'
implementation 'androidx.appcompat:appcompat:1.6.0'
implementation fileTree(dir: 'libs', include: ['*.jar'])
implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
testImplementation 'junit:junit:4.12'
androidTestImplementation 'androidx.test:runner:1.2.0'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.2.0'
implementation 'androidx.constraintlayout:constraintlayout:2.1.4'
testImplementation 'junit:junit:4.13.2'
androidTestImplementation 'androidx.test:runner:1.5.2'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'
implementation project(path: ':transfer_api')

implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly-SNAPSHOT'
// This dependency adds the necessary TF op support.
implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:0.0.0-nightly-SNAPSHOT'

implementation 'io.grpc:grpc-okhttp:1.27.2' // CURRENT_GRPC_VERSION
implementation 'io.grpc:grpc-protobuf-lite:1.27.2' // CURRENT_GRPC_VERSION
implementation 'io.grpc:grpc-stub:1.27.2' // CURRENT_GRPC_VERSION
implementation 'javax.annotation:javax.annotation-api:1.2'
implementation "io.grpc:grpc-okhttp:$grpc_version"
implementation "io.grpc:grpc-protobuf-lite:$grpc_version"
implementation "io.grpc:grpc-stub:$grpc_version"
implementation 'javax.annotation:javax.annotation-api:1.3.2'

def lifecycle_version = '2.1.0-rc01'
def lifecycle_version = '2.2.0'
implementation "androidx.lifecycle:lifecycle-extensions:$lifecycle_version"
implementation "androidx.lifecycle:lifecycle-common-java8:$lifecycle_version"

implementation 'com.google.android.material:material:1.0.0'
implementation 'com.google.android.material:material:1.7.0'

}

Expand Down
11 changes: 7 additions & 4 deletions examples/android/client/app/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="flwr.android_client">
<manifest xmlns:android="http://schemas.android.com/apk/res/android">

<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
Expand All @@ -12,13 +11,17 @@
android:usesCleartextTraffic="true"
android:supportsRtl="true"
android:theme="@style/Theme.AppCompat.Light">
<activity android:name="flwr.android_client.MainActivity">
<activity android:name="flwr.android_client.MainActivity"
android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />

<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
<meta-data
android:name="preloaded_fonts"
android:resource="@array/preloaded_fonts" />
</application>

</manifest>
</manifest>
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

import android.app.Activity;
import android.icu.text.SimpleDateFormat;
import android.os.AsyncTask;
import android.os.Build;
import android.os.Bundle;

import androidx.annotation.RequiresApi;
import androidx.appcompat.app.AppCompatActivity;

import android.os.Handler;
import android.os.Looper;
import android.text.TextUtils;
import android.text.method.ScrollingMovementMethod;
import android.util.Log;
Expand All @@ -25,23 +23,23 @@
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;

import flwr.android_client.FlowerServiceGrpc.FlowerServiceBlockingStub;
import flwr.android_client.FlowerServiceGrpc.FlowerServiceStub;
import com.google.protobuf.ByteString;

import io.grpc.stub.StreamObserver;

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.net.URL;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import javax.net.ssl.HttpsURLConnection;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;


public class MainActivity extends AppCompatActivity {
private EditText ip;
Expand All @@ -53,7 +51,7 @@ public class MainActivity extends AppCompatActivity {
private EditText device_id;
private ManagedChannel channel;
public FlowerClient fc;
private static String TAG = "Flower";
private static final String TAG = "Flower";

@Override
protected void onCreate(Bundle savedInstanceState) {
Expand Down Expand Up @@ -82,7 +80,7 @@ public static void hideKeyboard(Activity activity) {


public void setResultText(String text) {
SimpleDateFormat dateFormat = new SimpleDateFormat("HH:mm:ss");
SimpleDateFormat dateFormat = new SimpleDateFormat("HH:mm:ss", Locale.GERMANY);
String time = dateFormat.format(new Date());
resultText.append("\n" + time + " " + text);
}
Expand All @@ -99,15 +97,30 @@ else if (Integer.parseInt(device_id.getText().toString()) > 10 || Integer.parse
hideKeyboard(this);
setResultText("Loading the local training dataset in memory. It will take several seconds.");
loadDataButton.setEnabled(false);
final Handler handler = new Handler();
handler.postDelayed(new Runnable() {

ExecutorService executor = Executors.newSingleThreadExecutor();
Handler handler = new Handler(Looper.getMainLooper());

executor.execute(new Runnable() {
private String result;
@Override
public void run() {
fc.loadData(Integer.parseInt(device_id.getText().toString()));
setResultText("Training dataset is loaded in memory.");
connectButton.setEnabled(true);
try {
fc.loadData(Integer.parseInt(device_id.getText().toString()));
result = "Training dataset is loaded in memory.";
} catch (Exception e) {
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw);
e.printStackTrace(pw);
pw.flush();
result = "Training dataset is loaded in memory.";
}
handler.post(() -> {
setResultText(result);
connectButton.setEnabled(true);
});
}
}, 1000);
});
}
}

Expand All @@ -118,7 +131,7 @@ public void connect(View view) {
Toast.makeText(this, "Please enter the correct IP and port of the FL server", Toast.LENGTH_LONG).show();
}
else {
int port = TextUtils.isEmpty(portStr) ? 0 : Integer.valueOf(portStr);
int port = TextUtils.isEmpty(portStr) ? 0 : Integer.parseInt(portStr);
channel = ManagedChannelBuilder.forAddress(host, port).maxInboundMessageSize(10 * 1024 * 1024).usePlaintext().build();
hideKeyboard(this);
trainButton.setEnabled(true);
Expand All @@ -127,61 +140,44 @@ public void connect(View view) {
}
}

public void runGRCP(View view){
new GrpcTask(new FlowerServiceRunnable(), channel, this).execute();
}

private static class GrpcTask extends AsyncTask<Void, Void, String> {
private final GrpcRunnable grpcRunnable;
private final ManagedChannel channel;
private final MainActivity activityReference;

GrpcTask(GrpcRunnable grpcRunnable, ManagedChannel channel, MainActivity activity) {
this.grpcRunnable = grpcRunnable;
this.channel = channel;
this.activityReference = activity;
}

@Override
protected String doInBackground(Void... nothing) {
try {
grpcRunnable.run(FlowerServiceGrpc.newBlockingStub(channel), FlowerServiceGrpc.newStub(channel), this.activityReference);
return "Connection to the FL server successful \n";
} catch (Exception e) {
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw);
e.printStackTrace(pw);
pw.flush();
return "Failed to connect to the FL server \n" + sw;
}
}

@Override
protected void onPostExecute(String result) {
MainActivity activity = activityReference;
if (activity == null) {
return;
public void runGrpc(View view){
MainActivity activity = this;
ExecutorService executor = Executors.newSingleThreadExecutor();
Handler handler = new Handler(Looper.getMainLooper());

executor.execute(new Runnable() {
private String result;
@Override
public void run() {
try {
(new FlowerServiceRunnable()).run(FlowerServiceGrpc.newStub(channel), activity);
result = "Connection to the FL server successful \n";
} catch (Exception e) {
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw);
e.printStackTrace(pw);
pw.flush();
result = "Failed to connect to the FL server \n" + sw;
}
handler.post(() -> {
setResultText(result);
trainButton.setEnabled(false);
});
}
activity.setResultText(result);
activity.trainButton.setEnabled(false);
}
});
}

private interface GrpcRunnable {
void run(FlowerServiceBlockingStub blockingStub, FlowerServiceStub asyncStub, MainActivity activity) throws Exception;
}

private static class FlowerServiceRunnable implements GrpcRunnable {
private Throwable failed;
private static class FlowerServiceRunnable{
protected Throwable failed;
private StreamObserver<ClientMessage> requestObserver;
@Override
public void run(FlowerServiceBlockingStub blockingStub, FlowerServiceStub asyncStub, MainActivity activity)
throws Exception {

public void run(FlowerServiceStub asyncStub, MainActivity activity) {
join(asyncStub, activity);
}

private void join(FlowerServiceStub asyncStub, MainActivity activity)
throws InterruptedException, RuntimeException {
throws RuntimeException {

final CountDownLatch finishLatch = new CountDownLatch(1);
requestObserver = asyncStub.join(
Expand All @@ -193,6 +189,7 @@ public void onNext(ServerMessage msg) {

@Override
public void onError(Throwable t) {
t.printStackTrace();
failed = t;
finishLatch.countDown();
Log.e(TAG, t.getMessage());
Expand All @@ -212,7 +209,7 @@ private void handleMessage(ServerMessage message, MainActivity activity) {
ByteBuffer[] weights;
ClientMessage c = null;

if (message.hasGetParameters()) {
if (message.hasGetParametersIns()) {
Log.e(TAG, "Handling GetParameters");
activity.setResultText("Handling GetParameters message from the server.");

Expand All @@ -226,6 +223,7 @@ private void handleMessage(ServerMessage message, MainActivity activity) {

Scalar epoch_config = message.getFitIns().getConfigMap().getOrDefault("local_epochs", Scalar.newBuilder().setSint64(1).build());

assert epoch_config != null;
int local_epochs = (int) epoch_config.getSint64();

// Our model has 10 layers
Expand Down Expand Up @@ -257,7 +255,6 @@ private void handleMessage(ServerMessage message, MainActivity activity) {
}
requestObserver.onNext(c);
activity.setResultText("Response sent to the server");
c = null;
}
catch (Exception e){
Log.e(TAG, e.getMessage());
Expand All @@ -266,19 +263,19 @@ private void handleMessage(ServerMessage message, MainActivity activity) {
}

private static ClientMessage weightsAsProto(ByteBuffer[] weights){
List<ByteString> layers = new ArrayList<ByteString>();
for (int i=0; i < weights.length; i++) {
layers.add(ByteString.copyFrom(weights[i]));
List<ByteString> layers = new ArrayList<>();
for (ByteBuffer weight : weights) {
layers.add(ByteString.copyFrom(weight));
}
Parameters p = Parameters.newBuilder().addAllTensors(layers).setTensorType("ND").build();
ClientMessage.ParametersRes res = ClientMessage.ParametersRes.newBuilder().setParameters(p).build();
return ClientMessage.newBuilder().setParametersRes(res).build();
ClientMessage.GetParametersRes res = ClientMessage.GetParametersRes.newBuilder().setParameters(p).build();
return ClientMessage.newBuilder().setGetParametersRes(res).build();
}

private static ClientMessage fitResAsProto(ByteBuffer[] weights, int training_size){
List<ByteString> layers = new ArrayList<ByteString>();
for (int i=0; i < weights.length; i++) {
layers.add(ByteString.copyFrom(weights[i]));
List<ByteString> layers = new ArrayList<>();
for (ByteBuffer weight : weights) {
layers.add(ByteString.copyFrom(weight));
}
Parameters p = Parameters.newBuilder().addAllTensors(layers).setTensorType("ND").build();
ClientMessage.FitRes res = ClientMessage.FitRes.newBuilder().setParameters(p).setNumExamples(training_size).build();
Expand Down
Loading