Skip to content
This repository was archived by the owner on Dec 31, 2023. It is now read-only.

Commit 11c627b

Browse files
nnegrey and leahecole
authored and committed
Automl ga base samples [(#2613)](GoogleCloudPlatform/python-docs-samples#2613)
* automl: add base samples * automl: add base set of samples * Clean up tests * License year 2020, drop python2 print statement unicode * use centralized automl testing project * Fix GCS path typo * Use fake dataset for batch predict * lint: line length * fix fixture naming and use * Fix fixture changes * Catch resource exhausted error * use fake data for import test * update how to access an operation id Co-authored-by: Leah E. Cole <[email protected]>
1 parent 49594c6 commit 11c627b

12 files changed

+277
-58
lines changed

samples/snippets/batch_predict.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def batch_predict(project_id, model_id, input_uri, output_uri):
    """Run an AutoML batch prediction and print the operation's result.

    Args:
        project_id: ID of the Google Cloud project that owns the model.
        model_id: ID of the AutoML model to predict with.
        input_uri: Cloud Storage URI of the input file (see the TODO below).
        output_uri: Cloud Storage URI prefix the results are written under.
    """
    # [START automl_batch_predict]
    from google.cloud import automl

    # TODO(developer): Uncomment and set the following variables
    # project_id = "YOUR_PROJECT_ID"
    # model_id = "YOUR_MODEL_ID"
    # input_uri = "gs://YOUR_BUCKET_ID/path/to/your/input/csv_or_jsonl"
    # output_uri = "gs://YOUR_BUCKET_ID/path/to/save/results/"

    client = automl.PredictionServiceClient()

    # Fully-qualified resource name of the model in the us-central1 region.
    model_name = client.model_path(project_id, "us-central1", model_id)

    # Where the prediction input is read from.
    source_config = automl.types.BatchPredictInputConfig(
        gcs_source=automl.types.GcsSource(input_uris=[input_uri])
    )

    # Where the prediction output is written to.
    destination_config = automl.types.BatchPredictOutputConfig(
        gcs_destination=automl.types.GcsDestination(
            output_uri_prefix=output_uri
        )
    )

    operation = client.batch_predict(
        model_name, source_config, destination_config
    )

    print("Waiting for operation to complete...")
    outcome = operation.result()
    print(
        "Batch Prediction results saved to Cloud Storage bucket. {}".format(
            outcome
        )
    )
    # [END automl_batch_predict]
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import datetime
16+
import os
17+
18+
import batch_predict
19+
20+
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]
21+
BUCKET_ID = "{}-lcm".format(PROJECT_ID)
22+
MODEL_ID = "TEN0000000000000000000"
23+
PREFIX = "TEST_EXPORT_OUTPUT_" + datetime.datetime.now().strftime(
24+
"%Y%m%d%H%M%S"
25+
)
26+
27+
28+
def test_batch_predict(capsys):
    """Check that batch predicting against a placeholder model is rejected.

    Batch prediction can take a long time, so instead of running a real job
    the request targets MODEL_ID and the test confirms that the "model not
    found" message is reported, which shows the other elements of the
    request were valid. The message may arrive either in the sample's
    printed output or in the error it raises.
    """
    expected = (
        "The model is either not found or not supported for prediction yet"
    )
    input_uri = "gs://{}/entity-extraction/input.jsonl".format(BUCKET_ID)
    output_uri = "gs://{}/{}/".format(BUCKET_ID, PREFIX)

    # Keep only the call inside the ``try`` so that a failing assertion is
    # never swallowed by the ``except`` clause below.
    try:
        batch_predict.batch_predict(
            PROJECT_ID, MODEL_ID, input_uri, output_uri
        )
    except Exception as e:
        # The concrete error class is not imported by this test, hence the
        # broad catch. Plain Python 3 exceptions have no ``message``
        # attribute, so fall back to ``str(e)`` when it is missing.
        # NOTE(review): presumably the client raises a google.api_core
        # error that does expose ``message`` -- confirm.
        observed = getattr(e, "message", None) or str(e)
    else:
        observed, _ = capsys.readouterr()

    assert expected in observed

samples/snippets/delete_dataset_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626

2727
@pytest.fixture(scope="function")
28-
def create_dataset():
28+
def dataset_id():
2929
client = automl.AutoMlClient()
3030
project_location = client.location_path(PROJECT_ID, "us-central1")
3131
display_name = "test_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S")
@@ -39,8 +39,8 @@ def create_dataset():
3939
yield dataset_id
4040

4141

42-
def test_delete_dataset(capsys, create_dataset):
42+
def test_delete_dataset(capsys, dataset_id):
4343
# delete dataset
44-
delete_dataset.delete_dataset(PROJECT_ID, create_dataset)
44+
delete_dataset.delete_dataset(PROJECT_ID, dataset_id)
4545
out, _ = capsys.readouterr()
4646
assert "Dataset deleted." in out

samples/snippets/get_model_evaluation_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
@pytest.fixture(scope="function")
27-
def get_evaluation_id():
27+
def model_evaluation_id():
2828
client = automl.AutoMlClient()
2929
model_full_id = client.model_path(PROJECT_ID, "us-central1", MODEL_ID)
3030
evaluation = None
@@ -37,9 +37,9 @@ def get_evaluation_id():
3737
yield model_evaluation_id
3838

3939

40-
def test_get_model_evaluation(capsys, get_evaluation_id):
40+
def test_get_model_evaluation(capsys, model_evaluation_id):
4141
get_model_evaluation.get_model_evaluation(
42-
PROJECT_ID, MODEL_ID, get_evaluation_id
42+
PROJECT_ID, MODEL_ID, model_evaluation_id
4343
)
4444
out, _ = capsys.readouterr()
4545
assert "Model evaluation name: " in out
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def get_operation_status(operation_full_id):
    """Look up a long-running operation and print its name and details.

    Args:
        operation_full_id: Full resource name of the operation, in the form
            shown in the TODO below.
    """
    # [START automl_get_operation_status]
    from google.cloud import automl

    # TODO(developer): Uncomment and set the following variables
    # operation_full_id = \
    #     "projects/[projectId]/locations/us-central1/operations/[operationId]"

    automl_client = automl.AutoMlClient()

    # The operations client is reached through the transport's private
    # ``_operations_client`` attribute.
    operations_client = automl_client.transport._operations_client

    # Get the latest state of a long-running operation.
    operation = operations_client.get_operation(operation_full_id)

    print("Name: {}".format(operation.name))
    print("Operation details:")
    print(operation)
    # [END automl_get_operation_status]
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
from google.cloud import automl
18+
import pytest
19+
20+
import get_operation_status
21+
22+
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]
23+
24+
25+
@pytest.fixture(scope="function")
def operation_id():
    """Yield the name of the first operation listed for the test project."""
    automl_client = automl.AutoMlClient()
    parent = automl_client.location_path(PROJECT_ID, "us-central1")

    # List operations with an empty filter and walk the result page by page.
    listing = automl_client.transport._operations_client.list_operations(
        parent, filter_=""
    )
    page_iterator = listing.pages

    # Only the first entry of the first page is needed.
    first_page = next(page_iterator)
    first_operation = first_page.next()
    yield first_operation.name
35+
36+
37+
def test_get_operation_status(capsys, operation_id):
    """The sample should print an "Operation details" section."""
    get_operation_status.get_operation_status(operation_id)
    printed = capsys.readouterr()[0]
    assert "Operation details" in printed

samples/snippets/import_dataset_test.py

Lines changed: 21 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,49 +12,30 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import datetime
1615
import os
1716

18-
from google.cloud import automl
19-
import pytest
20-
2117
import import_dataset
2218

2319
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]
2420
BUCKET_ID = "{}-lcm".format(PROJECT_ID)
25-
26-
27-
@pytest.fixture(scope="function")
28-
def create_dataset():
29-
client = automl.AutoMlClient()
30-
project_location = client.location_path(PROJECT_ID, "us-central1")
31-
display_name = "test_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S")
32-
metadata = automl.types.TextSentimentDatasetMetadata(
33-
sentiment_max=4
34-
)
35-
dataset = automl.types.Dataset(
36-
display_name=display_name, text_sentiment_dataset_metadata=metadata
37-
)
38-
response = client.create_dataset(project_location, dataset)
39-
dataset_id = response.result().name.split("/")[-1]
40-
41-
yield dataset_id
42-
43-
44-
@pytest.mark.slow
45-
def test_import_dataset(capsys, create_dataset):
46-
data = (
47-
"gs://{}/sentiment-analysis/dataset.csv".format(BUCKET_ID)
48-
)
49-
dataset_id = create_dataset
50-
import_dataset.import_dataset(PROJECT_ID, dataset_id, data)
51-
out, _ = capsys.readouterr()
52-
assert "Data imported." in out
53-
54-
# delete created dataset
55-
client = automl.AutoMlClient()
56-
dataset_full_id = client.dataset_path(
57-
PROJECT_ID, "us-central1", dataset_id
58-
)
59-
response = client.delete_dataset(dataset_full_id)
60-
response.result()
21+
DATASET_ID = "TEN0000000000000000000"
22+
23+
24+
def test_import_dataset(capsys):
    """Check that importing into a placeholder dataset is rejected.

    Importing a dataset can take a long time, and only four operations can
    run on a dataset at once. So instead of running a real import, the
    request targets DATASET_ID and the test confirms that the "dataset
    doesn't exist" message is reported, which shows the other elements of
    the request were valid. The message may arrive either in the sample's
    printed output or in the error it raises.
    """
    expected = (
        "The Dataset doesn't exist or is inaccessible for use with AutoMl."
    )
    data = "gs://{}/sentiment-analysis/dataset.csv".format(BUCKET_ID)

    # Keep only the call inside the ``try`` so that a failing assertion is
    # never swallowed by the ``except`` clause below.
    try:
        import_dataset.import_dataset(PROJECT_ID, DATASET_ID, data)
    except Exception as e:
        # The concrete error class is not imported by this test, hence the
        # broad catch. Plain Python 3 exceptions have no ``message``
        # attribute, so fall back to ``str(e)`` when it is missing.
        # NOTE(review): presumably the client raises a google.api_core
        # error that does expose ``message`` -- confirm.
        observed = getattr(e, "message", None) or str(e)
    else:
        observed, _ = capsys.readouterr()

    assert expected in observed

samples/snippets/language_sentiment_analysis_predict_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323
MODEL_ID = os.environ["SENTIMENT_ANALYSIS_MODEL_ID"]
2424

2525

26-
@pytest.fixture(scope="function")
27-
def verify_model_state():
26+
@pytest.fixture(scope="function", autouse=True)
27+
def setup():
28+
# Verify the model is deployed before trying to predict
2829
client = automl.AutoMlClient()
2930
model_full_id = client.model_path(PROJECT_ID, "us-central1", MODEL_ID)
3031

@@ -35,8 +36,7 @@ def verify_model_state():
3536
response.result()
3637

3738

38-
def test_sentiment_analysis_predict(capsys, verify_model_state):
39-
verify_model_state
39+
def test_sentiment_analysis_predict(capsys):
4040
text = "Hopefully this Claritin kicks in soon"
4141
language_sentiment_analysis_predict.predict(PROJECT_ID, MODEL_ID, text)
4242
out, _ = capsys.readouterr()
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def list_operation_status(project_id):
    """Print the name and details of every operation listed for a project.

    Args:
        project_id: ID of the Google Cloud project whose operations in the
            us-central1 location are listed.
    """
    # [START automl_list_operation_status]
    from google.cloud import automl

    # TODO(developer): Uncomment and set the following variables
    # project_id = "YOUR_PROJECT_ID"

    automl_client = automl.AutoMlClient()

    # A resource that represents Google Cloud Platform location.
    parent = automl_client.location_path(project_id, "us-central1")

    # The operations client is reached through the transport's private
    # ``_operations_client`` attribute.
    operations_client = automl_client.transport._operations_client

    # List the operations available in the region; the second positional
    # argument is passed as an empty string, as in the original call.
    operations = operations_client.list_operations(parent, "")

    print("List of operations:")
    for entry in operations:
        print("Name: {}".format(entry.name))
        print("Operation details:")
        print(entry)
    # [END automl_list_operation_status]
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
import pytest
18+
19+
import list_operation_status
20+
21+
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]
22+
23+
24+
@pytest.mark.slow
def test_list_operation_status(capsys):
    """The sample should print at least one "Operation details" section."""
    list_operation_status.list_operation_status(PROJECT_ID)
    printed = capsys.readouterr()[0]
    assert "Operation details" in printed

samples/snippets/translate_predict_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@
2323
MODEL_ID = os.environ["TRANSLATION_MODEL_ID"]
2424

2525

26-
@pytest.fixture(scope="function")
27-
def verify_model_state():
26+
@pytest.fixture(scope="function", autouse=True)
27+
def setup():
28+
# Verify the model is deployed before trying to predict
2829
client = automl.AutoMlClient()
2930
model_full_id = client.model_path(PROJECT_ID, "us-central1", MODEL_ID)
3031

@@ -35,8 +36,7 @@ def verify_model_state():
3536
response.result()
3637

3738

38-
def test_translate_predict(capsys, verify_model_state):
39-
verify_model_state
39+
def test_translate_predict(capsys):
4040
translate_predict.predict(PROJECT_ID, MODEL_ID, "resources/input.txt")
4141
out, _ = capsys.readouterr()
4242
assert "Translated content: " in out

0 commit comments

Comments
 (0)