import numpy as np
import onnx
import onnxruntime as ort
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from sklearn.ensemble import RandomForestClassifier
from time import perf_counter

# Generate sample data
X_train = np.random.rand(100, 10).astype(np.float32)
y_train = np.random.randint(0, 2, size=(100,))

# Train a simple RandomForest model
model = RandomForestClassifier(n_estimators=10)
model.fit(X_train, y_train)

# Convert the model to ONNX. ZipMap is disabled so the probability output is
# a plain tensor rather than a sequence of maps, which downstream compilers
# handle more easily.
initial_type = [("input", FloatTensorType([None, 10]))]
onnx_model = convert_sklearn(
    model, initial_types=initial_type, options={id(model): {"zipmap": False}}
)
onnx.save_model(onnx_model, "model.onnx")

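# Optional sanity step: onnx.checker validates the serialized graph (standard
# onnx API) before any runtime touches it.
onnx.checker.check_model(onnx_model)
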
# Load the ONNX model for an inference test
ort_session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_data = {ort_session.get_inputs()[0].name: X_train[:5]}
start = perf_counter()
ort_outs = ort_session.run(None, input_data)
print(f"ONNX Inference Time: {perf_counter() - start:.4f}s")

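# Sanity check: skl2onnx classifiers emit the predicted label as their first
# output, so ONNX Runtime should agree with scikit-learn on the same batch.
assert np.array_equal(ort_outs[0], model.predict(X_train[:5]))
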
# Optimize the ONNX model with TVM. The declared shape must match the batch
# fed at inference time (5 rows here). Note: TVM's Relay ONNX frontend may not
# cover ai.onnx.ml operators such as TreeEnsembleClassifier; if from_onnx
# fails, consider converting the tree model to tensor ops first (e.g., with
# Hummingbird) or using a model built from standard ONNX ops.
onnx_model = onnx.load("model.onnx")
mod, params = relay.frontend.from_onnx(onnx_model, shape={"input": (5, 10)})

# Compile with TVM
target = "llvm"
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

# Run inference with TVM
dev = tvm.cpu()
dtype = "float32"
tvm_model = graph_executor.GraphModule(lib["default"](dev))
tvm_model.set_input("input", tvm.nd.array(X_train[:5].astype(dtype)))

start = perf_counter()
tvm_model.run()
tvm_out = tvm_model.get_output(0).numpy()
print(f"TVM Optimized Inference Time: {perf_counter() - start:.4f}s")

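# A single wall-clock sample is noisy for sub-millisecond workloads. TVM's
# time_evaluator (standard runtime API) runs the module repeatedly and
# reports per-repeat means, giving a steadier comparison.
ftimer = tvm_model.module.time_evaluator("run", dev, number=100, repeat=3)
mean_ms = np.mean(ftimer().results) * 1000
print(f"TVM mean inference time: {mean_ms:.4f} ms")
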
print("Optimization complete! Compare ONNX vs. TVM inference times.")
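
# The compiled module can be exported and reloaded later without rebuilding
# (standard TVM APIs; "model_tvm.so" is an arbitrary filename chosen here):
lib.export_library("model_tvm.so")
loaded_lib = tvm.runtime.load_module("model_tvm.so")
tvm_model_reloaded = graph_executor.GraphModule(loaded_lib["default"](dev))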