From 39ba7d0dd531f1c0fcfaadec035175f7c9e6def0 Mon Sep 17 00:00:00 2001 From: Shaz-hash <73599731+Shaz-hash@users.noreply.github.com> Date: Wed, 27 Sep 2023 20:59:29 +0500 Subject: [PATCH 01/12] Upgrade version of flower Added WorkManager library to allow the FL process to run in the background. --- examples/android/client/app/build.gradle | 4 +- .../client/app/src/main/AndroidManifest.xml | 38 +- .../flwr/android_client/FlowerClient.java | 39 +- .../flwr/android_client/MainActivity.java | 439 ++++++++-------- .../flwr/android_client/MessageAdapter.java | 58 +++ .../java/flwr/android_client/MyWorker.java | 480 ++++++++++++++++++ .../app/src/main/res/layout/activity_main.xml | 231 ++++++--- .../app/src/main/res/layout/item_message.xml | 17 + .../app/src/main/res/values/strings.xml | 4 + examples/android/client/build.gradle | 2 +- 10 files changed, 999 insertions(+), 313 deletions(-) create mode 100644 examples/android/client/app/src/main/java/flwr/android_client/MessageAdapter.java create mode 100644 examples/android/client/app/src/main/java/flwr/android_client/MyWorker.java create mode 100644 examples/android/client/app/src/main/res/layout/item_message.xml diff --git a/examples/android/client/app/build.gradle b/examples/android/client/app/build.gradle index abb2f5109d0..7b9e844105c 100644 --- a/examples/android/client/app/build.gradle +++ b/examples/android/client/app/build.gradle @@ -44,7 +44,7 @@ android { def grpc_version = '1.43.0' protobuf { - protoc { artifact = 'com.google.protobuf:protoc:3.11.0' } + protoc { artifact = 'com.google.protobuf:protoc:3.17.3' } plugins { grpc { artifact = "io.grpc:protoc-gen-grpc-java:$grpc_version" } @@ -66,6 +66,7 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.6.0' implementation fileTree(dir: 'libs', include: ['*.jar']) implementation 'androidx.constraintlayout:constraintlayout:2.1.4' + implementation 'androidx.work:work-runtime:2.8.1' testImplementation 'junit:junit:4.13.2' androidTestImplementation 'androidx.test:runner:1.5.2' androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1' @@ -85,6 +86,7 @@ dependencies { implementation "androidx.lifecycle:lifecycle-common-java8:$lifecycle_version" implementation 'com.google.android.material:material:1.7.0' + implementation 'com.google.protobuf:protobuf-javalite:3.17.3' } diff --git a/examples/android/client/app/src/main/AndroidManifest.xml b/examples/android/client/app/src/main/AndroidManifest.xml index 18eb6bad1fe..c318e56c70f 100644 --- a/examples/android/client/app/src/main/AndroidManifest.xml +++ b/examples/android/client/app/src/main/AndroidManifest.xml @@ -1,8 +1,16 @@ - - - + + + + + + + + - + + + + + + + + + + + diff --git a/examples/android/client/app/src/main/java/flwr/android_client/FlowerClient.java b/examples/android/client/app/src/main/java/flwr/android_client/FlowerClient.java index c453a1d106e..e789e8f15cb 100644 --- a/examples/android/client/app/src/main/java/flwr/android_client/FlowerClient.java +++ b/examples/android/client/app/src/main/java/flwr/android_client/FlowerClient.java @@ -11,6 +11,8 @@ import androidx.lifecycle.MutableLiveData; import java.io.BufferedReader; +import java.io.File; +import java.io.FileWriter; import java.io.IOException; import java.io.InputStreamReader; import java.nio.ByteBuffer; @@ -64,6 +66,7 @@ public void setLastLoss(int epoch, float newLoss) { public void loadData(int device_id) { try { + Log.d("FLOWERCLIENT_LOAD", "loadData: "); BufferedReader reader = new BufferedReader(new InputStreamReader(this.context.getAssets().open("data/partition_" + (device_id - 1) + "_train.txt"))); String line; int i = 0; @@ -137,4 +140,38 @@ private static float[] prepareImage(Bitmap bitmap) { return normalizedRgb; } -} \ No newline at end of file + + // function to write to a file : + + public void writeStringToFile( Context context , String fileName, String content) { + try { + // Get the app-specific external storage directory + File directory = context.getExternalFilesDir(null); + + if (directory != null) { + File file = new File(directory, fileName); + + // Check if the file exists + boolean fileExists = file.exists(); + + // Open a FileWriter in append mode + FileWriter writer = new FileWriter(file, true); + + // If the file exists and is not empty, add a new line + if (fileExists && file.length() > 0) { + writer.append("\n"); + } + + // Write the string to the file + writer.append(content); + + // Close the FileWriter + writer.close(); + } + } catch (IOException e) { + e.printStackTrace(); // Handle the exception as needed + } + } + + +} diff --git a/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java b/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java index 911d5043dfe..c4458391675 100644 --- a/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java +++ b/examples/android/client/app/src/main/java/flwr/android_client/MainActivity.java @@ -1,289 +1,278 @@ package flwr.android_client; -import android.app.Activity; -import android.icu.text.SimpleDateFormat; +import android.content.Context; +import android.content.Intent; +import android.net.Uri; +import android.os.Build; import android.os.Bundle; - import androidx.appcompat.app.AppCompatActivity; - -import android.os.Handler; -import android.os.Looper; +import androidx.lifecycle.LifecycleOwner; +import androidx.recyclerview.widget.LinearLayoutManager; +import androidx.recyclerview.widget.RecyclerView; +import androidx.work.Constraints; +import androidx.work.Data; +import androidx.work.ExistingPeriodicWorkPolicy; +import androidx.work.PeriodicWorkRequest; +import androidx.work.WorkInfo; +import androidx.work.WorkManager; +import android.os.PowerManager; +import android.provider.Settings; import android.text.TextUtils; -import android.text.method.ScrollingMovementMethod; -import android.util.Log; -import android.util.Pair; -import android.util.Patterns; import android.view.View; -import android.view.inputmethod.InputMethodManager; import android.widget.Button; import android.widget.EditText; -import android.widget.TextView; import android.widget.Toast; - -import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; - -import flwr.android_client.FlowerServiceGrpc.FlowerServiceStub; -import com.google.protobuf.ByteString; - -import io.grpc.stub.StreamObserver; - -import java.io.PrintWriter; -import java.io.StringWriter; -import java.nio.ByteBuffer; +import androidx.lifecycle.Observer; +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; import java.util.ArrayList; -import java.util.Date; import java.util.List; -import java.util.Locale; -import java.util.Objects; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + public class MainActivity extends AppCompatActivity { - private EditText ip; - private EditText port; - private Button loadDataButton; - private Button connectButton; - private Button trainButton; - private TextView resultText; - private EditText device_id; - private ManagedChannel channel; - public FlowerClient fc; - private static final String TAG = "Flower"; + private static final String TAG = "Flower"; + private static final int REQUEST_WRITE_PERMISSION = 786; + private Button batteryOptimisationButton; + MessageAdapter messageAdapter; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); - resultText = (TextView) findViewById(R.id.grpc_response_text); - resultText.setMovementMethod(new ScrollingMovementMethod()); - device_id = (EditText) findViewById(R.id.device_id_edit_text); - ip = (EditText) findViewById(R.id.serverIP); - port = (EditText) findViewById(R.id.serverPort); - loadDataButton = (Button) findViewById(R.id.load_data) ; - connectButton = (Button) findViewById(R.id.connect); - trainButton = (Button) findViewById(R.id.trainFederated); - - fc = new FlowerClient(this); - } + RecyclerView recyclerView = findViewById(R.id.recyclerView); + recyclerView.setLayoutManager(new LinearLayoutManager(this)); - public static void hideKeyboard(Activity activity) { - InputMethodManager imm = (InputMethodManager) activity.getSystemService(Activity.INPUT_METHOD_SERVICE); - View view = activity.getCurrentFocus(); - if (view == null) { - view = new View(activity); - } - imm.hideSoftInputFromWindow(view.getWindowToken(), 0); - } + messageAdapter = new MessageAdapter(readStringFromFile( getApplicationContext() , "FlowerResults.txt")); // Create your custom adapter + recyclerView.setLayoutManager(new LinearLayoutManager(this)); + recyclerView.setAdapter(messageAdapter); + requestPermission(); - public void setResultText(String text) { - SimpleDateFormat dateFormat = new SimpleDateFormat("HH:mm:ss", Locale.GERMANY); - String time = dateFormat.format(new Date()); - resultText.append("\n" + time + " " + text); + LifecycleOwner lifecycleOwner = this ; + WorkManager.getInstance(getApplicationContext()).getWorkInfosForUniqueWorkLiveData("my_unique_periodic_work").observe(lifecycleOwner, new Observer>() { + @Override + public void onChanged(List workInfos) { + if (workInfos.size() > 0) { + WorkInfo info = workInfos.get(0); + int progress = info.getProgress().getInt("progress", -1); + // You can recieve any message from the Worker Thread + refreshRecyclerView(); + } + } + }); + // code for functionality of permission buttons : + batteryOptimisationButton = findViewById(R.id.battery_optimisation); + batteryOptimisationButton.setOnClickListener(new View.OnClickListener() { + @Override + public void onClick(View v) { + toggleBatteryOptimization(); + } + }); } - public void loadData(View view){ - if (TextUtils.isEmpty(device_id.getText().toString())) { - Toast.makeText(this, "Please enter a client partition ID between 1 and 10 (inclusive)", Toast.LENGTH_LONG).show(); + private void requestPermission() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { + requestPermissions(new String[]{android.Manifest.permission.WRITE_EXTERNAL_STORAGE}, REQUEST_WRITE_PERMISSION); + createEmptyFile("FlowerResults.txt"); } - else if (Integer.parseInt(device_id.getText().toString()) > 10 || Integer.parseInt(device_id.getText().toString()) < 1) + else { - Toast.makeText(this, "Please enter a client partition ID between 1 and 10 (inclusive)", Toast.LENGTH_LONG).show(); - } - else{ - hideKeyboard(this); - setResultText("Loading the local training dataset in memory. It will take several seconds."); - loadDataButton.setEnabled(false); - - ExecutorService executor = Executors.newSingleThreadExecutor(); - Handler handler = new Handler(Looper.getMainLooper()); - - executor.execute(new Runnable() { - private String result; - @Override - public void run() { - try { - fc.loadData(Integer.parseInt(device_id.getText().toString())); - result = "Training dataset is loaded in memory."; - } catch (Exception e) { - StringWriter sw = new StringWriter(); - PrintWriter pw = new PrintWriter(sw); - e.printStackTrace(pw); - pw.flush(); - result = "Training dataset is loaded in memory."; - } - handler.post(() -> { - setResultText(result); - connectButton.setEnabled(true); - }); - } - }); + createEmptyFile("FlowerResults.txt"); } } + private List readStringFromFile(Context context, String fileName) { + List lines = new ArrayList<>(); - public void connect(View view) { - String host = ip.getText().toString(); - String portStr = port.getText().toString(); - if (TextUtils.isEmpty(host) || TextUtils.isEmpty(portStr) || !Patterns.IP_ADDRESS.matcher(host).matches()) { - Toast.makeText(this, "Please enter the correct IP and port of the FL server", Toast.LENGTH_LONG).show(); - } - else { - int port = TextUtils.isEmpty(portStr) ? 0 : Integer.parseInt(portStr); - channel = ManagedChannelBuilder.forAddress(host, port).maxInboundMessageSize(10 * 1024 * 1024).usePlaintext().build(); - hideKeyboard(this); - trainButton.setEnabled(true); - connectButton.setEnabled(false); - setResultText("Channel object created. Ready to train!"); + try { + File directory = context.getExternalFilesDir(null); + + if (directory != null) { + File file = new File(directory, fileName); + + // Checking if the file exists + if (!file.exists()) { + return lines; // File doesn't exist then return an empty list + } + // Opening a FileReader to read the file + FileReader reader = new FileReader(file); + BufferedReader bufferedReader = new BufferedReader(reader); + String line; + while ((line = bufferedReader.readLine()) != null) { + lines.add(line); + } + // Closing the readers + bufferedReader.close(); + reader.close(); + } + } catch (IOException e) { + e.printStackTrace(); // Handle the exception as needed } + + return lines; } - public void runGrpc(View view){ - MainActivity activity = this; - ExecutorService executor = Executors.newSingleThreadExecutor(); - Handler handler = new Handler(Looper.getMainLooper()); - executor.execute(new Runnable() { - private String result; - @Override - public void run() { - try { - (new FlowerServiceRunnable()).run(FlowerServiceGrpc.newStub(channel), activity); - result = "Connection to the FL server successful \n"; - } catch (Exception e) { - StringWriter sw = new StringWriter(); - PrintWriter pw = new PrintWriter(sw); - e.printStackTrace(pw); - pw.flush(); - result = "Failed to connect to the FL server \n" + sw; + private void clearFileContents(Context context, String fileName) { + try { + File directory = context.getExternalFilesDir(null); + + if (directory != null) { + File file = new File(directory, fileName); + + // Checking if the file exists + if (!file.exists()) { + // File doesn't exist, so there's nothing to clear + return; } - handler.post(() -> { - setResultText(result); - trainButton.setEnabled(false); - }); + + // Opens a FileWriter with append mode set to false (this will clear the file) + FileWriter writer = new FileWriter(file, false); + writer.write(""); // Write an empty string to clear the file + writer.close(); + + refreshRecyclerView(); } - }); + } catch (IOException e) { + e.printStackTrace(); // Handle the exception as needed + } } - private static class FlowerServiceRunnable{ - protected Throwable failed; - private StreamObserver requestObserver; - public void run(FlowerServiceStub asyncStub, MainActivity activity) { - join(asyncStub, activity); - } + public void startWorker(View view) { + + // ensuring all inputs are entered : + + EditText deviceIdEditText = findViewById(R.id.device_id_edit_text); + EditText serverIPEditText = findViewById(R.id.serverIP); + EditText serverPortEditText = findViewById(R.id.serverPort); + + // Get the text from the EditText widgets + String dataSlice = deviceIdEditText.getText().toString(); + String serverIP = serverIPEditText.getText().toString(); + String serverPort = serverPortEditText.getText().toString(); - private void join(FlowerServiceStub asyncStub, MainActivity activity) - throws RuntimeException { - - final CountDownLatch finishLatch = new CountDownLatch(1); - requestObserver = asyncStub.join( - new StreamObserver() { - @Override - public void onNext(ServerMessage msg) { - handleMessage(msg, activity); - } - - @Override - public void onError(Throwable t) { - t.printStackTrace(); - failed = t; - finishLatch.countDown(); - Log.e(TAG, t.getMessage()); - } - - @Override - public void onCompleted() { - finishLatch.countDown(); - Log.e(TAG, "Done"); - } - }); + if (TextUtils.isEmpty(dataSlice) || TextUtils.isEmpty(serverIP) || TextUtils.isEmpty(serverPort)) { + // Display a toast message indicating that fields are omitted + Toast.makeText(this, "Please fill in all fields", Toast.LENGTH_SHORT).show(); + } else { + + // Launching the Worker : + Constraints constraints = new Constraints.Builder() + // Add constraints if needed (e.g., network connectivity) + .build(); + + PeriodicWorkRequest workRequest = new PeriodicWorkRequest.Builder( + MyWorker.class, 15, TimeUnit.MINUTES) + .setInitialDelay(0, TimeUnit.MILLISECONDS) + .setInputData(new Data.Builder() + .putString( "dataslice", deviceIdEditText.getText().toString() ) + .putString( "server", serverIPEditText.getText().toString()) + .putString( "port" , serverPortEditText.getText().toString()) + .build()) + .setConstraints(constraints) + .build(); + + String uniqueWorkName = "my_unique_periodic_work"; + + WorkManager.getInstance(getApplicationContext()) + .enqueueUniquePeriodicWork(uniqueWorkName, ExistingPeriodicWorkPolicy.KEEP, workRequest); + + // Providing user feedback, e.g., a toast message + Toast.makeText(this, "Worker started!", Toast.LENGTH_SHORT).show(); } + } - private void handleMessage(ServerMessage message, MainActivity activity) { + // Listener function for the "Stop" button + public void stopWorker(View view) { + // Cancel the worker + WorkManager.getInstance(getApplicationContext()).cancelAllWork(); + // Providing user feedback again, e.g., a toast message + Toast.makeText(this, "Worker stopped!", Toast.LENGTH_SHORT).show(); + } - try { - ByteBuffer[] weights; - ClientMessage c = null; - if (message.hasGetParametersIns()) { - Log.e(TAG, "Handling GetParameters"); - activity.setResultText("Handling GetParameters message from the server."); + // Another Listener function to refresh the updates : - weights = activity.fc.getWeights(); - c = weightsAsProto(weights); - } else if (message.hasFitIns()) { - Log.e(TAG, "Handling FitIns"); - activity.setResultText("Handling Fit request from the server."); + public void refresh(View view) + { + refreshRecyclerView(); + } - List layers = message.getFitIns().getParameters().getTensorsList(); + // Another Listener to clear the contents of the File : - Scalar epoch_config = message.getFitIns().getConfigMap().getOrDefault("local_epochs", Scalar.newBuilder().setSint64(1).build()); + public void clear(View view) + { + clearFileContents(getApplicationContext() , "FlowerResults.txt"); + } - assert epoch_config != null; - int local_epochs = (int) epoch_config.getSint64(); - // Our model has 10 layers - ByteBuffer[] newWeights = new ByteBuffer[10] ; - for (int i = 0; i < 10; i++) { - newWeights[i] = ByteBuffer.wrap(layers.get(i).toByteArray()); - } + private void refreshRecyclerView() { + // Get messages from MessageRepository using the getMessagesArray method + List messages = readStringFromFile( getApplicationContext() ,"FlowerResults.txt"); - Pair outputs = activity.fc.fit(newWeights, local_epochs); - c = fitResAsProto(outputs.first, outputs.second); - } else if (message.hasEvaluateIns()) { - Log.e(TAG, "Handling EvaluateIns"); - activity.setResultText("Handling Evaluate request from the server"); + // Update the data source of your adapter with the new messages + messageAdapter.setData(messages); - List layers = message.getEvaluateIns().getParameters().getTensorsList(); + // Notify the adapter that the data has changed + messageAdapter.notifyDataSetChanged(); - // Our model has 10 layers - ByteBuffer[] newWeights = new ByteBuffer[10] ; - for (int i = 0; i < 10; i++) { - newWeights[i] = ByteBuffer.wrap(layers.get(i).toByteArray()); - } - Pair, Integer> inference = activity.fc.evaluate(newWeights); + } - float loss = inference.first.first; - float accuracy = inference.first.second; - activity.setResultText("Test Accuracy after this round = " + accuracy); - int test_size = inference.second; - c = evaluateResAsProto(loss, test_size); - } - requestObserver.onNext(c); - activity.setResultText("Response sent to the server"); - } - catch (Exception e){ - Log.e(TAG, e.getMessage()); - } + + + // following code is for just the permissions : + private void toggleBatteryOptimization() { + if (isBatteryOptimizationEnabled()) { + disableBatteryOptimization(); + } else { + requestBatteryOptimization(); } } - private static ClientMessage weightsAsProto(ByteBuffer[] weights){ - List layers = new ArrayList<>(); - for (ByteBuffer weight : weights) { - layers.add(ByteString.copyFrom(weight)); + private boolean isBatteryOptimizationEnabled() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { + String packageName = getPackageName(); + PowerManager powerManager = (PowerManager) getSystemService(Context.POWER_SERVICE); + return powerManager.isIgnoringBatteryOptimizations(packageName); } - Parameters p = Parameters.newBuilder().addAllTensors(layers).setTensorType("ND").build(); - ClientMessage.GetParametersRes res = ClientMessage.GetParametersRes.newBuilder().setParameters(p).build(); - return ClientMessage.newBuilder().setGetParametersRes(res).build(); + // Battery optimization is not available on versions prior to M, so return false. + return false; } - private static ClientMessage fitResAsProto(ByteBuffer[] weights, int training_size){ - List layers = new ArrayList<>(); - for (ByteBuffer weight : weights) { - layers.add(ByteString.copyFrom(weight)); + private void disableBatteryOptimization() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { + Intent intent = new Intent(Settings.ACTION_IGNORE_BATTERY_OPTIMIZATION_SETTINGS); + startActivity(intent); } - Parameters p = Parameters.newBuilder().addAllTensors(layers).setTensorType("ND").build(); - ClientMessage.FitRes res = ClientMessage.FitRes.newBuilder().setParameters(p).setNumExamples(training_size).build(); - return ClientMessage.newBuilder().setFitRes(res).build(); } - private static ClientMessage evaluateResAsProto(float accuracy, int testing_size){ - ClientMessage.EvaluateRes res = ClientMessage.EvaluateRes.newBuilder().setLoss(accuracy).setNumExamples(testing_size).build(); - return ClientMessage.newBuilder().setEvaluateRes(res).build(); + private void requestBatteryOptimization() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { + Intent intent = new Intent(Settings.ACTION_REQUEST_IGNORE_BATTERY_OPTIMIZATIONS); + intent.setData(Uri.parse("package:" + getPackageName())); + startActivity(intent); +// startActivityForResult(intent, BATTERY_OPTIMIZATION_REQUEST_CODE); + } + } + + public void createEmptyFile(String fileName) { + try { + File file = new File(fileName); + + // Create the file if it doesn't exist + if (!file.exists()) { + file.createNewFile(); + } + } catch (IOException e) { + e.printStackTrace(); + } } } + + diff --git a/examples/android/client/app/src/main/java/flwr/android_client/MessageAdapter.java b/examples/android/client/app/src/main/java/flwr/android_client/MessageAdapter.java new file mode 100644 index 00000000000..75bcfdbdbe6 --- /dev/null +++ b/examples/android/client/app/src/main/java/flwr/android_client/MessageAdapter.java @@ -0,0 +1,58 @@ +package flwr.android_client; + +import android.view.LayoutInflater; +import android.view.View; +import android.view.ViewGroup; +import android.widget.TextView; +import androidx.annotation.NonNull; +import androidx.recyclerview.widget.RecyclerView; +import java.util.List; + +public class MessageAdapter extends RecyclerView.Adapter { + + private List messages; + + // Constructor to initialize the data source + public MessageAdapter(List messages) { + this.messages = messages; + } + + @NonNull + @Override + public MessageViewHolder onCreateViewHolder(@NonNull ViewGroup parent, int viewType) { + View view = LayoutInflater.from(parent.getContext()).inflate(R.layout.item_message, parent, false); + return new MessageViewHolder(view); + } + + @Override + public void onBindViewHolder(@NonNull MessageViewHolder holder, int position) { + String message = messages.get(position); + holder.bind(message); + } + + @Override + public int getItemCount() { + return messages != null ? messages.size() : 0; + } + + public void setData(List messages) { + this.messages = messages; + notifyDataSetChanged(); + } + + + // ViewHolder class + public static class MessageViewHolder extends RecyclerView.ViewHolder { + TextView messageTextView; + + public MessageViewHolder(@NonNull View itemView) { + super(itemView); + messageTextView = itemView.findViewById(R.id.messageTextView); + } + + // Bind data to the TextView + public void bind(String message) { + messageTextView.setText(message); + } + } +} diff --git a/examples/android/client/app/src/main/java/flwr/android_client/MyWorker.java b/examples/android/client/app/src/main/java/flwr/android_client/MyWorker.java new file mode 100644 index 00000000000..ede924756ba --- /dev/null +++ b/examples/android/client/app/src/main/java/flwr/android_client/MyWorker.java @@ -0,0 +1,480 @@ +package flwr.android_client; + + + +import static android.content.Context.NOTIFICATION_SERVICE; +import android.app.Notification; +import android.app.NotificationChannel; +import android.app.NotificationManager; +import android.app.PendingIntent; +import android.content.Context; +import android.icu.text.SimpleDateFormat; +import android.os.Build; +import androidx.annotation.NonNull; +import androidx.annotation.RequiresApi; +import android.util.Log; +import android.util.Pair; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import flwr.android_client.FlowerServiceGrpc.FlowerServiceStub; +import com.google.protobuf.ByteString; +import io.grpc.stub.StreamObserver; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.HashMap; +import java.util.Map; +import androidx.core.app.NotificationCompat; +import androidx.work.Data; +import androidx.work.ForegroundInfo; +import androidx.work.WorkManager; +import androidx.work.Worker; +import androidx.work.WorkerParameters; + +public class MyWorker extends Worker { + + private ManagedChannel channel; + public FlowerClient fc; + private StreamObserver UniversalRequestObserver; + private static final String TAG = "Flower"; + String serverIp = "00:00:00"; + String serverPort = "0000"; + String dataslice = "1"; + public static String start_time; + public static String end_time; + // following variables are just to send the worker routine to the + public static String workerStartTime = ""; + + public static String workerEndTime = ""; + + public static String workerEndReason = "worker ended"; + + public String getTime() { + // Extract hours, minutes, and seconds + if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.O) { + java.text.SimpleDateFormat sdf = new java.text.SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.getDefault()); + String formattedDateTime = sdf.format(new Date()); + return formattedDateTime; + } + return ""; + } + private NotificationManager notificationManager; + + private static String PROGRESS = "PROGRESS"; + + public MyWorker(@NonNull Context context, @NonNull WorkerParameters workerParams) { + super(context, workerParams); + MyWorker worker = this; + notificationManager = (NotificationManager) + context.getSystemService(NOTIFICATION_SERVICE); + fc = new FlowerClient(context.getApplicationContext()); + } + + @NonNull + @Override + public Result doWork() { + + Data checkData = getInputData(); + serverIp = checkData.getString("server"); + serverPort = checkData.getString("port"); + dataslice = checkData.getString("dataslice"); + + // Creating Foreground Notification Service about the Background Worker FL tasks + setForegroundAsync(createForegroundInfo("Progress")); + try { + workerStartTime = getTime(); + // Ensuring whether the connection is establish or not with the given gRPC IP & port + boolean resultConnect = connect(); + if(resultConnect) + { + loadData(); + CompletableFuture grpcFuture = runGrpc(); + grpcFuture.get(); + return Result.success(); + } + else + { + workerEndReason = "GRPC Connection failed"; + return Result.failure(); + } + + } catch (Exception e) { + // To handle any exceptions and return a failure result + // Failure if there is any OOM or midway connection error + workerEndReason = "Unknown Error occured in main try catch"; + Log.e(TAG, "Error executing flower code: " + e.getMessage(), e); + return Result.failure(); + } + } + + @Override + public void onStopped() { + super.onStopped(); + // Worker is canceled, stopping the global requestObserver if it's not null + Throwable cancellationCause = new Throwable("Worker canceled"); + if (UniversalRequestObserver != null) { + UniversalRequestObserver.onError(cancellationCause); // Signal to the server that communication is done + } + } + + public boolean connect() { + int port = Integer.parseInt(serverPort); + try { + channel = ManagedChannelBuilder.forAddress(serverIp, port) + .maxInboundMessageSize(10 * 1024 * 1024) + .usePlaintext() + .build(); + fc.writeStringToFile(getApplicationContext(), "FlowerResults.txt" , "Connection : Successful with " + serverIp + " : " + serverPort + " : " + dataslice); + return true; // connection is successful + } catch (Exception e) { + Log.e(TAG, "Failed to connect to the server: " + e.getMessage(), e); + fc.writeStringToFile(getApplicationContext(), "FlowerResults.txt" , "Connection : Failed with " + serverIp + " : " + serverPort + " : " + dataslice); + return false; // connection failed + } + } + + public void loadData() { + try { + fc.loadData(Integer.parseInt(dataslice)); + Log.d("LOAD", "Loading is complete"); + fc.writeStringToFile(getApplicationContext(), "FlowerResults.txt", "Loading Bit Images : Success" ); + } catch (Exception e) { + StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw); + e.printStackTrace(pw); + pw.flush(); + Log.d("LOAD_ERROR", "Error occured in Loading"); + fc.writeStringToFile(getApplicationContext(), "FlowerResults.txt", "Loading Bit Images : Failed" ); + } + } + + public CompletableFuture runGrpc() { + + CompletableFuture future = new CompletableFuture<>(); + MyWorker worker = this; + ExecutorService executor = Executors.newSingleThreadExecutor(); + + ProgressUpdater progressUpdater = new ProgressUpdater(); + + executor.execute(new Runnable() { + @Override + public void run() { + try { + CountDownLatch latch = new CountDownLatch(1); + + (new FlowerServiceRunnable()).run(FlowerServiceGrpc.newStub(channel), worker, latch , progressUpdater , getApplicationContext()); + + latch.await(); // Wait for the latch to count down + future.complete(null); // Complete the future when the latch is counted down + + Log.d("GRPC", "inside GRPC"); + } catch (Exception e) { + StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw); + e.printStackTrace(pw); + pw.flush(); + Log.e("GRPC", "Failed to connect to the FL server \n" + sw); + future.completeExceptionally(e); // Complete the future with an exception + } + } + }); + + return future; + } + + + @NonNull + private ForegroundInfo createForegroundInfo(@NonNull String progress) { + // Building a notification using bytesRead and contentLength + Context context = getApplicationContext(); + String id = context.getString(R.string.notification_channel_id); + String title = context.getString(R.string.notification_title); + String cancel =context.getString(R.string.cancel_download); + // Creating a PendingIntent that can be used to cancel the worker + PendingIntent intent = WorkManager.getInstance(context) + .createCancelPendingIntent(getId()); + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) { + createChannel(); + } + Notification notification = new NotificationCompat.Builder(context, id) + .setContentTitle(title) + .setTicker(title) + .setSmallIcon(R.drawable.ic_logo) + .setOngoing(true) + // Add the cancel action to the notification which can + // be used to cancel the worker + .addAction(android.R.drawable.ic_delete, cancel, intent) + .build(); + int notificationId = 1002; + return new ForegroundInfo(notificationId, notification); + } + + @RequiresApi(Build.VERSION_CODES.O) + private void createChannel() { + Context context = getApplicationContext(); + String channelId = context.getString(R.string.notification_channel_id); + String channelName = context.getString(R.string.notification_title); + int importance = NotificationManager.IMPORTANCE_DEFAULT; + + NotificationChannel channel = new NotificationChannel(channelId, channelName, importance); + // Configure the channel + channel.setDescription("Channel description"); + // Set other properties of the channel as needed if needed ... + NotificationManager notificationManager = (NotificationManager) context.getSystemService(Context.NOTIFICATION_SERVICE); + notificationManager.createNotificationChannel(channel); + } + + + public class ProgressUpdater { + public void setProgress() { + // Aim of this class is to allow static FlowerServiceRunnable Object to notifiy Main Activity about the changes in real time to be displayed to User + Log.d("DATA-BACKGROUND","Sending it to the main activity"); + setProgressAsync(new Data.Builder().putInt("progress", 0).build()); + + } + } + + private static class FlowerServiceRunnable{ + protected Throwable failed; + public void run(FlowerServiceStub asyncStub, MyWorker worker , CountDownLatch latch , ProgressUpdater progressUpdater , Context context) { + join(asyncStub, worker , latch , progressUpdater , context); + } + + public void writeStringToFile( Context context , String fileName, String content) { + try { + // Getting the app-specific external storage directory + File directory = context.getExternalFilesDir(null); + + if (directory != null) { + File file = new File(directory, fileName); + + // Checking if the file exists + boolean fileExists = file.exists(); + + // Open a FileWriter in append mode + FileWriter writer = new FileWriter(file, true); + + // If the file exists and is not empty, add a new line + if (fileExists && file.length() > 0) { + writer.append("\n"); + } + + // Write the string to the file + writer.append(content); + + // Close the FileWriter + writer.close(); + } + } catch (IOException e) { + e.printStackTrace(); // Handle the exception as needed + } + } + + private void join(FlowerServiceStub asyncStub, MyWorker worker, CountDownLatch latch , ProgressUpdater progressUpdater , Context context) + throws RuntimeException { + final CountDownLatch finishLatch = new CountDownLatch(1); + + worker.UniversalRequestObserver = asyncStub.join(new StreamObserver() { + @Override + public void onNext(ServerMessage msg) { + handleMessage(msg, worker , progressUpdater , context); + } + + @Override + public void onError(Throwable t) { + t.printStackTrace(); + failed = t; + finishLatch.countDown(); + latch.countDown(); + // Error handling for timeout & other GRPC communication related Errors + workerEndReason = t.getMessage(); + writeStringToFile( context ,"FlowerResults.txt", workerEndReason); + Log.e(TAG, t.getMessage()); + } + + @Override + public void onCompleted() { + finishLatch.countDown(); + latch.countDown(); + Log.e(TAG, "Done"); + } + }); + + + try { + finishLatch.await(); + } catch (InterruptedException e) { + Log.e(TAG, "Interrupted while waiting for gRPC communication to finish: " + e.getMessage(), e); + Thread.currentThread().interrupt(); + } + } + + private void handleMessage(ServerMessage message, MyWorker worker , ProgressUpdater progressUpdater , Context context) { + + try { + ByteBuffer[] weights; + ClientMessage c = null; + + if (message.hasGetParametersIns()) { + Log.e(TAG, "Handling GetParameters"); + + weights = worker.fc.getWeights(); + c = weightsAsProto(weights); + } else if (message.hasFitIns()) { + + SimpleDateFormat sdf = null; + if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.N) { + sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.getDefault()); + } + + // Get the current date and time + Date currentDate = new Date(); + + // Format the date and time using the SimpleDateFormat object + // String formattedDate = null; + if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.N) { + start_time = sdf.format(currentDate); + } + Log.e(TAG, "Handling FitIns"); + + List layers = message.getFitIns().getParameters().getTensorsList(); + + Scalar epoch_config = null; + if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.N) { + epoch_config = message.getFitIns().getConfigMap().getOrDefault("local_epochs", Scalar.newBuilder().setSint64(1).build()); + } + + assert epoch_config != null; + int local_epochs = (int) epoch_config.getSint64(); + + // Our model has 10 layers + ByteBuffer[] newWeights = new ByteBuffer[10] ; + for (int i = 0; i < 10; i++) { + newWeights[i] = ByteBuffer.wrap(layers.get(i).toByteArray()); + } + + Pair outputs = worker.fc.fit(newWeights, local_epochs); + currentDate = new Date(); + // Format the date and time using the SimpleDateFormat object + if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.N) { + end_time = sdf.format(currentDate); + } + Log.d("FIT-RESPONSE", "ABOUT TO SEND FIT RESPONSE"); + c = fitResAsProto(outputs.first, outputs.second); + } else if (message.hasEvaluateIns()) { + Log.e(TAG, "Handling EvaluateIns"); + + SimpleDateFormat sdf = null; + if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.N) { + sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.getDefault()); + } + Date currentDate = new Date(); + // Format the date and time using the SimpleDateFormat object + if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.N) { + start_time = sdf.format(currentDate); + } + List layers = message.getEvaluateIns().getParameters().getTensorsList(); + // Our model has 10 layers + ByteBuffer[] newWeights = new ByteBuffer[10] ; + for (int i = 0; i < 10; i++) { + newWeights[i] = ByteBuffer.wrap(layers.get(i).toByteArray()); + } + Pair, Integer> inference = worker.fc.evaluate(newWeights); + float loss = inference.first.first; + float accuracy = inference.first.second; + int test_size = inference.second; + currentDate = new Date(); + // Format the date and time using the SimpleDateFormat object + if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.N) { + end_time = sdf.format(currentDate); + } + Log.d("EVALUATE-RESPONSE", "ABOUT TO SEND EVALUATE RESPONSE"); + String newMessage = "Time : " + end_time + " , " + " Round Accuracy : " + String.valueOf(accuracy); + writeStringToFile( context ,"FlowerResults.txt", newMessage); + progressUpdater.setProgress(); + c = evaluateResAsProto(loss , accuracy , test_size); + } + worker.UniversalRequestObserver.onNext(c); + } + catch (Exception e){ + Log.e("Exception","Exception occured in GRPC Connection"); + Log.e(TAG, e.getMessage()); + } + } + } + + private static ClientMessage weightsAsProto(ByteBuffer[] weights){ + List layers = new ArrayList<>(); + for (ByteBuffer weight : weights) { + layers.add(ByteString.copyFrom(weight)); + } + Parameters p = Parameters.newBuilder().addAllTensors(layers).setTensorType("ND").build(); + ClientMessage.GetParametersRes res = ClientMessage.GetParametersRes.newBuilder().setParameters(p).build(); + return ClientMessage.newBuilder().setGetParametersRes(res).build(); + } + + private static ClientMessage fitResAsProto(ByteBuffer[] weights, int training_size){ + List layers = new ArrayList<>(); + for (ByteBuffer weight : weights) { + layers.add(ByteString.copyFrom(weight)); + } + + Log.d("ENDTIME", end_time); + Log.d("STARTTIME", start_time); + + // An example portraying how to upload data to the server via FLower Server side GRPC + Map metrics = new HashMap<>(); + + metrics.put("start_time", Scalar.newBuilder().setString(start_time).build()); + metrics.put("end_time", Scalar.newBuilder().setString(end_time).build()); + Parameters p = Parameters.newBuilder().addAllTensors(layers).setTensorType("ND").build(); + ClientMessage.FitRes res = ClientMessage.FitRes.newBuilder().setParameters(p).setNumExamples(training_size).putAllMetrics(metrics).build(); + return ClientMessage.newBuilder().setFitRes(res).build(); + } + + + + private static ClientMessage evaluateResAsProto(float loss, float accuracy ,int testing_size){ + + // attempting to send accuracy to the server : + Map metrics = new HashMap<>(); + + + Log.d("ENDTIME", end_time); + Log.d("STARTTIME", start_time); + + Log.d("Accuracy", String.valueOf(accuracy)); + Log.d("Loss", String.valueOf(loss)); + + + // An example portraying how to upload data to the server via FLower Server side GRPC + metrics.put("Accuracy", Scalar.newBuilder().setString(String.valueOf(accuracy)).build()); + metrics.put("Loss" , Scalar.newBuilder().setString(String.valueOf(loss)).build()); + metrics.put("start_time", Scalar.newBuilder().setString(start_time).build()); + metrics.put("end_time", Scalar.newBuilder().setString(end_time).build()); + + + ClientMessage.EvaluateRes res = ClientMessage.EvaluateRes.newBuilder().setLoss(loss).setNumExamples(testing_size).putAllMetrics(metrics).build(); + return ClientMessage.newBuilder().setEvaluateRes(res).build(); + } + + +} + + + + + + + diff --git a/examples/android/client/app/src/main/res/layout/activity_main.xml b/examples/android/client/app/src/main/res/layout/activity_main.xml index 543f1eb1cd6..7d98e65823b 100644 --- a/examples/android/client/app/src/main/res/layout/activity_main.xml +++ b/examples/android/client/app/src/main/res/layout/activity_main.xml @@ -1,131 +1,200 @@ - + - + - - - - - + android:layout_margin="8dp" + android:hint="Client Partition ID (1-10)" + android:inputType="numberDecimal" + android:textAppearance="@style/TextAppearance.AppCompat.Medium" + android:textColor="#4a5663" + app:layout_constraintStart_toStartOf="parent" + app:layout_constraintEnd_toEndOf="parent" + app:layout_constraintTop_toTopOf="parent" /> + - + android:textColor="#4a5663" + app:layout_constraintStart_toStartOf="parent" + app:layout_constraintEnd_toEndOf="parent" + app:layout_constraintTop_toBottomOf="@+id/device_id_edit_text" /> + + android:textColor="#4a5663" + app:layout_constraintStart_toStartOf="parent" + app:layout_constraintEnd_toEndOf="parent" + app:layout_constraintTop_toBottomOf="@+id/serverIP" /> + + + +