@@ -31,6 +31,7 @@
     ConverseStreamWrapper,
     InvokeModelWithResponseStreamWrapper,
     _Choice,
+    estimate_token_count,
     genai_capture_message_content,
     message_to_event,
 )
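The newly imported `estimate_token_count` lives in the extension's utils module and is not shown in this diff. A minimal sketch of what such a helper can look like, assuming the common heuristic of roughly six characters per token (the actual implementation may differ):

```python
import math

def estimate_token_count(message: str) -> int:
    # Client-side approximation used when the model response does not
    # report token usage itself: assume ~6 characters per token.
    return math.ceil(len(message) / 6)
```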
@@ -223,6 +224,23 @@ def extract_attributes(self, attributes: _AttributeMapT):
                 self._extract_claude_attributes(
                     attributes, request_body
                 )
+            elif "cohere.command-r" in model_id:
+                self._extract_command_r_attributes(
+                    attributes, request_body
+                )
+            elif "cohere.command" in model_id:
+                self._extract_command_attributes(
+                    attributes, request_body
+                )
+            elif "meta.llama" in model_id:
+                self._extract_llama_attributes(
+                    attributes, request_body
+                )
+            elif "mistral" in model_id:
+                self._extract_mistral_attributes(
+                    attributes, request_body
+                )
+
         except json.JSONDecodeError:
             _logger.debug("Error: Unable to parse the body as JSON")
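Note that the `cohere.command-r` branch must come before the `cohere.command` one, since every Command R model ID also contains the substring `cohere.command`. A standalone snippet illustrating the substring dispatch (the model IDs are examples):

```python
# Substring dispatch: the more specific check has to run first.
for model_id in ("cohere.command-r-v1:0", "cohere.command-text-v14"):
    if "cohere.command-r" in model_id:
        print(model_id, "-> _extract_command_r_attributes")
    elif "cohere.command" in model_id:
        print(model_id, "-> _extract_command_attributes")
```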
@@ -280,14 +298,102 @@ def _extract_claude_attributes(self, attributes, request_body):
             request_body.get("stop_sequences"),
         )

+    def _extract_command_r_attributes(self, attributes, request_body):
+        prompt = request_body.get("message")
+        self._set_if_not_none(
+            attributes, GEN_AI_USAGE_INPUT_TOKENS, estimate_token_count(prompt)
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_MAX_TOKENS,
+            request_body.get("max_tokens"),
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_TEMPERATURE,
+            request_body.get("temperature"),
+        )
+        self._set_if_not_none(
+            attributes, GEN_AI_REQUEST_TOP_P, request_body.get("p")
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_STOP_SEQUENCES,
+            request_body.get("stop_sequences"),
+        )
+
+    def _extract_command_attributes(self, attributes, request_body):
+        prompt = request_body.get("prompt")
+        self._set_if_not_none(
+            attributes, GEN_AI_USAGE_INPUT_TOKENS, estimate_token_count(prompt)
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_MAX_TOKENS,
+            request_body.get("max_tokens"),
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_TEMPERATURE,
+            request_body.get("temperature"),
+        )
+        self._set_if_not_none(
+            attributes, GEN_AI_REQUEST_TOP_P, request_body.get("p")
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_STOP_SEQUENCES,
+            request_body.get("stop_sequences"),
+        )
+
+    def _extract_llama_attributes(self, attributes, request_body):
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_MAX_TOKENS,
+            request_body.get("max_gen_len"),
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_TEMPERATURE,
+            request_body.get("temperature"),
+        )
+        self._set_if_not_none(
+            attributes, GEN_AI_REQUEST_TOP_P, request_body.get("top_p")
+        )
+        # request for meta llama models does not contain stop_sequences field
+
+    def _extract_mistral_attributes(self, attributes, request_body):
+        prompt = request_body.get("prompt")
+        if prompt:
+            self._set_if_not_none(
+                attributes,
+                GEN_AI_USAGE_INPUT_TOKENS,
+                estimate_token_count(prompt),
+            )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_MAX_TOKENS,
+            request_body.get("max_tokens"),
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_TEMPERATURE,
+            request_body.get("temperature"),
+        )
+        self._set_if_not_none(
+            attributes, GEN_AI_REQUEST_TOP_P, request_body.get("top_p")
+        )
+        self._set_if_not_none(
+            attributes, GEN_AI_REQUEST_STOP_SEQUENCES, request_body.get("stop")
+        )
+
     @staticmethod
     def _set_if_not_none(attributes, key, value):
         if value is not None:
             attributes[key] = value

     def _get_request_messages(self):
         """Extracts and normalizes system and user / assistant messages"""
-        input_text = None
         if system := self._call_context.params.get("system", []):
             system_messages = [{"role": "system", "content": system}]
         else:
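For context, a hedged example of the InvokeModel request bodies these extractors parse. The field names (`message`, `p`, `max_gen_len`) match what the code above reads; the values are invented:

```python
import json

# Cohere Command R: the prompt lives in "message", nucleus sampling in "p".
command_r_body = json.dumps({
    "message": "Explain OpenTelemetry spans in one sentence.",
    "max_tokens": 256,
    "temperature": 0.5,
    "p": 0.9,
    "stop_sequences": ["\n\n"],
})

# Meta Llama: max tokens is "max_gen_len" and there is no stop-sequence field.
llama_body = json.dumps({
    "prompt": "Explain OpenTelemetry spans in one sentence.",
    "max_gen_len": 256,
    "temperature": 0.5,
    "top_p": 0.9,
})
```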
@@ -304,15 +410,37 @@ def _get_request_messages(self):
             system_messages = [{"role": "system", "content": content}]

         messages = decoded_body.get("messages", [])
+        # if no messages interface, convert to messages format from generic API
         if not messages:
-            # transform old school amazon titan invokeModel api to messages
-            if input_text := decoded_body.get("inputText"):
-                messages = [
-                    {"role": "user", "content": [{"text": input_text}]}
-                ]
+            model_id = self._call_context.params.get(_MODEL_ID_KEY)
+            if "amazon.titan" in model_id:
+                messages = self._get_messages_from_input_text(
+                    decoded_body, "inputText"
+                )
+            elif "cohere.command-r" in model_id:
+                # chat_history can be converted to messages; for now, just use message
+                messages = self._get_messages_from_input_text(
+                    decoded_body, "message"
+                )
+            elif (
+                "cohere.command" in model_id
+                or "meta.llama" in model_id
+                or "mistral.mistral" in model_id
+            ):
+                messages = self._get_messages_from_input_text(
+                    decoded_body, "prompt"
+                )

         return system_messages + messages

+    # pylint: disable=no-self-use
+    def _get_messages_from_input_text(
+        self, decoded_body: dict[str, Any], input_name: str
+    ):
+        if input_text := decoded_body.get(input_name):
+            return [{"role": "user", "content": [{"text": input_text}]}]
+        return []
+
     def before_service_call(
         self, span: Span, instrumentor_context: _BotocoreInstrumentorContext
     ):
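The new `_get_messages_from_input_text` helper normalizes the various single-prompt bodies into the messages shape used elsewhere in the extension. A standalone copy (renamed for illustration) showing the expected behavior:

```python
def get_messages_from_input_text(decoded_body: dict, input_name: str) -> list:
    # Wrap a plain prompt string as a single user message, or return
    # an empty list when the field is absent or empty.
    if input_text := decoded_body.get(input_name):
        return [{"role": "user", "content": [{"text": input_text}]}]
    return []

print(get_messages_from_input_text({"prompt": "Hi"}, "prompt"))
# [{'role': 'user', 'content': [{'text': 'Hi'}]}]
print(get_messages_from_input_text({}, "prompt"))
# []
```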
@@ -439,6 +567,22 @@ def _invoke_model_on_success(
                 self._handle_anthropic_claude_response(
                     span, response_body, instrumentor_context, capture_content
                 )
+            elif "cohere.command-r" in model_id:
+                self._handle_cohere_command_r_response(
+                    span, response_body, instrumentor_context, capture_content
+                )
+            elif "cohere.command" in model_id:
+                self._handle_cohere_command_response(
+                    span, response_body, instrumentor_context, capture_content
+                )
+            elif "meta.llama" in model_id:
+                self._handle_meta_llama_response(
+                    span, response_body, instrumentor_context, capture_content
+                )
+            elif "mistral" in model_id:
+                self._handle_mistral_ai_response(
+                    span, response_body, instrumentor_context, capture_content
+                )
         except json.JSONDecodeError:
             _logger.debug("Error: Unable to parse the response body as JSON")
         except Exception as exc:  # pylint: disable=broad-exception-caught
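The response handlers added below mirror the request-side dispatch. Unlike Cohere and Mistral, Meta Llama responses report real token counts, so no estimation is needed there. A hedged example of the body shape the Llama handler consumes (values invented):

```python
llama_response_body = {
    "generation": "Spans represent individual units of work in a trace.",
    "prompt_token_count": 12,      # -> GEN_AI_USAGE_INPUT_TOKENS
    "generation_token_count": 11,  # -> GEN_AI_USAGE_OUTPUT_TOKENS
    "stop_reason": "stop",         # -> GEN_AI_RESPONSE_FINISH_REASONS
}
```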
@@ -725,6 +869,106 @@ def _handle_anthropic_claude_response(
                 output_tokens, output_attributes
             )

+    def _handle_cohere_command_r_response(
+        self,
+        span: Span,
+        response_body: dict[str, Any],
+        instrumentor_context: _BotocoreInstrumentorContext,
+        capture_content: bool,
+    ):
+        if "text" in response_body:
+            span.set_attribute(
+                GEN_AI_USAGE_OUTPUT_TOKENS,
+                estimate_token_count(response_body["text"]),
+            )
+        if "finish_reason" in response_body:
+            span.set_attribute(
+                GEN_AI_RESPONSE_FINISH_REASONS,
+                [response_body["finish_reason"]],
+            )
+
+        event_logger = instrumentor_context.event_logger
+        choice = _Choice.from_invoke_cohere_command_r(
+            response_body, capture_content
+        )
+        event_logger.emit(choice.to_choice_event())
+
+    def _handle_cohere_command_response(
+        self,
+        span: Span,
+        response_body: dict[str, Any],
+        instrumentor_context: _BotocoreInstrumentorContext,
+        capture_content: bool,
+    ):
+        if "generations" in response_body and response_body["generations"]:
+            generations = response_body["generations"][0]
+            if "text" in generations:
+                span.set_attribute(
+                    GEN_AI_USAGE_OUTPUT_TOKENS,
+                    estimate_token_count(generations["text"]),
+                )
+            if "finish_reason" in generations:
+                span.set_attribute(
+                    GEN_AI_RESPONSE_FINISH_REASONS,
+                    [generations["finish_reason"]],
+                )
+
+        event_logger = instrumentor_context.event_logger
+        choice = _Choice.from_invoke_cohere_command(
+            response_body, capture_content
+        )
+        event_logger.emit(choice.to_choice_event())
+
+    def _handle_meta_llama_response(
+        self,
+        span: Span,
+        response_body: dict[str, Any],
+        instrumentor_context: _BotocoreInstrumentorContext,
+        capture_content: bool,
+    ):
+        if "prompt_token_count" in response_body:
+            span.set_attribute(
+                GEN_AI_USAGE_INPUT_TOKENS, response_body["prompt_token_count"]
+            )
+        if "generation_token_count" in response_body:
+            span.set_attribute(
+                GEN_AI_USAGE_OUTPUT_TOKENS,
+                response_body["generation_token_count"],
+            )
+        if "stop_reason" in response_body:
+            span.set_attribute(
+                GEN_AI_RESPONSE_FINISH_REASONS, [response_body["stop_reason"]]
+            )
+
+        event_logger = instrumentor_context.event_logger
+        choice = _Choice.from_invoke_meta_llama(response_body, capture_content)
+        event_logger.emit(choice.to_choice_event())
+
+    def _handle_mistral_ai_response(
+        self,
+        span: Span,
+        response_body: dict[str, Any],
+        instrumentor_context: _BotocoreInstrumentorContext,
+        capture_content: bool,
+    ):
+        if "outputs" in response_body:
+            outputs = response_body["outputs"][0]
+            if "text" in outputs:
+                span.set_attribute(
+                    GEN_AI_USAGE_OUTPUT_TOKENS,
+                    estimate_token_count(outputs["text"]),
+                )
+            if "stop_reason" in outputs:
+                span.set_attribute(
+                    GEN_AI_RESPONSE_FINISH_REASONS, [outputs["stop_reason"]]
+                )
+
+        event_logger = instrumentor_context.event_logger
+        choice = _Choice.from_invoke_mistral_mistral(
+            response_body, capture_content
+        )
+        event_logger.emit(choice.to_choice_event())
+
     def on_error(
         self,
         span: Span,
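For the handlers that estimate output tokens, a hedged sketch of the response shapes they walk. The field names match the code above; the values are invented:

```python
# Cohere Command: text and finish_reason live under "generations".
cohere_command_response = {
    "generations": [
        {"text": "Spans are units of work.", "finish_reason": "COMPLETE"}
    ]
}

# Mistral: the equivalent fields live under "outputs".
mistral_response = {
    "outputs": [{"text": "Spans are units of work.", "stop_reason": "stop"}]
}
```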