Skip to content

Commit ccece16

Browse files
chore: unit test for cohere and handle stop curl (#1856)
* chore: unit test for cohere and handle stop curl
* fix: parse failed nlohmann::json
* fix: tojson string
* fix: return
* fix: escape json
---------
Co-authored-by: vansangpfiev <[email protected]>
1 parent b3df25d commit ccece16

File tree

5 files changed

+302
-11
lines changed

5 files changed

+302
-11
lines changed

engine/extensions/remote-engine/remote_engine.cc

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
3535
status["has_error"] = true;
3636
status["is_stream"] = true;
3737
status["status_code"] = k400BadRequest;
38+
context->need_stop = false;
3839
(*context->callback)(std::move(status), std::move(check_error));
3940
return size * nmemb;
4041
}
@@ -58,7 +59,8 @@ size_t StreamWriteCallback(char* ptr, size_t size, size_t nmemb,
5859
status["is_done"] = true;
5960
status["has_error"] = false;
6061
status["is_stream"] = true;
61-
status["status_code"] = 200;
62+
status["status_code"] = k200OK;
63+
context->need_stop = false;
6264
(*context->callback)(std::move(status), Json::Value());
6365
break;
6466
}
@@ -169,6 +171,15 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest(
169171

170172
curl_slist_free_all(headers);
171173
curl_easy_cleanup(curl);
174+
if (context.need_stop) {
175+
CTL_DBG("No stop message received, need to stop");
176+
Json::Value status;
177+
status["is_done"] = true;
178+
status["has_error"] = false;
179+
status["is_stream"] = true;
180+
status["status_code"] = k200OK;
181+
(*context.callback)(std::move(status), Json::Value());
182+
}
172183
return response;
173184
}
174185

@@ -602,6 +613,7 @@ void RemoteEngine::HandleChatCompletion(
602613
status["status_code"] = k500InternalServerError;
603614
Json::Value error;
604615
error["error"] = "Failed to parse response";
616+
LOG_WARN << "Failed to parse response: " << response.body;
605617
callback(std::move(status), std::move(error));
606618
return;
607619
}
@@ -626,15 +638,19 @@ void RemoteEngine::HandleChatCompletion(
626638

627639
try {
628640
response_json["stream"] = false;
641+
if (!response_json.isMember("model")) {
642+
response_json["model"] = model;
643+
}
629644
response_str = renderer_.Render(template_str, response_json);
630645
} catch (const std::exception& e) {
631646
throw std::runtime_error("Template rendering error: " +
632647
std::string(e.what()));
633648
}
634649
} catch (const std::exception& e) {
635650
// Log error and potentially rethrow or handle accordingly
636-
LOG_WARN << "Error in TransformRequest: " << e.what();
637-
LOG_WARN << "Using original request body";
651+
LOG_WARN << "Error: " << e.what();
652+
LOG_WARN << "Response: " << response.body;
653+
LOG_WARN << "Using original body";
638654
response_str = response_json.toStyledString();
639655
}
640656

@@ -649,6 +665,7 @@ void RemoteEngine::HandleChatCompletion(
649665
Json::Value error;
650666
error["error"] = "Failed to parse response";
651667
callback(std::move(status), std::move(error));
668+
LOG_WARN << "Failed to parse response: " << response_str;
652669
return;
653670
}
654671

engine/extensions/remote-engine/remote_engine.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ struct StreamContext {
2424
std::string model;
2525
extensions::TemplateRenderer& renderer;
2626
std::string stream_template;
27+
bool need_stop = true;
2728
};
2829
struct CurlResponse {
2930
std::string body;

engine/extensions/template_renderer.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
#include <regex>
88
#include <stdexcept>
99
#include "utils/logging_utils.h"
10+
#include "utils/string_utils.h"
1011
namespace extensions {
12+
1113
TemplateRenderer::TemplateRenderer() {
1214
// Configure Inja environment
1315
env_.set_trim_blocks(true);
@@ -21,7 +23,8 @@ TemplateRenderer::TemplateRenderer() {
2123
const auto& value = *args[0];
2224

2325
if (value.is_string()) {
24-
return nlohmann::json(std::string("\"") + value.get<std::string>() +
26+
return nlohmann::json(std::string("\"") +
27+
string_utils::EscapeJson(value.get<std::string>()) +
2528
"\"");
2629
}
2730
return value;
@@ -46,16 +49,14 @@ std::string TemplateRenderer::Render(const std::string& tmpl,
4649
std::string result = env_.render(tmpl, template_data);
4750

4851
// Clean up any potential double quotes in JSON strings
49-
result = std::regex_replace(result, std::regex("\\\"\\\""), "\"");
52+
// result = std::regex_replace(result, std::regex("\\\"\\\""), "\"");
5053

5154
LOG_DEBUG << "Result: " << result;
5255

53-
// Validate JSON
54-
auto parsed = nlohmann::json::parse(result);
55-
5656
return result;
5757
} catch (const std::exception& e) {
5858
LOG_ERROR << "Template rendering failed: " << e.what();
59+
LOG_ERROR << "Data: " << data.toStyledString();
5960
LOG_ERROR << "Template: " << tmpl;
6061
throw std::runtime_error(std::string("Template rendering failed: ") +
6162
e.what());
@@ -133,4 +134,4 @@ std::string TemplateRenderer::RenderFile(const std::string& template_path,
133134
e.what());
134135
}
135136
}
136-
} // namespace remote_engine
137+
} // namespace extensions

engine/test/components/test_remote_engine.cc

Lines changed: 236 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ TEST_F(RemoteEngineTest, OpenAiToAnthropicRequest) {
1818
"messages": [
1919
{% for message in input_request.messages %}
2020
{% if not loop.is_first %}
21-
{"role": "{{ message.role }}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %}
21+
{"role": "{{ message.role }}", "content": {{ tojson(message.content) }} } {% if not loop.is_last %},{% endif %}
2222
{% endif %}
2323
{% endfor %}
2424
]
2525
{% else %}
2626
"messages": [
2727
{% for message in input_request.messages %}
28-
{"role": " {{ message.role}}", "content": "{{ message.content }}" } {% if not loop.is_last %},{% endif %}
28+
{"role": " {{ message.role}}", "content": {{ tojson(message.content) }} } {% if not loop.is_last %},{% endif %}
2929
{% endfor %}
3030
]
3131
{% endif %}
@@ -181,6 +181,240 @@ TEST_F(RemoteEngineTest, AnthropicResponse) {
181181
EXPECT_TRUE(res_json["choices"][0]["message"]["content"].isNull());
182182
}
183183

184+
TEST_F(RemoteEngineTest, CohereRequest) {
185+
std::string tpl =
186+
R"({
187+
{% for key, value in input_request %}
188+
{% if key == "messages" %}
189+
{% if input_request.messages.0.role == "system" %}
190+
"preamble": {{ tojson(input_request.messages.0.content) }},
191+
{% if length(input_request.messages) > 2 %}
192+
"chatHistory": [
193+
{% for message in input_request.messages %}
194+
{% if not loop.is_first and not loop.is_last %}
195+
{"role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": {{ tojson(message.content) }} } {% if loop.index < length(input_request.messages) - 2 %},{% endif %}
196+
{% endif %}
197+
{% endfor %}
198+
],
199+
{% endif %}
200+
"message": {{ tojson(last(input_request.messages).content) }}
201+
{% else %}
202+
{% if length(input_request.messages) > 2 %}
203+
"chatHistory": [
204+
{% for message in input_request.messages %}
205+
{% if not loop.is_last %}
206+
{ "role": {% if message.role == "user" %} "USER" {% else %} "CHATBOT" {% endif %}, "content": {{ tojson(message.content) }} } {% if loop.index < length(input_request.messages) - 2 %},{% endif %}
207+
{% endif %}
208+
{% endfor %}
209+
],
210+
{% endif %}
211+
"message": {{ tojson(last(input_request.messages).content) }}
212+
{% endif %}
213+
{% if not loop.is_last %},{% endif %}
214+
{% else if key == "system" or key == "model" or key == "temperature" or key == "store" or key == "max_tokens" or key == "stream" or key == "presence_penalty" or key == "metadata" or key == "frequency_penalty" or key == "tools" or key == "tool_choice" or key == "logprobs" or key == "top_logprobs" or key == "logit_bias" or key == "n" or key == "modalities" or key == "prediction" or key == "response_format" or key == "service_tier" or key == "seed" or key == "stop" or key == "stream_options" or key == "top_p" or key == "parallel_tool_calls" or key == "user" %}
215+
"{{ key }}": {{ tojson(value) }}
216+
{% if not loop.is_last %},{% endif %}
217+
{% endif %}
218+
{% endfor %} })";
219+
{
220+
std::string message_with_system = R"({
221+
"engine" : "cohere",
222+
"max_tokens" : 1024,
223+
"messages": [
224+
{"role": "system", "content": "You are a seasoned data scientist at a Fortune 500 company."},
225+
{"role": "user", "content": "Hello, world"},
226+
{"role": "assistant", "content": "The man who is widely credited with discovering gravity is Sir Isaac Newton"},
227+
{"role": "user", "content": "How are you today?"}
228+
],
229+
"model": "command-r-plus-08-2024",
230+
"stream" : true
231+
})";
232+
233+
auto data = json_helper::ParseJsonString(message_with_system);
234+
235+
extensions::TemplateRenderer rdr;
236+
auto res = rdr.Render(tpl, data);
237+
238+
auto res_json = json_helper::ParseJsonString(res);
239+
EXPECT_EQ(data["model"].asString(), res_json["model"].asString());
240+
EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt());
241+
for (auto const& msg : data["messages"]) {
242+
if (msg["role"].asString() == "system") {
243+
EXPECT_EQ(msg["content"].asString(), res_json["preamble"].asString());
244+
}
245+
}
246+
EXPECT_EQ(res_json["message"].asString(), "How are you today?");
247+
}
248+
249+
{
250+
std::string message_without_system = R"({
251+
"messages": [
252+
{"role": "user", "content": "Hello, \"the\" \n\nworld"}
253+
],
254+
"model": "command-r-plus-08-2024",
255+
"max_tokens": 1024,
256+
})";
257+
258+
auto data = json_helper::ParseJsonString(message_without_system);
259+
260+
extensions::TemplateRenderer rdr;
261+
auto res = rdr.Render(tpl, data);
262+
263+
auto res_json = json_helper::ParseJsonString(res);
264+
EXPECT_EQ(data["model"].asString(), res_json["model"].asString());
265+
EXPECT_EQ(data["max_tokens"].asInt(), res_json["max_tokens"].asInt());
266+
EXPECT_EQ(data["messages"][0]["content"].asString(),
267+
res_json["message"].asString());
268+
}
269+
}
270+
271+
TEST_F(RemoteEngineTest, CohereResponse) {
272+
std::string tpl = R"(
273+
{% if input_request.stream %}
274+
{"object": "chat.completion.chunk",
275+
"model": "{{ input_request.model }}",
276+
"choices": [{"index": 0, "delta": { {% if input_request.event_type == "text-generation" %} "role": "assistant", "content": {{ tojson(input_request.text) }} {% else %} "role": "assistant", "content": null {% endif %} },
277+
{% if input_request.event_type == "stream-end" %} "finish_reason": "{{ input_request.finish_reason }}" {% else %} "finish_reason": null {% endif %} }]
278+
}
279+
{% else %}
280+
{"id": "{{ input_request.generation_id }}",
281+
"created": null,
282+
"object": "chat.completion",
283+
"model": "{{ input_request.model }}",
284+
"choices": [{ "index": 0, "message": { "role": "assistant", "content": {% if not input_request.text %} null {% else %} {{ tojson(input_request.text) }} {% endif %}, "refusal": null }, "logprobs": null, "finish_reason": "{{ input_request.finish_reason }}" } ], "usage": { "prompt_tokens": {{ input_request.meta.tokens.input_tokens }}, "completion_tokens": {{ input_request.meta.tokens.output_tokens }}, "total_tokens": {{ input_request.meta.tokens.input_tokens + input_request.meta.tokens.output_tokens }}, "prompt_tokens_details": { "cached_tokens": 0 }, "completion_tokens_details": { "reasoning_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0 } }, "system_fingerprint": "fp_6b68a8204b"} {% endif %})";
285+
std::string message = R"({
286+
"event_type": "text-generation",
287+
"text": " help"
288+
})";
289+
auto data = json_helper::ParseJsonString(message);
290+
data["stream"] = true;
291+
data["model"] = "cohere";
292+
extensions::TemplateRenderer rdr;
293+
auto res = rdr.Render(tpl, data);
294+
auto res_json = json_helper::ParseJsonString(res);
295+
EXPECT_EQ(res_json["choices"][0]["delta"]["content"].asString(), " help");
296+
297+
message = R"(
298+
{
299+
"event_type": "stream-end",
300+
"response": {
301+
"text": "Hello! How can I help you today?",
302+
"generation_id": "29f14a5a-11de-4cae-9800-25e4747408ea",
303+
"chat_history": [
304+
{
305+
"role": "USER",
306+
"message": "hello world!"
307+
},
308+
{
309+
"role": "CHATBOT",
310+
"message": "Hello! How can I help you today?"
311+
}
312+
],
313+
"finish_reason": "COMPLETE",
314+
"meta": {
315+
"api_version": {
316+
"version": "1"
317+
},
318+
"billed_units": {
319+
"input_tokens": 3,
320+
"output_tokens": 9
321+
},
322+
"tokens": {
323+
"input_tokens": 69,
324+
"output_tokens": 9
325+
}
326+
}
327+
},
328+
"finish_reason": "COMPLETE"
329+
})";
330+
data = json_helper::ParseJsonString(message);
331+
data["stream"] = true;
332+
data["model"] = "cohere";
333+
res = rdr.Render(tpl, data);
334+
res_json = json_helper::ParseJsonString(res);
335+
EXPECT_TRUE(res_json["choices"][0]["delta"]["content"].isNull());
336+
337+
// non-stream
338+
message = R"(
339+
{
340+
"text": "Isaac \t\tNewton was 'born' on 25 \"December\" 1642 (Old Style) \n\nor 4 January 1643 (New Style).",
341+
"generation_id": "0385c7cf-4247-43a3-a450-b25b547a31e1",
342+
"citations": [
343+
{
344+
"start": 25,
345+
"end": 41,
346+
"text": "25 December 1642",
347+
"document_ids": [
348+
"web-search_0"
349+
]
350+
}
351+
],
352+
"search_queries": [
353+
{
354+
"text": "Isaac Newton birth year",
355+
"generation_id": "9a497980-c3e2-4460-b81c-ef44d293f95d"
356+
}
357+
],
358+
"search_results": [
359+
{
360+
"connector": {
361+
"id": "web-search"
362+
},
363+
"document_ids": [
364+
"web-search_0"
365+
],
366+
"search_query": {
367+
"text": "Isaac Newton birth year",
368+
"generation_id": "9a497980-c3e2-4460-b81c-ef44d293f95d"
369+
}
370+
}
371+
],
372+
"finish_reason": "COMPLETE",
373+
"chat_history": [
374+
{
375+
"role": "USER",
376+
"message": "Who discovered gravity?"
377+
},
378+
{
379+
"role": "CHATBOT",
380+
"message": "The man who is widely credited with discovering gravity is Sir Isaac Newton"
381+
},
382+
{
383+
"role": "USER",
384+
"message": "What year was he born?"
385+
},
386+
{
387+
"role": "CHATBOT",
388+
"message": "Isaac Newton was born on 25 December 1642 (Old Style) or 4 January 1643 (New Style)."
389+
}
390+
],
391+
"meta": {
392+
"api_version": {
393+
"version": "1"
394+
},
395+
"billed_units": {
396+
"input_tokens": 31738,
397+
"output_tokens": 35
398+
},
399+
"tokens": {
400+
"input_tokens": 32465,
401+
"output_tokens": 205
402+
}
403+
}
404+
}
405+
)";
406+
407+
data = json_helper::ParseJsonString(message);
408+
data["stream"] = false;
409+
data["model"] = "cohere";
410+
res = rdr.Render(tpl, data);
411+
res_json = json_helper::ParseJsonString(res);
412+
EXPECT_EQ(
413+
res_json["choices"][0]["message"]["content"].asString(),
414+
"Isaac \t\tNewton was 'born' on 25 \"December\" 1642 (Old Style) \n\nor 4 "
415+
"January 1643 (New Style).");
416+
}
417+
184418
TEST_F(RemoteEngineTest, HeaderTemplate) {
185419
{
186420
std::string header_template =

0 commit comments

Comments (0)