Skip to content

Commit 26a0653

Browse files
samples: Automl table batch test [(#4267)](GoogleCloudPlatform/python-docs-samples#4267)
* added rtest req.txt * samples: added automl batch predict test * added missing package * Update tables/automl/batch_predict_test.py Co-authored-by: Bu Sun Kim <[email protected]> Co-authored-by: Bu Sun Kim <[email protected]>
1 parent bb71346 commit 26a0653

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

samples/tables/batch_predict_test.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright 2020 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import os
18+
19+
from google.cloud.automl_v1beta1.gapic import enums
20+
21+
import pytest
22+
23+
import automl_tables_model
24+
import automl_tables_predict
25+
import model_test
26+
27+
28+
PROJECT = os.environ["GOOGLE_CLOUD_PROJECT"]
29+
REGION = "us-central1"
30+
STATIC_MODEL = model_test.STATIC_MODEL
31+
GCS_INPUT = "gs://{}-automl-tables-test/bank-marketing.csv".format(PROJECT)
32+
GCS_OUTPUT = "gs://{}-automl-tables-test/TABLE_TEST_OUTPUT/".format(PROJECT)
33+
34+
35+
@pytest.mark.slow
36+
def test_batch_predict(capsys):
37+
ensure_model_online()
38+
automl_tables_predict.batch_predict(
39+
PROJECT, REGION, STATIC_MODEL, GCS_INPUT, GCS_OUTPUT
40+
)
41+
out, _ = capsys.readouterr()
42+
assert "Batch prediction complete" in out
43+
44+
45+
def ensure_model_online():
46+
model = model_test.ensure_model_ready()
47+
if model.deployment_state != enums.Model.DeploymentState.DEPLOYED:
48+
automl_tables_model.deploy_model(PROJECT, REGION, model.display_name)
49+
50+
return automl_tables_model.get_model(PROJECT, REGION, model.display_name)

0 commit comments

Comments
 (0)