Skip to content

Commit 2f288bd

Browse files
authored
PgWire auth with ApiKey (#8283)
1 parent 8e7a11a commit 2f288bd

File tree

5 files changed

+251
-101
lines changed

5 files changed

+251
-101
lines changed

ydb/core/local_pgwire/local_pgwire.cpp

Lines changed: 24 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -18,33 +18,8 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
1818
using TBase = TActor<TPgYdbProxy>;
1919

2020
struct TSecurityState {
21-
TString Ticket;
22-
Ydb::Auth::LoginResult LoginResult;
23-
TEvTicketParser::TError Error;
24-
TIntrusiveConstPtr<NACLib::TUserToken> Token;
2521
TString SerializedToken;
26-
};
27-
28-
struct TTokenState {
29-
std::unordered_set<TActorId> Senders;
30-
};
31-
32-
struct TEvPrivate {
33-
enum EEv {
34-
EvTokenReady = EventSpaceBegin(NActors::TEvents::ES_PRIVATE),
35-
EvEnd
36-
};
37-
38-
static_assert(EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE), "expect EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE)");
39-
40-
struct TEvTokenReady : TEventLocal<TEvTokenReady, EvTokenReady> {
41-
Ydb::Auth::LoginResult LoginResult;
42-
TActorId Sender;
43-
TString Database;
44-
TString PeerName;
45-
46-
TEvTokenReady() = default;
47-
};
22+
TString Ticket;
4823
};
4924

5025
struct TConnectionState {
@@ -54,7 +29,6 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
5429

5530
std::unordered_map<TActorId, TConnectionState> ConnectionState;
5631
std::unordered_map<TActorId, TSecurityState> SecurityState;
57-
std::unordered_map<TString, TTokenState> TokenState;
5832
uint32_t ConnectionNum = 0;
5933

6034
public:
@@ -63,85 +37,24 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
6337
{
6438
}
6539

66-
void Handle(TEvTicketParser::TEvAuthorizeTicketResult::TPtr& ev) {
67-
auto token = ev->Get()->Ticket;
68-
auto itTokenState = TokenState.find(token);
69-
if (itTokenState == TokenState.end()) {
70-
BLOG_W("Couldn't find token in reply from TicketParser");
71-
return;
72-
}
73-
for (auto sender : itTokenState->second.Senders) {
74-
auto& securityState(SecurityState[sender]);
75-
securityState.Ticket = token;
76-
securityState.Error = ev->Get()->Error;
77-
securityState.Token = ev->Get()->Token;
78-
securityState.SerializedToken = ev->Get()->SerializedToken;
79-
auto authResponse = std::make_unique<NPG::TEvPGEvents::TEvAuthResponse>();
80-
if (ev->Get()->Error) {
81-
authResponse->Error = ev->Get()->Error.Message;
82-
}
83-
Send(sender, authResponse.release());
84-
}
85-
TokenState.erase(itTokenState);
86-
}
87-
88-
void Handle(TEvPrivate::TEvTokenReady::TPtr& ev) {
89-
auto token = ev->Get()->LoginResult.token();
90-
auto itTokenState = TokenState.find(token);
91-
if (itTokenState == TokenState.end()) {
92-
itTokenState = TokenState.insert({token, {}}).first;
93-
}
94-
bool needSend = itTokenState->second.Senders.empty();
95-
itTokenState->second.Senders.insert(ev->Get()->Sender);
96-
if (needSend) {
97-
Send(MakeTicketParserID(), new TEvTicketParser::TEvAuthorizeTicket({
98-
.Database = ev->Get()->Database,
99-
.Ticket = token,
100-
.PeerName = ev->Get()->PeerName,
101-
}));
102-
}
103-
SecurityState[ev->Get()->Sender].LoginResult = std::move(ev->Get()->LoginResult);
104-
}
105-
10640
void Handle(NPG::TEvPGEvents::TEvAuth::TPtr& ev) {
107-
std::unordered_map<TString, TString> clientParams = ev->Get()->InitialMessage->GetClientParams();
10841
BLOG_D("TEvAuth " << ev->Get()->InitialMessage->Dump() << " cookie " << ev->Cookie);
109-
Ydb::Auth::LoginRequest request;
110-
request.set_user(clientParams["user"]);
42+
std::unordered_map<TString, TString> clientParams = ev->Get()->InitialMessage->GetClientParams();
43+
TPgWireAuthData pgWireAuthData;
44+
pgWireAuthData.UserName = clientParams["user"];
11145
if (ev->Get()->PasswordMessage) {
112-
request.set_password(TString(ev->Get()->PasswordMessage->GetPassword()));
46+
pgWireAuthData.Password = TString(ev->Get()->PasswordMessage->GetPassword());
11347
}
114-
TActorSystem* actorSystem = TActivationContext::ActorSystem();
115-
TActorId sender = ev->Sender;
116-
TString database = clientParams["database"];
117-
if (database == "/postgres") {
48+
pgWireAuthData.Sender = ev->Sender;
49+
pgWireAuthData.DatabasePath = clientParams["database"];
50+
if (pgWireAuthData.DatabasePath == "/postgres") {
11851
auto authResponse = std::make_unique<NPG::TEvPGEvents::TEvAuthResponse>();
11952
authResponse->Error = Ydb::StatusIds_StatusCode_Name(Ydb::StatusIds_StatusCode::StatusIds_StatusCode_BAD_REQUEST);
120-
actorSystem->Send(sender, authResponse.release());
53+
Send(pgWireAuthData.Sender, authResponse.release());
12154
}
122-
TString peerName = TStringBuilder() << ev->Get()->Address;
55+
pgWireAuthData.PeerName = TStringBuilder() << ev->Get()->Address;
12356

124-
using TRpcEv = NGRpcService::TGRpcRequestWrapperNoAuth<NGRpcService::TRpcServices::EvLogin, Ydb::Auth::LoginRequest, Ydb::Auth::LoginResponse>;
125-
auto rpcFuture = NRpcService::DoLocalRpc<TRpcEv>(std::move(request), database, {}, actorSystem);
126-
rpcFuture.Subscribe([actorSystem, sender, database, peerName, selfId = SelfId()](const NThreading::TFuture<Ydb::Auth::LoginResponse>& future) {
127-
auto& response = future.GetValueSync();
128-
if (response.operation().status() == Ydb::StatusIds::SUCCESS) {
129-
auto tokenReady = std::make_unique<TEvPrivate::TEvTokenReady>();
130-
response.operation().result().UnpackTo(&(tokenReady->LoginResult));
131-
tokenReady->Sender = sender;
132-
tokenReady->Database = database;
133-
tokenReady->PeerName = peerName;
134-
actorSystem->Send(selfId, tokenReady.release());
135-
} else {
136-
auto authResponse = std::make_unique<NPG::TEvPGEvents::TEvAuthResponse>();
137-
if (response.operation().issues_size() > 0) {
138-
authResponse->Error = response.operation().issues(0).message();
139-
} else {
140-
authResponse->Error = Ydb::StatusIds_StatusCode_Name(response.operation().status());
141-
}
142-
actorSystem->Send(sender, authResponse.release());
143-
}
144-
});
57+
Register(CreateLocalPgWireAuthActor(pgWireAuthData, SelfId()));
14558
}
14659

14760
void Handle(NPG::TEvPGEvents::TEvConnectionOpened::TPtr& ev) {
@@ -173,7 +86,6 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
17386
}
17487
SecurityState.erase(ev->Sender);
17588
ConnectionState.erase(itConnection);
176-
// TODO: cleanup TokenState too
17789
}
17890

17991
void Handle(NPG::TEvPGEvents::TEvQuery::TPtr& ev) {
@@ -236,6 +148,18 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
236148
}
237149
}
238150

151+
void Handle(TEvEvents::TEvAuthResponse::TPtr& ev) {
152+
auto& securityState = SecurityState[ev->Get()->Sender];
153+
auto authResponse = std::make_unique<NPG::TEvPGEvents::TEvAuthResponse>();
154+
if (!ev->Get()->ErrorMessage.empty()) {
155+
authResponse->Error = ev->Get()->ErrorMessage;
156+
} else {
157+
securityState.SerializedToken = ev->Get()->SerializedToken;
158+
securityState.Ticket = ev->Get()->Ticket;
159+
}
160+
Send(ev->Get()->Sender, authResponse.release());
161+
}
162+
239163
STATEFN(StateWork) {
240164
switch (ev->GetTypeRewrite()) {
241165
hFunc(NPG::TEvPGEvents::TEvAuth, Handle);
@@ -248,8 +172,7 @@ class TPgYdbProxy : public TActor<TPgYdbProxy> {
248172
hFunc(NPG::TEvPGEvents::TEvExecute, Handle);
249173
hFunc(NPG::TEvPGEvents::TEvClose, Handle);
250174
hFunc(NPG::TEvPGEvents::TEvCancelRequest, Handle);
251-
hFunc(TEvPrivate::TEvTokenReady, Handle);
252-
hFunc(TEvTicketParser::TEvAuthorizeTicketResult, Handle);
175+
hFunc(TEvEvents::TEvAuthResponse, Handle);
253176
}
254177
}
255178
};

ydb/core/local_pgwire/local_pgwire.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
#pragma once
2+
3+
#include "local_pgwire_util.h"
14
#include <ydb/library/actors/core/actor.h>
25

36
namespace NLocalPgWire {
47

58
inline NActors::TActorId CreateLocalPgWireProxyId(uint32_t nodeId = 0) { return NActors::TActorId(nodeId, "localpgwire"); }
69
NActors::IActor* CreateLocalPgWireProxy();
710

11+
NActors::IActor* CreateLocalPgWireAuthActor(const TPgWireAuthData& pgWireAuthData, const NActors::TActorId& pgYdbProxy);
12+
813
}
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
#include "log_impl.h"
2+
#include "local_pgwire.h"
3+
#include "local_pgwire_util.h"
4+
5+
#include <ydb/core/base/path.h>
6+
#include <ydb/core/base/ticket_parser.h>
7+
#include <ydb/core/grpc_services/local_rpc/local_rpc.h>
8+
#include <ydb/core/tx/scheme_cache/scheme_cache.h>
9+
10+
#include <ydb/library/actors/core/actor.h>
11+
#include <ydb/library/actors/core/actor_bootstrapped.h>
12+
13+
#include <ydb/public/api/grpc/ydb_auth_v1.grpc.pb.h>
14+
15+
#include <ydb/services/persqueue_v1/actors/persqueue_utils.h>
16+
17+
namespace NLocalPgWire {
18+
19+
using namespace NActors;
20+
using namespace NKikimr;
21+
22+
class TPgYdbAuthActor : public NActors::TActorBootstrapped<TPgYdbAuthActor> {
23+
using TBase = TActor<TPgYdbAuthActor>;
24+
25+
struct TEvPrivate {
26+
enum EEv {
27+
EvTokenReady = EventSpaceBegin(NActors::TEvents::ES_PRIVATE),
28+
EvAuthFailed,
29+
EvEnd
30+
};
31+
32+
static_assert(EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE), "expect EvEnd < EventSpaceEnd(NActors::TEvents::ES_PRIVATE)");
33+
34+
struct TEvTokenReady : TEventLocal<TEvTokenReady, EvTokenReady> {
35+
Ydb::Auth::LoginResult LoginResult;
36+
37+
TEvTokenReady() = default;
38+
};
39+
40+
struct TEvAuthFailed : NActors::TEventLocal<TEvAuthFailed, EvAuthFailed> {
41+
TString ErrorMessage;
42+
};
43+
};
44+
45+
TPgWireAuthData PgWireAuthData;
46+
TActorId PgYdbProxy;
47+
48+
TString DatabaseId;
49+
TString FolderId;
50+
TString SerializedToken;
51+
TString Ticket;
52+
53+
public:
54+
TPgYdbAuthActor(const TPgWireAuthData& pgWireAuthData, const TActorId& pgYdbProxy)
55+
: PgWireAuthData(pgWireAuthData)
56+
, PgYdbProxy(pgYdbProxy) {
57+
}
58+
59+
void Bootstrap() {
60+
if (PgWireAuthData.UserName == "__ydb_apikey") {
61+
if (PgWireAuthData.Password.empty()) {
62+
SendResponseAndDie("Invalid password");
63+
}
64+
SendDescribeRequest();
65+
} else {
66+
SendLoginRequest();
67+
}
68+
69+
Become(&TPgYdbAuthActor::StateWork);
70+
}
71+
72+
void Handle(TEvTicketParser::TEvAuthorizeTicketResult::TPtr& ev) {
73+
if (ev->Get()->Error) {
74+
SendResponseAndDie(ev->Get()->Error.Message);
75+
return;
76+
}
77+
78+
SerializedToken = ev->Get()->SerializedToken;
79+
Ticket = ev->Get()->Ticket;
80+
81+
SendResponseAndDie();
82+
}
83+
84+
void Handle(TEvPrivate::TEvTokenReady::TPtr& ev) {
85+
Send(MakeTicketParserID(), new TEvTicketParser::TEvAuthorizeTicket({
86+
.Database = PgWireAuthData.DatabasePath,
87+
.Ticket = ev->Get()->LoginResult.token(),
88+
.PeerName = PgWireAuthData.PeerName,
89+
}));
90+
}
91+
92+
void Handle(TEvPrivate::TEvAuthFailed::TPtr& ev) {
93+
SendResponseAndDie(ev->Get()->ErrorMessage);
94+
}
95+
96+
void Handle(NKikimr::TEvTxProxySchemeCache::TEvNavigateKeySetResult::TPtr& ev) {
97+
const NKikimr::NSchemeCache::TSchemeCacheNavigate* navigate = ev->Get()->Request.Get();
98+
if (navigate->ErrorCount) {
99+
SendResponseAndDie(TStringBuilder() << "Database with path '" << PgWireAuthData.DatabasePath << "' doesn't exists");
100+
return;
101+
}
102+
Y_ABORT_UNLESS(navigate->ResultSet.size() == 1);
103+
104+
const auto& entry = navigate->ResultSet.front();
105+
106+
for (const auto& attr : entry.Attributes) {
107+
if (attr.first == "folderId") FolderId = attr.second;
108+
else if (attr.first == "database_id") DatabaseId = attr.second;
109+
}
110+
111+
SendApiKeyRequest();
112+
}
113+
114+
STATEFN(StateWork) {
115+
switch (ev->GetTypeRewrite()) {
116+
hFunc(TEvPrivate::TEvTokenReady, Handle);
117+
hFunc(TEvTicketParser::TEvAuthorizeTicketResult, Handle);
118+
hFunc(TEvTxProxySchemeCache::TEvNavigateKeySetResult, Handle);
119+
hFunc(TEvPrivate::TEvAuthFailed, Handle);
120+
}
121+
}
122+
private:
123+
void SendLoginRequest() {
124+
Ydb::Auth::LoginRequest request;
125+
request.set_user(PgWireAuthData.UserName);
126+
if (!PgWireAuthData.Password.empty()) {
127+
request.set_password(PgWireAuthData.Password);
128+
}
129+
130+
auto* actorSystem = TActivationContext::ActorSystem();;
131+
132+
using TRpcEv = NGRpcService::TGRpcRequestWrapperNoAuth<NGRpcService::TRpcServices::EvLogin, Ydb::Auth::LoginRequest, Ydb::Auth::LoginResponse>;
133+
auto rpcFuture = NRpcService::DoLocalRpc<TRpcEv>(std::move(request), PgWireAuthData.DatabasePath, {}, actorSystem);
134+
rpcFuture.Subscribe([actorSystem, selfId = SelfId()](const NThreading::TFuture<Ydb::Auth::LoginResponse>& future) {
135+
auto& response = future.GetValueSync();
136+
if (response.operation().status() == Ydb::StatusIds::SUCCESS) {
137+
auto tokenReady = std::make_unique<TEvPrivate::TEvTokenReady>();
138+
response.operation().result().UnpackTo(&(tokenReady->LoginResult));
139+
actorSystem->Send(selfId, tokenReady.release());
140+
} else {
141+
auto authFailedEvent = std::make_unique<TEvPrivate::TEvAuthFailed>();
142+
if (response.operation().issues_size() > 0) {
143+
authFailedEvent->ErrorMessage = response.operation().issues(0).message();
144+
} else {
145+
authFailedEvent->ErrorMessage = Ydb::StatusIds_StatusCode_Name(response.operation().status());
146+
}
147+
actorSystem->Send(selfId, authFailedEvent.release());
148+
}
149+
});
150+
}
151+
152+
void SendApiKeyRequest() {
153+
auto entries = NKikimr::NGRpcProxy::V1::GetTicketParserEntries(DatabaseId, FolderId);
154+
155+
Send(NKikimr::MakeTicketParserID(), new NKikimr::TEvTicketParser::TEvAuthorizeTicket({
156+
.Database = PgWireAuthData.DatabasePath,
157+
.Ticket = "ApiKey " + PgWireAuthData.Password,
158+
.PeerName = PgWireAuthData.PeerName,
159+
.Entries = entries
160+
}));
161+
}
162+
163+
void SendDescribeRequest() {
164+
auto schemeCacheRequest = std::make_unique<NKikimr::NSchemeCache::TSchemeCacheNavigate>();
165+
NKikimr::NSchemeCache::TSchemeCacheNavigate::TEntry entry;
166+
entry.Path = NKikimr::SplitPath(PgWireAuthData.DatabasePath);
167+
entry.Operation = NKikimr::NSchemeCache::TSchemeCacheNavigate::OpPath;
168+
entry.SyncVersion = false;
169+
schemeCacheRequest->ResultSet.emplace_back(entry);
170+
Send(NKikimr::MakeSchemeCacheID(), MakeHolder<NKikimr::TEvTxProxySchemeCache::TEvNavigateKeySet>(schemeCacheRequest.release()));
171+
}
172+
173+
void SendResponseAndDie(const TString& errorMessage = "") {
174+
std::unique_ptr<TEvEvents::TEvAuthResponse> authResponse;
175+
if (!errorMessage.empty()) {
176+
authResponse = std::make_unique<TEvEvents::TEvAuthResponse>(errorMessage, PgWireAuthData.Sender);
177+
} else {
178+
authResponse = std::make_unique<TEvEvents::TEvAuthResponse>(SerializedToken, Ticket, PgWireAuthData.Sender);
179+
}
180+
181+
Send(PgYdbProxy, authResponse.release());
182+
183+
PassAway();
184+
}
185+
};
186+
187+
188+
NActors::IActor* CreateLocalPgWireAuthActor(const TPgWireAuthData& pgWireAuthData, const TActorId& pgYdbProxy) {
189+
return new TPgYdbAuthActor(pgWireAuthData, pgYdbProxy);
190+
}
191+
192+
}

0 commit comments

Comments
 (0)