diff --git a/mobile/examples/phi-3/android/README.md b/mobile/examples/phi-3/android/README.md
index 129a6d961..a0931b775 100644
--- a/mobile/examples/phi-3/android/README.md
+++ b/mobile/examples/phi-3/android/README.md
@@ -58,3 +58,4 @@ Here are some sample example screenshots of the app.
+
diff --git a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/BottomSheet.java b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/BottomSheet.java
new file mode 100644
index 000000000..be12fce3c
--- /dev/null
+++ b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/BottomSheet.java
@@ -0,0 +1,47 @@
+package ai.onnxruntime.genai.demo;
+
+import android.os.Bundle;
+import android.view.LayoutInflater;
+import android.view.View;
+import android.view.ViewGroup;
+import android.widget.Button;
+import android.widget.EditText;
+import androidx.annotation.NonNull;
+import androidx.annotation.Nullable;
+import com.google.android.material.bottomsheet.BottomSheetDialogFragment;
+
+public class BottomSheet extends BottomSheetDialogFragment {
+ private EditText maxLengthEditText;
+ private EditText lengthPenaltyEditText;
+ private SettingsListener settingsListener;
+
+ public interface SettingsListener {
+ void onSettingsApplied(int maxLength, float lengthPenalty);
+ }
+
+ public void setSettingsListener(SettingsListener listener) {
+ this.settingsListener = listener;
+ }
+
+ @Nullable
+ @Override
+ public View onCreateView(@NonNull LayoutInflater inflater, @Nullable ViewGroup container, @Nullable Bundle savedInstanceState) {
+ View view = inflater.inflate(R.layout.bottom_sheet, container, false);
+
+ maxLengthEditText = view.findViewById(R.id.idEdtMaxLength);
+ lengthPenaltyEditText = view.findViewById(R.id.idEdtLengthPenalty);
+
+ Button applyButton = view.findViewById(R.id.applySettingsButton);
+
+ applyButton.setOnClickListener(v -> {
+ if (settingsListener != null) {
+ int maxLength = Integer.parseInt(maxLengthEditText.getText().toString());
+ float lengthPenalty = Float.parseFloat(lengthPenaltyEditText.getText().toString());
+ settingsListener.onSettingsApplied(maxLength, lengthPenalty);
+ dismiss();
+ }
+ });
+
+ return view;
+ }
+}
diff --git a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java
index 6c2809551..a6efcb3ed 100644
--- a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java
+++ b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java
@@ -2,6 +2,7 @@
import androidx.appcompat.app.AppCompatActivity;
+import android.app.Dialog;
import android.content.Context;
import android.os.Bundle;
import android.text.method.ScrollingMovementMethod;
@@ -9,6 +10,7 @@
import android.util.Pair;
import android.view.View;
import android.view.WindowManager;
+import android.widget.Button;
import android.widget.EditText;
import android.widget.ImageButton;
import android.widget.TextView;
@@ -41,7 +43,10 @@ public class MainActivity extends AppCompatActivity implements Consumer
private TextView generatedTV;
private TextView promptTV;
private TextView progressText;
+ private ImageButton settingsButton;
private static final String TAG = "genai.demo.MainActivity";
+ private int maxLength = 100;
+ private float lengthPenalty = 1.0f;
private static boolean fileExists(Context context, String fileName) {
File file = new File(context.getFilesDir(), fileName);
@@ -55,6 +60,13 @@ protected void onCreate(Bundle savedInstanceState) {
binding = ActivityMainBinding.inflate(getLayoutInflater());
setContentView(binding.getRoot());
+ sendMsgIB = findViewById(R.id.idIBSend);
+ userMsgEdt = findViewById(R.id.idEdtMessage);
+ generatedTV = findViewById(R.id.sample_text);
+ promptTV = findViewById(R.id.user_text);
+ progressText = findViewById(R.id.progress_text);
+ settingsButton = findViewById(R.id.idIBSettings);
+
// Trigger the download operation when the application is created
try {
downloadModels(
@@ -63,10 +75,20 @@ protected void onCreate(Bundle savedInstanceState) {
throw new RuntimeException(e);
}
- sendMsgIB = findViewById(R.id.idIBSend);
- userMsgEdt = findViewById(R.id.idEdtMessage);
- generatedTV = findViewById(R.id.sample_text);
- promptTV = findViewById(R.id.user_text);
+ settingsButton.setOnClickListener(v -> {
+ BottomSheet bottomSheet = new BottomSheet();
+ bottomSheet.setSettingsListener(new BottomSheet.SettingsListener() {
+ @Override
+ public void onSettingsApplied(int maxLength, float lengthPenalty) {
+ MainActivity.this.maxLength = maxLength;
+ MainActivity.this.lengthPenalty = lengthPenalty;
+ Log.i(TAG, "Setting max response length to: " + maxLength);
+ Log.i(TAG, "Setting length penalty to: " + lengthPenalty);
+ }
+ });
+ bottomSheet.show(getSupportFragmentManager(), "BottomSheet");
+ });
+
Consumer tokenListener = this;
@@ -99,6 +121,7 @@ public void onClick(View v) {
// Disable send button while responding to prompt.
sendMsgIB.setEnabled(false);
+ sendMsgIB.setAlpha(0.5f);
promptTV.setText(promptQuestion);
// Clear Edit Text or prompt question.
@@ -117,22 +140,53 @@ public void run() {
generatorParams = model.createGeneratorParams();
//examples for optional parameters to format AI response
- //generatorParams.setSearchOption("length_penalty", 1000);
- //generatorParams.setSearchOption("max_length", 500);
+ // https://onnxruntime.ai/docs/genai/reference/config.html
+ generatorParams.setSearchOption("length_penalty", lengthPenalty);
+ generatorParams.setSearchOption("max_length", maxLength);
encodedPrompt = tokenizer.encode(promptQuestion_formatted);
generatorParams.setInput(encodedPrompt);
generator = new Generator(model, generatorParams);
+ // try to measure average time taken to generate each token.
+ long startTime = System.currentTimeMillis();
+ long firstTokenTime = startTime;
+ long currentTime = startTime;
+ int numTokens = 0;
while (!generator.isDone()) {
generator.computeLogits();
generator.generateNextToken();
int token = generator.getLastTokenInSequence(0);
-
+
+ if (numTokens == 0) { //first token
+ firstTokenTime = System.currentTimeMillis();
+ }
+
tokenListener.accept(stream.decode(token));
+
+
+ Log.i(TAG, "Generated token: " + token + ": " + stream.decode(token));
+ Log.i(TAG, "Time taken to generate token: " + (System.currentTimeMillis() - currentTime)/ 1000.0 + " seconds");
+ currentTime = System.currentTimeMillis();
+ numTokens++;
}
+ long totalTime = System.currentTimeMillis() - firstTokenTime;
+
+ float promptProcessingTime = (firstTokenTime - startTime)/ 1000.0f;
+ float tokensPerSecond = (1000 * (numTokens -1)) / totalTime;
+
+ runOnUiThread(() -> {
+ sendMsgIB.setEnabled(true);
+ sendMsgIB.setAlpha(1.0f);
+
+ // Display the token generation rate in a dialog popup
+ showTokenPopup(promptProcessingTime, tokensPerSecond);
+ });
+
+ Log.i(TAG, "Prompt processing time (first token): " + promptProcessingTime + " seconds");
+ Log.i(TAG, "Tokens generated per second (excluding prompt processing): " + tokensPerSecond);
}
catch (GenAIException e) {
Log.e(TAG, "Exception occurred during model query: " + e.getMessage());
@@ -146,6 +200,7 @@ public void run() {
runOnUiThread(() -> {
sendMsgIB.setEnabled(true);
+ sendMsgIB.setAlpha(1.0f);
});
}
}).start();
@@ -256,4 +311,23 @@ public void setVisibility() {
TextView botView = (TextView) findViewById(R.id.sample_text);
botView.setVisibility(View.VISIBLE);
}
+
+ private void showTokenPopup(float promptProcessingTime, float tokenRate) {
+
+ final Dialog dialog = new Dialog(MainActivity.this);
+ dialog.setContentView(R.layout.info_popup);
+
+ TextView promptProcessingTimeTv = dialog.findViewById(R.id.prompt_processing_time_tv);
+ TextView tokensPerSecondTv = dialog.findViewById(R.id.tokens_per_second_tv);
+ Button closeBtn = dialog.findViewById(R.id.close_btn);
+
+ promptProcessingTimeTv.setText(String.format("Prompt processing time: %.2f seconds", promptProcessingTime));
+ tokensPerSecondTv.setText(String.format("Tokens per second: %.2f", tokenRate));
+
+ closeBtn.setOnClickListener(v -> dialog.dismiss());
+
+ dialog.show();
+ }
+
+
}
diff --git a/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml b/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml
index d863a41d6..2d8ea1b6c 100644
--- a/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml
+++ b/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml
@@ -1,7 +1,7 @@
-
+
+
+ app:layout_constraintVertical_bias="0.0"
+ tools:visibility="visible" />
+ app:layout_constraintVertical_bias="0.0"
+ tools:visibility="visible" />
+
+
+
+
diff --git a/mobile/examples/phi-3/android/app/src/main/res/layout/bottom_sheet.xml b/mobile/examples/phi-3/android/app/src/main/res/layout/bottom_sheet.xml
new file mode 100644
index 000000000..c076ef87a
--- /dev/null
+++ b/mobile/examples/phi-3/android/app/src/main/res/layout/bottom_sheet.xml
@@ -0,0 +1,61 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/mobile/examples/phi-3/android/app/src/main/res/layout/info_popup.xml b/mobile/examples/phi-3/android/app/src/main/res/layout/info_popup.xml
new file mode 100644
index 000000000..220745787
--- /dev/null
+++ b/mobile/examples/phi-3/android/app/src/main/res/layout/info_popup.xml
@@ -0,0 +1,28 @@
+
+
+
+
+
+
+
+
+
diff --git a/mobile/examples/phi-3/android/images/Local_LLM_4.png b/mobile/examples/phi-3/android/images/Local_LLM_4.png
new file mode 100644
index 000000000..7259b59e7
Binary files /dev/null and b/mobile/examples/phi-3/android/images/Local_LLM_4.png differ