Skip to content

Commit d9779ea

Browse files
committed
feat: add support for hugging face model handle
1 parent effadfb commit d9779ea

File tree

6 files changed

+362
-12
lines changed

6 files changed

+362
-12
lines changed

engine/services/model_service.cc

+106-11
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,107 @@
11
#include "model_service.h"
22
#include <filesystem>
33
#include <iostream>
4+
#include <ostream>
45
#include "commands/cmd_info.h"
56
#include "utils/cortexso_parser.h"
67
#include "utils/file_manager_utils.h"
8+
#include "utils/huggingface_utils.h"
79
#include "utils/logging_utils.h"
810
#include "utils/model_callback_utils.h"
9-
#include "utils/url_parser.h"
11+
12+
/**
 * Prints `options` to stdout as a 1-based numbered menu, one entry per line,
 * followed by a blank line, then flushes the stream.
 *
 * @param options entries to list, in display order
 */
void PrintMenu(const std::vector<std::string>& options) {
  auto index{1};
  for (const auto& option : options) {
    std::cout << index << ". " << option << "\n";
    index++;
  }
  std::endl(std::cout);
}

/**
 * Shows `options` as a numbered menu and reads the user's choice from stdin.
 *
 * @param options entries the user can choose from
 * @return the chosen entry, or std::nullopt when nothing could be read or the
 *         input is not a whole number in the range [1, options.size()]
 */
std::optional<std::string> PrintSelection(
    const std::vector<std::string>& options) {
  PrintMenu(options);

  std::string selection;
  if (!(std::cin >> selection) || selection.empty()) {
    return std::nullopt;
  }

  // Parse the digits by hand instead of calling std::stoi: stoi throws on
  // non-numeric or overflowing input and silently accepts trailing garbage
  // such as "2abc". Any non-digit character rejects the whole input.
  std::size_t choice{0};
  for (const char ch : selection) {
    if (ch < '0' || ch > '9') {
      return std::nullopt;
    }
    choice = choice * 10 + static_cast<std::size_t>(ch - '0');
    // Bail out as soon as the value leaves the valid range; this also keeps
    // `choice` from overflowing on absurdly long digit strings.
    if (choice > options.size()) {
      return std::nullopt;
    }
  }

  if (choice < 1) {
    return std::nullopt;
  }
  return options[choice - 1];
}
1039

1140
void ModelService::DownloadModel(const std::string& input) {
1241
if (input.empty()) {
1342
throw std::runtime_error(
1443
"Input must be Cortex Model Hub handle or HuggingFace url!");
1544
}
1645

17-
// case input is a direct url
18-
auto url_obj = url_parser::FromUrlString(input);
19-
// TODO: handle case user paste url from cortexso
20-
if (url_obj.protocol == "https") {
21-
if (url_obj.host != kHuggingFaceHost) {
22-
CLI_LOG("Only huggingface.co is supported for now");
46+
if (input.starts_with("https://")) {
47+
return DownloadModelByDirectUrl(input);
48+
}
49+
50+
// if input contains / then handle it differently
51+
if (input.find("/") != std::string::npos) {
52+
// TODO: what if we have more than one /?
53+
// TODO: what if the left size of / is cortexso?
54+
55+
// split by /. TODO: Move this function to somewhere else
56+
std::string model_input = input;
57+
std::string delimiter{"/"};
58+
std::string token{""};
59+
std::vector<std::string> parsed{};
60+
std::string author{""};
61+
std::string model_name{""};
62+
while (token != model_input) {
63+
token = model_input.substr(0, model_input.find_first_of("/"));
64+
model_input = model_input.substr(model_input.find_first_of("/") + 1);
65+
std::string new_str{token};
66+
parsed.push_back(new_str);
67+
}
68+
69+
author = parsed[0];
70+
model_name = parsed[1];
71+
auto repo_info =
72+
huggingface_utils::GetHuggingFaceModelRepoInfo(author, model_name);
73+
if (!repo_info.has_value()) {
74+
// throw is better?
75+
CTL_ERR("Model not found");
2376
return;
2477
}
25-
return DownloadModelByDirectUrl(input);
26-
} else {
27-
commands::CmdInfo ci(input);
28-
return DownloadModelFromCortexso(ci.model_name, ci.branch);
78+
79+
if (!repo_info->gguf.has_value()) {
80+
throw std::runtime_error(
81+
"Not a GGUF model. Currently, only GGUF single file is supported.");
82+
}
83+
84+
std::vector<std::string> options{};
85+
for (const auto& sibling : repo_info->siblings) {
86+
if (sibling.rfilename.ends_with(".gguf")) {
87+
options.push_back(sibling.rfilename);
88+
}
89+
}
90+
auto selection = PrintSelection(options);
91+
std::cout << "Selected: " << selection.value() << std::endl;
92+
93+
auto download_url = huggingface_utils::GetDownloadableUrl(
94+
author, model_name, selection.value());
95+
96+
std::cout << "Download url: " << download_url << std::endl;
97+
// TODO: split to this function
98+
// DownloadHuggingFaceGgufModel(author, model_name, nullptr);
99+
return;
29100
}
101+
102+
// user just input a text, seems like a model name only, maybe comes with a branch, using : as delimeter
103+
// handle cortexso here
104+
// separate into another function and the above can route to it if we regconize a cortexso url
30105
}
31106

32107
std::optional<config::ModelConfig> ModelService::GetDownloadedModel(
@@ -114,3 +189,23 @@ void ModelService::DownloadModelFromCortexso(const std::string& name,
114189
CTL_ERR("Model not found");
115190
}
116191
}
192+
193+
/**
 * Placeholder for downloading a GGUF file from an `author/modelName` repo.
 *
 * At the moment it only echoes `author` and `modelName` to stdout, each on
 * its own line; `fileName` is accepted but not used yet.
 */
void ModelService::DownloadHuggingFaceGgufModel(
    const std::string& author, const std::string& modelName,
    std::optional<std::string> fileName) {
  std::cout << author << std::endl << modelName << std::endl;

  // TODO: when no file name is given, fetch the repo info with
  // huggingface_utils::GetHuggingFaceModelRepoInfo(author, modelName), report
  // "Model not found" if the lookup fails (or throw?), and otherwise list each
  // sibling's rfilename so the user can pick one.
}

engine/services/model_service.h

+7
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ class ModelService {
1919
void DownloadModelFromCortexso(const std::string& name,
2020
const std::string& branch);
2121

22+
/**
23+
* Handles downloading a model whose handle follows the pattern: author/model_name
24+
*/
25+
void DownloadHuggingFaceGgufModel(const std::string& author,
26+
const std::string& modelName,
27+
std::optional<std::string> fileName);
28+
2229
DownloadService download_service_;
2330

2431
constexpr auto static kHuggingFaceHost = "huggingface.co";

engine/test/components/CMakeLists.txt

+3
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@ find_package(Drogon CONFIG REQUIRED)
99
find_package(GTest CONFIG REQUIRED)
1010
find_package(yaml-cpp CONFIG REQUIRED)
1111
find_package(jinja2cpp CONFIG REQUIRED)
12+
find_package(httplib CONFIG REQUIRED)
1213

1314
target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gtest_main yaml-cpp::yaml-cpp jinja2cpp
1415
${CMAKE_THREAD_LIBS_INIT})
16+
17+
target_link_libraries(${PROJECT_NAME} PRIVATE httplib::httplib)
1518
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)
1619

1720
add_test(NAME ${PROJECT_NAME}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#include "gtest/gtest.h"
2+
#include "utils/huggingface_utils.h"
3+
4+
// Empty gtest fixture that groups the huggingface_utils test cases below; it
// adds no shared setup or teardown.
class HuggingFaceUtilTestSuite : public ::testing::Test {};
5+
6+
// Checks that the branches reported for cortexso/tinyllama carry the expected
// names and refs, in the expected order.
// NOTE(review): presumably this queries the live Hugging Face API, so the
// expected values depend on the current state of that repository - confirm.
TEST_F(HuggingFaceUtilTestSuite, TestGetModelRepositoryBranches) {
  auto branches =
      huggingface_utils::GetModelRepositoryBranches("cortexso", "tinyllama");

  // Fatal assertion: the element accesses below would read out of bounds if
  // fewer than three branches came back.
  ASSERT_EQ(branches.size(), 3);
  EXPECT_EQ(branches[0].name, "gguf");
  EXPECT_EQ(branches[0].ref, "refs/heads/gguf");
  EXPECT_EQ(branches[1].name, "1b-gguf");
  EXPECT_EQ(branches[1].ref, "refs/heads/1b-gguf");
  EXPECT_EQ(branches[2].name, "main");
  EXPECT_EQ(branches[2].ref, "refs/heads/main");
}
18+
19+
// Checks the repo info returned for cortexso/tinyllama: identity fields,
// the "gguf" tag, presence of GGUF info and a non-empty sibling list.
// NOTE(review): presumably this queries the live Hugging Face API - confirm.
TEST_F(HuggingFaceUtilTestSuite, TestGetHuggingFaceModelRepoInfoSuccessfully) {
  auto model_info =
      huggingface_utils::GetHuggingFaceModelRepoInfo("cortexso", "tinyllama");

  // Fatal assertion: every check below dereferences model_info, which must
  // not happen on an empty optional.
  ASSERT_TRUE(model_info.has_value());
  EXPECT_EQ(model_info->id, "cortexso/tinyllama");
  EXPECT_EQ(model_info->modelId, "cortexso/tinyllama");
  EXPECT_EQ(model_info->author, "cortexso");
  EXPECT_EQ(model_info->disabled, false);
  EXPECT_EQ(model_info->gated, false);

  auto tag_contains_gguf =
      std::find(model_info->tags.begin(), model_info->tags.end(), "gguf") !=
      model_info->tags.end();
  EXPECT_TRUE(tag_contains_gguf);

  auto contain_gguf_info = model_info->gguf.has_value();
  EXPECT_TRUE(contain_gguf_info);

  auto sibling_not_empty = !model_info->siblings.empty();
  EXPECT_TRUE(sibling_not_empty);
}
42+
43+
// Checks that a non-GGUF repo (BAAI/bge-reranker-v2-m3) still yields repo
// info, but without the "gguf" tag and without GGUF info.
// NOTE(review): presumably this queries the live Hugging Face API - confirm.
TEST_F(HuggingFaceUtilTestSuite,
       TestGetHuggingFaceModelRepoInfoReturnNullGgufInfoWhenNotAGgufModel) {
  auto model_info = huggingface_utils::GetHuggingFaceModelRepoInfo(
      "BAAI", "bge-reranker-v2-m3");

  // Fatal assertion: every check below dereferences model_info, which must
  // not happen on an empty optional.
  ASSERT_TRUE(model_info.has_value());
  EXPECT_EQ(model_info->disabled, false);
  EXPECT_EQ(model_info->gated, false);

  auto tag_not_contain_gguf =
      std::find(model_info->tags.begin(), model_info->tags.end(), "gguf") ==
      model_info->tags.end();
  EXPECT_TRUE(tag_not_contain_gguf);

  auto contain_gguf_info = model_info->gguf.has_value();
  EXPECT_TRUE(!contain_gguf_info);

  auto sibling_not_empty = !model_info->siblings.empty();
  EXPECT_TRUE(sibling_not_empty);
}
64+
65+
// With no branch argument, the generated URL is expected to resolve the file
// against the "main" branch.
TEST_F(HuggingFaceUtilTestSuite,
       TestGetHuggingFaceDownloadUrlWithoutBranchName) {
  const char* const want =
      "https://huggingface.co/pervll/bge-reranker-v2-gemma-Q4_K_M-GGUF/resolve/"
      "main/bge-reranker-v2-gemma-q4_k_m.gguf";

  const auto got = huggingface_utils::GetDownloadableUrl(
      "pervll", "bge-reranker-v2-gemma-Q4_K_M-GGUF",
      "bge-reranker-v2-gemma-q4_k_m.gguf");

  EXPECT_EQ(got, want);
}
77+
78+
// With an explicit branch argument, the generated URL is expected to resolve
// the file against that branch ("1b-gguf" here).
TEST_F(HuggingFaceUtilTestSuite, TestGetHuggingFaceDownloadUrlWithBranchName) {
  const char* const want =
      "https://huggingface.co/pervll/bge-reranker-v2-gemma-Q4_K_M-GGUF/resolve/"
      "1b-gguf/bge-reranker-v2-gemma-q4_k_m.gguf";

  const auto got = huggingface_utils::GetDownloadableUrl(
      "pervll", "bge-reranker-v2-gemma-Q4_K_M-GGUF",
      "bge-reranker-v2-gemma-q4_k_m.gguf", "1b-gguf");

  EXPECT_EQ(got, want);
}

engine/test/components/test_url_parser.cc

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#include <iostream>
21
#include "gtest/gtest.h"
32
#include "utils/url_parser.h"
43

0 commit comments

Comments
 (0)