2
2
3
3
import androidx .appcompat .app .AppCompatActivity ;
4
4
5
+ import android .app .Dialog ;
5
6
import android .content .Context ;
6
7
import android .os .Bundle ;
7
8
import android .text .method .ScrollingMovementMethod ;
8
9
import android .util .Log ;
9
10
import android .util .Pair ;
10
11
import android .view .View ;
11
12
import android .view .WindowManager ;
13
+ import android .widget .Button ;
12
14
import android .widget .EditText ;
13
15
import android .widget .ImageButton ;
14
16
import android .widget .TextView ;
@@ -41,7 +43,10 @@ public class MainActivity extends AppCompatActivity implements Consumer<String>
41
43
private TextView generatedTV ;
42
44
private TextView promptTV ;
43
45
private TextView progressText ;
46
+ private ImageButton settingsButton ;
44
47
private static final String TAG = "genai.demo.MainActivity" ;
48
+ private int maxLength = 100 ;
49
+ private float lengthPenalty = 1.0f ;
45
50
46
51
private static boolean fileExists (Context context , String fileName ) {
47
52
File file = new File (context .getFilesDir (), fileName );
@@ -55,6 +60,13 @@ protected void onCreate(Bundle savedInstanceState) {
55
60
binding = ActivityMainBinding .inflate (getLayoutInflater ());
56
61
setContentView (binding .getRoot ());
57
62
63
+ sendMsgIB = findViewById (R .id .idIBSend );
64
+ userMsgEdt = findViewById (R .id .idEdtMessage );
65
+ generatedTV = findViewById (R .id .sample_text );
66
+ promptTV = findViewById (R .id .user_text );
67
+ progressText = findViewById (R .id .progress_text );
68
+ settingsButton = findViewById (R .id .idIBSettings );
69
+
58
70
// Trigger the download operation when the application is created
59
71
try {
60
72
downloadModels (
@@ -63,10 +75,20 @@ protected void onCreate(Bundle savedInstanceState) {
63
75
throw new RuntimeException (e );
64
76
}
65
77
66
- sendMsgIB = findViewById (R .id .idIBSend );
67
- userMsgEdt = findViewById (R .id .idEdtMessage );
68
- generatedTV = findViewById (R .id .sample_text );
69
- promptTV = findViewById (R .id .user_text );
78
+ settingsButton .setOnClickListener (v -> {
79
+ BottomSheet bottomSheet = new BottomSheet ();
80
+ bottomSheet .setSettingsListener (new BottomSheet .SettingsListener () {
81
+ @ Override
82
+ public void onSettingsApplied (int maxLength , float lengthPenalty ) {
83
+ MainActivity .this .maxLength = maxLength ;
84
+ MainActivity .this .lengthPenalty = lengthPenalty ;
85
+ Log .i (TAG , "Setting max response length to: " + maxLength );
86
+ Log .i (TAG , "Setting length penalty to: " + lengthPenalty );
87
+ }
88
+ });
89
+ bottomSheet .show (getSupportFragmentManager (), "BottomSheet" );
90
+ });
91
+
70
92
71
93
Consumer <String > tokenListener = this ;
72
94
@@ -99,6 +121,7 @@ public void onClick(View v) {
99
121
100
122
// Disable send button while responding to prompt.
101
123
sendMsgIB .setEnabled (false );
124
+ sendMsgIB .setAlpha (0.5f );
102
125
103
126
promptTV .setText (promptQuestion );
104
127
// Clear Edit Text or prompt question.
@@ -117,22 +140,53 @@ public void run() {
117
140
118
141
generatorParams = model .createGeneratorParams ();
119
142
//examples for optional parameters to format AI response
120
- //generatorParams.setSearchOption("length_penalty", 1000);
121
- //generatorParams.setSearchOption("max_length", 500);
143
+ // https://onnxruntime.ai/docs/genai/reference/config.html
144
+ generatorParams .setSearchOption ("length_penalty" , lengthPenalty );
145
+ generatorParams .setSearchOption ("max_length" , maxLength );
122
146
123
147
encodedPrompt = tokenizer .encode (promptQuestion_formatted );
124
148
generatorParams .setInput (encodedPrompt );
125
149
126
150
generator = new Generator (model , generatorParams );
127
151
152
+ // try to measure average time taken to generate each token.
153
+ long startTime = System .currentTimeMillis ();
154
+ long firstTokenTime = startTime ;
155
+ long currentTime = startTime ;
156
+ int numTokens = 0 ;
128
157
while (!generator .isDone ()) {
129
158
generator .computeLogits ();
130
159
generator .generateNextToken ();
131
160
132
161
int token = generator .getLastTokenInSequence (0 );
133
-
162
+
163
+ if (numTokens == 0 ) { //first token
164
+ firstTokenTime = System .currentTimeMillis ();
165
+ }
166
+
134
167
tokenListener .accept (stream .decode (token ));
168
+
169
+
170
+ Log .i (TAG , "Generated token: " + token + ": " + stream .decode (token ));
171
+ Log .i (TAG , "Time taken to generate token: " + (System .currentTimeMillis () - currentTime )/ 1000.0 + " seconds" );
172
+ currentTime = System .currentTimeMillis ();
173
+ numTokens ++;
135
174
}
175
+ long totalTime = System .currentTimeMillis () - firstTokenTime ;
176
+
177
+ float promptProcessingTime = (firstTokenTime - startTime )/ 1000.0f ;
178
+ float tokensPerSecond = (1000 * (numTokens -1 )) / totalTime ;
179
+
180
+ runOnUiThread (() -> {
181
+ sendMsgIB .setEnabled (true );
182
+ sendMsgIB .setAlpha (1.0f );
183
+
184
+ // Display the token generation rate in a dialog popup
185
+ showTokenPopup (promptProcessingTime , tokensPerSecond );
186
+ });
187
+
188
+ Log .i (TAG , "Prompt processing time (first token): " + promptProcessingTime + " seconds" );
189
+ Log .i (TAG , "Tokens generated per second (excluding prompt processing): " + tokensPerSecond );
136
190
}
137
191
catch (GenAIException e ) {
138
192
Log .e (TAG , "Exception occurred during model query: " + e .getMessage ());
@@ -146,6 +200,7 @@ public void run() {
146
200
147
201
runOnUiThread (() -> {
148
202
sendMsgIB .setEnabled (true );
203
+ sendMsgIB .setAlpha (1.0f );
149
204
});
150
205
}
151
206
}).start ();
@@ -256,4 +311,23 @@ public void setVisibility() {
256
311
TextView botView = (TextView ) findViewById (R .id .sample_text );
257
312
botView .setVisibility (View .VISIBLE );
258
313
}
314
+
315
+ private void showTokenPopup (float promptProcessingTime , float tokenRate ) {
316
+
317
+ final Dialog dialog = new Dialog (MainActivity .this );
318
+ dialog .setContentView (R .layout .info_popup );
319
+
320
+ TextView promptProcessingTimeTv = dialog .findViewById (R .id .prompt_processing_time_tv );
321
+ TextView tokensPerSecondTv = dialog .findViewById (R .id .tokens_per_second_tv );
322
+ Button closeBtn = dialog .findViewById (R .id .close_btn );
323
+
324
+ promptProcessingTimeTv .setText (String .format ("Prompt processing time: %.2f seconds" , promptProcessingTime ));
325
+ tokensPerSecondTv .setText (String .format ("Tokens per second: %.2f" , tokenRate ));
326
+
327
+ closeBtn .setOnClickListener (v -> dialog .dismiss ());
328
+
329
+ dialog .show ();
330
+ }
331
+
332
+
259
333
}
0 commit comments