Skip to content

Commit 7182374

Browse files
nnegreyleahecole
andauthored
automl: video beta - move beta samples out of branch and into master (#2750)
* automl: video beta - move beta samples out of branch and into master * lint * update error message on batch predict Co-authored-by: Leah E. Cole <[email protected]>
1 parent 18dc311 commit 7182374

6 files changed

+283
-0
lines changed

automl/beta/batch_predict.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def batch_predict(project_id, model_id, input_uri, output_uri):
17+
"""Batch predict"""
18+
# [START automl_batch_predict_beta]
19+
from google.cloud import automl_v1beta1 as automl
20+
21+
# TODO(developer): Uncomment and set the following variables
22+
# project_id = "YOUR_PROJECT_ID"
23+
# model_id = "YOUR_MODEL_ID"
24+
# input_uri = "gs://YOUR_BUCKET_ID/path/to/your/input/csv_or_jsonl"
25+
# output_uri = "gs://YOUR_BUCKET_ID/path/to/save/results/"
26+
27+
prediction_client = automl.PredictionServiceClient()
28+
29+
# Get the full path of the model.
30+
model_full_id = prediction_client.model_path(
31+
project_id, "us-central1", model_id
32+
)
33+
34+
gcs_source = automl.types.GcsSource(input_uris=[input_uri])
35+
36+
input_config = automl.types.BatchPredictInputConfig(gcs_source=gcs_source)
37+
gcs_destination = automl.types.GcsDestination(output_uri_prefix=output_uri)
38+
output_config = automl.types.BatchPredictOutputConfig(
39+
gcs_destination=gcs_destination
40+
)
41+
42+
response = prediction_client.batch_predict(
43+
model_full_id, input_config, output_config
44+
)
45+
46+
print("Waiting for operation to complete...")
47+
print(
48+
"Batch Prediction results saved to Cloud Storage bucket. {}".format(
49+
response.result()
50+
)
51+
)
52+
# [END automl_batch_predict_beta]

automl/beta/batch_predict_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific ladnguage governing permissions and
13+
# limitations under the License.
14+
15+
import datetime
16+
import os
17+
18+
import batch_predict
19+
20+
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]
21+
BUCKET_ID = "{}-lcm".format(PROJECT_ID)
22+
MODEL_ID = "TEN0000000000000000000"
23+
PREFIX = "TEST_EXPORT_OUTPUT_" + datetime.datetime.now().strftime(
24+
"%Y%m%d%H%M%S"
25+
)
26+
27+
28+
def test_batch_predict(capsys):
29+
# As batch prediction can take a long time. Try to batch predict on a model
30+
# and confirm that the model was not found, but other elements of the
31+
# request were valid.
32+
try:
33+
input_uri = "gs://{}/entity-extraction/input.jsonl".format(BUCKET_ID)
34+
output_uri = "gs://{}/{}/".format(BUCKET_ID, PREFIX)
35+
batch_predict.batch_predict(
36+
PROJECT_ID, MODEL_ID, input_uri, output_uri
37+
)
38+
out, _ = capsys.readouterr()
39+
assert (
40+
"does not exist"
41+
in out
42+
)
43+
except Exception as e:
44+
assert (
45+
"does not exist"
46+
in e.message
47+
)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def create_dataset(project_id, display_name):
17+
"""Create a dataset."""
18+
# [START automl_video_classification_create_dataset_beta]
19+
from google.cloud import automl_v1beta1 as automl
20+
21+
# TODO(developer): Uncomment and set the following variables
22+
# project_id = "YOUR_PROJECT_ID"
23+
# display_name = "your_datasets_display_name"
24+
25+
client = automl.AutoMlClient()
26+
27+
# A resource that represents Google Cloud Platform location.
28+
project_location = client.location_path(project_id, "us-central1")
29+
metadata = automl.types.VideoClassificationDatasetMetadata()
30+
dataset = automl.types.Dataset(
31+
display_name=display_name,
32+
video_classification_dataset_metadata=metadata,
33+
)
34+
35+
# Create a dataset with the dataset metadata in the region.
36+
created_dataset = client.create_dataset(project_location, dataset)
37+
38+
# Display the dataset information
39+
print("Dataset name: {}".format(created_dataset.name))
40+
# To get the dataset id, you have to parse it out of the `name` field.
41+
# As dataset Ids are required for other methods.
42+
# Name Form:
43+
# `projects/{project_id}/locations/{location_id}/datasets/{dataset_id}`
44+
print("Dataset id: {}".format(created_dataset.name.split("/")[-1]))
45+
# [END automl_video_classification_create_dataset_beta]
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import datetime
16+
import os
17+
18+
from google.cloud import automl_v1beta1 as automl
19+
import pytest
20+
21+
import video_classification_create_dataset
22+
23+
24+
PROJECT_ID = os.environ["AUTOML_PROJECT_ID"]
25+
pytest.DATASET_ID = None
26+
27+
28+
@pytest.fixture(scope="function", autouse=True)
29+
def teardown():
30+
yield
31+
32+
# Delete the created dataset
33+
client = automl.AutoMlClient()
34+
dataset_full_id = client.dataset_path(
35+
PROJECT_ID, "us-central1", pytest.DATASET_ID
36+
)
37+
response = client.delete_dataset(dataset_full_id)
38+
response.result()
39+
40+
41+
def test_video_classification_create_dataset(capsys):
42+
# create dataset
43+
dataset_name = "test_" + datetime.datetime.now().strftime("%Y%m%d%H%M%S")
44+
video_classification_create_dataset.create_dataset(
45+
PROJECT_ID, dataset_name
46+
)
47+
out, _ = capsys.readouterr()
48+
assert "Dataset id: " in out
49+
50+
# Get the the created dataset id for deletion
51+
pytest.DATASET_ID = out.splitlines()[1].split()[2]
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def create_model(project_id, dataset_id, display_name):
17+
"""Create a model."""
18+
# [START automl_video_classification_create_model_beta]
19+
from google.cloud import automl_v1beta1 as automl
20+
21+
# TODO(developer): Uncomment and set the following variables
22+
# project_id = "YOUR_PROJECT_ID"
23+
# dataset_id = "YOUR_DATASET_ID"
24+
# display_name = "your_models_display_name"
25+
26+
client = automl.AutoMlClient()
27+
28+
# A resource that represents Google Cloud Platform location.
29+
project_location = client.location_path(project_id, "us-central1")
30+
metadata = automl.types.VideoClassificationModelMetadata()
31+
model = automl.types.Model(
32+
display_name=display_name,
33+
dataset_id=dataset_id,
34+
video_classification_model_metadata=metadata,
35+
)
36+
37+
# Create a model with the model metadata in the region.
38+
response = client.create_model(project_location, model)
39+
40+
print("Training operation name: {}".format(response.operation.name))
41+
print("Training started...")
42+
# [END automl_video_classification_create_model_beta]
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
from google.cloud import automl_v1beta1 as automl
18+
import pytest
19+
20+
import video_classification_create_model
21+
22+
PROJECT_ID = os.environ["GCLOUD_PROJECT"]
23+
DATASET_ID = "VCN510437278078730240"
24+
pytest.OPERATION_ID = None
25+
26+
27+
@pytest.fixture(scope="function", autouse=True)
28+
def teardown():
29+
yield
30+
31+
# Cancel the operation
32+
client = automl.AutoMlClient()
33+
client.transport._operations_client.cancel_operation(pytest.OPERATION_ID)
34+
35+
36+
def test_video_classification_create_model(capsys):
37+
video_classification_create_model.create_model(
38+
PROJECT_ID, DATASET_ID, "classification_test_create_model"
39+
)
40+
out, _ = capsys.readouterr()
41+
assert "Training started" in out
42+
43+
# Get the the operation id for cancellation
44+
pytest.OPERATION_ID = out.split("Training operation name: ")[1].split(
45+
"\n"
46+
)[0]

0 commit comments

Comments
 (0)