Skip to content

Commit 1a73c0c

Browse files
fix: forward start model parameters (#1825)
Co-authored-by: vansangpfiev <[email protected]>
1 parent fb72167 commit 1a73c0c

File tree

5 files changed

+116
-73
lines changed

5 files changed

+116
-73
lines changed

engine/controllers/models.cc

Lines changed: 15 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -488,65 +488,40 @@ void Models::StartModel(
488488
if (!http_util::HasFieldInReq(req, callback, "model"))
489489
return;
490490
auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
491-
StartParameterOverride params_override;
492-
if (auto& o = (*(req->getJsonObject()))["prompt_template"]; !o.isNull()) {
493-
params_override.custom_prompt_template = o.asString();
494-
}
495-
496-
if (auto& o = (*(req->getJsonObject()))["cache_enabled"]; !o.isNull()) {
497-
params_override.cache_enabled = o.asBool();
498-
}
499-
500-
if (auto& o = (*(req->getJsonObject()))["ngl"]; !o.isNull()) {
501-
params_override.ngl = o.asInt();
502-
}
503-
504-
if (auto& o = (*(req->getJsonObject()))["n_parallel"]; !o.isNull()) {
505-
params_override.n_parallel = o.asInt();
506-
}
507-
508-
if (auto& o = (*(req->getJsonObject()))["ctx_len"]; !o.isNull()) {
509-
params_override.ctx_len = o.asInt();
510-
}
511-
512-
if (auto& o = (*(req->getJsonObject()))["cache_type"]; !o.isNull()) {
513-
params_override.cache_type = o.asString();
514-
}
515491

492+
std::optional<std::string> mmproj;
516493
if (auto& o = (*(req->getJsonObject()))["mmproj"]; !o.isNull()) {
517-
params_override.mmproj = o.asString();
494+
mmproj = o.asString();
518495
}
519496

497+
auto bypass_llama_model_path = false;
520498
// Support both llama_model_path and model_path for backward compatible
521499
// model_path has higher priority
522500
if (auto& o = (*(req->getJsonObject()))["llama_model_path"]; !o.isNull()) {
523-
params_override.model_path = o.asString();
501+
auto model_path = o.asString();
524502
if (auto& mp = (*(req->getJsonObject()))["model_path"]; mp.isNull()) {
525503
// Bypass if model does not exist in DB and llama_model_path exists
526-
if (std::filesystem::exists(params_override.model_path.value()) &&
504+
if (std::filesystem::exists(model_path) &&
527505
!model_service_->HasModel(model_handle)) {
528506
CTL_INF("llama_model_path exists, bypass check model id");
529-
params_override.bypass_llama_model_path = true;
507+
bypass_llama_model_path = true;
530508
}
531509
}
532510
}
533511

534-
if (auto& o = (*(req->getJsonObject()))["model_path"]; !o.isNull()) {
535-
params_override.model_path = o.asString();
536-
}
512+
auto bypass_model_check = (mmproj.has_value() || bypass_llama_model_path);
537513

538514
auto model_entry = model_service_->GetDownloadedModel(model_handle);
539-
if (!model_entry.has_value() && !params_override.bypass_model_check()) {
515+
if (!model_entry.has_value() && !bypass_model_check) {
540516
Json::Value ret;
541517
ret["message"] = "Cannot find model: " + model_handle;
542518
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
543519
resp->setStatusCode(drogon::k400BadRequest);
544520
callback(resp);
545521
return;
546522
}
547-
std::string engine_name = params_override.bypass_model_check()
548-
? kLlamaEngine
549-
: model_entry.value().engine;
523+
std::string engine_name =
524+
bypass_model_check ? kLlamaEngine : model_entry.value().engine;
550525
auto engine_validate = engine_service_->IsEngineReady(engine_name);
551526
if (engine_validate.has_error()) {
552527
Json::Value ret;
@@ -565,7 +540,9 @@ void Models::StartModel(
565540
return;
566541
}
567542

568-
auto result = model_service_->StartModel(model_handle, params_override);
543+
auto result = model_service_->StartModel(
544+
model_handle, *(req->getJsonObject()) /*params_override*/,
545+
bypass_model_check);
569546
if (result.has_error()) {
570547
Json::Value ret;
571548
ret["message"] = result.error();
@@ -668,7 +645,7 @@ void Models::AddRemoteModel(
668645

669646
auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
670647
auto engine_name = (*(req->getJsonObject())).get("engine", "").asString();
671-
648+
672649
auto engine_validate = engine_service_->IsEngineReady(engine_name);
673650
if (engine_validate.has_error()) {
674651
Json::Value ret;
@@ -687,7 +664,7 @@ void Models::AddRemoteModel(
687664
callback(resp);
688665
return;
689666
}
690-
667+
691668
config::RemoteModelConfig model_config;
692669
model_config.LoadFromJson(*(req->getJsonObject()));
693670
cortex::db::Models modellist_utils_obj;

engine/services/model_service.cc

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -749,19 +749,28 @@ cpp::result<void, std::string> ModelService::DeleteModel(
749749
}
750750

751751
cpp::result<StartModelResult, std::string> ModelService::StartModel(
752-
const std::string& model_handle,
753-
const StartParameterOverride& params_override) {
752+
const std::string& model_handle, const Json::Value& params_override,
753+
bool bypass_model_check) {
754754
namespace fs = std::filesystem;
755755
namespace fmu = file_manager_utils;
756756
cortex::db::Models modellist_handler;
757757
config::YamlHandler yaml_handler;
758+
std::optional<std::string> custom_prompt_template;
759+
std::optional<int> ctx_len;
760+
if (auto& o = params_override["prompt_template"]; !o.isNull()) {
761+
custom_prompt_template = o.asString();
762+
}
763+
764+
if (auto& o = params_override["ctx_len"]; !o.isNull()) {
765+
ctx_len = o.asInt();
766+
}
758767

759768
try {
760769
constexpr const int kDefautlContextLength = 8192;
761770
int max_model_context_length = kDefautlContextLength;
762771
Json::Value json_data;
763772
// Currently we don't support download vision models, so we need to bypass check
764-
if (!params_override.bypass_model_check()) {
773+
if (!bypass_model_check) {
765774
auto model_entry = modellist_handler.GetModelInfo(model_handle);
766775
if (model_entry.has_error()) {
767776
CTL_WRN("Error: " + model_entry.error());
@@ -839,29 +848,19 @@ cpp::result<StartModelResult, std::string> ModelService::StartModel(
839848
}
840849

841850
json_data["model"] = model_handle;
842-
if (auto& cpt = params_override.custom_prompt_template;
843-
!cpt.value_or("").empty()) {
851+
if (auto& cpt = custom_prompt_template; !cpt.value_or("").empty()) {
844852
auto parse_prompt_result = string_utils::ParsePrompt(cpt.value());
845853
json_data["system_prompt"] = parse_prompt_result.system_prompt;
846854
json_data["user_prompt"] = parse_prompt_result.user_prompt;
847855
json_data["ai_prompt"] = parse_prompt_result.ai_prompt;
848856
}
849857

850-
#define ASSIGN_IF_PRESENT(json_obj, param_override, param_name) \
851-
if (param_override.param_name) { \
852-
json_obj[#param_name] = param_override.param_name.value(); \
853-
}
858+
json_helper::MergeJson(json_data, params_override);
854859

855-
ASSIGN_IF_PRESENT(json_data, params_override, cache_enabled);
856-
ASSIGN_IF_PRESENT(json_data, params_override, ngl);
857-
ASSIGN_IF_PRESENT(json_data, params_override, n_parallel);
858-
ASSIGN_IF_PRESENT(json_data, params_override, cache_type);
859-
ASSIGN_IF_PRESENT(json_data, params_override, mmproj);
860-
ASSIGN_IF_PRESENT(json_data, params_override, model_path);
861-
#undef ASSIGN_IF_PRESENT
862-
if (params_override.ctx_len) {
860+
// Set the latest ctx_len
861+
if (ctx_len) {
863862
json_data["ctx_len"] =
864-
std::min(params_override.ctx_len.value(), max_model_context_length);
863+
std::min(ctx_len.value(), max_model_context_length);
865864
}
866865
CTL_INF(json_data.toStyledString());
867866
auto may_fallback_res = MayFallbackToCpu(json_data["model_path"].asString(),

engine/services/model_service.h

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,6 @@ struct ModelPullInfo {
2222
std::string download_url;
2323
};
2424

25-
struct StartParameterOverride {
26-
std::optional<bool> cache_enabled;
27-
std::optional<int> ngl;
28-
std::optional<int> n_parallel;
29-
std::optional<int> ctx_len;
30-
std::optional<std::string> custom_prompt_template;
31-
std::optional<std::string> cache_type;
32-
std::optional<std::string> mmproj;
33-
std::optional<std::string> model_path;
34-
bool bypass_llama_model_path = false;
35-
bool bypass_model_check() const {
36-
return mmproj.has_value() || bypass_llama_model_path;
37-
}
38-
};
39-
4025
struct StartModelResult {
4126
bool success;
4227
std::optional<std::string> warning;
@@ -82,8 +67,8 @@ class ModelService {
8267
cpp::result<void, std::string> DeleteModel(const std::string& model_handle);
8368

8469
cpp::result<StartModelResult, std::string> StartModel(
85-
const std::string& model_handle,
86-
const StartParameterOverride& params_override);
70+
const std::string& model_handle, const Json::Value& params_override,
71+
bool bypass_model_check);
8772

8873
cpp::result<bool, std::string> StopModel(const std::string& model_handle);
8974

engine/test/components/test_json_helper.cc

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,61 @@ TEST(ParseJsonStringTest, EmptyString) {
3333

3434
EXPECT_TRUE(result.isNull());
3535
}
36+
37+
TEST(MergeJsonTest, MergeSimpleObjects) {
38+
Json::Value json1, json2;
39+
json1["name"] = "John";
40+
json1["age"] = 30;
41+
42+
json2["age"] = 31;
43+
json2["email"] = "[email protected]";
44+
45+
json_helper::MergeJson(json1, json2);
46+
47+
Json::Value expected;
48+
expected["name"] = "John";
49+
expected["age"] = 31;
50+
expected["email"] = "[email protected]";
51+
52+
EXPECT_EQ(json1, expected);
53+
}
54+
55+
TEST(MergeJsonTest, MergeNestedObjects) {
56+
Json::Value json1, json2;
57+
json1["person"]["name"] = "John";
58+
json1["person"]["age"] = 30;
59+
60+
json2["person"]["age"] = 31;
61+
json2["person"]["email"] = "[email protected]";
62+
63+
json_helper::MergeJson(json1, json2);
64+
65+
Json::Value expected;
66+
expected["person"]["name"] = "John";
67+
expected["person"]["age"] = 31;
68+
expected["person"]["email"] = "[email protected]";
69+
70+
EXPECT_EQ(json1, expected);
71+
}
72+
73+
TEST(MergeJsonTest, MergeArrays) {
74+
Json::Value json1, json2;
75+
json1["hobbies"] = Json::Value(Json::arrayValue);
76+
json1["hobbies"].append("reading");
77+
json1["hobbies"].append("painting");
78+
79+
json2["hobbies"] = Json::Value(Json::arrayValue);
80+
json2["hobbies"].append("hiking");
81+
json2["hobbies"].append("painting");
82+
83+
json_helper::MergeJson(json1, json2);
84+
85+
Json::Value expected;
86+
expected["hobbies"] = Json::Value(Json::arrayValue);
87+
expected["hobbies"].append("reading");
88+
expected["hobbies"].append("painting");
89+
expected["hobbies"].append("hiking");
90+
expected["hobbies"].append("painting");
91+
92+
EXPECT_EQ(json1, expected);
93+
}

engine/utils/json_helper.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,28 @@ inline std::string DumpJsonString(const Json::Value& json) {
1616
builder["indentation"] = "";
1717
return Json::writeString(builder, json);
1818
}
19+
20+
inline void MergeJson(Json::Value& target, const Json::Value& source) {
21+
for (const auto& member : source.getMemberNames()) {
22+
if (target.isMember(member)) {
23+
// If the member exists in both objects, recursively merge the values
24+
if (target[member].type() == Json::objectValue &&
25+
source[member].type() == Json::objectValue) {
26+
MergeJson(target[member], source[member]);
27+
} else if (target[member].type() == Json::arrayValue &&
28+
source[member].type() == Json::arrayValue) {
29+
// If the member is an array in both objects, merge the arrays
30+
for (const auto& value : source[member]) {
31+
target[member].append(value);
32+
}
33+
} else {
34+
// Otherwise, overwrite the value in the target with the value from the source
35+
target[member] = source[member];
36+
}
37+
} else {
38+
// If the member doesn't exist in the target, add it
39+
target[member] = source[member];
40+
}
41+
}
42+
}
1943
} // namespace json_helper

0 commit comments

Comments
 (0)