Skip to content

fix: handle preflight requests #2175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 83 additions & 32 deletions engine/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void RunServer(std::optional<std::string> host, std::optional<int> port,
bool ignore_cout) {
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
auto signal_handler = +[](int sig) -> void {
std::cout << "\rCaught interrupt signal:" << sig << ", shutting down\n";;
std::cout << "\rCaught interrupt signal:" << sig << ", shutting down\n";
shutdown_signal = true;
};
signal(SIGINT, signal_handler);
Expand Down Expand Up @@ -288,54 +288,105 @@ void RunServer(std::optional<std::string> host, std::optional<int> port,
return false;
};

auto handle_cors = [config_service](const drogon::HttpRequestPtr& req,
const drogon::HttpResponsePtr& resp) {
const std::string& origin = req->getHeader("Origin");
CTL_INF("Origin: " << origin);

auto allowed_origins =
config_service->GetApiServerConfiguration()->allowed_origins;

auto is_contains_asterisk =
std::find(allowed_origins.begin(), allowed_origins.end(), "*");
if (is_contains_asterisk != allowed_origins.end()) {
resp->addHeader("Access-Control-Allow-Origin", "*");
resp->addHeader("Access-Control-Allow-Methods", "*");
return;
}

// Check if the origin is in our allowed list
auto it = std::find(allowed_origins.begin(), allowed_origins.end(), origin);
if (it != allowed_origins.end()) {
resp->addHeader("Access-Control-Allow-Origin", origin);
} else if (allowed_origins.empty()) {
resp->addHeader("Access-Control-Allow-Origin", "*");
}
resp->addHeader("Access-Control-Allow-Methods", "*");
};

drogon::app().registerPreRoutingAdvice(
[&validate_api_key](
[&validate_api_key, &handle_cors](
const drogon::HttpRequestPtr& req,
std::function<void(const drogon::HttpResponsePtr&)>&& cb,
drogon::AdviceChainCallback&& ccb) {
std::function<void(const drogon::HttpResponsePtr&)>&& stop,
drogon::AdviceChainCallback&& pass) {
// Handle OPTIONS preflight requests
if (req->method() == drogon::HttpMethod::Options) {
auto resp = HttpResponse::newHttpResponse();
auto handlers = drogon::app().getHandlersInfo();
bool has_ep = [req, &handlers]() {
for (auto const& h : handlers) {
if (string_utils::AreUrlPathsEqual(req->path(), std::get<0>(h)))
return true;
}
return false;
}();
if (!has_ep) {
resp->setStatusCode(drogon::HttpStatusCode::k404NotFound);
stop(resp);
return;
}

handle_cors(req, resp);
std::string supported_methods = [req, &handlers]() {
std::string methods;
for (auto const& h : handlers) {
if (string_utils::AreUrlPathsEqual(req->path(), std::get<0>(h))) {
auto m = drogon::to_string_view(std::get<1>(h));
if (methods.find(m) == std::string::npos) {
methods += drogon::to_string_view(std::get<1>(h));
methods += ", ";
}
}
}
if (methods.size() < 2)
return std::string();
return methods.substr(0, methods.size() - 2);
}();

// Add more info to header
resp->addHeader("Access-Control-Allow-Methods", supported_methods);
{
const auto& val = req->getHeader("Access-Control-Request-Headers");
if (!val.empty())
resp->addHeader("Access-Control-Allow-Headers", val);
}
// Set Access-Control-Max-Age
resp->addHeader("Access-Control-Max-Age",
"600"); // Cache for 10 minutes
stop(resp);
return;
}

if (!validate_api_key(req)) {
Json::Value ret;
ret["message"] = "Invalid API Key";
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
resp->setStatusCode(drogon::k401Unauthorized);
cb(resp);
stop(resp);
return;
}
ccb();
pass();
});

// CORS
drogon::app().registerPostHandlingAdvice(
[config_service](const drogon::HttpRequestPtr& req,
const drogon::HttpResponsePtr& resp) {
[config_service, &handle_cors](const drogon::HttpRequestPtr& req,
const drogon::HttpResponsePtr& resp) {
if (!config_service->GetApiServerConfiguration()->cors) {
CTL_INF("CORS is disabled!");
return;
}

const std::string& origin = req->getHeader("Origin");
CTL_INF("Origin: " << origin);

auto allowed_origins =
config_service->GetApiServerConfiguration()->allowed_origins;

auto is_contains_asterisk =
std::find(allowed_origins.begin(), allowed_origins.end(), "*");
if (is_contains_asterisk != allowed_origins.end()) {
resp->addHeader("Access-Control-Allow-Origin", "*");
resp->addHeader("Access-Control-Allow-Methods", "*");
return;
}

// Check if the origin is in our allowed list
auto it =
std::find(allowed_origins.begin(), allowed_origins.end(), origin);
if (it != allowed_origins.end()) {
resp->addHeader("Access-Control-Allow-Origin", origin);
} else if (allowed_origins.empty()) {
resp->addHeader("Access-Control-Allow-Origin", "*");
}
resp->addHeader("Access-Control-Allow-Methods", "*");
handle_cors(req, resp);
});

// ssl
Expand Down
47 changes: 47 additions & 0 deletions engine/test/components/test_string_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,50 @@ TEST_F(StringUtilsTestSuite, LargeInputPerformance) {
}


TEST_F(StringUtilsTestSuite, UrlPaths_SimilarStrings) {
std::string str1 = "/v1/threads/{1}/messages/{2}";
std::string str2 = "/v1/threads/xxx/messages/yyy";
EXPECT_TRUE( AreUrlPathsEqual(str1, str2));
}

TEST_F(StringUtilsTestSuite, UrlPaths_DifferentPaths) {
std::string str1 = "/v1/threads/{1}/messages/{2}";
std::string str2 = "/v1/threads/xxx/messages/yyy/extra";
EXPECT_FALSE(AreUrlPathsEqual(str1, str2));
}

TEST_F(StringUtilsTestSuite, UrlPaths_DifferentPlaceholderCounts) {
std::string str1 = "/v1/threads/{1}/messages/{2}";
std::string str2 = "/v1/threads/{1}/messages/{2}/{3}";
EXPECT_FALSE(AreUrlPathsEqual(str1, str2));
}

TEST_F(StringUtilsTestSuite, UrlPaths_NoPlaceholders) {
std::string str1 = "/v1/threads/1/messages/2";
std::string str2 = "/v1/threads/xxx/messages/yyy";
EXPECT_FALSE(AreUrlPathsEqual(str1, str2));
}

TEST_F(StringUtilsTestSuite, UrlPaths_EmptyStrings) {
std::string str1 = "";
std::string str2 = "";
EXPECT_TRUE(AreUrlPathsEqual(str1, str2));
}

TEST_F(StringUtilsTestSuite, UrlPaths_SinglePlaceholder) {
std::string str1 = "/v1/threads/{1}";
std::string str2 = "/v1/threads/xxx";
EXPECT_TRUE(AreUrlPathsEqual(str1, str2));
}

TEST_F(StringUtilsTestSuite, UrlPaths_MultiplePlaceholdersSameFormat) {
std::string str1 = "/v1/threads/{1}/messages/{2}/comments/{3}";
std::string str2 = "/v1/threads/xxx/messages/yyy/comments/zzz";
EXPECT_TRUE(AreUrlPathsEqual(str1, str2));
}

TEST_F(StringUtilsTestSuite, UrlPaths_NonPlaceholderDifferences) {
std::string str1 = "/v1/threads/{1}/messages/{2}";
std::string str2 = "/v2/threads/xxx/messages/yyy";
EXPECT_FALSE(AreUrlPathsEqual(str1, str2));
}
28 changes: 28 additions & 0 deletions engine/utils/string_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <cctype>
#include <chrono>
#include <iomanip>
#include <regex>
#include <sstream>
#include <string>
#include <vector>
Expand Down Expand Up @@ -200,4 +201,31 @@ inline std::string EscapeJson(const std::string& s) {
}
return o.str();
}

// Add a method to compares two url paths
inline bool AreUrlPathsEqual(const std::string& path1,
const std::string& path2) {
auto has_placeholder = [](const std::string& s) {
if (s.empty())
return false;
return s.find_first_of('{') < s.find_last_of('}');
};
std::vector<std::string> parts1 = SplitBy(path1, "/");
std::vector<std::string> parts2 = SplitBy(path2, "/");

// Check if both strings have the same number of parts
if (parts1.size() != parts2.size()) {
return false;
}

for (size_t i = 0; i < parts1.size(); ++i) {
if (has_placeholder(parts1[i]) || has_placeholder(parts2[i]))
continue;
if (parts1[i] != parts2[i]) {
return false;
}
}

return true;
}
} // namespace string_utils
Loading