Skip to content

Commit 61e58fe

Browse files
committed
Add android model tester app
1 parent 0c70b9f commit 61e58fe

File tree

11 files changed

+851
-161
lines changed

11 files changed

+851
-161
lines changed

mobile/examples/model_tester/android/app/src/main/cpp/native-lib.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <string>
77
#include <stdexcept>
88
#include <vector>
9+
#include <optional>
910

1011
#include "model_runner.h"
1112

@@ -67,20 +68,29 @@ auto MakeUniqueJbyteArrayElementsPtr(JNIEnv& env, jbyteArray array) {
6768

6869
extern "C" JNIEXPORT jstring JNICALL
6970
Java_com_onnxruntime_example_modeltester_MainActivity_run(JNIEnv* env, jobject thiz,
70-
jbyteArray java_model_bytes,
71+
jobject model_path_or_bytes,
7172
jint num_iterations,
73+
jboolean run_warmup_iteration,
7274
jstring java_execution_provider_type,
7375
jobjectArray java_execution_provider_option_names,
74-
jobjectArray java_execution_provider_option_values) {
76+
jobjectArray java_execution_provider_option_values,
77+
jint log_level) {
7578
try {
76-
auto model_bytes = util::MakeUniqueJbyteArrayElementsPtr(*env, java_model_bytes);
77-
const size_t model_bytes_length = env->GetArrayLength(java_model_bytes);
78-
auto model_bytes_span = std::span{reinterpret_cast<const std::byte*>(model_bytes.get()),
79-
model_bytes_length};
80-
8179
auto config = model_runner::RunConfig{};
82-
config.model_path_or_bytes = model_bytes_span;
8380
config.num_iterations = num_iterations;
81+
config.run_warmup_iteration = run_warmup_iteration;
82+
83+
// Handle model_path_or_bytes
84+
jclass byte_array_class = env->FindClass("[B");
85+
if (env->IsInstanceOf(model_path_or_bytes, byte_array_class)) {
86+
jbyteArray java_model_bytes = static_cast<jbyteArray>(model_path_or_bytes);
87+
auto model_bytes = util::MakeUniqueJbyteArrayElementsPtr(*env, java_model_bytes);
88+
const size_t model_bytes_length = env->GetArrayLength(java_model_bytes);
89+
config.model_path_or_bytes = std::span{reinterpret_cast<const std::byte*>(model_bytes.get()), model_bytes_length};
90+
} else {
91+
jstring java_model_path = static_cast<jstring>(model_path_or_bytes);
92+
config.model_path_or_bytes = util::JstringToStdString(*env, java_model_path);
93+
}
8494

8595
if (java_execution_provider_type != nullptr) {
8696
config.ep.emplace();
@@ -99,6 +109,13 @@ Java_com_onnxruntime_example_modeltester_MainActivity_run(JNIEnv* env, jobject t
99109
}
100110
}
101111

112+
// If log_level is -1 (sentinel from Java), config.log_level will remain std::nullopt,
113+
// and ONNX Runtime will use its default log level.
114+
// Otherwise, set the log level specified from Java.
115+
if (log_level != -1) {
116+
config.log_level = log_level;
117+
}
118+
102119
auto result = model_runner::Run(config);
103120

104121
auto summary = model_runner::GetRunSummary(config, result);

mobile/examples/model_tester/android/app/src/main/java/com/onnxruntime/example/modeltester/MainActivity.kt

Lines changed: 213 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,243 @@
11
package com.onnxruntime.example.modeltester
22

3+
import android.net.Uri
34
import androidx.appcompat.app.AppCompatActivity
45
import android.os.Bundle
6+
import android.view.View
7+
import android.widget.AdapterView
8+
import android.widget.ArrayAdapter
59
import com.onnxruntime.example.modeltester.databinding.ActivityMainBinding
10+
import androidx.activity.result.contract.ActivityResultContracts
11+
import java.io.File
12+
import java.io.FileOutputStream
13+
import java.io.IOException
14+
import java.util.Locale
615

716
class MainActivity : AppCompatActivity() {
817

918
private lateinit var binding: ActivityMainBinding
19+
private var currentModel: Any = "" // Can be String (path) or ByteArray (default model)
20+
21+
// ActivityResultLauncher for picking a model file
22+
private val pickFileLauncher = registerForActivityResult(ActivityResultContracts.GetContent()) { uri ->
23+
uri?.let { handleSelectedModelFile(it) }
24+
}
1025

1126
override fun onCreate(savedInstanceState: Bundle?) {
1227
super.onCreate(savedInstanceState)
13-
1428
binding = ActivityMainBinding.inflate(layoutInflater)
1529
setContentView(binding.root)
1630

17-
val modelResourceId = R.raw.model
18-
val modelBytes = resources.openRawResource(modelResourceId).readBytes()
31+
loadDefaultModel()
32+
setupUI()
33+
runInitialModel()
34+
}
35+
36+
private fun loadDefaultModel() {
37+
try {
38+
currentModel = resources.openRawResource(R.raw.yolo11n).readBytes()
39+
} catch (e: IOException) {
40+
// Handle error loading default model, e.g., show a toast or log
41+
displayError("Failed to load default model: ${e.message}")
42+
}
43+
}
44+
45+
private fun setupUI() {
46+
binding.browseModelButton.setOnClickListener { pickFileLauncher.launch("*/*") }
47+
binding.runButton.setOnClickListener { runModelFromUI() }
48+
49+
// Setup Spinners with ArrayAdapter if not already done via XML entries
50+
// ArrayAdapter.createFromResource(
51+
// this,
52+
// R.array.ep_array,
53+
// android.R.layout.simple_spinner_item
54+
// ).also { adapter ->
55+
// adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item)
56+
// binding.epSpinner.adapter = adapter
57+
// }
58+
59+
// ArrayAdapter.createFromResource(
60+
// this,
61+
// R.array.log_level_array,
62+
// android.R.layout.simple_spinner_item
63+
// ).also { adapter ->
64+
// adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item)
65+
// binding.logLevelSpinner.adapter = adapter
66+
// }
67+
}
68+
69+
private fun handleSelectedModelFile(uri: Uri) {
70+
binding.modelPathEditText.setText(uri.toString())
71+
// Attempt to get a real path or copy to a cache file if it's a content URI
72+
val filePath = getPathFromUri(uri)
73+
if (filePath != null) {
74+
currentModel = filePath
75+
} else {
76+
// Fallback or error handling if path resolution fails
77+
binding.modelPathEditText.error = "Could not resolve file path"
78+
// Optionally, revert to default model or prevent run
79+
}
80+
}
81+
82+
// Helper to attempt to get a real file path from a URI
83+
// This can be complex. For content URIs, copying to a cache file is often most reliable.
84+
private fun getPathFromUri(uri: Uri): String? {
85+
if ("content".equals(uri.scheme, ignoreCase = true)) {
86+
return try {
87+
val inputStream = contentResolver.openInputStream(uri) ?: return null
88+
val tempFile = File(cacheDir, "temp_model_file")
89+
FileOutputStream(tempFile).use { outputStream ->
90+
inputStream.copyTo(outputStream)
91+
}
92+
inputStream.close()
93+
tempFile.absolutePath
94+
} catch (e: IOException) {
95+
e.printStackTrace()
96+
null
97+
}
98+
}
99+
return uri.path // For file URIs or if direct path access is possible
100+
}
19101

20-
val summary = run(modelBytes, 10, null, null, null)
21102

22-
binding.sampleText.text = summary
103+
private fun runModelFromUI() {
104+
val numIterations = binding.iterationsEditText.text.toString().toIntOrNull() ?: 10
105+
val runWarmup = binding.warmupSwitchMaterial.isChecked
106+
107+
val selectedEpString = binding.epSpinner.selectedItem.toString()
108+
val epName = if (selectedEpString.equals("CPU", ignoreCase = true)) {
109+
null
110+
} else {
111+
selectedEpString
112+
}
113+
114+
val logLevel = mapSpinnerPositionToLogLevel(binding.logLevelSpinner.selectedItemPosition)
115+
116+
val modelPathString = binding.modelPathEditText.text.toString()
117+
if (modelPathString.isNotEmpty() && modelPathString != currentModel.toString()) {
118+
// User typed a path directly or it wasn't a content URI initially handled by pickFileLauncher
119+
if (modelPathString.startsWith("content://")) {
120+
val newUri = Uri.parse(modelPathString)
121+
val resolvedPath = getPathFromUri(newUri)
122+
if (resolvedPath != null) {
123+
currentModel = resolvedPath
124+
} else {
125+
binding.modelPathEditText.error = "Invalid model path/URI"
126+
return
127+
}
128+
} else {
129+
currentModel = modelPathString // Assume direct path
130+
}
131+
} else if (modelPathString.isEmpty()) {
132+
loadDefaultModel() // Revert to default if path is cleared
133+
}
134+
// If currentModel is already a ByteArray (default model), it's used directly.
135+
// If it's a String (path), it's used directly.
136+
137+
executeNativeRun(currentModel, numIterations, runWarmup, epName, logLevel)
138+
}
139+
140+
private fun runInitialModel() {
141+
// Use default values for the initial run
142+
val defaultNumIterations = 10
143+
val defaultRunWarmup = true
144+
// For initial run, always use CPU, which means passing null for epName
145+
val defaultEpName: String? = null
146+
val defaultLogLevel = -1 // ORT default
147+
148+
// Ensure default model (ByteArray) is used for initial run
149+
if (currentModel !is ByteArray) {
150+
loadDefaultModel()
151+
}
152+
executeNativeRun(currentModel, defaultNumIterations, defaultRunWarmup, defaultEpName, defaultLogLevel)
153+
}
154+
155+
private fun mapSpinnerPositionToLogLevel(position: Int): Int {
156+
return when (position) {
157+
0 -> -1 // Default (ORT default)
158+
1 -> 0 // Verbose
159+
2 -> 1 // Info
160+
3 -> 2 // Warning
161+
4 -> 3 // Error
162+
5 -> 4 // Fatal
163+
else -> -1 // Default to ORT default
164+
}
165+
}
166+
167+
private fun executeNativeRun(model: Any, numIterations: Int, runWarmup: Boolean, epName: String?, logLevel: Int) {
168+
try {
169+
val summary = run(
170+
model, // This is currentModel (String path or ByteArray)
171+
numIterations,
172+
runWarmup,
173+
epName,
174+
null, // executionProviderOptionNames - not used in this example
175+
null, // executionProviderOptionValues - not used in this example
176+
logLevel
177+
)
178+
parseAndDisplaySummary(summary)
179+
} catch (e: Exception) {
180+
displayError("Native run failed: ${e.message}", e.stackTraceToString())
181+
}
182+
}
183+
184+
private fun parseAndDisplaySummary(summary: String) {
185+
val na = getString(R.string.na)
186+
val loadTimeRegex = "Load time: (\\S+)".toRegex()
187+
val numRunsRegex = "N \\(number of runs\\): (\\d+)".toRegex()
188+
val avgLatencyRegex = "avg: (\\S+)".toRegex()
189+
val p50LatencyRegex = "p50: (\\S+)".toRegex()
190+
val p90LatencyRegex = "p90: (\\S+)".toRegex()
191+
val p99LatencyRegex = "p99: (\\S+)".toRegex()
192+
val minLatencyRegex = "min: (\\S+)".toRegex()
193+
val maxLatencyRegex = "max: (\\S+)".toRegex()
194+
195+
binding.loadTimeTextView.text = getString(R.string.load_time_label, loadTimeRegex.find(summary)?.groupValues?.get(1) ?: na)
196+
binding.numRunsTextView.text = getString(R.string.num_runs_label, numRunsRegex.find(summary)?.groupValues?.get(1) ?: na)
197+
binding.latencyTitleTextView.text = getString(R.string.latency_title_label)
198+
binding.avgLatencyTextView.text = getString(R.string.avg_latency_label, avgLatencyRegex.find(summary)?.groupValues?.get(1) ?: na)
199+
binding.p50LatencyTextView.text = getString(R.string.p50_latency_label, p50LatencyRegex.find(summary)?.groupValues?.get(1) ?: na)
200+
binding.p90LatencyTextView.text = getString(R.string.p90_latency_label, p90LatencyRegex.find(summary)?.groupValues?.get(1) ?: na)
201+
binding.p99LatencyTextView.text = getString(R.string.p99_latency_label, p99LatencyRegex.find(summary)?.groupValues?.get(1) ?: na)
202+
binding.minLatencyTextView.text = getString(R.string.min_latency_label, minLatencyRegex.find(summary)?.groupValues?.get(1) ?: na)
203+
binding.maxLatencyTextView.text = getString(R.string.max_latency_label, maxLatencyRegex.find(summary)?.groupValues?.get(1) ?: na)
204+
205+
binding.rawSummaryText.text = summary
206+
binding.rawSummaryText.visibility = View.GONE // Hide by default, show on error or if explicitly toggled
207+
}
208+
209+
private fun displayError(errorMessage: String, stackTrace: String? = null) {
210+
binding.loadTimeTextView.text = getString(R.string.error_prefix, errorMessage)
211+
binding.numRunsTextView.text = ""
212+
binding.avgLatencyTextView.text = ""
213+
binding.p50LatencyTextView.text = ""
214+
binding.p90LatencyTextView.text = ""
215+
binding.p99LatencyTextView.text = ""
216+
binding.minLatencyTextView.text = ""
217+
binding.maxLatencyTextView.text = ""
218+
binding.latencyTitleTextView.text = getString(R.string.latency_title_label) // Keep title
219+
220+
if (stackTrace != null) {
221+
binding.rawSummaryText.text = getString(R.string.exception_prefix, stackTrace)
222+
binding.rawSummaryText.visibility = View.VISIBLE
223+
} else {
224+
binding.rawSummaryText.text = ""
225+
binding.rawSummaryText.visibility = View.GONE
226+
}
23227
}
24228

25229
/**
26230
* A native method that is implemented by the 'modeltester' native library,
27231
* which is packaged with this application.
28232
*/
29-
external fun run(modelBytes: ByteArray,
233+
external fun run(modelPathOrBytes: Any, // Can be String (path) or ByteArray (bytes)
30234
numIterations: Int,
235+
runWarmupIteration: Boolean,
31236
executionProviderType: String?,
32237
executionProviderOptionNames: Array<String>?,
33238
executionProviderOptionValues: Array<String>?,
34-
): String
239+
logLevel: Int
240+
): String
35241

36242
companion object {
37243
// Used to load the 'modeltester' library on application startup.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
<shape xmlns:android="http://schemas.android.com/apk/res/android"
2+
android:shape="rectangle">
3+
<solid android:color="@color/muted_card_background"/> <!-- Or a slightly different background if preferred -->
4+
<corners android:radius="12dp"/>
5+
<stroke android:width="1dp" android:color="@color/muted_input_border"/>
6+
</shape>

0 commit comments

Comments
 (0)