Skip to content

Commit 67bd83e

Browse files
author
Rebecca Taylor
committed
automl_vision_batch_predict.py
1 parent b76e222 commit 67bd83e

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright 2019 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# https://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# DO NOT EDIT! This is a generated sample ("LongRunningPromise", "automl_vision_batch_predict")
18+
19+
# To install the latest published package dependency, execute the following:
20+
# pip install google-cloud-automl
21+
22+
# sample-metadata
23+
# title: AutoML Batch Predict (AutoML Vision)
24+
# description: AutoML Batch Predict using AutoML Vision
25+
# usage: python3 samples/v1beta1/automl_vision_batch_predict.py [--input_uri "gs://[BUCKET-NAME]/path/to/file-with-image-urls.csv"] [--output_uri "gs://[BUCKET-NAME]/directory-for-output-files/"] [--project "[Google Cloud Project ID]"] [--model_id "[Model ID]"]
26+
27+
# [START automl_vision_batch_predict]
28+
from google.cloud import automl_v1beta1
29+
30+
31+
def sample_batch_predict(input_uri, output_uri, project, model_id):
32+
"""
33+
AutoML Batch Predict using AutoML Vision
34+
35+
Args:
36+
input_uri Google Cloud Storage URI to CSV file in your bucket that contains the
37+
paths to the images to annotate, e.g. gs://[BUCKET-NAME]/path/to/images.csv
38+
Each line specifies a separate path to an image in Google Cloud Storage.
39+
output_uri Identifies where to store the output of your prediction request
40+
in your Google Cloud Storage bucket.
41+
You must have write permissions to the Google Cloud Storage bucket.
42+
project Required. Your Google Cloud Project ID.
43+
model_id Model ID, e.g. VOT1234567890123456789
44+
"""
45+
46+
client = automl_v1beta1.PredictionServiceClient()
47+
48+
# input_uri = 'gs://[BUCKET-NAME]/path/to/file-with-image-urls.csv'
49+
# output_uri = 'gs://[BUCKET-NAME]/directory-for-output-files/'
50+
# project = '[Google Cloud Project ID]'
51+
# model_id = '[Model ID]'
52+
name = client.model_path(project, "us-central1", model_id)
53+
input_uris = [input_uri]
54+
gcs_source = {"input_uris": input_uris}
55+
input_config = {"gcs_source": gcs_source}
56+
gcs_destination = {"output_uri_prefix": output_uri}
57+
output_config = {"gcs_destination": gcs_destination}
58+
59+
# A value from 0.0 to 1.0. When the model detects objects on video frames,
60+
# it will only produce bounding boxes that have at least this confidence score.
61+
# The default is 0.5.
62+
params_item = "0.0"
63+
params = {"score_threshold": params_item}
64+
65+
operation = client.batch_predict(name, input_config, output_config, params=params)
66+
67+
print(u"Waiting for operation to complete...")
68+
response = operation.result()
69+
70+
print(u"Batch Prediction results saved to specified Cloud Storage bucket.")
71+
72+
73+
# [END automl_vision_batch_predict]
74+
75+
76+
def main():
77+
import argparse
78+
79+
parser = argparse.ArgumentParser()
80+
parser.add_argument(
81+
"--input_uri",
82+
type=str,
83+
default="gs://[BUCKET-NAME]/path/to/file-with-image-urls.csv",
84+
)
85+
parser.add_argument(
86+
"--output_uri",
87+
type=str,
88+
default="gs://[BUCKET-NAME]/directory-for-output-files/",
89+
)
90+
parser.add_argument("--project", type=str, default="[Google Cloud Project ID]")
91+
parser.add_argument("--model_id", type=str, default="[Model ID]")
92+
args = parser.parse_args()
93+
94+
sample_batch_predict(args.input_uri, args.output_uri, args.project, args.model_id)
95+
96+
97+
if __name__ == "__main__":
98+
main()

0 commit comments

Comments
 (0)