|
1 | 1 | package com.onnxruntime.example.modeltester
|
2 | 2 |
|
| 3 | +import android.net.Uri |
3 | 4 | import androidx.appcompat.app.AppCompatActivity
|
4 | 5 | import android.os.Bundle
|
| 6 | +import android.view.View |
| 7 | +import android.widget.AdapterView |
| 8 | +import android.widget.ArrayAdapter |
5 | 9 | import com.onnxruntime.example.modeltester.databinding.ActivityMainBinding
|
| 10 | +import androidx.activity.result.contract.ActivityResultContracts |
| 11 | +import java.io.File |
| 12 | +import java.io.FileOutputStream |
| 13 | +import java.io.IOException |
| 14 | +import java.util.Locale |
6 | 15 |
|
7 | 16 | class MainActivity : AppCompatActivity() {
|
8 | 17 |
|
9 | 18 | private lateinit var binding: ActivityMainBinding
|
| 19 | + private var currentModel: Any = "" // Can be String (path) or ByteArray (default model) |
| 20 | + |
| 21 | + // ActivityResultLauncher for picking a model file |
| 22 | + private val pickFileLauncher = registerForActivityResult(ActivityResultContracts.GetContent()) { uri -> |
| 23 | + uri?.let { handleSelectedModelFile(it) } |
| 24 | + } |
10 | 25 |
|
11 | 26 | override fun onCreate(savedInstanceState: Bundle?) {
|
12 | 27 | super.onCreate(savedInstanceState)
|
13 |
| - |
14 | 28 | binding = ActivityMainBinding.inflate(layoutInflater)
|
15 | 29 | setContentView(binding.root)
|
16 | 30 |
|
17 |
| - val modelResourceId = R.raw.model |
18 |
| - val modelBytes = resources.openRawResource(modelResourceId).readBytes() |
| 31 | + loadDefaultModel() |
| 32 | + setupUI() |
| 33 | + runInitialModel() |
| 34 | + } |
| 35 | + |
| 36 | + private fun loadDefaultModel() { |
| 37 | + try { |
| 38 | + currentModel = resources.openRawResource(R.raw.yolo11n).readBytes() |
| 39 | + } catch (e: IOException) { |
| 40 | + // Handle error loading default model, e.g., show a toast or log |
| 41 | + displayError("Failed to load default model: ${e.message}") |
| 42 | + } |
| 43 | + } |
| 44 | + |
| 45 | + private fun setupUI() { |
| 46 | + binding.browseModelButton.setOnClickListener { pickFileLauncher.launch("*/*") } |
| 47 | + binding.runButton.setOnClickListener { runModelFromUI() } |
| 48 | + |
| 49 | + // Setup Spinners with ArrayAdapter if not already done via XML entries |
| 50 | + // ArrayAdapter.createFromResource( |
| 51 | + // this, |
| 52 | + // R.array.ep_array, |
| 53 | + // android.R.layout.simple_spinner_item |
| 54 | + // ).also { adapter -> |
| 55 | + // adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item) |
| 56 | + // binding.epSpinner.adapter = adapter |
| 57 | + // } |
| 58 | + |
| 59 | + // ArrayAdapter.createFromResource( |
| 60 | + // this, |
| 61 | + // R.array.log_level_array, |
| 62 | + // android.R.layout.simple_spinner_item |
| 63 | + // ).also { adapter -> |
| 64 | + // adapter.setDropDownViewResource(android.R.layout.simple_spinner_dropdown_item) |
| 65 | + // binding.logLevelSpinner.adapter = adapter |
| 66 | + // } |
| 67 | + } |
| 68 | + |
| 69 | + private fun handleSelectedModelFile(uri: Uri) { |
| 70 | + binding.modelPathEditText.setText(uri.toString()) |
| 71 | + // Attempt to get a real path or copy to a cache file if it's a content URI |
| 72 | + val filePath = getPathFromUri(uri) |
| 73 | + if (filePath != null) { |
| 74 | + currentModel = filePath |
| 75 | + } else { |
| 76 | + // Fallback or error handling if path resolution fails |
| 77 | + binding.modelPathEditText.error = "Could not resolve file path" |
| 78 | + // Optionally, revert to default model or prevent run |
| 79 | + } |
| 80 | + } |
| 81 | + |
| 82 | + // Helper to attempt to get a real file path from a URI |
| 83 | + // This can be complex. For content URIs, copying to a cache file is often most reliable. |
| 84 | + private fun getPathFromUri(uri: Uri): String? { |
| 85 | + if ("content".equals(uri.scheme, ignoreCase = true)) { |
| 86 | + return try { |
| 87 | + val inputStream = contentResolver.openInputStream(uri) ?: return null |
| 88 | + val tempFile = File(cacheDir, "temp_model_file") |
| 89 | + FileOutputStream(tempFile).use { outputStream -> |
| 90 | + inputStream.copyTo(outputStream) |
| 91 | + } |
| 92 | + inputStream.close() |
| 93 | + tempFile.absolutePath |
| 94 | + } catch (e: IOException) { |
| 95 | + e.printStackTrace() |
| 96 | + null |
| 97 | + } |
| 98 | + } |
| 99 | + return uri.path // For file URIs or if direct path access is possible |
| 100 | + } |
19 | 101 |
|
20 |
| - val summary = run(modelBytes, 10, null, null, null) |
21 | 102 |
|
22 |
| - binding.sampleText.text = summary |
| 103 | + private fun runModelFromUI() { |
| 104 | + val numIterations = binding.iterationsEditText.text.toString().toIntOrNull() ?: 10 |
| 105 | + val runWarmup = binding.warmupSwitchMaterial.isChecked |
| 106 | + |
| 107 | + val selectedEpString = binding.epSpinner.selectedItem.toString() |
| 108 | + val epName = if (selectedEpString.equals("CPU", ignoreCase = true)) { |
| 109 | + null |
| 110 | + } else { |
| 111 | + selectedEpString |
| 112 | + } |
| 113 | + |
| 114 | + val logLevel = mapSpinnerPositionToLogLevel(binding.logLevelSpinner.selectedItemPosition) |
| 115 | + |
| 116 | + val modelPathString = binding.modelPathEditText.text.toString() |
| 117 | + if (modelPathString.isNotEmpty() && modelPathString != currentModel.toString()) { |
| 118 | + // User typed a path directly or it wasn't a content URI initially handled by pickFileLauncher |
| 119 | + if (modelPathString.startsWith("content://")) { |
| 120 | + val newUri = Uri.parse(modelPathString) |
| 121 | + val resolvedPath = getPathFromUri(newUri) |
| 122 | + if (resolvedPath != null) { |
| 123 | + currentModel = resolvedPath |
| 124 | + } else { |
| 125 | + binding.modelPathEditText.error = "Invalid model path/URI" |
| 126 | + return |
| 127 | + } |
| 128 | + } else { |
| 129 | + currentModel = modelPathString // Assume direct path |
| 130 | + } |
| 131 | + } else if (modelPathString.isEmpty()) { |
| 132 | + loadDefaultModel() // Revert to default if path is cleared |
| 133 | + } |
| 134 | + // If currentModel is already a ByteArray (default model), it's used directly. |
| 135 | + // If it's a String (path), it's used directly. |
| 136 | + |
| 137 | + executeNativeRun(currentModel, numIterations, runWarmup, epName, logLevel) |
| 138 | + } |
| 139 | + |
| 140 | + private fun runInitialModel() { |
| 141 | + // Use default values for the initial run |
| 142 | + val defaultNumIterations = 10 |
| 143 | + val defaultRunWarmup = true |
| 144 | + // For initial run, always use CPU, which means passing null for epName |
| 145 | + val defaultEpName: String? = null |
| 146 | + val defaultLogLevel = -1 // ORT default |
| 147 | + |
| 148 | + // Ensure default model (ByteArray) is used for initial run |
| 149 | + if (currentModel !is ByteArray) { |
| 150 | + loadDefaultModel() |
| 151 | + } |
| 152 | + executeNativeRun(currentModel, defaultNumIterations, defaultRunWarmup, defaultEpName, defaultLogLevel) |
| 153 | + } |
| 154 | + |
| 155 | + private fun mapSpinnerPositionToLogLevel(position: Int): Int { |
| 156 | + return when (position) { |
| 157 | + 0 -> -1 // Default (ORT default) |
| 158 | + 1 -> 0 // Verbose |
| 159 | + 2 -> 1 // Info |
| 160 | + 3 -> 2 // Warning |
| 161 | + 4 -> 3 // Error |
| 162 | + 5 -> 4 // Fatal |
| 163 | + else -> -1 // Default to ORT default |
| 164 | + } |
| 165 | + } |
| 166 | + |
| 167 | + private fun executeNativeRun(model: Any, numIterations: Int, runWarmup: Boolean, epName: String?, logLevel: Int) { |
| 168 | + try { |
| 169 | + val summary = run( |
| 170 | + model, // This is currentModel (String path or ByteArray) |
| 171 | + numIterations, |
| 172 | + runWarmup, |
| 173 | + epName, |
| 174 | + null, // executionProviderOptionNames - not used in this example |
| 175 | + null, // executionProviderOptionValues - not used in this example |
| 176 | + logLevel |
| 177 | + ) |
| 178 | + parseAndDisplaySummary(summary) |
| 179 | + } catch (e: Exception) { |
| 180 | + displayError("Native run failed: ${e.message}", e.stackTraceToString()) |
| 181 | + } |
| 182 | + } |
| 183 | + |
| 184 | + private fun parseAndDisplaySummary(summary: String) { |
| 185 | + val na = getString(R.string.na) |
| 186 | + val loadTimeRegex = "Load time: (\\S+)".toRegex() |
| 187 | + val numRunsRegex = "N \\(number of runs\\): (\\d+)".toRegex() |
| 188 | + val avgLatencyRegex = "avg: (\\S+)".toRegex() |
| 189 | + val p50LatencyRegex = "p50: (\\S+)".toRegex() |
| 190 | + val p90LatencyRegex = "p90: (\\S+)".toRegex() |
| 191 | + val p99LatencyRegex = "p99: (\\S+)".toRegex() |
| 192 | + val minLatencyRegex = "min: (\\S+)".toRegex() |
| 193 | + val maxLatencyRegex = "max: (\\S+)".toRegex() |
| 194 | + |
| 195 | + binding.loadTimeTextView.text = getString(R.string.load_time_label, loadTimeRegex.find(summary)?.groupValues?.get(1) ?: na) |
| 196 | + binding.numRunsTextView.text = getString(R.string.num_runs_label, numRunsRegex.find(summary)?.groupValues?.get(1) ?: na) |
| 197 | + binding.latencyTitleTextView.text = getString(R.string.latency_title_label) |
| 198 | + binding.avgLatencyTextView.text = getString(R.string.avg_latency_label, avgLatencyRegex.find(summary)?.groupValues?.get(1) ?: na) |
| 199 | + binding.p50LatencyTextView.text = getString(R.string.p50_latency_label, p50LatencyRegex.find(summary)?.groupValues?.get(1) ?: na) |
| 200 | + binding.p90LatencyTextView.text = getString(R.string.p90_latency_label, p90LatencyRegex.find(summary)?.groupValues?.get(1) ?: na) |
| 201 | + binding.p99LatencyTextView.text = getString(R.string.p99_latency_label, p99LatencyRegex.find(summary)?.groupValues?.get(1) ?: na) |
| 202 | + binding.minLatencyTextView.text = getString(R.string.min_latency_label, minLatencyRegex.find(summary)?.groupValues?.get(1) ?: na) |
| 203 | + binding.maxLatencyTextView.text = getString(R.string.max_latency_label, maxLatencyRegex.find(summary)?.groupValues?.get(1) ?: na) |
| 204 | + |
| 205 | + binding.rawSummaryText.text = summary |
| 206 | + binding.rawSummaryText.visibility = View.GONE // Hide by default, show on error or if explicitly toggled |
| 207 | + } |
| 208 | + |
| 209 | + private fun displayError(errorMessage: String, stackTrace: String? = null) { |
| 210 | + binding.loadTimeTextView.text = getString(R.string.error_prefix, errorMessage) |
| 211 | + binding.numRunsTextView.text = "" |
| 212 | + binding.avgLatencyTextView.text = "" |
| 213 | + binding.p50LatencyTextView.text = "" |
| 214 | + binding.p90LatencyTextView.text = "" |
| 215 | + binding.p99LatencyTextView.text = "" |
| 216 | + binding.minLatencyTextView.text = "" |
| 217 | + binding.maxLatencyTextView.text = "" |
| 218 | + binding.latencyTitleTextView.text = getString(R.string.latency_title_label) // Keep title |
| 219 | + |
| 220 | + if (stackTrace != null) { |
| 221 | + binding.rawSummaryText.text = getString(R.string.exception_prefix, stackTrace) |
| 222 | + binding.rawSummaryText.visibility = View.VISIBLE |
| 223 | + } else { |
| 224 | + binding.rawSummaryText.text = "" |
| 225 | + binding.rawSummaryText.visibility = View.GONE |
| 226 | + } |
23 | 227 | }
|
24 | 228 |
|
25 | 229 | /**
|
26 | 230 | * A native method that is implemented by the 'modeltester' native library,
|
27 | 231 | * which is packaged with this application.
|
28 | 232 | */
|
29 |
| - external fun run(modelBytes: ByteArray, |
| 233 | + external fun run(modelPathOrBytes: Any, // Can be String (path) or ByteArray (bytes) |
30 | 234 | numIterations: Int,
|
| 235 | + runWarmupIteration: Boolean, |
31 | 236 | executionProviderType: String?,
|
32 | 237 | executionProviderOptionNames: Array<String>?,
|
33 | 238 | executionProviderOptionValues: Array<String>?,
|
34 |
| - ): String |
| 239 | + logLevel: Int |
| 240 | + ): String |
35 | 241 |
|
36 | 242 | companion object {
|
37 | 243 | // Used to load the 'modeltester' library on application startup.
|
|
0 commit comments