Skip to content

Commit 4bf02c2

Browse files
Merge 2906438 into 9ec6f05
2 parents 9ec6f05 + 2906438 commit 4bf02c2

File tree

8 files changed

+174
-31
lines changed

8 files changed

+174
-31
lines changed

ydb/core/external_sources/object_storage.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ struct TObjectStorageExternalSource : public IExternalSource {
333333
}
334334
for (const auto& entry : entries.Objects) {
335335
if (entry.Size > 0) {
336-
return entry.Path;
336+
return entry;
337337
}
338338
}
339339
throw yexception() << "couldn't find any files for type inference, please check that the right path is provided";
@@ -350,9 +350,9 @@ struct TObjectStorageExternalSource : public IExternalSource {
350350

351351
auto fileFormat = NObjectStorage::NInference::ConvertFileFormat(*format);
352352
auto arrowFetcherId = ActorSystem->Register(NObjectStorage::NInference::CreateArrowFetchingActor(s3FetcherId, fileFormat, meta->Attributes));
353-
auto arrowInferencinatorId = ActorSystem->Register(NObjectStorage::NInference::CreateArrowInferencinator(arrowFetcherId, fileFormat, meta->Attributes));
353+
auto arrowInferencinatorId = ActorSystem->Register(NObjectStorage::NInference::CreateArrowInferencinator(arrowFetcherId, s3FetcherId, fileFormat, meta->Attributes));
354354

355-
return afterListing.Apply([arrowInferencinatorId, meta, actorSystem = ActorSystem](const NThreading::TFuture<TString>& pathFut) {
355+
return afterListing.Apply([arrowInferencinatorId, meta, actorSystem = ActorSystem](const NThreading::TFuture<NYql::NS3Lister::TObjectListEntry>& entryFut) {
356356
auto promise = NThreading::NewPromise<TMetadataResult>();
357357
auto schemaToMetadata = [meta](NThreading::TPromise<TMetadataResult> metaPromise, NObjectStorage::TEvInferredFileSchema&& response) {
358358
if (!response.Status.IsSuccess()) {
@@ -370,9 +370,10 @@ struct TObjectStorageExternalSource : public IExternalSource {
370370
result.Metadata = meta;
371371
metaPromise.SetValue(std::move(result));
372372
};
373+
auto [path, size, _] = entryFut.GetValue();
373374
actorSystem->Register(new NKqp::TActorRequestHandler<NObjectStorage::TEvInferFileSchema, NObjectStorage::TEvInferredFileSchema, TMetadataResult>(
374375
arrowInferencinatorId,
375-
new NObjectStorage::TEvInferFileSchema(TString{pathFut.GetValue()}),
376+
new NObjectStorage::TEvInferFileSchema(TString{path}, size),
376377
promise,
377378
std::move(schemaToMetadata)
378379
));

ydb/core/external_sources/object_storage/events.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ enum EEventTypes : ui32 {
2929
EvInferredFileSchema,
3030

3131
EvArrowFile,
32+
EvArrowSchema,
3233

3334
EvEnd,
3435
};
@@ -118,12 +119,24 @@ struct TEvArrowFile : public NActors::TEventLocal<TEvArrowFile, EvArrowFile> {
118119
TString Path;
119120
};
120121

122+
struct TEvArrowSchema : public NActors::TEventLocal<TEvArrowSchema, EvArrowSchema> {
123+
TEvArrowSchema(std::shared_ptr<arrow::Schema> schema, TString path)
124+
: Schema{std::move(schema)}
125+
, Path{std::move(path)}
126+
{}
127+
128+
std::shared_ptr<arrow::Schema> Schema;
129+
TString Path;
130+
};
131+
121132
struct TEvInferFileSchema : public NActors::TEventLocal<TEvInferFileSchema, EvInferFileSchema> {
122-
explicit TEvInferFileSchema(TString&& path)
133+
explicit TEvInferFileSchema(TString&& path, ui64 size)
123134
: Path{std::move(path)}
135+
, Size{size}
124136
{}
125137

126138
TString Path;
139+
ui64 Size = 0;
127140
};
128141

129142
struct TEvInferredFileSchema : public NActors::TEventLocal<TEvInferredFileSchema, EvInferredFileSchema> {

ydb/core/external_sources/object_storage/inference/arrow_inferencinator.cpp

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,15 @@ std::unique_ptr<FormatConfig> MakeFormatConfig(EFileFormat format, const THashMa
240240

241241
class TArrowInferencinator : public NActors::TActorBootstrapped<TArrowInferencinator> {
242242
public:
243-
TArrowInferencinator(NActors::TActorId arrowFetcher, EFileFormat format, const THashMap<TString, TString>& params)
243+
TArrowInferencinator(
244+
NActors::TActorId arrowFetcher,
245+
NActors::TActorId s3Fetcher,
246+
EFileFormat format,
247+
const THashMap<TString, TString>& params)
244248
: Format_{format}
245249
, Config_{MakeFormatConfig(Format_, params)}
246250
, ArrowFetcherId_{arrowFetcher}
251+
, S3FetcherId_{s3Fetcher}
247252
{
248253
Y_ABORT_UNLESS(IsArrowInferredFormat(Format_));
249254
}
@@ -256,11 +261,31 @@ class TArrowInferencinator : public NActors::TActorBootstrapped<TArrowInferencin
256261
HFunc(TEvInferFileSchema, HandleInferRequest);
257262
HFunc(TEvFileError, HandleFileError);
258263
HFunc(TEvArrowFile, HandleFileInference);
264+
HFunc(TEvArrowSchema, HandleFileSchema);
259265
)
260266

261267
void HandleInferRequest(TEvInferFileSchema::TPtr& ev, const NActors::TActorContext& ctx) {
262268
RequesterId_ = ev->Sender;
263-
ctx.Send(ArrowFetcherId_, ev->Release());
269+
auto& event = *ev->Get();
270+
271+
switch (Format_) {
272+
case EFileFormat::CsvWithNames:
273+
case EFileFormat::TsvWithNames: {
274+
ctx.Send(ArrowFetcherId_, ev->Release());
275+
return;
276+
}
277+
case EFileFormat::Parquet: {
278+
ctx.Send(S3FetcherId_, ev->Release());
279+
return;
280+
}
281+
default: {
282+
ctx.Send(RequesterId_, MakeError(event.Path, NFq::TIssuesIds::UNSUPPORTED, TStringBuilder{} << "unsupported format for inference: " << ConvertFileFormat(Format_)));
283+
return;
284+
}
285+
case EFileFormat::Undefined:
286+
Y_ABORT("Invalid format should be unreachable");
287+
}
288+
264289
}
265290

266291
void HandleFileInference(TEvArrowFile::TPtr& ev, const NActors::TActorContext& ctx) {
@@ -270,34 +295,47 @@ class TArrowInferencinator : public NActors::TActorBootstrapped<TArrowInferencin
270295
ctx.Send(RequesterId_, MakeErrorSchema(file.Path, NFq::TIssuesIds::INTERNAL_ERROR, std::get<TString>(mbArrowFields)));
271296
return;
272297
}
298+
ConvertArrowSchema(std::get<ArrowFields>(mbArrowFields), file.Path, ctx);
299+
}
300+
301+
void HandleFileSchema(TEvArrowSchema::TPtr& ev, const NActors::TActorContext& ctx) {
302+
auto& schema = *ev->Get();
303+
ConvertArrowSchema(schema.Schema->fields(), schema.Path, ctx);
304+
}
305+
306+
void HandleFileError(TEvFileError::TPtr& ev, const NActors::TActorContext& ctx) {
307+
Cout << "TArrowInferencinator::HandleFileError" << Endl;
308+
ctx.Send(RequesterId_, new TEvInferredFileSchema(ev->Get()->Path, std::move(ev->Get()->Issues)));
309+
}
273310

274-
auto& arrowFields = std::get<ArrowFields>(mbArrowFields);
311+
private:
312+
void ConvertArrowSchema(const ArrowFields& fields, const TString& path, const NActors::TActorContext& ctx) const {
275313
std::vector<Ydb::Column> ydbFields;
276-
for (const auto& field : arrowFields) {
314+
for (const auto& field : fields) {
277315
ydbFields.emplace_back();
278316
auto& ydbField = ydbFields.back();
279317
if (!ArrowToYdbType(*ydbField.mutable_type(), *field->type())) {
280-
ctx.Send(RequesterId_, MakeErrorSchema(file.Path, NFq::TIssuesIds::UNSUPPORTED, TStringBuilder{} << "couldn't convert arrow type to ydb: " << field->ToString()));
318+
ctx.Send(RequesterId_, MakeErrorSchema(path, NFq::TIssuesIds::UNSUPPORTED, TStringBuilder{} << "couldn't convert arrow type to ydb: " << field->ToString()));
281319
return;
282320
}
283321
ydbField.mutable_name()->assign(field->name());
284322
}
285-
ctx.Send(RequesterId_, new TEvInferredFileSchema(file.Path, std::move(ydbFields)));
286-
}
287-
288-
void HandleFileError(TEvFileError::TPtr& ev, const NActors::TActorContext& ctx) {
289-
Cout << "TArrowInferencinator::HandleFileError" << Endl;
290-
ctx.Send(RequesterId_, new TEvInferredFileSchema(ev->Get()->Path, std::move(ev->Get()->Issues)));
323+
ctx.Send(RequesterId_, new TEvInferredFileSchema(path, std::move(ydbFields)));
291324
}
292325

293-
private:
294326
EFileFormat Format_;
295327
std::unique_ptr<FormatConfig> Config_;
296328
NActors::TActorId ArrowFetcherId_;
329+
NActors::TActorId S3FetcherId_;
297330
NActors::TActorId RequesterId_;
298331
};
299332

300-
NActors::IActor* CreateArrowInferencinator(NActors::TActorId arrowFetcher, EFileFormat format, const THashMap<TString, TString>& params) {
301-
return new TArrowInferencinator{arrowFetcher, format, params};
333+
NActors::IActor* CreateArrowInferencinator(
334+
NActors::TActorId arrowFetcher,
335+
NActors::TActorId s3Fetcher,
336+
EFileFormat format,
337+
const THashMap<TString, TString>& params) {
338+
339+
return new TArrowInferencinator{arrowFetcher, s3Fetcher, format, params};
302340
}
303341
} // namespace NKikimr::NExternalSource::NObjectStorage::NInference

ydb/core/external_sources/object_storage/inference/arrow_inferencinator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,5 +53,5 @@ constexpr bool IsArrowInferredFormat(TStringBuf format) {
5353
return IsArrowInferredFormat(ConvertFileFormat(format));
5454
}
5555

56-
NActors::IActor* CreateArrowInferencinator(NActors::TActorId arrowFetcher, EFileFormat format, const THashMap<TString, TString>& params);
56+
NActors::IActor* CreateArrowInferencinator(NActors::TActorId arrowFetcher, NActors::TActorId s3Fetcher, EFileFormat format, const THashMap<TString, TString>& params);
5757
} // namespace NKikimr::NExternalSource::NObjectStorage::NInference

ydb/core/external_sources/object_storage/inference/ut/arrow_inference_ut.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class ArrowInferenceTest : public testing::Test {
5151
NActors::TActorId RegisterInferencinator(TStringBuf formatStr) {
5252
auto format = NInference::ConvertFileFormat(formatStr);
5353
auto arrowFetcher = ActorSystem.Register(NInference::CreateArrowFetchingActor(S3ActorId, format, {}), 1);
54-
return ActorSystem.Register(NInference::CreateArrowInferencinator(arrowFetcher, format, {}), 1);
54+
return ActorSystem.Register(NInference::CreateArrowInferencinator(arrowFetcher, S3ActorId, format, {}), 1);
5555
}
5656

5757
void TearDown() override {
@@ -85,7 +85,7 @@ TEST_F(ArrowInferenceTest, csv_simple) {
8585

8686
auto inferencinatorId = RegisterInferencinator("csv_with_names");
8787
ActorSystem.WrapInActorContext(EdgeActorId, [this, inferencinatorId] {
88-
NActors::TActivationContext::AsActorContext().Send(inferencinatorId, new TEvInferFileSchema(TString{Path}));
88+
NActors::TActivationContext::AsActorContext().Send(inferencinatorId, new TEvInferFileSchema(TString{Path}, 0));
8989
});
9090

9191
std::unique_ptr<NActors::IEventHandle> event = ActorSystem.WaitForEdgeActorEvent({EdgeActorId});
@@ -121,7 +121,7 @@ TEST_F(ArrowInferenceTest, tsv_simple) {
121121

122122
auto inferencinatorId = RegisterInferencinator("tsv_with_names");
123123
ActorSystem.WrapInActorContext(EdgeActorId, [this, inferencinatorId] {
124-
NActors::TActivationContext::AsActorContext().Send(inferencinatorId, new TEvInferFileSchema(TString{Path}));
124+
NActors::TActivationContext::AsActorContext().Send(inferencinatorId, new TEvInferFileSchema(TString{Path}, 0));
125125
});
126126

127127
std::unique_ptr<NActors::IEventHandle> event = ActorSystem.WaitForEdgeActorEvent({EdgeActorId});

ydb/core/external_sources/object_storage/inference/ya.make

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ PEERDIR(
1717
ydb/core/external_sources/object_storage
1818

1919
ydb/library/yql/providers/s3/compressors
20+
ydb/library/yql/providers/common/arrow
2021
)
2122

2223
END()

ydb/core/external_sources/object_storage/s3_fetcher.cpp

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "s3_fetcher.h"
22

33
#include <ydb/library/actors/core/hfunc.h>
4+
#include <ydb/library/yql/providers/common/arrow/interface/arrow_reader.h>
45

56
namespace NKikimr::NExternalSource::NObjectStorage {
67

@@ -23,6 +24,7 @@ class S3Fetcher : public NActors::TActorBootstrapped<S3Fetcher> {
2324

2425
STRICT_STFUNC(WorkingState,
2526
HFunc(TEvRequestS3Range, HandleRequest);
27+
HFunc(TEvInferFileSchema, HandleRequest);
2628

2729
HFunc(TEvS3DownloadResponse, HandleDownloadReponse);
2830
)
@@ -31,6 +33,10 @@ class S3Fetcher : public NActors::TActorBootstrapped<S3Fetcher> {
3133
StartDownload(std::shared_ptr<TEvRequestS3Range>(ev->Release().Release()), ctx.ActorSystem());
3234
}
3335

36+
void HandleRequest(TEvInferFileSchema::TPtr& ev, const NActors::TActorContext& ctx) {
37+
StartDownload(std::shared_ptr<TEvInferFileSchema>(ev->Release().Release()), ctx.ActorSystem(), ev->Sender);
38+
}
39+
3440
void HandleDownloadReponse(TEvS3DownloadResponse::TPtr& ev, const NActors::TActorContext& ctx) {
3541
auto& response = *ev->Get();
3642
auto& result = response.Result;
@@ -60,23 +66,60 @@ class S3Fetcher : public NActors::TActorBootstrapped<S3Fetcher> {
6066

6167
void StartDownload(std::shared_ptr<TEvRequestS3Range>&& request, NActors::TActorSystem* actorSystem) {
6268
auto length = request->End - request->Start;
69+
auto headers = MakeHeaders(request->RequestId.AsGuidString());
70+
71+
Gateway_->Download(
72+
Url_ + request->Path, std::move(headers), request->Start, length,
73+
[actorSystem, selfId = SelfId(), request = std::move(request)](NYql::IHTTPGateway::TResult&& result) mutable {
74+
actorSystem->Send(selfId, new TEvS3DownloadResponse(std::move(request), std::move(result)));
75+
}, {}, RetryPolicy_);
76+
}
77+
78+
void StartDownload(std::shared_ptr<TEvInferFileSchema>&& request, NActors::TActorSystem* actorSystem, NActors::TActorId sender) {
79+
NYql::TArrowFileDesc desc(
80+
Url_ + request->Path,
81+
Gateway_,
82+
MakeHeaders(CreateGuidAsString()),
83+
RetryPolicy_,
84+
request->Size,
85+
"parquet"
86+
);
87+
88+
auto schemaReader = NYql::MakeArrowReader(NYql::TArrowReaderSettings());
89+
auto futureSchema = schemaReader->GetSchema(desc);
90+
futureSchema.Apply([actorSystem, sender, request](NThreading::TFuture<NYql::IArrowReader::TSchemaResponse> response) {
91+
if (response.HasException()) {
92+
try {
93+
response.TryRethrow();
94+
} catch (const yexception& exception) {
95+
auto error = MakeError(
96+
request->Path,
97+
NFq::TIssuesIds::INTERNAL_ERROR,
98+
TStringBuilder() << "couldn't read file schema, check format params: " << exception.what()
99+
);
100+
actorSystem->Send(sender, error);
101+
return;
102+
}
103+
}
104+
105+
actorSystem->Send(sender, new TEvArrowSchema(response.GetValue().Schema, request->Path));
106+
});
107+
}
108+
109+
private:
110+
NYql::IHTTPGateway::THeaders MakeHeaders(const TString& guid) const {
63111
const auto& authInfo = Credentials_.GetAuthInfo();
64112
auto headers = NYql::IHTTPGateway::MakeYcHeaders(
65-
request->RequestId.AsGuidString(),
113+
guid,
66114
authInfo.GetToken(),
67115
{},
68116
authInfo.GetAwsUserPwd(),
69117
authInfo.GetAwsSigV4()
70118
);
71119

72-
Gateway_->Download(
73-
Url_ + request->Path, std::move(headers), request->Start, length,
74-
[actorSystem, selfId = SelfId(), request = std::move(request)](NYql::IHTTPGateway::TResult&& result) mutable {
75-
actorSystem->Send(selfId, new TEvS3DownloadResponse(std::move(request), std::move(result)));
76-
}, {}, RetryPolicy_);
120+
return std::move(headers);
77121
}
78122

79-
private:
80123
TString Url_;
81124
NYql::IHTTPGateway::TPtr Gateway_;
82125
NYql::IHTTPGateway::TRetryPolicy::TPtr RetryPolicy_;

ydb/tests/fq/s3/test_formats.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import ydb.public.api.protos.draft.fq_pb2 as fq
1212

1313
import ydb.tests.fq.s3.s3_helpers as s3_helpers
14-
from ydb.tests.tools.fq_runner.kikimr_utils import yq_all, YQ_STATS_FULL
14+
from ydb.tests.tools.fq_runner.kikimr_utils import yq_all, yq_v2, YQ_STATS_FULL
1515

1616

1717
class TestS3Formats:
@@ -38,6 +38,26 @@ def validate_result(self, result_set):
3838
assert result_set.rows[2].items[0].bytes_value == b"Pear"
3939
assert result_set.rows[2].items[1].int32_value == 15
4040
assert result_set.rows[2].items[2].int32_value == 33
41+
42+
def validate_result_inference(self, result_set):
43+
logging.debug(str(result_set))
44+
assert len(result_set.columns) == 3
45+
assert result_set.columns[0].name == "Fruit"
46+
assert result_set.columns[0].type.type_id == ydb.Type.UTF8
47+
assert result_set.columns[1].name == "Price"
48+
assert result_set.columns[1].type.optional_type.item.type_id == ydb.Type.INT64
49+
assert result_set.columns[2].name == "Weight"
50+
assert result_set.columns[2].type.optional_type.item.type_id == ydb.Type.INT64
51+
assert len(result_set.rows) == 3
52+
assert result_set.rows[0].items[0].text_value == "Banana"
53+
assert result_set.rows[0].items[1].int64_value == 3
54+
assert result_set.rows[0].items[2].int64_value == 100
55+
assert result_set.rows[1].items[0].text_value == "Apple"
56+
assert result_set.rows[1].items[1].int64_value == 2
57+
assert result_set.rows[1].items[2].int64_value == 22
58+
assert result_set.rows[2].items[0].text_value == "Pear"
59+
assert result_set.rows[2].items[1].int64_value == 15
60+
assert result_set.rows[2].items[2].int64_value == 33
4161

4262
def validate_pg_result(self, result_set):
4363
logging.debug(str(result_set))
@@ -104,6 +124,33 @@ def test_format(self, kikimr, s3, client, filename, type_format, yq_version, uni
104124
if type_format != "json_list":
105125
assert stat["ResultSet"]["IngressRows"]["sum"] == 3
106126

127+
@yq_v2
128+
@pytest.mark.parametrize(
129+
"filename, type_format",
130+
[
131+
("test.csv", "csv_with_names"),
132+
("test.tsv", "tsv_with_names"),
133+
("test.parquet", "parquet"),
134+
],
135+
)
136+
def test_format_inference(self, kikimr, s3, client, filename, type_format, unique_prefix):
137+
self.create_bucket_and_upload_file(filename, s3, kikimr)
138+
storage_connection_name = unique_prefix + "fruitbucket"
139+
client.create_storage_connection(storage_connection_name, "fbucket")
140+
141+
sql = f'''
142+
SELECT *
143+
FROM `{storage_connection_name}`.`{filename}`
144+
WITH (format=`{type_format}`, with_infer='true');
145+
'''
146+
147+
query_id = client.create_query("simple", sql, type=fq.QueryContent.QueryType.ANALYTICS).result.query_id
148+
client.wait_query_status(query_id, fq.QueryMeta.COMPLETED)
149+
150+
data = client.get_result_data(query_id)
151+
result_set = data.result.result_set
152+
self.validate_result_inference(result_set)
153+
107154
@yq_all
108155
def test_btc(self, kikimr, s3, client, unique_prefix):
109156
self.create_bucket_and_upload_file("btct.parquet", s3, kikimr)

0 commit comments

Comments
 (0)