diff --git a/mobile/examples/phi-3/android/README.md b/mobile/examples/phi-3/android/README.md
index 129a6d961..a0931b775 100644
--- a/mobile/examples/phi-3/android/README.md
+++ b/mobile/examples/phi-3/android/README.md
@@ -58,3 +58,4 @@ Here are some sample example screenshots of the app.
 App Screenshot 3
+App Screenshot 3
diff --git a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/BottomSheet.java b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/BottomSheet.java
new file mode 100644
index 000000000..be12fce3c
--- /dev/null
+++ b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/BottomSheet.java
@@ -0,0 +1,71 @@
+package ai.onnxruntime.genai.demo;
+
+import android.os.Bundle;
+import android.view.LayoutInflater;
+import android.view.View;
+import android.view.ViewGroup;
+import android.widget.Button;
+import android.widget.EditText;
+import androidx.annotation.NonNull;
+import androidx.annotation.Nullable;
+import com.google.android.material.bottomsheet.BottomSheetDialogFragment;
+
+/**
+ * Bottom sheet that lets the user type a maximum length and a length penalty and hands the
+ * parsed values to a {@link SettingsListener}.
+ */
+public class BottomSheet extends BottomSheetDialogFragment {
+    private EditText maxLengthEditText;
+    private EditText lengthPenaltyEditText;
+    private SettingsListener settingsListener;
+
+    /** Receives the values entered in the sheet once both have been parsed successfully. */
+    public interface SettingsListener {
+        void onSettingsApplied(int maxLength, float lengthPenalty);
+    }
+
+    /** Registers the listener that is notified when the apply button is pressed. */
+    public void setSettingsListener(SettingsListener listener) {
+        this.settingsListener = listener;
+    }
+
+    @Nullable
+    @Override
+    public View onCreateView(@NonNull LayoutInflater inflater, @Nullable ViewGroup container, @Nullable Bundle savedInstanceState) {
+        View view = inflater.inflate(R.layout.bottom_sheet, container, false);
+
+        maxLengthEditText = view.findViewById(R.id.idEdtMaxLength);
+        lengthPenaltyEditText = view.findViewById(R.id.idEdtLengthPenalty);
+
+        Button applyButton = view.findViewById(R.id.applySettingsButton);
+
+        applyButton.setOnClickListener(v -> {
+            if (settingsListener == null) {
+                return;
+            }
+
+            // The fields are free text: an empty or malformed value must not crash the app,
+            // so flag the offending field and keep the sheet open instead.
+            int maxLength;
+            try {
+                maxLength = Integer.parseInt(maxLengthEditText.getText().toString().trim());
+            } catch (NumberFormatException e) {
+                maxLengthEditText.setError("Enter a whole number");
+                return;
+            }
+
+            float lengthPenalty;
+            try {
+                lengthPenalty = Float.parseFloat(lengthPenaltyEditText.getText().toString().trim());
+            } catch (NumberFormatException e) {
+                lengthPenaltyEditText.setError("Enter a number");
+                return;
+            }
+
+            settingsListener.onSettingsApplied(maxLength, lengthPenalty);
+            dismiss();
+        });
+
+        return view;
+    }
+}
diff --git a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java
index 6c2809551..a6efcb3ed 100644
--- a/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java
+++ b/mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java
@@ -2,6 +2,7 @@
 import androidx.appcompat.app.AppCompatActivity;
+import android.app.Dialog;
 import android.content.Context;
 import android.os.Bundle;
 import android.text.method.ScrollingMovementMethod;
@@ -9,6 +10,7 @@
 import android.util.Pair;
 import android.view.View;
 import android.view.WindowManager;
+import android.widget.Button;
 import android.widget.EditText;
 import android.widget.ImageButton;
 import android.widget.TextView;
@@ -41,7 +43,10 @@ public class MainActivity extends AppCompatActivity implements Consumer
     private TextView generatedTV;
     private TextView promptTV;
     private TextView progressText;
+    private ImageButton settingsButton;
     private static final String TAG = "genai.demo.MainActivity";
+    private int maxLength = 100;
+    private float lengthPenalty = 1.0f;
     private static boolean fileExists(Context context, String fileName) {
         File file = new File(context.getFilesDir(), fileName);
@@ -55,6 +60,13 @@ protected void onCreate(Bundle savedInstanceState) {
         binding = ActivityMainBinding.inflate(getLayoutInflater());
         setContentView(binding.getRoot());
+        sendMsgIB = findViewById(R.id.idIBSend);
+        userMsgEdt = findViewById(R.id.idEdtMessage);
+        generatedTV = findViewById(R.id.sample_text);
+        promptTV = findViewById(R.id.user_text);
+        progressText = findViewById(R.id.progress_text);
+        settingsButton = findViewById(R.id.idIBSettings);
+
         // Trigger the download operation when the application is created
         try {
             downloadModels(
@@ -63,10 +75,20 @@ protected void onCreate(Bundle savedInstanceState) {
             throw new RuntimeException(e);
         }
-        sendMsgIB = findViewById(R.id.idIBSend);
-        userMsgEdt = findViewById(R.id.idEdtMessage);
-        generatedTV = findViewById(R.id.sample_text);
-        promptTV = findViewById(R.id.user_text);
+        settingsButton.setOnClickListener(v -> {
+            BottomSheet bottomSheet = new BottomSheet();
+            bottomSheet.setSettingsListener(new BottomSheet.SettingsListener() {
+                @Override
+                public void onSettingsApplied(int maxLength, float lengthPenalty) {
+                    MainActivity.this.maxLength = maxLength;
+                    MainActivity.this.lengthPenalty = lengthPenalty;
+                    Log.i(TAG, "Setting max response length to: " + maxLength);
+                    Log.i(TAG, "Setting length penalty to: " + lengthPenalty);
+                }
+            });
+            bottomSheet.show(getSupportFragmentManager(), "BottomSheet");
+        });
+
         Consumer tokenListener = this;
@@ -99,6 +121,7 @@ public void onClick(View v) {
                 // Disable send button while responding to prompt.
                 sendMsgIB.setEnabled(false);
+                sendMsgIB.setAlpha(0.5f);
                 promptTV.setText(promptQuestion);
                 // Clear Edit Text or prompt question.
@@ -117,22 +140,60 @@ public void run() {
                             generatorParams = model.createGeneratorParams();
                             //examples for optional parameters to format AI response
-                            //generatorParams.setSearchOption("length_penalty", 1000);
-                            //generatorParams.setSearchOption("max_length", 500);
+                            // https://onnxruntime.ai/docs/genai/reference/config.html
+                            generatorParams.setSearchOption("length_penalty", lengthPenalty);
+                            generatorParams.setSearchOption("max_length", maxLength);
                             encodedPrompt = tokenizer.encode(promptQuestion_formatted);
                             generatorParams.setInput(encodedPrompt);
                             generator = new Generator(model, generatorParams);
+                            // try to measure average time taken to generate each token.
+                            long startTime = System.currentTimeMillis();
+                            long firstTokenTime = startTime;
+                            long currentTime = startTime;
+                            int numTokens = 0;
                             while (!generator.isDone()) {
                                 generator.computeLogits();
                                 generator.generateNextToken();
                                 int token = generator.getLastTokenInSequence(0);
-
+
+                                if (numTokens == 0) { //first token
+                                    firstTokenTime = System.currentTimeMillis();
+                                }
+
-                                tokenListener.accept(stream.decode(token));
+                                // Decode once per token and reuse the text for both the UI and the log.
+                                // NOTE(review): stream.decode looks like a stateful streaming decoder - confirm.
+                                String decodedToken = stream.decode(token);
+                                tokenListener.accept(decodedToken);
+
+
+                                Log.i(TAG, "Generated token: " + token + ": " + decodedToken);
+                                Log.i(TAG, "Time taken to generate token: " + (System.currentTimeMillis() - currentTime)/ 1000.0 + " seconds");
+                                currentTime = System.currentTimeMillis();
+                                numTokens++;
                             }
+                            long totalTime = System.currentTimeMillis() - firstTokenTime;
+
+                            float promptProcessingTime = (firstTokenTime - startTime)/ 1000.0f;
+                            // Floating-point rate over the tokens after the first one; falls back to 0 when
+                            // there is nothing to measure, since dividing by a zero totalTime would throw.
+                            float tokensPerSecond = (numTokens > 1 && totalTime > 0)
+                                    ? (1000.0f * (numTokens - 1)) / totalTime
+                                    : 0.0f;
+
+                            runOnUiThread(() -> {
+                                sendMsgIB.setEnabled(true);
+                                sendMsgIB.setAlpha(1.0f);
+
+                                // Display the token generation rate in a dialog popup
+                                showTokenPopup(promptProcessingTime, tokensPerSecond);
+                            });
+
+                            Log.i(TAG, "Prompt processing time (first token): " + promptProcessingTime + " seconds");
+                            Log.i(TAG, "Tokens generated per second (excluding prompt processing): " + tokensPerSecond);
                         } catch (GenAIException e) {
                             Log.e(TAG, "Exception occurred during model query: " + e.getMessage());
@@ -146,6 +207,7 @@ public void run() {
                         runOnUiThread(() -> {
                             sendMsgIB.setEnabled(true);
+                            sendMsgIB.setAlpha(1.0f);
                         });
                     }
                 }).start();
@@ -256,4 +318,34 @@ public void setVisibility() {
         TextView botView = (TextView) findViewById(R.id.sample_text);
         botView.setVisibility(View.VISIBLE);
     }
+
+    /**
+     * Shows a popup with the timing statistics of the last generation run.
+     *
+     * @param promptProcessingTime value displayed as the prompt processing time, in seconds
+     * @param tokenRate            value displayed as the number of tokens per second
+     */
+    private void showTokenPopup(float promptProcessingTime, float tokenRate) {
+        // Generation runs on a background thread and may complete after the user has left the
+        // screen; showing a dialog on a finished activity throws BadTokenException.
+        if (isFinishing() || isDestroyed()) {
+            return;
+        }
+
+        final Dialog dialog = new Dialog(MainActivity.this);
+        dialog.setContentView(R.layout.info_popup);
+
+        TextView promptProcessingTimeTv = dialog.findViewById(R.id.prompt_processing_time_tv);
+        TextView tokensPerSecondTv = dialog.findViewById(R.id.tokens_per_second_tv);
+        Button closeBtn = dialog.findViewById(R.id.close_btn);
+
+        promptProcessingTimeTv.setText(String.format("Prompt processing time: %.2f seconds", promptProcessingTime));
+        tokensPerSecondTv.setText(String.format("Tokens per second: %.2f", tokenRate));
+
+        closeBtn.setOnClickListener(v -> dialog.dismiss());
+
+        dialog.show();
+    }
+
+
 }
diff --git a/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml b/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml
index d863a41d6..2d8ea1b6c 100644
--- a/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml
+++ b/mobile/examples/phi-3/android/app/src/main/res/drawable/rounded_corner2.xml
@@ -1,7 +1,7 @@
-
+
+
+        app:layout_constraintVertical_bias="0.0"
+        tools:visibility="visible" />
+        app:layout_constraintVertical_bias="0.0"
+        tools:visibility="visible" />
+
+
+
+
diff --git a/mobile/examples/phi-3/android/app/src/main/res/layout/bottom_sheet.xml b/mobile/examples/phi-3/android/app/src/main/res/layout/bottom_sheet.xml
new file mode 100644
index 000000000..c076ef87a
--- /dev/null
+++ b/mobile/examples/phi-3/android/app/src/main/res/layout/bottom_sheet.xml
@@ -0,0 +1,61 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+