
Commit 7884901

Merge pull request #1 from JaimeAdanCuevas/bench
Create compilation_frameworks.py
2 parents: d30f084 + a73d25d

1 file changed: +51 −0 lines changed


compilation_frameworks.py

import numpy as np
import onnx
import onnxruntime as ort
import tvm
from tvm import relay
from tvm.contrib import graph_executor
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from sklearn.ensemble import RandomForestClassifier
from time import time
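
# Note on dependencies (assumption, not part of the original commit): these
# imports correspond to the PyPI packages numpy, onnx, onnxruntime, skl2onnx,
# scikit-learn, and apache-tvm; install any that are missing, e.g.
#   pip install numpy onnx onnxruntime skl2onnx scikit-learn apache-tvm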

# Generate sample data: 100 samples, 10 float32 features, binary labels
X_train = np.random.rand(100, 10).astype(np.float32)
y_train = np.random.randint(0, 2, size=(100,))

# Train a simple RandomForest model
model = RandomForestClassifier(n_estimators=10)
model.fit(X_train, y_train)

# Convert the model to ONNX (skl2onnx emits the tree as an ai.onnx.ml
# TreeEnsembleClassifier op)
initial_type = [("input", FloatTensorType([None, 10]))]
onnx_model = convert_sklearn(model, initial_types=initial_type)
onnx.save_model(onnx_model, "model.onnx")
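
# Optional sanity check (sketch, not in the original commit): verify that the
# exported graph is structurally valid before any runtime loads it.
onnx.checker.check_model(onnx_model)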

# Load the ONNX model and time one inference over a batch of 5 samples
ort_session = ort.InferenceSession("model.onnx")
input_data = {ort_session.get_inputs()[0].name: X_train[:5]}
start = time()
ort_outs = ort_session.run(None, input_data)
print(f"ONNX Inference Time: {time() - start:.4f}s")
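
# A single timed call is noisy. A more stable measurement (sketch, with
# assumed iteration counts) warms up the session and averages many runs:
N_WARMUP, N_RUNS = 10, 100
for _ in range(N_WARMUP):
    ort_session.run(None, input_data)
start = time()
for _ in range(N_RUNS):
    ort_session.run(None, input_data)
print(f"ONNX mean latency: {(time() - start) / N_RUNS * 1e3:.3f} ms")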

# Import the ONNX model into TVM Relay; the declared input shape must match
# the (5, 10) batch used at inference time below.
# Caveat: TVM's Relay ONNX frontend does not cover the ai.onnx.ml domain, so
# the TreeEnsembleClassifier op emitted by skl2onnx may fail to import; if it
# does, substitute a model that exports to standard ONNX ops (e.g. an MLP).
onnx_model = onnx.load("model.onnx")
mod, params = relay.frontend.from_onnx(onnx_model, shape={"input": (5, 10)})

# Compile with TVM (LLVM target for the local CPU, highest optimization level)
target = "llvm"
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)
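
# Optionally persist the compiled module so later runs skip compilation
# (sketch; the .so filename is an assumption):
lib.export_library("compiled_model.so")
# loaded_lib = tvm.runtime.load_module("compiled_model.so")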

# Run inference with TVM's graph executor on CPU
dev = tvm.cpu()
dtype = "float32"
tvm_model = graph_executor.GraphModule(lib["default"](dev))
tvm_model.set_input("input", tvm.nd.array(X_train[:5].astype(dtype)))

start = time()
tvm_model.run()
tvm_out = tvm_model.get_output(0).numpy()
print(f"TVM Optimized Inference Time: {time() - start:.4f}s")

print("Optimization complete! Compare ONNX vs. TVM inference times.")
