Skip to content

Add Group Query Attention support with OV base OPs #28163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Mar 19, 2025

Conversation

sgbihu
Copy link
Contributor

@sgbihu sgbihu commented Dec 20, 2024

Details:

  • Try to enable LLM based on onnxruntime. (Phi3, Llama3 is working on CPU, Phi3 can work with iGPU)

Test scripts

import onnxruntime as rt
import os
import numpy as np
import time

import onnxruntime.tools.add_openvino_win_libs as utils
utils.add_openvino_libs_to_path()
from transformers import PreTrainedTokenizerFast


test_lama3 = False
test_phi3 = True
if test_phi3:
    modelPath = os.path.join('D:\\', 'models', 'llm', 'Phi-3-mini-4k-instruct-onnx', 'model.onnx')
    tokenizerPath = os.path.join('D:\\', 'models', 'llm', 'Phi-3-mini-4k-instruct-onnx', 'tokenizer.json')

if test_lama3:
    modelPath = os.path.join('D:\\', 'models', 'llm', 'llama3.1-8B-instruct-onnx', 'model.onnx')

so = rt.SessionOptions()
# so.log_severity_level = 3

# sess = rt.InferenceSession(modelPath, so, providers=['CPUExecutionProvider'])
sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "CPU", 'cache_dir': "cache"}])
# sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "CPU"}])
# sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "NPU"}])
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizerPath)

# print(sess.get_device())
# for name in sess.get_inputs():
#     print(f"Name: {name.name}, Shape: {name.shape}, Type: {name.type}")
outputs = sess.get_outputs()
output_names = list(map(lambda output: output.name, outputs))


# Assuming the model has 32 layers and each layer has a key and value state
# Phi3
def get_phi3_param():
    num_layers = 32
    batch_size = 1
    num_heads = 32
    sequence_length = 2048
    hidden_size = 96
    return num_layers, batch_size, num_heads, sequence_length, hidden_size

# lama
def get_llama3_param():
    num_layers = 32
    batch_size = 1
    num_heads = 8
    sequence_length = 2048
    hidden_size = 128
    return num_layers, batch_size, num_heads, sequence_length, hidden_size

if test_phi3:
    num_layers, batch_size, num_heads, sequence_length, hidden_size = get_phi3_param()

if test_lama3:
    num_layers, batch_size, num_heads, sequence_length, hidden_size = get_llama3_param()

# Initialize past_key_values with zeros
cpu_array = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32)

# print("Output names: ", outputs[0].type.data)

def create_present_state_binding(binding, outputs):
    outputMap={}
    for output in outputs:
        shapes = []
        for item in output.shape:
            if isinstance(item, str):
                if 'batch_size' in item:
                    shapes.append(batch_size)
                elif 'sequence_length' in item:
                    if output.name == 'logits':
                        shapes.append(len(inputToken))
                    else:
                        shapes.append(sequence_length)
                elif 'hidden_size' in item:
                    shapes.append(hidden_size)
                elif 'num_heads' in item:
                    shapes.append(num_heads)
                else:
                    raise ValueError(f"Unknown dimension: {item}")
            else:
                shapes.append(item)
            
        present_state = rt.OrtValue.ortvalue_from_shape_and_type(shapes, np.float32)
        binding.bind_ortvalue_output(output.name, present_state)
        outputMap[output.name] = present_state
    return outputMap

def rebind_inputs(lastOutput, binding):
    for index in range(num_layers):
        binding.bind_ortvalue_input(f'past_key_values.{index}.key', lastOutput[f'present.{index}.key'])
        binding.bind_ortvalue_input(f'past_key_values.{index}.value', lastOutput[f'present.{index}.value'])
    return binding

def init_input_with_binding(binding):
    for index in range(num_layers):
        key_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        value_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        binding.bind_ortvalue_input(f'past_key_values.{index}.key', key_state)
        binding.bind_ortvalue_input(f'past_key_values.{index}.value', value_state)
    return binding

def reinit_input_bindings(bindings, lastOutput):
    newOutput = create_present_state_binding(bindings, lastOutput)
    binding = rebind_inputs(lastOutput, bindings)
    return binding, newOutput

def create_numpy_inputs(inputToken):
    tokenLen = len(inputToken)
    npinput_ids = np.array([inputToken], dtype=np.int64)
    npattention_mask = np.array([[1] * (tokenLen)], dtype=np.int64)
    return npinput_ids, npattention_mask


def init_ortinput(inputToken):
    flattened_past_key_values = {}
    for index in range(num_layers):
        key_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        value_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        flattened_past_key_values[f'past_key_values.{index}.key'] = key_state
        flattened_past_key_values[f'past_key_values.{index}.value'] = value_state
    ids, mask = create_numpy_inputs(inputToken)
    flattened_past_key_values['input_ids'] = rt.OrtValue.ortvalue_from_numpy(ids)
    flattened_past_key_values['attention_mask'] = rt.OrtValue.ortvalue_from_numpy(mask)
    return flattened_past_key_values

def init_npinput(inputToken):
    flattened_past_key_values = {}
    for index in range(num_layers):
        key_state = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32)
        value_state = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32)
        flattened_past_key_values[f'past_key_values.{index}.key'] = key_state
        flattened_past_key_values[f'past_key_values.{index}.value'] = value_state
    flattened_past_key_values['input_ids'], flattened_past_key_values['attention_mask'] = create_numpy_inputs(inputToken)
    return flattened_past_key_values

def init_bindinginput(inputToken):
    binding = sess.io_binding()
    binding = init_input_with_binding(binding)
    
    ids, mask = create_numpy_inputs(inputToken)
    binding.bind_ortvalue_input(f'attention_mask', rt.OrtValue.ortvalue_from_numpy(mask))
    binding.bind_ortvalue_input(f'input_ids',  rt.OrtValue.ortvalue_from_numpy(ids))
    return binding


# Question
# The Sun is yellow because

# Phi3
if test_phi3:
    # 450 8991 5692
    # inputToken = [32010, 29871, 13]
    inputToken = [32010, 29871, 13, 1576, 8991, 338, 13328, 1363, 29871, 32007, 13, 32001]
    # inputToken = [32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010]
# lama3
if test_lama3:
    # 315 1202 7479
    inputToken = [128000, 27, 91, 882, 91, 397, 791, 8219, 374, 14071, 1606, 83739, 408, 91, 397, 27, 91, 78191, 91, 29]
    # inputToken = [315]
history_tokens = inputToken

flattened_past_key_values = init_npinput(inputToken)

# flattened_past_key_values = init_ortinput(inputToken)

# binding = init_bindinginput(inputToken)
# lastoutput = create_present_state_binding(binding, outputs)

lastTokenLen = len(inputToken)


# roption = rt.RunOptions()
# roption.add_run_config_entry("gpu_graph_id", "-1")

before = time.time()
results = sess.run(output_names, flattened_past_key_values)
# results = sess.run_with_iobinding(binding)
# results = sess.run_with_ort_values(output_names, flattened_past_key_values)
after = time.time()
print("Time cost in ms: ", (after - before) * 1000)

# print(np.argmax(results[0].numpy(), axis=-1)[-1])
print(np.argmax(results[0], axis=-1)[-1])

# print(results[0])
# print(output_names[1])
# print(results[1][0][0][0])
# print(results[1][0][0][1])
# print(results[1][0][0][2])
# # print(results[1][0][0][14])
# # print(results[1])
# print(output_names[2])
# # print(results[2])
# print(results[2][0][0][0])
# print(results[2][0][0][1])
# print(results[2][0][0][2])
# print(results[2][0][0][14])
# inputToken.append(450)

# rebind_inputs(lastOutput, binding)

def update_kvcache(inputsMap, results):
    for index in range(len(output_names)):
        if not output_names[index].startswith('present'):
            continue
        # print(f'{output_names[index]}: {results[index].shape}')
        outputname = output_names[index]
        inputname = outputname.replace('present', 'past_key_values')
        inputsMap[inputname] = results[index]
    return inputsMap
# lastOutput = create_present_state_binding(binding, sess.get_outputs())

# flattened_past_key_values = update_kvcache(flattened_past_key_values, results)

for index in range(len(output_names)):
    if not output_names[index].startswith('present'):
        continue
    # print(f'{output_names[index]}: {results[index].shape}')
    outputname = output_names[index]
    inputname = outputname.replace('present', 'past_key_values')
    flattened_past_key_values[inputname] = results[index]
if test_phi3:
    inputToken = [450]

if test_lama3:
    inputToken = [315]
history_tokens += inputToken

npinput_ids = np.array([inputToken], dtype=np.int64)
npattention_mask = np.array([[1] * (lastTokenLen+1)], dtype=np.int64)
print(f"lastTokenLen:{lastTokenLen}")

# attention_mask = rt.OrtValue.ortvalue_from_numpy(npattention_mask)
# input_ids = rt.OrtValue.ortvalue_from_numpy(npinput_ids)
# binding.bind_ortvalue_input(f'attention_mask', attention_mask)
# binding.bind_ortvalue_input(f'input_ids', input_ids)
# flattened_past_key_values[f'attention_mask'].update_inplace(npattention_mask)
# flattened_past_key_values[f'input_ids'].update_inplace(npinput_ids)
# flattened_past_key_values[f'attention_mask'] = attention_mask
# flattened_past_key_values[f'input_ids'] = input_ids
flattened_past_key_values[f'attention_mask'] = npattention_mask
flattened_past_key_values[f'input_ids'] = npinput_ids
# print(flattened_past_key_values)

before = time.time()
results = sess.run(output_names, flattened_past_key_values)
# results = sess.run_with_iobinding(binding)
# results = sess.run_with_ort_values(output_names, flattened_past_key_values)
after = time.time()
print("Time cost in ms: ", (after - before) * 1000)

# Results:  [np.int32(450), np.int32(8991), np.int32(5692), np.int32(13328), np.int32(304), np.int32(502), np.int32(19434), np.int32(2861), np.int32(304), np.int32(9596), np.int32(280), np.int32(1141), np.int32(14801), np.int32(292), np.int32(29889), np.int32(1932), np.int32(6575), np.int32(4366), np.int32(14517), np.int32(1549), np.int32(278), np.int32(11563), np.int32(29915), np.int32(29879), np.int32(25005), np.int32(29892), np.int32(278), np.int32(20511), np.int32(7254), np.int32(281), np.int32(6447), np.int32(1477), np.int32(29879), np.int32(526), np.int32(29574), np.int32(297), np.int32(599), np.int32(18112), np.int32(491), np.int32(278), np.int32(330), np.int32(2129), np.int32(322), np.int32(17105), np.int32(297), np.int32(278), np.int32(4799), np.int32(29889), np.int32(910), np.int32(14801), np.int32(292), np.int32(9946), np.int32(278), np.int32(14744), np.int32(304), np.int32(1106), np.int32(7254), np.int32(29889), np.int32(2398), np.int32(29892), np.int32(278), np.int32(5520), np.int32(2654), np.int32(322), np.int32(13328), np.int32(281), np.int32(6447), np.int32(1477), np.int32(29879), np.int32(1209), np.int32(1549), np.int32(278), np.int32(25005), np.int32(901), np.int32(5948), np.int32(322), np.int32(526), np.int32(3109), np.int32(29574), np.int32(29889), np.int32(1932), np.int32(591), np.int32(1106), np.int32(472), np.int32(278), np.int32(8991), np.int32(29892), np.int32(591), np.int32(1074), np.int32(372), np.int32(408), np.int32(263), np.int32(13328), np.int32(470), np.int32(24841), np.int32(8086), np.int32(1363), np.int32(278), np.int32(7254), np.int32(3578), np.int32(338), np.int32(29574), np.int32(714), np.int32(310), np.int32(1749), np.int32(1196), np.int32(310), np.int32(11126), np.int32(29892), np.int32(322), np.int32(278), np.int32(9886), np.int32(3578), np.int32(393), np.int32(22170), np.int32(1749), np.int32(5076), np.int32(338), np.int32(758), np.int32(24130), np.int32(10835), np.int32(13328), np.int32(322), np.int32(2654), np.int32(29889), np.int32(32000)]
# index = 0
# for result in results:
#     print(f'{output_names[index]}: {result.shape}, {result.dtype}')
#     index += 1
print(np.argmax(results[0], axis=-1)[-1])
# print(np.argmax(results[0].numpy(), axis=-1)[-1])


# golden results
# Time cost in ms:  1255.2332878112793
# [30751    13    13  1494  1731   263 29889   372    13 24380    13   450]
# lastTokenLen:12
# Time cost in ms:  1006.781816482544
# [8991]

last_generated_token = np.argmax(results[0], axis=-1)[-1][-1]
history_tokens.append(last_generated_token)
NUM_INFERENCE = 15
for i in range(NUM_INFERENCE):
    # update kvcahe
    for index in range(len(output_names)):
        if not output_names[index].startswith('present'):
            continue
        # print(f'{output_names[index]}: {results[index].shape}')
        outputname = output_names[index]
        inputname = outputname.replace('present', 'past_key_values')
        flattened_past_key_values[inputname] = results[index]

    # update input token
    flattened_past_key_values[f'input_ids'] = np.array([[last_generated_token]], dtype=np.int64)
    flattened_past_key_values[f'attention_mask'] = np.array([[1] * len(history_tokens)], dtype=np.int64)

    before = time.time()
    results = sess.run(output_names, flattened_past_key_values)
    after = time.time()
    print("Time cost in ms: ", (after - before) * 1000)

    last_generated_token = np.argmax(results[0], axis=-1)[-1][-1]
    history_tokens.append(last_generated_token)

print(tokenizer.decode(history_tokens))

Tickets:

  • related to 155287, 157123

@github-actions github-actions bot added the category: ONNX FE OpenVINO ONNX FrontEnd label Dec 20, 2024
@sys-openvino-ci sys-openvino-ci added the ExternalIntelPR External contributor from Intel label Dec 20, 2024
@slyalin
Copy link
Contributor

slyalin commented Dec 20, 2024

How is it related to #27648?

@github-actions github-actions bot added category: Core OpenVINO Core (aka ngraph) category: GPU OpenVINO GPU plugin category: CPU OpenVINO CPU plugin category: transformations OpenVINO Runtime library - Transformations labels Jan 16, 2025
@github-actions github-actions bot added category: CPP API OpenVINO CPP API bindings and removed category: GPU OpenVINO GPU plugin category: CPU OpenVINO CPU plugin labels Jan 21, 2025
@wine99 wine99 force-pushed the gqa_enabling branch 2 times, most recently from f4770e0 to 911691b Compare January 26, 2025 02:32
@wine99
Copy link
Contributor

wine99 commented Feb 5, 2025

@slyalin we have relocated the transformation code from the ONNX frontend to the plugin transformation passes as detailed in #27648. Could you please review and provide feedback? Currently, the GQA node is defined in opset15, which likely needs to be updated.

@sgbihu sgbihu marked this pull request as ready for review February 6, 2025 13:03
@sgbihu sgbihu requested review from a team as code owners February 6, 2025 13:03
@sgbihu sgbihu requested review from itikhono and removed request for a team February 6, 2025 13:03
@mlukasze
Copy link
Contributor

mlukasze commented Feb 6, 2025

hey @sgbihu
please, resolve conflicts before CI will be triggered

@wine99 wine99 requested a review from a team as a code owner February 7, 2025 03:00
@wine99 wine99 requested review from ilya-lavrenov and removed request for a team February 7, 2025 03:00
const std::shared_ptr<ov::op::v3::ShapeOf>& shape,
const std::vector<int>& dims) {
using namespace ov::op;
const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor:
We use i64 in some places and i32 in others. There isn't a restriction to do so but it's better to align element types, e.g. use i64

GroupQueryAttentionDecomposition();

private:
ov::OutputVector decompose(std::shared_ptr<ov::op::GroupQueryAttention> node);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor:
as I can see, these functions don't use any members of GroupQueryAttentionDecomposition and can be used separately,
so it's better to move it to unnamed namespace inside group_query_attention_decomposition.cpp file or to utils if we plan to re-use it

std::shared_ptr<ov::Node> minus_inf = nullptr;
if (T == ov::element::f32)
minus_inf = register_new_node(v0::Constant::create(T, ov::Shape{}, {-std::numeric_limits<float>::infinity()}));
else if (T == ov::element::f16)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks unsafe, minus_inf might be nullptr after this if in case of another type, probably now it's impossible but it's not guaranteed that it won't be extended in the future
should we throw exception or return false or add some default else branch if T is not f32, f16?
usually we prefer to return false not to break model inference, if no replacements or model modifications were done at this moment

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will address this in a following PR

@itikhono
Copy link
Contributor

Approved. We agreed to resolve the remaining comments in the next PR

@wine99
Copy link
Contributor

wine99 commented Mar 11, 2025

@t-jankowski @gkrivor Could you review again?

Copy link
Contributor

@t-jankowski t-jankowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok for core part

@rkazants
Copy link
Member

build_jenkins

1 similar comment
@XinWangIntel
Copy link
Contributor

build_jenkins

@t-jankowski
Copy link
Contributor

build_jenkins

@wine99
Copy link
Contributor

wine99 commented Mar 14, 2025

Failed jobs in build_jenkins are same as in other PRs, e.g. commit fd793e4 in PR #29446
image

@wine99
Copy link
Contributor

wine99 commented Mar 14, 2025

CI status: 1 failing: ie_tests_cldnn_unit_dg2_ubuntu22_release

Details: 1 unit test failed:

[2025-03-14T01:54:32.242Z] cldnn_unit_tests_dg2-0 INFO: [2261/22863] convolution_int8_fw_gpu.quantized_convolution_u8s8f32_asymmetric_activations_per_channel_dynamic (939 ms)
[2025-03-14T01:54:32.242Z] cldnn_unit_tests_dg2-0 INFO: ERROR: ld.so: object 'libSegFault.so' from LD_PRELOAD cannot be preloaded (cannot open shared object file): ignored.
[2025-03-14T01:54:32.242Z] cldnn_unit_tests_dg2-0 INFO: Running main() from src/plugins/intel_gpu/tests/unit/gtest_main_gpu.cpp
[2025-03-14T01:54:32.242Z] cldnn_unit_tests_dg2-0 INFO: Env variable cl_cache_dir: /nfs/ov-share-02/data/volatile/cache/cl_cache_linux
[2025-03-14T01:54:32.242Z] cldnn_unit_tests_dg2-0 INFO: �[0;33mNote: Google Test filter = convolution_int8_fw_gpu.quantized_convolution_u8s8f32_asymmetric_activations_per_channel_dynamic
[2025-03-14T01:54:32.242Z] cldnn_unit_tests_dg2-0 INFO: �[m�[0;32m[==========] �[mRunning 1 test from 1 test suite.
[2025-03-14T01:54:32.242Z] cldnn_unit_tests_dg2-0 INFO: �[0;32m[----------] �[mGlobal test environment set-up.
[2025-03-14T01:54:32.242Z] cldnn_unit_tests_dg2-0 INFO: �[0;32m[ PG INFO  ] �[mPostgreSQL Reporting is disabled due to missing environment settings
[2025-03-14T01:54:32.242Z] cldnn_unit_tests_dg2-0 INFO: �[0;32m[----------] �[m1 test from convolution_int8_fw_gpu
[2025-03-14T01:54:32.242Z] cldnn_unit_tests_dg2-0 INFO: �[0;32m[ RUN      ] �[mconvolution_int8_fw_gpu.quantized_convolution_u8s8f32_asymmetric_activations_per_channel_dynamic
[2025-03-14T01:54:32.243Z] cldnn_unit_tests_dg2-0 INFO: src/plugins/intel_gpu/tests/unit/test_cases/convolution_gpu_test.cpp:4746: Failure
[2025-03-14T01:54:32.243Z] cldnn_unit_tests_dg2-0 INFO: The difference between output_vec[f][y][x] and ((float)output_ptr[f * y_size * x_size + y * x_size + x]) is 80, which exceeds 1e-5f, where
[2025-03-14T01:54:32.243Z] cldnn_unit_tests_dg2-0 INFO: output_vec[f][y][x] evaluates to -36,
[2025-03-14T01:54:32.243Z] cldnn_unit_tests_dg2-0 INFO: ((float)output_ptr[f * y_size * x_size + y * x_size + x]) evaluates to 44, and
[2025-03-14T01:54:32.243Z] cldnn_unit_tests_dg2-0 INFO: 1e-5f evaluates to 9.9999997473787516e-06.
[2025-03-14T01:54:32.243Z] cldnn_unit_tests_dg2-0 INFO:  x=0 y=0 f=0
[2025-03-14T01:54:32.243Z] cldnn_unit_tests_dg2-0 INFO: �[0;31m[  FAILED  ] �[mconvolution_int8_fw_gpu.quantized_convolution_u8s8f32_asymmetric_activations_per_channel_dynamic (417 ms)
[2025-03-14T01:54:32.243Z] cldnn_unit_tests_dg2-0 INFO: �[0;32m[----------] �[m1 test from convolution_int8_fw_gpu (417 ms total)
[2025-03-14T01:54:32.243Z] cldnn_unit_tests_dg2-0 INFO: 
[2025-03-14T01:54:32.243Z] cldnn_unit_tests_dg2-0 INFO: �[0;32m[----------] �[mGlobal test environment tear-down
[2025-03-14T01:54:32.243Z] cldnn_unit_tests_dg2-0 INFO: �[0;32m[==========] �[m1 test from 1 test suite ran. (419 ms total)
[2025-03-14T01:54:32.244Z] cldnn_unit_tests_dg2-0 INFO: �[0;32m[  PASSED  ] �[m0 tests.
[2025-03-14T01:54:32.244Z] cldnn_unit_tests_dg2-0 INFO: �[0;31m[  FAILED  ] �[m1 test, listed below:
[2025-03-14T01:54:32.244Z] cldnn_unit_tests_dg2-0 INFO: �[0;31m[  FAILED  ] �[mconvolution_int8_fw_gpu.quantized_convolution_u8s8f32_asymmetric_activations_per_channel_dynamic
[2025-03-14T01:54:32.244Z] cldnn_unit_tests_dg2-0 INFO: 
[2025-03-14T01:54:32.244Z] cldnn_unit_tests_dg2-0 INFO:  1 FAILED TEST
[2025-03-14T01:54:32.244Z] cldnn_unit_tests_dg2-0 INFO: [2261/22863] convolution_int8_fw_gpu.quantized_convolution_u8s8f32_asymmetric_activations_per_channel_dynamic returned/aborted with exit code 1 (939 ms)

The failure should not be related to this PR

@XinWangIntel
Copy link
Contributor

build_jenkins

@XinWangIntel
Copy link
Contributor

build_jenkins

1 similar comment
@itikhono
Copy link
Contributor

build_jenkins

@wine99
Copy link
Contributor

wine99 commented Mar 18, 2025

@praasz Suggestions applied

@itikhono
Copy link
Contributor

build_jenkins

Copy link
Contributor

@praasz praasz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok for core part

@praasz praasz added this pull request to the merge queue Mar 19, 2025
Merged via the queue into openvinotoolkit:master with commit 307db82 Mar 19, 2025
189 checks passed
timxu826 pushed a commit to timxu826/openvino that referenced this pull request Apr 7, 2025
…8163)

### Details:
- Try to enable LLM based on onnxruntime. (Phi3, Llama3 is working on
CPU, Phi3 can work with iGPU)
 
### Test scripts
```
import onnxruntime as rt
import os
import numpy as np
import time

import onnxruntime.tools.add_openvino_win_libs as utils
utils.add_openvino_libs_to_path()
from transformers import PreTrainedTokenizerFast


test_lama3 = False
test_phi3 = True
if test_phi3:
    modelPath = os.path.join('D:\\', 'models', 'llm', 'Phi-3-mini-4k-instruct-onnx', 'model.onnx')
    tokenizerPath = os.path.join('D:\\', 'models', 'llm', 'Phi-3-mini-4k-instruct-onnx', 'tokenizer.json')

if test_lama3:
    modelPath = os.path.join('D:\\', 'models', 'llm', 'llama3.1-8B-instruct-onnx', 'model.onnx')

so = rt.SessionOptions()
# so.log_severity_level = 3

# sess = rt.InferenceSession(modelPath, so, providers=['CPUExecutionProvider'])
sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "CPU", 'cache_dir': "cache"}])
# sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "CPU"}])
# sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "NPU"}])
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizerPath)

# print(sess.get_device())
# for name in sess.get_inputs():
#     print(f"Name: {name.name}, Shape: {name.shape}, Type: {name.type}")
outputs = sess.get_outputs()
output_names = list(map(lambda output: output.name, outputs))


# Assuming the model has 32 layers and each layer has a key and value state
# Phi3
def get_phi3_param():
    num_layers = 32
    batch_size = 1
    num_heads = 32
    sequence_length = 2048
    hidden_size = 96
    return num_layers, batch_size, num_heads, sequence_length, hidden_size

# lama
def get_llama3_param():
    num_layers = 32
    batch_size = 1
    num_heads = 8
    sequence_length = 2048
    hidden_size = 128
    return num_layers, batch_size, num_heads, sequence_length, hidden_size

if test_phi3:
    num_layers, batch_size, num_heads, sequence_length, hidden_size = get_phi3_param()

if test_lama3:
    num_layers, batch_size, num_heads, sequence_length, hidden_size = get_llama3_param()

# Initialize past_key_values with zeros
cpu_array = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32)

# print("Output names: ", outputs[0].type.data)

def create_present_state_binding(binding, outputs):
    outputMap={}
    for output in outputs:
        shapes = []
        for item in output.shape:
            if isinstance(item, str):
                if 'batch_size' in item:
                    shapes.append(batch_size)
                elif 'sequence_length' in item:
                    if output.name == 'logits':
                        shapes.append(len(inputToken))
                    else:
                        shapes.append(sequence_length)
                elif 'hidden_size' in item:
                    shapes.append(hidden_size)
                elif 'num_heads' in item:
                    shapes.append(num_heads)
                else:
                    raise ValueError(f"Unknown dimension: {item}")
            else:
                shapes.append(item)
            
        present_state = rt.OrtValue.ortvalue_from_shape_and_type(shapes, np.float32)
        binding.bind_ortvalue_output(output.name, present_state)
        outputMap[output.name] = present_state
    return outputMap

def rebind_inputs(lastOutput, binding):
    for index in range(num_layers):
        binding.bind_ortvalue_input(f'past_key_values.{index}.key', lastOutput[f'present.{index}.key'])
        binding.bind_ortvalue_input(f'past_key_values.{index}.value', lastOutput[f'present.{index}.value'])
    return binding

def init_input_with_binding(binding):
    for index in range(num_layers):
        key_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        value_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        binding.bind_ortvalue_input(f'past_key_values.{index}.key', key_state)
        binding.bind_ortvalue_input(f'past_key_values.{index}.value', value_state)
    return binding

def reinit_input_bindings(bindings, lastOutput):
    newOutput = create_present_state_binding(bindings, lastOutput)
    binding = rebind_inputs(lastOutput, bindings)
    return binding, newOutput

def create_numpy_inputs(inputToken):
    tokenLen = len(inputToken)
    npinput_ids = np.array([inputToken], dtype=np.int64)
    npattention_mask = np.array([[1] * (tokenLen)], dtype=np.int64)
    return npinput_ids, npattention_mask


def init_ortinput(inputToken):
    flattened_past_key_values = {}
    for index in range(num_layers):
        key_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        value_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        flattened_past_key_values[f'past_key_values.{index}.key'] = key_state
        flattened_past_key_values[f'past_key_values.{index}.value'] = value_state
    ids, mask = create_numpy_inputs(inputToken)
    flattened_past_key_values['input_ids'] = rt.OrtValue.ortvalue_from_numpy(ids)
    flattened_past_key_values['attention_mask'] = rt.OrtValue.ortvalue_from_numpy(mask)
    return flattened_past_key_values

def init_npinput(inputToken):
    flattened_past_key_values = {}
    for index in range(num_layers):
        key_state = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32)
        value_state = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32)
        flattened_past_key_values[f'past_key_values.{index}.key'] = key_state
        flattened_past_key_values[f'past_key_values.{index}.value'] = value_state
    flattened_past_key_values['input_ids'], flattened_past_key_values['attention_mask'] = create_numpy_inputs(inputToken)
    return flattened_past_key_values

def init_bindinginput(inputToken):
    binding = sess.io_binding()
    binding = init_input_with_binding(binding)
    
    ids, mask = create_numpy_inputs(inputToken)
    binding.bind_ortvalue_input(f'attention_mask', rt.OrtValue.ortvalue_from_numpy(mask))
    binding.bind_ortvalue_input(f'input_ids',  rt.OrtValue.ortvalue_from_numpy(ids))
    return binding


# Question
# The Sun is yellow because

# Phi3
if test_phi3:
    # 450 8991 5692
    # inputToken = [32010, 29871, 13]
    inputToken = [32010, 29871, 13, 1576, 8991, 338, 13328, 1363, 29871, 32007, 13, 32001]
    # inputToken = [32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010]
# lama3
if test_lama3:
    # 315 1202 7479
    inputToken = [128000, 27, 91, 882, 91, 397, 791, 8219, 374, 14071, 1606, 83739, 408, 91, 397, 27, 91, 78191, 91, 29]
    # inputToken = [315]
history_tokens = inputToken

flattened_past_key_values = init_npinput(inputToken)

# flattened_past_key_values = init_ortinput(inputToken)

# binding = init_bindinginput(inputToken)
# lastoutput = create_present_state_binding(binding, outputs)

lastTokenLen = len(inputToken)


# roption = rt.RunOptions()
# roption.add_run_config_entry("gpu_graph_id", "-1")

before = time.time()
results = sess.run(output_names, flattened_past_key_values)
# results = sess.run_with_iobinding(binding)
# results = sess.run_with_ort_values(output_names, flattened_past_key_values)
after = time.time()
print("Time cost in ms: ", (after - before) * 1000)

# print(np.argmax(results[0].numpy(), axis=-1)[-1])
print(np.argmax(results[0], axis=-1)[-1])

# print(results[0])
# print(output_names[1])
# print(results[1][0][0][0])
# print(results[1][0][0][1])
# print(results[1][0][0][2])
# # print(results[1][0][0][14])
# # print(results[1])
# print(output_names[2])
# # print(results[2])
# print(results[2][0][0][0])
# print(results[2][0][0][1])
# print(results[2][0][0][2])
# print(results[2][0][0][14])
# inputToken.append(450)

# rebind_inputs(lastOutput, binding)

def update_kvcache(inputsMap, results):
    for index in range(len(output_names)):
        if not output_names[index].startswith('present'):
            continue
        # print(f'{output_names[index]}: {results[index].shape}')
        outputname = output_names[index]
        inputname = outputname.replace('present', 'past_key_values')
        inputsMap[inputname] = results[index]
    return inputsMap
# lastOutput = create_present_state_binding(binding, sess.get_outputs())

# flattened_past_key_values = update_kvcache(flattened_past_key_values, results)

for index in range(len(output_names)):
    if not output_names[index].startswith('present'):
        continue
    # print(f'{output_names[index]}: {results[index].shape}')
    outputname = output_names[index]
    inputname = outputname.replace('present', 'past_key_values')
    flattened_past_key_values[inputname] = results[index]
if test_phi3:
    inputToken = [450]

if test_lama3:
    inputToken = [315]
history_tokens += inputToken

npinput_ids = np.array([inputToken], dtype=np.int64)
npattention_mask = np.array([[1] * (lastTokenLen+1)], dtype=np.int64)
print(f"lastTokenLen:{lastTokenLen}")

# attention_mask = rt.OrtValue.ortvalue_from_numpy(npattention_mask)
# input_ids = rt.OrtValue.ortvalue_from_numpy(npinput_ids)
# binding.bind_ortvalue_input(f'attention_mask', attention_mask)
# binding.bind_ortvalue_input(f'input_ids', input_ids)
# flattened_past_key_values[f'attention_mask'].update_inplace(npattention_mask)
# flattened_past_key_values[f'input_ids'].update_inplace(npinput_ids)
# flattened_past_key_values[f'attention_mask'] = attention_mask
# flattened_past_key_values[f'input_ids'] = input_ids
flattened_past_key_values[f'attention_mask'] = npattention_mask
flattened_past_key_values[f'input_ids'] = npinput_ids
# print(flattened_past_key_values)

before = time.time()
results = sess.run(output_names, flattened_past_key_values)
# results = sess.run_with_iobinding(binding)
# results = sess.run_with_ort_values(output_names, flattened_past_key_values)
after = time.time()
print("Time cost in ms: ", (after - before) * 1000)

# Results:  [np.int32(450), np.int32(8991), np.int32(5692), np.int32(13328), np.int32(304), np.int32(502), np.int32(19434), np.int32(2861), np.int32(304), np.int32(9596), np.int32(280), np.int32(1141), np.int32(14801), np.int32(292), np.int32(29889), np.int32(1932), np.int32(6575), np.int32(4366), np.int32(14517), np.int32(1549), np.int32(278), np.int32(11563), np.int32(29915), np.int32(29879), np.int32(25005), np.int32(29892), np.int32(278), np.int32(20511), np.int32(7254), np.int32(281), np.int32(6447), np.int32(1477), np.int32(29879), np.int32(526), np.int32(29574), np.int32(297), np.int32(599), np.int32(18112), np.int32(491), np.int32(278), np.int32(330), np.int32(2129), np.int32(322), np.int32(17105), np.int32(297), np.int32(278), np.int32(4799), np.int32(29889), np.int32(910), np.int32(14801), np.int32(292), np.int32(9946), np.int32(278), np.int32(14744), np.int32(304), np.int32(1106), np.int32(7254), np.int32(29889), np.int32(2398), np.int32(29892), np.int32(278), np.int32(5520), np.int32(2654), np.int32(322), np.int32(13328), np.int32(281), np.int32(6447), np.int32(1477), np.int32(29879), np.int32(1209), np.int32(1549), np.int32(278), np.int32(25005), np.int32(901), np.int32(5948), np.int32(322), np.int32(526), np.int32(3109), np.int32(29574), np.int32(29889), np.int32(1932), np.int32(591), np.int32(1106), np.int32(472), np.int32(278), np.int32(8991), np.int32(29892), np.int32(591), np.int32(1074), np.int32(372), np.int32(408), np.int32(263), np.int32(13328), np.int32(470), np.int32(24841), np.int32(8086), np.int32(1363), np.int32(278), np.int32(7254), np.int32(3578), np.int32(338), np.int32(29574), np.int32(714), np.int32(310), np.int32(1749), np.int32(1196), np.int32(310), np.int32(11126), np.int32(29892), np.int32(322), np.int32(278), np.int32(9886), np.int32(3578), np.int32(393), np.int32(22170), np.int32(1749), np.int32(5076), np.int32(338), np.int32(758), np.int32(24130), np.int32(10835), np.int32(13328), np.int32(322), np.int32(2654), np.int32(29889), np.int32(32000)]
# index = 0
# for result in results:
#     print(f'{output_names[index]}: {result.shape}, {result.dtype}')
#     index += 1
print(np.argmax(results[0], axis=-1)[-1])
# print(np.argmax(results[0].numpy(), axis=-1)[-1])


# golden results
# Time cost in ms:  1255.2332878112793
# [30751    13    13  1494  1731   263 29889   372    13 24380    13   450]
# lastTokenLen:12
# Time cost in ms:  1006.781816482544
# [8991]

last_generated_token = np.argmax(results[0], axis=-1)[-1][-1]
history_tokens.append(last_generated_token)
NUM_INFERENCE = 15
for i in range(NUM_INFERENCE):
    # update kvcahe
    for index in range(len(output_names)):
        if not output_names[index].startswith('present'):
            continue
        # print(f'{output_names[index]}: {results[index].shape}')
        outputname = output_names[index]
        inputname = outputname.replace('present', 'past_key_values')
        flattened_past_key_values[inputname] = results[index]

    # update input token
    flattened_past_key_values[f'input_ids'] = np.array([[last_generated_token]], dtype=np.int64)
    flattened_past_key_values[f'attention_mask'] = np.array([[1] * len(history_tokens)], dtype=np.int64)

    before = time.time()
    results = sess.run(output_names, flattened_past_key_values)
    after = time.time()
    print("Time cost in ms: ", (after - before) * 1000)

    last_generated_token = np.argmax(results[0], axis=-1)[-1][-1]
    history_tokens.append(last_generated_token)

print(tokenizer.decode(history_tokens))
```

### Tickets:
 - related to 155287, 157123

---------

Co-authored-by: Yu, Zijun <[email protected]>
Co-authored-by: Tomasz Jankowski <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
category: Core OpenVINO Core (aka ngraph) category: ONNX FE OpenVINO ONNX FrontEnd category: transformations OpenVINO Runtime library - Transformations ExternalIntelPR External contributor from Intel
Projects
None yet
Development

Successfully merging this pull request may close these issues.