diff --git a/engine/main.cc b/engine/main.cc index a7b5bb81f..d01ba1b65 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -64,7 +64,7 @@ void RunServer(std::optional host, std::optional port, bool ignore_cout) { #if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) auto signal_handler = +[](int sig) -> void { - std::cout << "\rCaught interrupt signal:" << sig << ", shutting down\n";; + std::cout << "\rCaught interrupt signal:" << sig << ", shutting down\n"; shutdown_signal = true; }; signal(SIGINT, signal_handler); @@ -288,54 +288,105 @@ void RunServer(std::optional host, std::optional port, return false; }; + auto handle_cors = [config_service](const drogon::HttpRequestPtr& req, + const drogon::HttpResponsePtr& resp) { + const std::string& origin = req->getHeader("Origin"); + CTL_INF("Origin: " << origin); + + auto allowed_origins = + config_service->GetApiServerConfiguration()->allowed_origins; + + auto is_contains_asterisk = + std::find(allowed_origins.begin(), allowed_origins.end(), "*"); + if (is_contains_asterisk != allowed_origins.end()) { + resp->addHeader("Access-Control-Allow-Origin", "*"); + resp->addHeader("Access-Control-Allow-Methods", "*"); + return; + } + + // Check if the origin is in our allowed list + auto it = std::find(allowed_origins.begin(), allowed_origins.end(), origin); + if (it != allowed_origins.end()) { + resp->addHeader("Access-Control-Allow-Origin", origin); + } else if (allowed_origins.empty()) { + resp->addHeader("Access-Control-Allow-Origin", "*"); + } + resp->addHeader("Access-Control-Allow-Methods", "*"); + }; + drogon::app().registerPreRoutingAdvice( - [&validate_api_key]( + [&validate_api_key, &handle_cors]( const drogon::HttpRequestPtr& req, - std::function&& cb, - drogon::AdviceChainCallback&& ccb) { + std::function&& stop, + drogon::AdviceChainCallback&& pass) { + // Handle OPTIONS preflight requests + if (req->method() == drogon::HttpMethod::Options) { + auto resp = HttpResponse::newHttpResponse(); + auto handlers = drogon::app().getHandlersInfo(); + bool has_ep = [req, &handlers]() { + for (auto const& h : handlers) { + if (string_utils::AreUrlPathsEqual(req->path(), std::get<0>(h))) + return true; + } + return false; + }(); + if (!has_ep) { + resp->setStatusCode(drogon::HttpStatusCode::k404NotFound); + stop(resp); + return; + } + + handle_cors(req, resp); + std::string supported_methods = [req, &handlers]() { + std::string methods; + for (auto const& h : handlers) { + if (string_utils::AreUrlPathsEqual(req->path(), std::get<0>(h))) { + auto m = drogon::to_string_view(std::get<1>(h)); + if (methods.find(m) == std::string::npos) { + methods += drogon::to_string_view(std::get<1>(h)); + methods += ", "; + } + } + } + if (methods.size() < 2) + return std::string(); + return methods.substr(0, methods.size() - 2); + }(); + + // Add more info to header + resp->addHeader("Access-Control-Allow-Methods", supported_methods); + { + const auto& val = req->getHeader("Access-Control-Request-Headers"); + if (!val.empty()) + resp->addHeader("Access-Control-Allow-Headers", val); + } + // Set Access-Control-Max-Age + resp->addHeader("Access-Control-Max-Age", + "600"); // Cache for 10 minutes + stop(resp); + return; + } + if (!validate_api_key(req)) { Json::Value ret; ret["message"] = "Invalid API Key"; auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); resp->setStatusCode(drogon::k401Unauthorized); - cb(resp); + stop(resp); return; } - ccb(); + pass(); }); // CORS drogon::app().registerPostHandlingAdvice( - [config_service](const drogon::HttpRequestPtr& req, - const drogon::HttpResponsePtr& resp) { + [config_service, &handle_cors](const drogon::HttpRequestPtr& req, + const drogon::HttpResponsePtr& resp) { if (!config_service->GetApiServerConfiguration()->cors) { CTL_INF("CORS is disabled!"); return; } - - const std::string& origin = req->getHeader("Origin"); - CTL_INF("Origin: " << origin); - - auto allowed_origins = - config_service->GetApiServerConfiguration()->allowed_origins; - - auto is_contains_asterisk = - std::find(allowed_origins.begin(), allowed_origins.end(), "*"); - if (is_contains_asterisk != allowed_origins.end()) { - resp->addHeader("Access-Control-Allow-Origin", "*"); - resp->addHeader("Access-Control-Allow-Methods", "*"); - return; - } - - // Check if the origin is in our allowed list - auto it = - std::find(allowed_origins.begin(), allowed_origins.end(), origin); - if (it != allowed_origins.end()) { - resp->addHeader("Access-Control-Allow-Origin", origin); - } else if (allowed_origins.empty()) { - resp->addHeader("Access-Control-Allow-Origin", "*"); - } - resp->addHeader("Access-Control-Allow-Methods", "*"); + handle_cors(req, resp); }); // ssl diff --git a/engine/test/components/test_string_utils.cc b/engine/test/components/test_string_utils.cc index e396f0ed1..42211b668 100644 --- a/engine/test/components/test_string_utils.cc +++ b/engine/test/components/test_string_utils.cc @@ -289,3 +289,50 @@ TEST_F(StringUtilsTestSuite, LargeInputPerformance) { } +TEST_F(StringUtilsTestSuite, UrlPaths_SimilarStrings) { + std::string str1 = "/v1/threads/{1}/messages/{2}"; + std::string str2 = "/v1/threads/xxx/messages/yyy"; + EXPECT_TRUE( AreUrlPathsEqual(str1, str2)); +} + +TEST_F(StringUtilsTestSuite, UrlPaths_DifferentPaths) { + std::string str1 = "/v1/threads/{1}/messages/{2}"; + std::string str2 = "/v1/threads/xxx/messages/yyy/extra"; + EXPECT_FALSE(AreUrlPathsEqual(str1, str2)); +} + +TEST_F(StringUtilsTestSuite, UrlPaths_DifferentPlaceholderCounts) { + std::string str1 = "/v1/threads/{1}/messages/{2}"; + std::string str2 = "/v1/threads/{1}/messages/{2}/{3}"; + EXPECT_FALSE(AreUrlPathsEqual(str1, str2)); +} + +TEST_F(StringUtilsTestSuite, UrlPaths_NoPlaceholders) { + std::string str1 = "/v1/threads/1/messages/2"; + std::string str2 = "/v1/threads/xxx/messages/yyy"; + EXPECT_FALSE(AreUrlPathsEqual(str1, str2)); +} + +TEST_F(StringUtilsTestSuite, UrlPaths_EmptyStrings) { + std::string str1 = ""; + std::string str2 = ""; + EXPECT_TRUE(AreUrlPathsEqual(str1, str2)); +} + +TEST_F(StringUtilsTestSuite, UrlPaths_SinglePlaceholder) { + std::string str1 = "/v1/threads/{1}"; + std::string str2 = "/v1/threads/xxx"; + EXPECT_TRUE(AreUrlPathsEqual(str1, str2)); +} + +TEST_F(StringUtilsTestSuite, UrlPaths_MultiplePlaceholdersSameFormat) { + std::string str1 = "/v1/threads/{1}/messages/{2}/comments/{3}"; + std::string str2 = "/v1/threads/xxx/messages/yyy/comments/zzz"; + EXPECT_TRUE(AreUrlPathsEqual(str1, str2)); +} + +TEST_F(StringUtilsTestSuite, UrlPaths_NonPlaceholderDifferences) { + std::string str1 = "/v1/threads/{1}/messages/{2}"; + std::string str2 = "/v2/threads/xxx/messages/yyy"; + EXPECT_FALSE(AreUrlPathsEqual(str1, str2)); +} diff --git a/engine/utils/string_utils.h b/engine/utils/string_utils.h index a962109e8..d7db8b29d 100644 --- a/engine/utils/string_utils.h +++ b/engine/utils/string_utils.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -200,4 +201,31 @@ inline std::string EscapeJson(const std::string& s) { } return o.str(); } + +// Add a method to compares two url paths +inline bool AreUrlPathsEqual(const std::string& path1, + const std::string& path2) { + auto has_placeholder = [](const std::string& s) { + if (s.empty()) + return false; + return s.find_first_of('{') < s.find_last_of('}'); + }; + std::vector parts1 = SplitBy(path1, "/"); + std::vector parts2 = SplitBy(path2, "/"); + + // Check if both strings have the same number of parts + if (parts1.size() != parts2.size()) { + return false; + } + + for (size_t i = 0; i < parts1.size(); ++i) { + if (has_placeholder(parts1[i]) || has_placeholder(parts2[i])) + continue; + if (parts1[i] != parts2[i]) { + return false; + } + } + + return true; +} } // namespace string_utils