Add Group Query Attention support with OV base OPs #28163
Conversation
How is it related to #27648?
Force-pushed from f4770e0 to 911691b
hey @sgbihu
Resolved review threads (outdated) on:
...ansformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp
...n/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp
const std::shared_ptr<ov::op::v3::ShapeOf>& shape,
                                         const std::vector<int>& dims) {
    using namespace ov::op;
    const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
minor:
We use i64 in some places and i32 in others. There is no restriction forcing either, but it's better to align the element types, e.g. use i64 everywhere.
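For illustration only, a minimal sketch of the alignment being suggested (the constant names mirror the snippet above; this is not the PR's actual code):

```cpp
#include "openvino/op/constant.hpp"

using namespace ov::op;

// Pick one index element type (i64 here) and use it for every shape/index
// constant the decomposition creates, instead of mixing i32 and i64.
const auto zero = v0::Constant::create(ov::element::i64, ov::Shape{}, {0});
const auto one  = v0::Constant::create(ov::element::i64, ov::Shape{}, {1});
```

Keeping a single index type also avoids element-type mismatches when these constants feed the same Gather/Concat chain.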
GroupQueryAttentionDecomposition();

private:
    ov::OutputVector decompose(std::shared_ptr<ov::op::GroupQueryAttention> node);
minor:
As far as I can see, these functions don't use any members of GroupQueryAttentionDecomposition and can be used separately, so it's better to move them into an unnamed namespace inside the group_query_attention_decomposition.cpp file, or into utils if we plan to re-use them (see the sketch below).
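A hedged sketch of the suggested move, reusing the get_dimensions signature from the snippet earlier in this thread (the body shown is an assumption based on that snippet, not the PR's exact code):

```cpp
// group_query_attention_decomposition.cpp
#include <memory>
#include <vector>

#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/shape_of.hpp"

namespace {
// Free helper: it uses no members of GroupQueryAttentionDecomposition,
// so an unnamed namespace keeps it private to this translation unit.
std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::op::v3::ShapeOf>& shape,
                                         const std::vector<int>& dims) {
    using namespace ov::op;
    const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
    const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims);
    // Gather the requested dimensions out of the shape tensor along axis 0.
    return std::make_shared<v8::Gather>(shape, dims_const, zero);
}
}  // namespace
```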
std::shared_ptr<ov::Node> minus_inf = nullptr;
if (T == ov::element::f32)
    minus_inf = register_new_node(v0::Constant::create(T, ov::Shape{}, {-std::numeric_limits<float>::infinity()}));
else if (T == ov::element::f16)
This looks unsafe: minus_inf might still be nullptr after this if in case of another type. That's probably impossible now, but it's not guaranteed that the supported types won't be extended in the future. Should we throw an exception, return false, or add a default else branch for when T is neither f32 nor f16? Usually we prefer to return false, so as not to break model inference when no replacements or model modifications have been made at that point (see the sketch below).
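A minimal sketch of the return false variant, assuming it sits in the matcher callback before any nodes are created or replaced (per the reply below, the actual fix is deferred to a follow-up PR):

```cpp
// Bail out for element types the decomposition does not handle.
// Returning false from the matcher callback means no replacement is
// made, so the original model keeps running instead of dereferencing
// a nullptr minus_inf later.
if (T != ov::element::f32 && T != ov::element::f16) {
    return false;
}
```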
Will address this in a follow-up PR.
Approved. We agreed to resolve the remaining comments in the next PR.
@t-jankowski @gkrivor Could you review again?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok for core part
build_jenkins
build_jenkins
build_jenkins
CI status: 1 failing: ie_tests_cldnn_unit_dg2_ubuntu22_release. Details: 1 unit test failed:
The failure should not be related to this PR.
build_jenkins
build_jenkins
build_jenkins
Resolved review thread (outdated) on:
src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
@praasz Suggestions applied
build_jenkins
Ok for core part
Add Group Query Attention support with OV base OPs (#28163)

### Details:
- Try to enable LLM based on onnxruntime. (Phi3 and Llama3 work on CPU; Phi3 can work with iGPU)

### Test scripts
```
import onnxruntime as rt
import os
import numpy as np
import time
import onnxruntime.tools.add_openvino_win_libs as utils
utils.add_openvino_libs_to_path()
from transformers import PreTrainedTokenizerFast

test_lama3 = False
test_phi3 = True

if test_phi3:
    modelPath = os.path.join('D:\\', 'models', 'llm', 'Phi-3-mini-4k-instruct-onnx', 'model.onnx')
    tokenizerPath = os.path.join('D:\\', 'models', 'llm', 'Phi-3-mini-4k-instruct-onnx', 'tokenizer.json')
if test_lama3:
    modelPath = os.path.join('D:\\', 'models', 'llm', 'llama3.1-8B-instruct-onnx', 'model.onnx')

so = rt.SessionOptions()
# so.log_severity_level = 3
# sess = rt.InferenceSession(modelPath, so, providers=['CPUExecutionProvider'])
sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "CPU", 'cache_dir': "cache"}])
# sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "CPU"}])
# sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "NPU"}])

tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizerPath)
# print(sess.get_device())
# for name in sess.get_inputs():
#     print(f"Name: {name.name}, Shape: {name.shape}, Type: {name.type}")

outputs = sess.get_outputs()
output_names = list(map(lambda output: output.name, outputs))

# Assuming the model has 32 layers and each layer has a key and value state
# Phi3
def get_phi3_param():
    num_layers = 32
    batch_size = 1
    num_heads = 32
    sequence_length = 2048
    hidden_size = 96
    return num_layers, batch_size, num_heads, sequence_length, hidden_size

# lama
def get_llama3_param():
    num_layers = 32
    batch_size = 1
    num_heads = 8
    sequence_length = 2048
    hidden_size = 128
    return num_layers, batch_size, num_heads, sequence_length, hidden_size

if test_phi3:
    num_layers, batch_size, num_heads, sequence_length, hidden_size = get_phi3_param()
if test_lama3:
    num_layers, batch_size, num_heads, sequence_length, hidden_size = get_llama3_param()

# Initialize past_key_values with zeros
cpu_array = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32)
# print("Output names: ", outputs[0].type.data)

def create_present_state_binding(binding, outputs):
    outputMap = {}
    for output in outputs:
        shapes = []
        for item in output.shape:
            if isinstance(item, str):
                if 'batch_size' in item:
                    shapes.append(batch_size)
                elif 'sequence_length' in item:
                    if output.name == 'logits':
                        shapes.append(len(inputToken))
                    else:
                        shapes.append(sequence_length)
                elif 'hidden_size' in item:
                    shapes.append(hidden_size)
                elif 'num_heads' in item:
                    shapes.append(num_heads)
                else:
                    raise ValueError(f"Unknown dimension: {item}")
            else:
                shapes.append(item)
        present_state = rt.OrtValue.ortvalue_from_shape_and_type(shapes, np.float32)
        binding.bind_ortvalue_output(output.name, present_state)
        outputMap[output.name] = present_state
    return outputMap

def rebind_inputs(lastOutput, binding):
    for index in range(num_layers):
        binding.bind_ortvalue_input(f'past_key_values.{index}.key', lastOutput[f'present.{index}.key'])
        binding.bind_ortvalue_input(f'past_key_values.{index}.value', lastOutput[f'present.{index}.value'])
    return binding

def init_input_with_binding(binding):
    for index in range(num_layers):
        key_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        value_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        binding.bind_ortvalue_input(f'past_key_values.{index}.key', key_state)
        binding.bind_ortvalue_input(f'past_key_values.{index}.value', value_state)
    return binding

def reinit_input_bindings(bindings, lastOutput):
    newOutput = create_present_state_binding(bindings, lastOutput)
    binding = rebind_inputs(lastOutput, bindings)
    return binding, newOutput

def create_numpy_inputs(inputToken):
    tokenLen = len(inputToken)
    npinput_ids = np.array([inputToken], dtype=np.int64)
    npattention_mask = np.array([[1] * (tokenLen)], dtype=np.int64)
    return npinput_ids, npattention_mask

def init_ortinput(inputToken):
    flattened_past_key_values = {}
    for index in range(num_layers):
        key_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        value_state = rt.OrtValue.ortvalue_from_numpy(cpu_array)
        flattened_past_key_values[f'past_key_values.{index}.key'] = key_state
        flattened_past_key_values[f'past_key_values.{index}.value'] = value_state
    ids, mask = create_numpy_inputs(inputToken)
    flattened_past_key_values['input_ids'] = rt.OrtValue.ortvalue_from_numpy(ids)
    flattened_past_key_values['attention_mask'] = rt.OrtValue.ortvalue_from_numpy(mask)
    return flattened_past_key_values

def init_npinput(inputToken):
    flattened_past_key_values = {}
    for index in range(num_layers):
        key_state = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32)
        value_state = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32)
        flattened_past_key_values[f'past_key_values.{index}.key'] = key_state
        flattened_past_key_values[f'past_key_values.{index}.value'] = value_state
    flattened_past_key_values['input_ids'], flattened_past_key_values['attention_mask'] = create_numpy_inputs(inputToken)
    return flattened_past_key_values

def init_bindinginput(inputToken):
    binding = sess.io_binding()
    binding = init_input_with_binding(binding)
    ids, mask = create_numpy_inputs(inputToken)
    binding.bind_ortvalue_input(f'attention_mask', rt.OrtValue.ortvalue_from_numpy(mask))
    binding.bind_ortvalue_input(f'input_ids', rt.OrtValue.ortvalue_from_numpy(ids))
    return binding

# Question
# The Sun is yellow because

# Phi3
if test_phi3:
    # 450 8991 5692
    # inputToken = [32010, 29871, 13]
    inputToken = [32010, 29871, 13, 1576, 8991, 338, 13328, 1363, 29871, 32007, 13, 32001]
    # inputToken = [32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010]

# lama3
if test_lama3:
    # 315 1202 7479
    inputToken = [128000, 27, 91, 882, 91, 397, 791, 8219, 374, 14071, 1606, 83739, 408, 91, 397, 27, 91, 78191, 91, 29]
    # inputToken = [315]

history_tokens = inputToken

flattened_past_key_values = init_npinput(inputToken)
# flattened_past_key_values = init_ortinput(inputToken)
# binding = init_bindinginput(inputToken)
# lastoutput = create_present_state_binding(binding, outputs)

lastTokenLen = len(inputToken)

# roption = rt.RunOptions()
# roption.add_run_config_entry("gpu_graph_id", "-1")

before = time.time()
results = sess.run(output_names, flattened_past_key_values)
# results = sess.run_with_iobinding(binding)
# results = sess.run_with_ort_values(output_names, flattened_past_key_values)
after = time.time()
print("Time cost in ms: ", (after - before) * 1000)

# print(np.argmax(results[0].numpy(), axis=-1)[-1])
print(np.argmax(results[0], axis=-1)[-1])
# print(results[0])
# print(output_names[1])
# print(results[1][0][0][0])
# print(results[1][0][0][1])
# print(results[1][0][0][2])
# # print(results[1][0][0][14])
# # print(results[1])
# print(output_names[2])
# # print(results[2])
# print(results[2][0][0][0])
# print(results[2][0][0][1])
# print(results[2][0][0][2])
# print(results[2][0][0][14])

# inputToken.append(450)
# rebind_inputs(lastOutput, binding)

def update_kvcache(inputsMap, results):
    for index in range(len(output_names)):
        if not output_names[index].startswith('present'):
            continue
        # print(f'{output_names[index]}: {results[index].shape}')
        outputname = output_names[index]
        inputname = outputname.replace('present', 'past_key_values')
        inputsMap[inputname] = results[index]
    return inputsMap

# lastOutput = create_present_state_binding(binding, sess.get_outputs())
# flattened_past_key_values = update_kvcache(flattened_past_key_values, results)
for index in range(len(output_names)):
    if not output_names[index].startswith('present'):
        continue
    # print(f'{output_names[index]}: {results[index].shape}')
    outputname = output_names[index]
    inputname = outputname.replace('present', 'past_key_values')
    flattened_past_key_values[inputname] = results[index]

if test_phi3:
    inputToken = [450]
if test_lama3:
    inputToken = [315]
history_tokens += inputToken

npinput_ids = np.array([inputToken], dtype=np.int64)
npattention_mask = np.array([[1] * (lastTokenLen+1)], dtype=np.int64)
print(f"lastTokenLen:{lastTokenLen}")

# attention_mask = rt.OrtValue.ortvalue_from_numpy(npattention_mask)
# input_ids = rt.OrtValue.ortvalue_from_numpy(npinput_ids)
# binding.bind_ortvalue_input(f'attention_mask', attention_mask)
# binding.bind_ortvalue_input(f'input_ids', input_ids)
# flattened_past_key_values[f'attention_mask'].update_inplace(npattention_mask)
# flattened_past_key_values[f'input_ids'].update_inplace(npinput_ids)
# flattened_past_key_values[f'attention_mask'] = attention_mask
# flattened_past_key_values[f'input_ids'] = input_ids
flattened_past_key_values[f'attention_mask'] = npattention_mask
flattened_past_key_values[f'input_ids'] = npinput_ids
# print(flattened_past_key_values)

before = time.time()
results = sess.run(output_names, flattened_past_key_values)
# results = sess.run_with_iobinding(binding)
# results = sess.run_with_ort_values(output_names, flattened_past_key_values)
after = time.time()
print("Time cost in ms: ", (after - before) * 1000)

# Results: [np.int32(450), np.int32(8991), np.int32(5692), np.int32(13328), np.int32(304), np.int32(502), np.int32(19434), np.int32(2861), np.int32(304), np.int32(9596), np.int32(280), np.int32(1141), np.int32(14801), np.int32(292), np.int32(29889), np.int32(1932), np.int32(6575), np.int32(4366), np.int32(14517), np.int32(1549), np.int32(278), np.int32(11563), np.int32(29915), np.int32(29879), np.int32(25005), np.int32(29892), np.int32(278), np.int32(20511), np.int32(7254), np.int32(281), np.int32(6447), np.int32(1477), np.int32(29879), np.int32(526), np.int32(29574), np.int32(297), np.int32(599), np.int32(18112), np.int32(491), np.int32(278), np.int32(330), np.int32(2129), np.int32(322), np.int32(17105), np.int32(297), np.int32(278), np.int32(4799), np.int32(29889), np.int32(910), np.int32(14801), np.int32(292), np.int32(9946), np.int32(278), np.int32(14744), np.int32(304), np.int32(1106), np.int32(7254), np.int32(29889), np.int32(2398), np.int32(29892), np.int32(278), np.int32(5520), np.int32(2654), np.int32(322), np.int32(13328), np.int32(281), np.int32(6447), np.int32(1477), np.int32(29879), np.int32(1209), np.int32(1549), np.int32(278), np.int32(25005), np.int32(901), np.int32(5948), np.int32(322), np.int32(526), np.int32(3109), np.int32(29574), np.int32(29889), np.int32(1932), np.int32(591), np.int32(1106), np.int32(472), np.int32(278), np.int32(8991), np.int32(29892), np.int32(591), np.int32(1074), np.int32(372), np.int32(408), np.int32(263), np.int32(13328), np.int32(470), np.int32(24841), np.int32(8086), np.int32(1363), np.int32(278), np.int32(7254), np.int32(3578), np.int32(338), np.int32(29574), np.int32(714), np.int32(310), np.int32(1749), np.int32(1196), np.int32(310), np.int32(11126), np.int32(29892), np.int32(322), np.int32(278), np.int32(9886), np.int32(3578), np.int32(393), np.int32(22170), np.int32(1749), np.int32(5076), np.int32(338), np.int32(758), np.int32(24130), np.int32(10835), np.int32(13328), np.int32(322), np.int32(2654), np.int32(29889), np.int32(32000)]

# index = 0
# for result in results:
#     print(f'{output_names[index]}: {result.shape}, {result.dtype}')
#     index += 1

print(np.argmax(results[0], axis=-1)[-1])
# print(np.argmax(results[0].numpy(), axis=-1)[-1])

# golden results
# Time cost in ms: 1255.2332878112793
# [30751 13 13 1494 1731 263 29889 372 13 24380 13 450]
# lastTokenLen:12
# Time cost in ms: 1006.781816482544
# [8991]

last_generated_token = np.argmax(results[0], axis=-1)[-1][-1]
history_tokens.append(last_generated_token)

NUM_INFERENCE = 15
for i in range(NUM_INFERENCE):
    # update kvcache
    for index in range(len(output_names)):
        if not output_names[index].startswith('present'):
            continue
        # print(f'{output_names[index]}: {results[index].shape}')
        outputname = output_names[index]
        inputname = outputname.replace('present', 'past_key_values')
        flattened_past_key_values[inputname] = results[index]
    # update input token
    flattened_past_key_values[f'input_ids'] = np.array([[last_generated_token]], dtype=np.int64)
    flattened_past_key_values[f'attention_mask'] = np.array([[1] * len(history_tokens)], dtype=np.int64)
    before = time.time()
    results = sess.run(output_names, flattened_past_key_values)
    after = time.time()
    print("Time cost in ms: ", (after - before) * 1000)
    last_generated_token = np.argmax(results[0], axis=-1)[-1][-1]
    history_tokens.append(last_generated_token)

print(tokenizer.decode(history_tokens))
```

### Tickets:
- related to 155287, 157123

---------

Co-authored-by: Yu, Zijun <[email protected]>
Co-authored-by: Tomasz Jankowski <[email protected]>