From 4eae3f5eba1eb7515d082f779430b81cd29d76ac Mon Sep 17 00:00:00 2001 From: vraspar Date: Mon, 23 Sep 2024 15:08:29 -0700 Subject: [PATCH 1/6] Update phi-3 Android app layout and add logs --- .../onnxruntime/genai/demo/MainActivity.java | 20 +++++++++++++++---- .../app/src/main/res/layout/activity_main.xml | 12 +++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java index 6c2809551..ac44b5b72 100644 --- a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java +++ b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java @@ -55,6 +55,12 @@ protected void onCreate(Bundle savedInstanceState) { binding = ActivityMainBinding.inflate(getLayoutInflater()); setContentView(binding.getRoot()); + sendMsgIB = findViewById(R.id.idIBSend); + userMsgEdt = findViewById(R.id.idEdtMessage); + generatedTV = findViewById(R.id.sample_text); + promptTV = findViewById(R.id.user_text); + progressText = findViewById(R.id.progress_text); + // Trigger the download operation when the application is created try { downloadModels( @@ -63,10 +69,6 @@ protected void onCreate(Bundle savedInstanceState) { throw new RuntimeException(e); } - sendMsgIB = findViewById(R.id.idIBSend); - userMsgEdt = findViewById(R.id.idEdtMessage); - generatedTV = findViewById(R.id.sample_text); - promptTV = findViewById(R.id.user_text); Consumer tokenListener = this; @@ -99,6 +101,7 @@ public void onClick(View v) { // Disable send button while responding to prompt. sendMsgIB.setEnabled(false); + sendMsgIB.setAlpha(0.5f); promptTV.setText(promptQuestion); // Clear Edit Text or prompt question. @@ -125,6 +128,9 @@ public void run() { generator = new Generator(model, generatorParams); + // try to measure average time taken to generate each token. + long startTime = System.currentTimeMillis(); + int numTokens = 0; while (!generator.isDone()) { generator.computeLogits(); generator.generateNextToken(); @@ -132,7 +138,12 @@ public void run() { int token = generator.getLastTokenInSequence(0); tokenListener.accept(stream.decode(token)); + Log.i(TAG, "Generated token: " + token + ": " + stream.decode(token)); + numTokens++; } + long totalTime = System.currentTimeMillis() - startTime; + Log.i(TAG, "Total time taken to generate + " + numTokens + "tokens: " + totalTime); + Log.i(TAG, "Average time taken to generate each token: " + totalTime / numTokens); } catch (GenAIException e) { Log.e(TAG, "Exception occurred during model query: " + e.getMessage()); @@ -146,6 +157,7 @@ public void run() { runOnUiThread(() -> { sendMsgIB.setEnabled(true); + sendMsgIB.setAlpha(1.0f); }); } }).start(); diff --git a/mobile/examples/phi-3/android/app/src/main/res/layout/activity_main.xml b/mobile/examples/phi-3/android/app/src/main/res/layout/activity_main.xml index 07bfe6530..db507d29a 100644 --- a/mobile/examples/phi-3/android/app/src/main/res/layout/activity_main.xml +++ b/mobile/examples/phi-3/android/app/src/main/res/layout/activity_main.xml @@ -83,4 +83,16 @@ app:layout_constraintEnd_toEndOf="parent" tools:ignore="UseAppTint" /> + + From 4df9ab2dcb586012c0fc350353d5c3e1ba511703 Mon Sep 17 00:00:00 2001 From: vraspar Date: Tue, 24 Sep 2024 13:17:13 -0700 Subject: [PATCH 2/6] Update Android app layout and add settings functionality --- .../onnxruntime/genai/demo/BottomSheet.java | 47 ++++++++++++++++ .../onnxruntime/genai/demo/MainActivity.java | 28 +++++++++- .../src/main/res/drawable/rounded_corner2.xml | 2 +- .../app/src/main/res/layout/activity_main.xml | 40 ++++++++++---- .../app/src/main/res/layout/bottom_sheet.xml | 55 +++++++++++++++++++ 5 files changed, 157 insertions(+), 15 deletions(-) create mode 100644 mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/BottomSheet.java create mode 100644 mobile/examples/phi-3/android/app/src/main/res/layout/bottom_sheet.xml diff --git a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/BottomSheet.java b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/BottomSheet.java new file mode 100644 index 000000000..562f013ee --- /dev/null +++ b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/BottomSheet.java @@ -0,0 +1,47 @@ +package ai.onnxruntime.genai.demo; + +import android.os.Bundle; +import android.view.LayoutInflater; +import android.view.View; +import android.view.ViewGroup; +import android.widget.Button; +import android.widget.EditText; +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; +import com.google.android.material.bottomsheet.BottomSheetDialogFragment; + +public class BottomSheet extends BottomSheetDialogFragment { + private EditText maxLengthEditText; + private EditText lengthPenaltyEditText; + private SettingsListener settingsListener; + + public interface SettingsListener { + void onSettingsApplied(int maxLength, int lengthPenalty); + } + + public void setSettingsListener(SettingsListener listener) { + this.settingsListener = listener; + } + + @Nullable + @Override + public View onCreateView(@NonNull LayoutInflater inflater, @Nullable ViewGroup container, @Nullable Bundle savedInstanceState) { + View view = inflater.inflate(R.layout.bottom_sheet, container, false); + + maxLengthEditText = view.findViewById(R.id.idEdtMaxLength); + lengthPenaltyEditText = view.findViewById(R.id.idEdtLengthPenalty); + + Button applyButton = view.findViewById(R.id.applySettingsButton); + + applyButton.setOnClickListener(v -> { + if (settingsListener != null) { + int maxLength = Integer.parseInt(maxLengthEditText.getText().toString()); + int lengthPenalty = Integer.parseInt(lengthPenaltyEditText.getText().toString()); + settingsListener.onSettingsApplied(maxLength, lengthPenalty); + dismiss(); + } + }); + + return view; + } +} diff --git a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java index ac44b5b72..a647b1db4 100644 --- a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java +++ b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java @@ -41,7 +41,10 @@ public class MainActivity extends AppCompatActivity implements Consumer private TextView generatedTV; private TextView promptTV; private TextView progressText; + private ImageButton settingsButton; private static final String TAG = "genai.demo.MainActivity"; + private int maxLength = 100; + private int lengthPenalty = 1000; private static boolean fileExists(Context context, String fileName) { File file = new File(context.getFilesDir(), fileName); @@ -60,6 +63,7 @@ protected void onCreate(Bundle savedInstanceState) { generatedTV = findViewById(R.id.sample_text); promptTV = findViewById(R.id.user_text); progressText = findViewById(R.id.progress_text); + settingsButton = findViewById(R.id.idIBSettings); // Trigger the download operation when the application is created try { @@ -69,6 +73,20 @@ protected void onCreate(Bundle savedInstanceState) { throw new RuntimeException(e); } + settingsButton.setOnClickListener(v -> { + BottomSheet bottomSheet = new BottomSheet(); + bottomSheet.setSettingsListener(new BottomSheet.SettingsListener() { + @Override + public void onSettingsApplied(int maxLength, int lengthPenalty) { + MainActivity.this.maxLength = maxLength; + MainActivity.this.lengthPenalty = lengthPenalty; + Log.i(TAG, "Max Response length: " + maxLength); + Log.i(TAG, "Length penalty: " + lengthPenalty); + } + }); + bottomSheet.show(getSupportFragmentManager(), "BottomSheet"); + }); + Consumer tokenListener = this; @@ -120,8 +138,12 @@ public void run() { generatorParams = model.createGeneratorParams(); //examples for optional parameters to format AI response - //generatorParams.setSearchOption("length_penalty", 1000); - //generatorParams.setSearchOption("max_length", 500); + generatorParams.setSearchOption("length_penalty", lengthPenalty); + generatorParams.setSearchOption("max_length", maxLength); + + Log.i(TAG, "Length penalty: " + lengthPenalty); + Log.i(TAG, "Max Response length: " + maxLength); + encodedPrompt = tokenizer.encode(promptQuestion_formatted); generatorParams.setInput(encodedPrompt); @@ -143,7 +165,7 @@ public void run() { } long totalTime = System.currentTimeMillis() - startTime; Log.i(TAG, "Total time taken to generate + " + numTokens + "tokens: " + totalTime); - Log.i(TAG, "Average time taken to generate each token: " + totalTime / numTokens); + Log.i(TAG, "Tokens generated per second: " + 1000 * (numTokens / totalTime)); } catch (GenAIException e) { Log.e(TAG, "Exception occurred during model query: " + e.getMessage()); diff --git a/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml b/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml index d863a41d6..2d8ea1b6c 100644 --- a/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml +++ b/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml @@ -1,7 +1,7 @@ - + + + app:layout_constraintVertical_bias="0.0" + tools:visibility="visible" /> + app:layout_constraintVertical_bias="0.0" + tools:visibility="visible" /> + + + + + + + + + + + + + + + +