Skip to content

Commit d1510d6

Browse files
nnegrey and leahecole
authored and committed
automl: add translate ga samples [(#2679)](#2679)
* automl: add translate ga samples * While still testing python2 on kokoro, use unicode print for non-ascii strings Co-authored-by: Leah E. Cole <[email protected]>
1 parent 9b5f526 commit d1510d6

7 files changed

+251
-0
lines changed

automl/snippets/resources/input.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Tell me how this ends
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def create_dataset(project_id, display_name):
    """Create an English-to-Japanese AutoML Translation dataset.

    Args:
        project_id: ID of the Google Cloud project to create the dataset in.
        display_name: Display name to assign to the new dataset.

    The dataset is created in the "us-central1" location. Once the
    creation response has been resolved with ``result()``, the dataset's
    full resource name and its id (the last path segment of that name)
    are printed.
    """
    # [START automl_translate_create_dataset]
    from google.cloud import automl

    # TODO(developer): Uncomment and set the following variables
    # project_id = "YOUR_PROJECT_ID"
    # display_name = "YOUR_DATASET_NAME"

    client = automl.AutoMlClient()

    # Describe the dataset: a translation dataset going from "en" to "ja".
    translation_metadata = automl.types.TranslationDatasetMetadata(
        source_language_code="en", target_language_code="ja"
    )
    dataset_spec = automl.types.Dataset(
        display_name=display_name,
        translation_dataset_metadata=translation_metadata,
    )

    # The parent resource is the project's location (region).
    parent = client.location_path(project_id, "us-central1")

    # Request creation and resolve the response to the created dataset.
    operation = client.create_dataset(parent, dataset_spec)
    dataset_resource_name = operation.result().name

    # Display the dataset information
    print("Dataset name: {}".format(dataset_resource_name))
    print("Dataset id: {}".format(dataset_resource_name.split("/")[-1]))
    # [END automl_translate_create_dataset]
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import datetime
16+
import os
17+
18+
from google.cloud import automl
19+
20+
import translate_create_dataset
21+
22+
23+
# Project the test runs against, read from the GCLOUD_PROJECT environment
# variable at import time (a missing variable raises KeyError).
PROJECT_ID = os.environ["GCLOUD_PROJECT"]
24+
25+
26+
def test_translate_create_dataset(capsys):
    """Create a dataset through the sample, then delete it again.

    Args:
        capsys: pytest fixture used to capture what the sample prints.
    """
    # create dataset
    # A timestamp suffix keeps the display name distinct between runs.
    dataset_name = "test_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    translate_create_dataset.create_dataset(PROJECT_ID, dataset_name)
    out, _ = capsys.readouterr()
    assert "Dataset id: " in out

    # Delete the created dataset
    # Locate the id by its printed label rather than by line/word position,
    # so the cleanup keeps working if the sample prints additional lines.
    dataset_id = out.split("Dataset id: ")[1].split("\n")[0].strip()
    client = automl.AutoMlClient()
    dataset_full_id = client.dataset_path(
        PROJECT_ID, "us-central1", dataset_id
    )
    response = client.delete_dataset(dataset_full_id)
    response.result()
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def create_model(project_id, dataset_id, display_name):
    """Start training an AutoML Translation model on a dataset.

    Args:
        project_id: ID of the Google Cloud project to create the model in.
        dataset_id: ID of the dataset the model is created from.
        display_name: Display name to assign to the new model.

    The create request is sent for the "us-central1" location. The
    function prints the name of the returned operation followed by a
    "Training started..." line; it does not call ``result()`` on the
    operation.
    """
    # [START automl_translate_create_model]
    from google.cloud import automl

    # TODO(developer): Uncomment and set the following variables
    # project_id = "YOUR_PROJECT_ID"
    # dataset_id = "YOUR_DATASET_ID"
    # display_name = "YOUR_MODEL_NAME"

    client = automl.AutoMlClient()

    # The translation metadata is left empty: leaving the base model unset
    # uses the default base model provided by Google.
    model_spec = automl.types.Model(
        display_name=display_name,
        dataset_id=dataset_id,
        translation_model_metadata=automl.types.TranslationModelMetadata(),
    )

    # The parent resource is the project's location (region).
    parent = client.location_path(project_id, "us-central1")

    # Create a model with the model metadata in the region.
    operation = client.create_model(parent, model_spec)

    print("Training operation name: {}".format(operation.operation.name))
    print("Training started...")
    # [END automl_translate_create_model]
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
from google.cloud import automl
18+
19+
import translate_create_model
20+
21+
# Project the test runs against, read from the GCLOUD_PROJECT environment
# variable at import time (a missing variable raises KeyError).
PROJECT_ID = os.environ["GCLOUD_PROJECT"]
22+
# Dataset id handed to create_model by the test below.
# NOTE(review): presumably a long-lived dataset that already exists in the
# test project -- confirm before running against a different project.
DATASET_ID = "TRL3876092572857648864"
23+
24+
25+
def test_translate_create_model(capsys):
    """Start model training through the sample, then cancel the operation.

    Args:
        capsys: pytest fixture used to capture what the sample prints.
    """
    translate_create_model.create_model(
        PROJECT_ID, DATASET_ID, "translate_test_create_model"
    )
    captured, _ = capsys.readouterr()
    assert "Training started" in captured

    # Cancel the operation
    # The operation name is the remainder of the line after its label.
    after_label = captured.split("Training operation name: ")[1]
    operation_name = after_label.split("\n")[0]
    operations_client = automl.AutoMlClient().transport._operations_client
    operations_client.cancel_operation(operation_name)

automl/snippets/translate_predict.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def predict(project_id, model_id, file_path):
    """Translate the contents of a local text file with an AutoML model.

    Args:
        project_id: ID of the Google Cloud project that owns the model.
        model_id: ID of the translation model to query.
        file_path: Path to a local UTF-8 encoded text file to translate.

    Raises:
        IOError/OSError: If ``file_path`` cannot be opened.
        UnicodeDecodeError: If the file content is not valid UTF-8.

    The model is addressed in the "us-central1" location. The translated
    content of the first payload in the response is printed.
    """
    # [START automl_translate_predict]
    from google.cloud import automl

    # TODO(developer): Uncomment and set the following variables
    # project_id = "YOUR_PROJECT_ID"
    # model_id = "YOUR_MODEL_ID"
    # file_path = "path_to_local_file.txt"

    prediction_client = automl.PredictionServiceClient()

    # Get the full path of the model.
    model_full_id = prediction_client.model_path(
        project_id, "us-central1", model_id
    )

    # Read the file content for translation.
    # bytes.decode() returns a new object rather than converting in place,
    # so the decoded text must be kept; otherwise the snippet is built from
    # the raw bytes.
    with open(file_path, "rb") as content_file:
        content = content_file.read().decode("utf-8")

    text_snippet = automl.types.TextSnippet(content=content)
    payload = automl.types.ExamplePayload(text_snippet=text_snippet)

    response = prediction_client.predict(model_full_id, payload)
    translated_content = response.payload[0].translation.translated_content

    # Unicode format string: translated text may be non-ASCII and the
    # sample is still exercised on Python 2.
    print(u"Translated content: {}".format(translated_content.content))
    # [END automl_translate_predict]
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
from google.cloud import automl
18+
import pytest
19+
20+
import translate_predict
21+
22+
# Project the test runs against, read from the GCLOUD_PROJECT environment
# variable at import time (a missing variable raises KeyError).
PROJECT_ID = os.environ["GCLOUD_PROJECT"]
23+
# Model id used by the fixture and the prediction test below.
# NOTE(review): presumably a trained model that already exists in the test
# project -- confirm before running against a different project.
MODEL_ID = "TRL3128559826197068699"
24+
25+
26+
@pytest.fixture(scope="function")
def verify_model_state():
    """Deploy MODEL_ID before a test if it is currently undeployed.

    Looks the model up in the "us-central1" location and, only when its
    deployment state equals UNDEPLOYED, requests deployment and resolves
    the returned response with ``result()``. The fixture yields no value.
    """
    client = automl.AutoMlClient()
    model_name = client.model_path(PROJECT_ID, "us-central1", MODEL_ID)

    undeployed = automl.enums.Model.DeploymentState.UNDEPLOYED
    if client.get_model(model_name).deployment_state != undeployed:
        # Any state other than UNDEPLOYED is left as it is.
        return

    # Deploy model if it is not deployed
    client.deploy_model(model_name).result()
36+
37+
38+
def test_predict(capsys, verify_model_state):
    """Run the predict sample on the bundled input file and check its output.

    Args:
        capsys: pytest fixture used to capture what the sample prints.
        verify_model_state: Requested only for its side effect (deploying
            the model when it is undeployed); pytest runs it because it is
            named as a parameter, and its value is not used here.
    """
    translate_predict.predict(PROJECT_ID, MODEL_ID, "resources/input.txt")
    out, _ = capsys.readouterr()
    assert "Translated content: " in out

0 commit comments

Comments
 (0)