diff --git a/package-lock.json b/package-lock.json index faacfd5dbf..4c5b575b3a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -240,9 +240,9 @@ } }, "@google-cloud/common": { - "version": "3.3.1", - "resolved": "https://registry.npmjs.org/@google-cloud/common/-/common-3.3.1.tgz", - "integrity": "sha512-bJamcNvZ2j5xS01uFBT1GqfHIKrtwpyUhIU/Xn3uwMZkK/t6JA3mlID0wuZlo7XjbjFSRT2iLBEmDWv9T2hP8g==", + "version": "3.3.2", + "resolved": "https://registry.npmjs.org/@google-cloud/common/-/common-3.3.2.tgz", + "integrity": "sha512-W7JRLBEJWYtZQQuGQX06U6GBOSLrSrlvZxv6kGNwJtFrusu6AVgZltQ9Pajuz9Dh9aSXy9aTnBcyxn2/O0EGUw==", "optional": true, "requires": { "@google-cloud/projectify": "^2.0.0", @@ -309,15 +309,15 @@ "optional": true }, "@google-cloud/promisify": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/@google-cloud/promisify/-/promisify-2.0.1.tgz", - "integrity": "sha512-82EQzwrNauw1fkbUSr3f+50Bcq7g4h0XvLOk8C5e9ABkXYHei7ZPi9tiMMD7Vh3SfcdH97d1ibJ3KBWp2o1J+w==", + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/@google-cloud/promisify/-/promisify-2.0.2.tgz", + "integrity": "sha512-EvuabjzzZ9E2+OaYf+7P9OAiiwbTxKYL0oGLnREQd+Su2NTQBpomkdlkBowFvyWsaV0d1sSGxrKpSNcrhPqbxg==", "optional": true }, "@google-cloud/storage": { - "version": "5.1.1", - "resolved": "https://registry.npmjs.org/@google-cloud/storage/-/storage-5.1.1.tgz", - "integrity": "sha512-w/64V+eJl+vpYUXT15sBcO8pX0KTmb9Ni2ZNuQQ8HmyhAbEA3//G8JFaLPCXGBWO2/b0OQZytUT6q2wII9a9aQ==", + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/@google-cloud/storage/-/storage-5.1.2.tgz", + "integrity": "sha512-j2blsBVv6Tt5Z7ff6kOSIg5zVQPdlcTQh/4zMb9h7xMj4ekwndQA60le8c1KEa+Y6SR3EM6ER2AvKYK53P7vdQ==", "optional": true, "requires": { "@google-cloud/common": "^3.0.0", @@ -340,7 +340,7 @@ "readable-stream": "^3.4.0", "snakeize": "^0.1.0", "stream-events": "^1.0.1", - "through2": "^3.0.0", + "through2": "^4.0.0", "xdg-basedir": "^4.0.0" }, "dependencies": { @@ -354,6 +354,15 @@ "string_decoder": "^1.1.1", "util-deprecate": "^1.0.1" } + }, + "through2": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/through2/-/through2-4.0.2.tgz", + "integrity": "sha512-iOqSav00cVxEEICeD7TjLB1sueEL+81Wpzp2bY17uZjZN0pWZPuo4suZ/61VujxmqSGFfgOcNuTZ85QJwNZQpw==", + "optional": true, + "requires": { + "readable-stream": "3" + } } } }, @@ -538,7 +547,7 @@ }, "@types/firebase-token-generator": { "version": "2.0.28", - "resolved": "https://registry.npmjs.org/@types/firebase-token-generator/-/firebase-token-generator-2.0.28.tgz", + "resolved": "http://registry.npmjs.org/@types/firebase-token-generator/-/firebase-token-generator-2.0.28.tgz", "integrity": "sha1-Z1VIHZMk4mt6XItFXWgUg3aCw5Y=", "dev": true }, @@ -577,7 +586,7 @@ }, "@types/minimist": { "version": "1.2.0", - "resolved": "https://registry.npmjs.org/@types/minimist/-/minimist-1.2.0.tgz", + "resolved": "http://registry.npmjs.org/@types/minimist/-/minimist-1.2.0.tgz", "integrity": "sha1-aaI6OtKcrwCX8G7aWbNh7i8GOfY=", "dev": true }, @@ -1265,7 +1274,7 @@ }, "binaryextensions": { "version": "1.0.1", - "resolved": "https://registry.npmjs.org/binaryextensions/-/binaryextensions-1.0.1.tgz", + "resolved": "http://registry.npmjs.org/binaryextensions/-/binaryextensions-1.0.1.tgz", "integrity": "sha1-HmN0iLNbWL2l9HdL+WpSEqjJB1U=", "dev": true }, @@ -2971,7 +2980,7 @@ }, "firebase-token-generator": { "version": "2.0.0", - "resolved": "https://registry.npmjs.org/firebase-token-generator/-/firebase-token-generator-2.0.0.tgz", + "resolved": "http://registry.npmjs.org/firebase-token-generator/-/firebase-token-generator-2.0.0.tgz", "integrity": "sha1-l2fXWewTq9yZuhFf1eqZ2Lk9EgY=", "dev": true }, @@ -3180,9 +3189,9 @@ } }, "gcs-resumable-upload": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/gcs-resumable-upload/-/gcs-resumable-upload-3.1.0.tgz", - "integrity": "sha512-gB8xH6EjYCv9lfBEL4FK5+AMgTY0feYoNHAYOV5nCuOrDPhy5MOiyJE8WosgxhbKBPS361H7fkwv6CTufEh9bg==", + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/gcs-resumable-upload/-/gcs-resumable-upload-3.1.1.tgz", + "integrity": "sha512-RS1osvAicj9+MjCc6jAcVL1Pt3tg7NK2C2gXM5nqD1Gs0klF2kj5nnAFSBy97JrtslMIQzpb7iSuxaG8rFWd2A==", "optional": true, "requires": { "abort-controller": "^3.0.0", @@ -3359,7 +3368,7 @@ }, "globby": { "version": "5.0.0", - "resolved": "https://registry.npmjs.org/globby/-/globby-5.0.0.tgz", + "resolved": "http://registry.npmjs.org/globby/-/globby-5.0.0.tgz", "integrity": "sha1-69hGZ8oNuzMLmbz8aOrCvFQ3Dg0=", "dev": true, "requires": { @@ -4575,7 +4584,7 @@ }, "istextorbinary": { "version": "1.0.2", - "resolved": "https://registry.npmjs.org/istextorbinary/-/istextorbinary-1.0.2.tgz", + "resolved": "http://registry.npmjs.org/istextorbinary/-/istextorbinary-1.0.2.tgz", "integrity": "sha1-rOGTVNGpoBc+/rEITOD4ewrX3s8=", "dev": true, "requires": { @@ -6091,9 +6100,9 @@ } }, "p-limit": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.0.1.tgz", - "integrity": "sha512-mw/p92EyOzl2MhauKodw54Rx5ZK4624rNfgNaBguFZkHzyUG9WsDzFF5/yQVEJinbJDdP4jEfMN+uBquiGnaLg==", + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.0.2.tgz", + "integrity": "sha512-iwqZSOoWIW+Ew4kAGUlN16J4M7OB3ysMLSZtnhmqx7njIHFPlxWBX8xo3lVTyFVq6mI/lL9qt2IsN1sHwaxJkg==", "optional": true, "requires": { "p-try": "^2.0.0" @@ -6206,7 +6215,7 @@ }, "path-is-absolute": { "version": "1.0.1", - "resolved": "https://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", + "resolved": "http://registry.npmjs.org/path-is-absolute/-/path-is-absolute-1.0.1.tgz", "integrity": "sha1-F0uSaHNVNP+8es5r9TpanhtcX18=", "dev": true }, @@ -6362,7 +6371,7 @@ }, "pretty-hrtime": { "version": "1.0.3", - "resolved": "https://registry.npmjs.org/pretty-hrtime/-/pretty-hrtime-1.0.3.tgz", + "resolved": "http://registry.npmjs.org/pretty-hrtime/-/pretty-hrtime-1.0.3.tgz", "integrity": "sha1-t+PqQkNaTJsnWdmeDyAesZWALuE=", "dev": true }, @@ -6908,7 +6917,7 @@ }, "safe-regex": { "version": "1.1.0", - "resolved": "https://registry.npmjs.org/safe-regex/-/safe-regex-1.1.0.tgz", + "resolved": "http://registry.npmjs.org/safe-regex/-/safe-regex-1.1.0.tgz", "integrity": "sha1-QKNmnzsHfR6UPURinhV91IAjvy4=", "dev": true, "requires": { @@ -7687,7 +7696,7 @@ }, "textextensions": { "version": "1.0.2", - "resolved": "https://registry.npmjs.org/textextensions/-/textextensions-1.0.2.tgz", + "resolved": "http://registry.npmjs.org/textextensions/-/textextensions-1.0.2.tgz", "integrity": "sha1-ZUhjk+4fK7A5pgy7oFsLaL2VAdI=", "dev": true }, diff --git a/src/index.d.ts b/src/index.d.ts index 3df285bed9..42f25b8977 100644 --- a/src/index.d.ts +++ b/src/index.d.ts @@ -1068,8 +1068,8 @@ declare namespace admin.remoteConfig { * The `nextPageToken` value returned from a previous list versions request, if any. */ pageToken?: string; - - /** + + /** * Specifies the newest version number to include in the results. * If specified, must be greater than zero. Defaults to the newest version. */ @@ -1126,9 +1126,9 @@ declare namespace admin.remoteConfig { /** * Gets the requested version of the {@link admin.remoteConfig.RemoteConfigTemplate * `RemoteConfigTemplate`} of the project. - * + * * @param versionNumber Version number of the Remote Config template to look up. - * + * * @return A promise that fulfills with a `RemoteConfigTemplate`. */ getTemplateAtVersion(versionNumber: number | string): Promise; @@ -1161,7 +1161,7 @@ declare namespace admin.remoteConfig { * Rolls back a project's published Remote Config template to the specified version. * A rollback is equivalent to getting a previously published Remote Config * template and re-publishing it using a force update. - * + * * @param versionNumber The version number of the Remote Config template to roll back to. * The specified version number must be lower than the current version number, and not have * been deleted due to staleness. Only the last 300 versions are stored. @@ -1172,11 +1172,11 @@ declare namespace admin.remoteConfig { rollback(versionNumber: string | number): Promise; /** - * Gets a list of Remote Config template versions that have been published, sorted in reverse + * Gets a list of Remote Config template versions that have been published, sorted in reverse * chronological order. Only the last 300 versions are stored. - * All versions that correspond to non-active Remote Config templates (that is, all except the + * All versions that correspond to non-active Remote Config templates (that is, all except the * template that is being fetched by clients) are also deleted if they are more than 90 days old. - * + * * @param options Optional {@link admin.remoteConfig.ListVersionsOptions `ListVersionsOptions`} * object for getting a list of template versions. * @return A promise that fulfills with a `ListVersionsResult`. @@ -1317,12 +1317,18 @@ declare namespace admin.machineLearning { /** * Wait for the model to be unlocked. * - * @param {number} maxTimeSeconds The maximum time in seconds to wait. + * @param {number} maxTimeMillis The maximum time in milliseconds to wait. + * If not specified, a default maximum of 2 minutes is used. * * @return {Promise} A promise that resolves when the model is unlocked * or the maximum wait time has passed. */ - waitForUnlocked(maxTimeSeconds?: number): Promise; + waitForUnlocked(maxTimeMillis?: number): Promise; + + /** + * Return the model as a JSON object. + */ + toJSON(): {[key: string]: any}; /** Metadata about the model's TensorFlow Lite model file. */ readonly tfliteModel?: TFLiteModel; diff --git a/src/machine-learning/machine-learning-api-client.ts b/src/machine-learning/machine-learning-api-client.ts index 56b47118c7..7b48fdfcd1 100644 --- a/src/machine-learning/machine-learning-api-client.ts +++ b/src/machine-learning/machine-learning-api-client.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { HttpRequestConfig, HttpClient, HttpError, AuthorizedHttpClient } from '../utils/api-request'; +import { HttpRequestConfig, HttpClient, HttpError, AuthorizedHttpClient, ExponentialBackoffPoller } from '../utils/api-request'; import { PrefixedFirebaseError } from '../utils/error'; import { FirebaseMachineLearningError, MachineLearningErrorCode } from './machine-learning-utils'; import * as utils from '../utils/index'; @@ -26,6 +26,11 @@ const FIREBASE_VERSION_HEADER = { 'X-Firebase-Client': 'fire-admin-node/', }; +// Operation polling defaults +const POLL_DEFAULT_MAX_TIME_MILLISECONDS = 120000; // Maximum overall 2 minutes +const POLL_BASE_WAIT_TIME_MILLISECONDS = 3000; // Start with 3 second delay +const POLL_MAX_WAIT_TIME_MILLISECONDS = 30000; // Maximum 30 second delay + export interface StatusErrorResponse { readonly code: number; readonly message: string; @@ -71,6 +76,7 @@ export interface ModelResponse extends ModelContent { readonly updateTime: string; readonly etag: string; readonly modelHash?: string; + readonly activeOperations?: OperationResponse[]; } export interface ListModelsResponse { @@ -80,6 +86,7 @@ export interface ListModelsResponse { export interface OperationResponse { readonly name?: string; + readonly metadata?: {[key: string]: any}; readonly done: boolean; readonly error?: StatusErrorResponse; readonly response?: ModelResponse; @@ -112,7 +119,7 @@ export class MachineLearningApiClient { const err = new FirebaseMachineLearningError('invalid-argument', 'Invalid model content.'); return Promise.reject(err); } - return this.getUrl() + return this.getProjectUrl() .then((url) => { const request: HttpRequestConfig = { method: 'POST', @@ -130,7 +137,7 @@ export class MachineLearningApiClient { const err = new FirebaseMachineLearningError('invalid-argument', 'Invalid model or mask content.'); return Promise.reject(err); } - return this.getUrl() + return this.getProjectUrl() .then((url) => { const request: HttpRequestConfig = { method: 'PATCH', @@ -141,14 +148,20 @@ export class MachineLearningApiClient { }); } - public getModel(modelId: string): Promise { return Promise.resolve() .then(() => { return this.getModelName(modelId); }) .then((modelName) => { - return this.getResource(modelName); + return this.getResourceWithShortName(modelName); + }); + } + + public getOperation(operationName: string): Promise { + return Promise.resolve() + .then(() => { + return this.getResourceWithFullName(operationName); }); } @@ -177,7 +190,7 @@ export class MachineLearningApiClient { 'invalid-argument', 'Next page token must be a non-empty string.'); return Promise.reject(err); } - return this.getUrl() + return this.getProjectUrl() .then((url) => { const request: HttpRequestConfig = { method: 'GET', @@ -189,7 +202,7 @@ export class MachineLearningApiClient { } public deleteModel(modelId: string): Promise { - return this.getUrl() + return this.getProjectUrl() .then((url) => { const modelName = this.getModelName(modelId); const request: HttpRequestConfig = { @@ -200,15 +213,93 @@ export class MachineLearningApiClient { }); } + /** + * Handles a Long Running Operation coming back from the server. + * + * @param op The operation to handle + * @param options The options for polling + */ + public handleOperation( + op: OperationResponse, + options?: { + wait?: boolean; + maxTimeMillis?: number; + baseWaitMillis?: number; + maxWaitMillis?: number; + }): + Promise { + if (op.done) { + if (op.response) { + return Promise.resolve(op.response); + } else if (op.error) { + const err = FirebaseMachineLearningError.fromOperationError( + op.error.code, op.error.message); + return Promise.reject(err); + } + + // Done operations must have either a response or an error. + throw new FirebaseMachineLearningError('invalid-server-response', + 'Invalid operation response.'); + } + + // Operation is not done + if (options?.wait) { + return this.pollOperationWithExponentialBackoff(op.name!, options); + } + + const metadata = op.metadata || {}; + const metadataType: string = metadata['@type'] || ''; + if (!metadataType.includes('ModelOperationMetadata')) { + throw new FirebaseMachineLearningError('invalid-server-response', + `Unknown Metadata type: ${JSON.stringify(metadata)}`); + } + + return this.getModel(extractModelId(metadata.name)); + } + + // baseWaitMillis and maxWaitMillis should only ever be modified by unit tests to run faster. + private pollOperationWithExponentialBackoff( + opName: string, + options?: { + maxTimeMillis?: number; + baseWaitMillis?: number; + maxWaitMillis?: number; + }): Promise { + + const maxTimeMilliseconds = options?.maxTimeMillis ?? POLL_DEFAULT_MAX_TIME_MILLISECONDS; + const baseWaitMillis = options?.baseWaitMillis ?? POLL_BASE_WAIT_TIME_MILLISECONDS; + const maxWaitMillis = options?.maxWaitMillis ?? POLL_MAX_WAIT_TIME_MILLISECONDS; + + const poller = new ExponentialBackoffPoller( + baseWaitMillis, + maxWaitMillis, + maxTimeMilliseconds); + + return poller.poll(() => { + return this.getOperation(opName) + .then((responseData: {[key: string]: any}) => { + if (!responseData.done) { + return null; + } + if (responseData.error) { + const err = FirebaseMachineLearningError.fromOperationError( + responseData.error.code, responseData.error.message); + throw err; + } + return responseData.response; + }); + }); + } + /** * Gets the specified resource from the ML API. Resource names must be the short names without project * ID prefix (e.g. `models/123456789`). * - * @param {string} name Full qualified name of the resource to get. + * @param {string} name Short name of the resource to get. e.g. 'models/12345' * @returns {Promise} A promise that fulfills with the resource. */ - private getResource(name: string): Promise { - return this.getUrl() + private getResourceWithShortName(name: string): Promise { + return this.getProjectUrl() .then((url) => { const request: HttpRequestConfig = { method: 'GET', @@ -218,6 +309,20 @@ export class MachineLearningApiClient { }); } + /** + * Gets the specified resource from the ML API. Resource names must be the full names including project + * number prefix. + * @param fullName Full resource name of the resource to get. e.g. projects/123465/operations/987654 + * @returns {Promise} A promise that fulfulls with the resource. + */ + private getResourceWithFullName(fullName: string): Promise { + const request: HttpRequestConfig = { + method: 'GET', + url: `${ML_V1BETA2_API}/${fullName}` + }; + return this.sendRequest(request); + } + private sendRequest(request: HttpRequestConfig): Promise { request.headers = FIREBASE_VERSION_HEADER; return this.httpClient.send(request) @@ -250,7 +355,7 @@ export class MachineLearningApiClient { return new FirebaseMachineLearningError(code, message); } - private getUrl(): Promise { + private getProjectUrl(): Promise { return this.getProjectIdPrefix() .then((projectIdPrefix) => { return `${ML_V1BETA2_API}/${projectIdPrefix}`; @@ -309,3 +414,7 @@ const ERROR_CODE_MAPPING: {[key: string]: MachineLearningErrorCode} = { UNAUTHENTICATED: 'authentication-error', UNKNOWN: 'unknown-error', }; + +function extractModelId(resourceName: string): string { + return resourceName.split('/').pop()!; +} diff --git a/src/machine-learning/machine-learning.ts b/src/machine-learning/machine-learning.ts index 9d2330aa92..4b474e374f 100644 --- a/src/machine-learning/machine-learning.ts +++ b/src/machine-learning/machine-learning.ts @@ -16,8 +16,8 @@ import { FirebaseApp } from '../firebase-app'; import { FirebaseServiceInterface, FirebaseServiceInternalsInterface } from '../firebase-service'; -import { MachineLearningApiClient, ModelResponse, OperationResponse, - ModelOptions, ModelUpdateOptions, ListModelsOptions } from './machine-learning-api-client'; +import { MachineLearningApiClient, ModelResponse, ModelOptions, + ModelUpdateOptions, ListModelsOptions } from './machine-learning-api-client'; import { FirebaseError } from '../utils/error'; import * as validator from '../utils/validator'; @@ -92,7 +92,8 @@ export class MachineLearning implements FirebaseServiceInterface { public createModel(model: ModelOptions): Promise { return this.signUrlIfPresent(model) .then((modelContent) => this.client.createModel(modelContent)) - .then((operation) => handleOperation(operation)); + .then((operation) => this.client.handleOperation(operation)) + .then((modelResponse) => new Model(modelResponse, this.client)); } /** @@ -107,7 +108,8 @@ export class MachineLearning implements FirebaseServiceInterface { const updateMask = utils.generateUpdateMask(model); return this.signUrlIfPresent(model) .then((modelContent) => this.client.updateModel(modelId, modelContent, updateMask)) - .then((operation) => handleOperation(operation)); + .then((operation) => this.client.handleOperation(operation)) + .then((modelResponse) => new Model(modelResponse, this.client)); } /** @@ -141,7 +143,7 @@ export class MachineLearning implements FirebaseServiceInterface { */ public getModel(modelId: string): Promise { return this.client.getModel(modelId) - .then((modelResponse) => new Model(modelResponse)); + .then((modelResponse) => new Model(modelResponse, this.client)); } /** @@ -164,7 +166,7 @@ export class MachineLearning implements FirebaseServiceInterface { } let models: Model[] = []; if (resp.models) { - models = resp.models.map((rs) => new Model(rs)); + models = resp.models.map((rs) => new Model(rs, this.client)); } const result: ListModelsResult = { models }; if (resp.nextPageToken) { @@ -187,7 +189,8 @@ export class MachineLearning implements FirebaseServiceInterface { const updateMask = ['state.published']; const options: ModelUpdateOptions = { state: { published: publish } }; return this.client.updateModel(modelId, options, updateMask) - .then((operation) => handleOperation(operation)); + .then((operation) => this.client.handleOperation(operation)) + .then((modelResponse) => new Model(modelResponse, this.client)); } private signUrlIfPresent(options: ModelOptions): Promise { @@ -229,68 +232,138 @@ export class MachineLearning implements FirebaseServiceInterface { } } - /** * A Firebase ML Model output object. */ export class Model { - public readonly modelId: string; - public readonly displayName: string; - public readonly tags?: string[]; - public readonly createTime: string; - public readonly updateTime: string; - public readonly validationError?: string; - public readonly published: boolean; - public readonly etag: string; - public readonly modelHash?: string; - - public readonly tfliteModel?: TFLiteModel; - - constructor(model: ModelResponse) { - if (!validator.isNonNullObject(model) || - !validator.isNonEmptyString(model.name) || - !validator.isNonEmptyString(model.createTime) || - !validator.isNonEmptyString(model.updateTime) || - !validator.isNonEmptyString(model.displayName) || - !validator.isNonEmptyString(model.etag)) { - throw new FirebaseMachineLearningError( - 'invalid-server-response', - `Invalid Model response: ${JSON.stringify(model)}`); - } + private model: ModelResponse; + private readonly client?: MachineLearningApiClient; + + constructor(model: ModelResponse, client: MachineLearningApiClient) { + this.model = Model.validateAndClone(model); + this.client = client; + } + + get modelId(): string { + return extractModelId(this.model.name); + } + + get displayName(): string { + return this.model.displayName!; + } + + get tags(): string[] { + return this.model.tags || []; + } + + get createTime(): string { + return new Date(this.model.createTime).toUTCString(); + } + + get updateTime(): string { + return new Date(this.model.updateTime).toUTCString(); + } + + get validationError(): string | undefined { + return this.model.state?.validationError?.message; + } + + get published(): boolean { + return this.model.state?.published || false; + } - this.modelId = extractModelId(model.name); - this.displayName = model.displayName; - this.tags = model.tags || []; - this.createTime = new Date(model.createTime).toUTCString(); - this.updateTime = new Date(model.updateTime).toUTCString(); - if (model.state?.validationError?.message) { - this.validationError = model.state?.validationError?.message; + get etag(): string { + return this.model.etag; + } + + get modelHash(): string | undefined { + return this.model.modelHash; + } + + get tfliteModel(): TFLiteModel | undefined { + // Make a copy so people can't directly modify the private this.model object. + return deepCopy(this.model.tfliteModel); + } + + /** + * Locked indicates if there are active long running operations on the model. + * Models may not be modified when they are locked. + */ + public get locked(): boolean { + return (this.model.activeOperations?.length ?? 0) > 0; + } + + public toJSON(): {[key: string]: any} { + // We can't just return this.model because it has extra fields and + // different formats etc. So we build the expected model object. + const jsonModel: {[key: string]: any} = { + modelId: this.modelId, + displayName: this.displayName, + tags: this.tags, + createTime: this.createTime, + updateTime: this.updateTime, + published: this.published, + etag: this.etag, + locked: this.locked, + }; + + // Also add possibly undefined fields if they exist. + + if (this.validationError) { + jsonModel['validationError'] = this.validationError; } - this.published = model.state?.published || false; - this.etag = model.etag; - if (model.modelHash) { - this.modelHash = model.modelHash; + + if (this.modelHash) { + jsonModel['modelHash'] = this.modelHash; } - if (model.tfliteModel) { - this.tfliteModel = { - gcsTfliteUri: model.tfliteModel.gcsTfliteUri, - sizeBytes: model.tfliteModel.sizeBytes, - }; + + if (this.tfliteModel) { + jsonModel['tfliteModel'] = this.tfliteModel; } - } - public get locked(): boolean { - // Backend does not currently return locked models. - // This will likely change in future. - return false; + return jsonModel; } - /* eslint-disable-next-line @typescript-eslint/no-unused-vars */ - public waitForUnlocked(maxTimeSeconds?: number): Promise { + + /** + * Wait for the active operations on the model to complete. + * @param maxTimeMillis The number of milliseconds to wait for the model to be unlocked. If unspecified, + * a default will be used. + */ + public waitForUnlocked(maxTimeMillis?: number): Promise { // Backend does not currently return locked models. // This will likely change in future. + if ((this.model.activeOperations?.length ?? 0) > 0) { + // The client will always be defined on Models that have activeOperations + // because models with active operations came back from the server and + // were constructed with a non-empty client. + return this.client!.handleOperation(this.model.activeOperations![0], { wait: true, maxTimeMillis }) + .then((modelResponse) => { + this.model = Model.validateAndClone(modelResponse); + }); + } return Promise.resolve(); } + + private static validateAndClone(model: ModelResponse): ModelResponse { + if (!validator.isNonNullObject(model) || + !validator.isNonEmptyString(model.name) || + !validator.isNonEmptyString(model.createTime) || + !validator.isNonEmptyString(model.updateTime) || + !validator.isNonEmptyString(model.displayName) || + !validator.isNonEmptyString(model.etag)) { + throw new FirebaseMachineLearningError( + 'invalid-server-response', + `Invalid Model response: ${JSON.stringify(model)}`); + } + + const tmpModel = deepCopy(model); + // Remove '@type' field. We don't need it. + if ((tmpModel as any)["@type"]) { + delete (tmpModel as any)["@type"]; + } + return tmpModel; + } } /** @@ -305,19 +378,3 @@ export interface TFLiteModel { function extractModelId(resourceName: string): string { return resourceName.split('/').pop()!; } - -function handleOperation(op: OperationResponse): Model { - // Backend currently does not return operations that are not done. - if (op.done) { - // Done operations must have either a response or an error. - if (op.response) { - return new Model(op.response); - } else if (op.error) { - throw FirebaseMachineLearningError.fromOperationError( - op.error.code, op.error.message); - } - } - throw new FirebaseMachineLearningError( - 'invalid-server-response', - `Invalid Operation response: ${JSON.stringify(op)}`); -} diff --git a/src/project-management/project-management-api-request.ts b/src/project-management/project-management-api-request.ts index c7528f9157..c5b9cc1859 100644 --- a/src/project-management/project-management-api-request.ts +++ b/src/project-management/project-management-api-request.ts @@ -275,7 +275,7 @@ export class ProjectManagementRequestHandler { private pollRemoteOperationWithExponentialBackoff( operationResourceName: string): Promise { - const poller = new ExponentialBackoffPoller(); + const poller = new ExponentialBackoffPoller(); return poller.poll(() => { return this.invokeRequestHandler('GET', operationResourceName, /* requestData */ null) diff --git a/src/utils/api-request.ts b/src/utils/api-request.ts index 03f52d0240..3b83221d5f 100644 --- a/src/utils/api-request.ts +++ b/src/utils/api-request.ts @@ -909,15 +909,15 @@ export class ApiSettings { * }); * ``` */ -export class ExponentialBackoffPoller extends EventEmitter { +export class ExponentialBackoffPoller extends EventEmitter { private numTries = 0; private completed = false; private masterTimer: NodeJS.Timer; private repollTimer: NodeJS.Timer; - private pollCallback?: () => Promise; - private resolve: (result: object) => void; + private pollCallback?: () => Promise; + private resolve: (result: T) => void; private reject: (err: object) => void; constructor( @@ -930,13 +930,13 @@ export class ExponentialBackoffPoller extends EventEmitter { /** * Poll the provided callback with exponential backoff. * - * @param {() => Promise} callback The callback to be called for each poll. If the + * @param {() => Promise} callback The callback to be called for each poll. If the * callback resolves to a falsey value, polling will continue. Otherwise, the truthy * resolution will be used to resolve the promise returned by this method. - * @return {Promise} A Promise which resolves to the truthy value returned by the provided + * @return {Promise} A Promise which resolves to the truthy value returned by the provided * callback when polling is complete. */ - public poll(callback: () => Promise): Promise { + public poll(callback: () => Promise): Promise { if (this.pollCallback) { throw new Error('poll() can only be called once per instance of ExponentialBackoffPoller'); } @@ -953,7 +953,7 @@ export class ExponentialBackoffPoller extends EventEmitter { this.reject(new Error('ExponentialBackoffPoller deadline exceeded - Master timeout reached')); }, this.masterTimeoutMillis); - return new Promise((resolve, reject) => { + return new Promise((resolve, reject) => { this.resolve = resolve; this.reject = reject; this.repoll(); diff --git a/test/integration/machine-learning.spec.ts b/test/integration/machine-learning.spec.ts index a679c31111..9917d89a4f 100644 --- a/test/integration/machine-learning.spec.ts +++ b/test/integration/machine-learning.spec.ts @@ -74,8 +74,9 @@ describe('admin.machineLearning', () => { describe('createModel()', () => { it('creates a new Model without ModelFormat', () => { const modelOptions: admin.machineLearning.ModelOptions = { - displayName: 'node-integration-test-create-1', - tags: ['tag123', 'tag345'] }; + displayName: 'node-integ-test-create-1', + tags: ['tag123', 'tag345'] + }; return admin.machineLearning().createModel(modelOptions) .then((model) => { scheduleForDelete(model); @@ -85,7 +86,7 @@ describe('admin.machineLearning', () => { it('creates a new Model with valid ModelFormat', () => { const modelOptions: admin.machineLearning.ModelOptions = { - displayName: 'node-integration-test-create-2', + displayName: 'node-integ-test-create-2', tags: ['tag234', 'tag456'], tfliteModel: { gcsTfliteUri: 'this will be replaced below' }, }; @@ -103,7 +104,7 @@ describe('admin.machineLearning', () => { it('creates a new Model with invalid ModelFormat', () => { // Upload a file to default gcs bucket const modelOptions: admin.machineLearning.ModelOptions = { - displayName: 'node-integration-test-create-3', + displayName: 'node-integ-test-create-3', tags: ['tag234', 'tag456'], tfliteModel: { gcsTfliteUri: 'this will be replaced below' }, }; @@ -150,15 +151,15 @@ describe('admin.machineLearning', () => { const modelOptions: admin.machineLearning.ModelOptions = { displayName: 'Invalid Name#*^!', }; - return createTemporaryModel({ displayName: 'node-integration-invalid-arg' }) + return createTemporaryModel({ displayName: 'node-integ-invalid-argument' }) .then((model) => admin.machineLearning().updateModel(model.modelId, modelOptions) .should.eventually.be.rejected.and.have.property( 'code', 'machine-learning/invalid-argument')); }); it('updates the displayName', () => { - const DISPLAY_NAME = 'node-integration-test-update-1b'; - return createTemporaryModel({ displayName: 'node-integration-test-update-1a' }) + const DISPLAY_NAME = 'node-integ-test-update-1b'; + return createTemporaryModel({ displayName: 'node-integ-test-update-1a' }) .then((model) => { const modelOptions: admin.machineLearning.ModelOptions = { displayName: DISPLAY_NAME, @@ -175,7 +176,7 @@ describe('admin.machineLearning', () => { const NEW_TAGS = ['tag-node-update-2', 'tag-node-update-3']; return createTemporaryModel({ - displayName: 'node-integration-test-update-2', + displayName: 'node-integ-test-update-2', tags: ORIGINAL_TAGS, }).then((expectedModel) => { const modelOptions: admin.machineLearning.ModelOptions = { @@ -205,9 +206,9 @@ describe('admin.machineLearning', () => { }); it('can update more than 1 field', () => { - const DISPLAY_NAME = 'node-integration-test-update-3b'; - const TAGS = ['node-integration-tag-1', 'node-integration-tag-2']; - return createTemporaryModel({ displayName: 'node-integration-test-update-3a' }) + const DISPLAY_NAME = 'node-integ-test-update-3b'; + const TAGS = ['node-integ-tag-1', 'node-integ-tag-2']; + return createTemporaryModel({ displayName: 'node-integ-test-update-3a' }) .then((model) => { const modelOptions: admin.machineLearning.ModelOptions = { displayName: DISPLAY_NAME, @@ -238,7 +239,7 @@ describe('admin.machineLearning', () => { it('publishes the model successfully', () => { const modelOptions: admin.machineLearning.ModelOptions = { - displayName: 'node-integration-test-publish-1', + displayName: 'node-integ-test-publish-1', tfliteModel: { gcsTfliteUri: 'this will be replaced below' }, }; return uploadModelToGcs('model1.tflite', 'valid_model.tflite') @@ -273,7 +274,7 @@ describe('admin.machineLearning', () => { it('unpublishes the model successfully', () => { const modelOptions: admin.machineLearning.ModelOptions = { - displayName: 'node-integration-test-unpublish1', + displayName: 'node-integ-test-unpublish-1', tfliteModel: { gcsTfliteUri: 'this will be replaced below' }, }; return uploadModelToGcs('model1.tflite', 'valid_model.tflite') @@ -330,16 +331,16 @@ describe('admin.machineLearning', () => { before(() => { return Promise.all([ admin.machineLearning().createModel({ - displayName: 'node-integration-list1', - tags: ['node-integration-tag-1'], + displayName: 'node-integ-list1', + tags: ['node-integ-tag-1'], }), admin.machineLearning().createModel({ - displayName: 'node-integration-list2', - tags: ['node-integration-tag-1'], + displayName: 'node-integ-list2', + tags: ['node-integ-tag-1'], }), admin.machineLearning().createModel({ - displayName: 'node-integration-list3', - tags: ['node-integration-tag-1'], + displayName: 'node-integ-list3', + tags: ['node-integ-tag-1'], })]) .then(([m1, m2, m3]: admin.machineLearning.Model[]) => { model1 = m1; @@ -370,12 +371,12 @@ describe('admin.machineLearning', () => { return admin.machineLearning().listModels({ pageSize: 2 }) .then((modelList) => { expect(modelList.models.length).to.equal(2); - expect(modelList.pageToken).not.to.be.undefined; + expect(modelList.pageToken).not.to.be.empty; }); }); it('filters by exact displayName', () => { - return admin.machineLearning().listModels({ filter: 'displayName=node-integration-list1' }) + return admin.machineLearning().listModels({ filter: 'displayName=node-integ-list1' }) .then((modelList) => { expect(modelList.models.length).to.equal(1); expect(modelList.models[0]).to.deep.equal(model1); @@ -384,7 +385,7 @@ describe('admin.machineLearning', () => { }); it('filters by displayName prefix', () => { - return admin.machineLearning().listModels({ filter: 'displayName:node-integration-list*', pageSize: 100 }) + return admin.machineLearning().listModels({ filter: 'displayName:node-integ-list*', pageSize: 100 }) .then((modelList) => { expect(modelList.models.length).to.be.at.least(3); expect(modelList.models).to.deep.include(model1); @@ -395,7 +396,7 @@ describe('admin.machineLearning', () => { }); it('filters by tag', () => { - return admin.machineLearning().listModels({ filter: 'tags:node-integration-tag-1', pageSize: 100 }) + return admin.machineLearning().listModels({ filter: 'tags:node-integ-tag-1', pageSize: 100 }) .then((modelList) => { expect(modelList.models.length).to.be.at.least(3); expect(modelList.models).to.deep.include(model1); @@ -406,14 +407,15 @@ describe('admin.machineLearning', () => { }); it('handles pageTokens properly', () => { - return admin.machineLearning().listModels({ filter: 'displayName:node-integration-list*', pageSize: 2 }) + return admin.machineLearning().listModels({ filter: 'displayName:node-integ-list*', pageSize: 2 }) .then((modelList) => { expect(modelList.models.length).to.equal(2); - expect(modelList.pageToken).not.to.be.empty; + expect(modelList.pageToken).not.to.be.undefined; return admin.machineLearning().listModels({ - filter: 'displayName:node-integration-list*', + filter: 'displayName:node-integ-list*', pageSize: 2, - pageToken: modelList.pageToken }) + pageToken: modelList.pageToken + }) .then((modelList2) => { expect(modelList2.models.length).to.be.at.least(1); expect(modelList2.pageToken).to.be.undefined; diff --git a/test/unit/machine-learning/machine-learning-api-client.spec.ts b/test/unit/machine-learning/machine-learning-api-client.spec.ts index 69f5f1d877..320142e872 100644 --- a/test/unit/machine-learning/machine-learning-api-client.spec.ts +++ b/test/unit/machine-learning/machine-learning-api-client.spec.ts @@ -64,9 +64,14 @@ describe('MachineLearningApiClient', () => { }, }; + const PROJECT_ID = 'test-project'; + const PROJECT_NUMBER = '1234567'; + const OPERATION_ID = '987654'; + const OPERATION_NAME = `projects/${PROJECT_NUMBER}/operations/${OPERATION_ID}`; + const STATUS_ERROR_MESSAGE = 'Invalid Argument message' const STATUS_ERROR_RESPONSE = { code: 3, - message: 'Invalid Argument message', + message: STATUS_ERROR_MESSAGE, }; const OPERATION_SUCCESS_RESPONSE = { done: true, @@ -76,6 +81,30 @@ describe('MachineLearningApiClient', () => { done: true, error: STATUS_ERROR_RESPONSE, }; + const OPERATION_NOT_DONE_RESPONSE = { + name: OPERATION_NAME, + metadata: { + '@type': 'type.googleapis.com/google.firebase.ml.v1beta2.ModelOperationMetadata', + name: `projects/${PROJECT_ID}/models/${MODEL_ID}`, + basicOperationStatus: 'BASIC_OPERATION_STATUS_UPLOADING' + }, + done: false, + }; + const LOCKED_MODEL_RESPONSE = { + name: 'projects/test-project/models/1234567', + createTime: '2020-02-07T23:45:23.288047Z', + updateTime: '2020-02-08T23:45:23.288047Z', + etag: 'etag123', + modelHash: 'modelHash123', + displayName: 'model_1', + tags: ['tag_1', 'tag_2'], + activeOperations: [OPERATION_NOT_DONE_RESPONSE], + state: { published: true }, + tfliteModel: { + gcsTfliteUri: 'gs://test-project-bucket/Firebase/ML/Models/model1.tflite', + sizeBytes: 16900988, + }, + }; const ERROR_RESPONSE = { error: { @@ -386,6 +415,153 @@ describe('MachineLearningApiClient', () => { }); }); + describe('getOperation', () => { + it('should resolve with the requested operation on success', () => { + const stub = sinon + .stub(HttpClient.prototype, 'send') + .resolves(utils.responseFrom(OPERATION_SUCCESS_RESPONSE)); + stubs.push(stub); + return apiClient.getOperation(OPERATION_NAME) + .then((resp) => { + expect(resp).to.deep.equal(OPERATION_SUCCESS_RESPONSE); + expect(stub).to.have.been.calledOnce.and.calledWith({ + method: 'GET', + url: `${BASE_URL}/projects/${PROJECT_NUMBER}/operations/${OPERATION_ID}`, + headers: EXPECTED_HEADERS, + }); + }); + }); + + it('should reject when a full platform error response is received', () => { + const stub = sinon + .stub(HttpClient.prototype, 'send') + .rejects(utils.errorFrom(ERROR_RESPONSE, 404)); + stubs.push(stub); + const expected = new FirebaseMachineLearningError('not-found', 'Requested entity not found'); + return apiClient.getOperation(OPERATION_NAME) + .should.eventually.be.rejected.and.deep.include(expected); + }); + + it('should reject with unknown-error when error code is not present', () => { + const stub = sinon + .stub(HttpClient.prototype, 'send') + .rejects(utils.errorFrom({}, 404)); + stubs.push(stub); + const expected = new FirebaseMachineLearningError('unknown-error', 'Unknown server error: {}'); + return apiClient.getOperation(OPERATION_NAME) + .should.eventually.be.rejected.and.deep.include(expected); + }); + + it('should reject with unknown-error for non-json response', () => { + const stub = sinon + .stub(HttpClient.prototype, 'send') + .rejects(utils.errorFrom('not json', 404)); + stubs.push(stub); + const expected = new FirebaseMachineLearningError( + 'unknown-error', 'Unexpected response with status: 404 and body: not json'); + return apiClient.getOperation(OPERATION_NAME) + .should.eventually.be.rejected.and.deep.include(expected); + }); + + it('should reject when failed with a FirebaseAppError', () => { + const expected = new FirebaseAppError('network-error', 'socket hang up'); + const stub = sinon + .stub(HttpClient.prototype, 'send') + .rejects(expected); + stubs.push(stub); + return apiClient.getOperation(OPERATION_NAME) + .should.eventually.be.rejected.and.deep.include(expected); + }); + }); + + describe('handleOperation', () => { + it('handles a done operation with result', () => { + return apiClient.handleOperation(OPERATION_SUCCESS_RESPONSE) + .then((resp) => { + expect(resp).deep.equals(MODEL_RESPONSE); + }); + }); + + it('handles a done operation with error', () => { + const expected = new FirebaseMachineLearningError('invalid-argument', STATUS_ERROR_MESSAGE); + return apiClient.handleOperation(OPERATION_ERROR_RESPONSE) + .should.eventually.be.rejected.and.deep.include(expected); + }); + + it('handles a running operation with no wait', () => { + const stub = sinon + .stub(HttpClient.prototype, 'send') + .resolves(utils.responseFrom(LOCKED_MODEL_RESPONSE)); + stubs.push(stub); + return apiClient.handleOperation(OPERATION_NOT_DONE_RESPONSE) + .then((resp) => { + expect(resp).to.deep.equal(LOCKED_MODEL_RESPONSE); + expect(stub).to.have.been.calledOnce.and.calledWith({ + method: 'GET', + url: `${BASE_URL}/projects/${PROJECT_ID}/models/${MODEL_ID}`, + headers: EXPECTED_HEADERS, + }); + }); + }); + + it('handles a running operation with wait', () => { + const stub = sinon.stub(HttpClient.prototype, 'send'); + stub.onCall(0).resolves(utils.responseFrom(OPERATION_NOT_DONE_RESPONSE)); + stub.onCall(1).resolves(utils.responseFrom(OPERATION_SUCCESS_RESPONSE)); + stubs.push(stub); + return apiClient.handleOperation(OPERATION_NOT_DONE_RESPONSE, { + wait: true, + maxTimeMillis: 1000, + baseWaitMillis: 2, + maxWaitMillis: 5 }) + .then((resp) => { + expect(resp).to.deep.equal(MODEL_RESPONSE); + expect(stub).to.have.been.calledTwice.and.calledWith({ + method: 'GET', + url: `${BASE_URL}/projects/${PROJECT_NUMBER}/operations/${OPERATION_ID}`, + headers: EXPECTED_HEADERS, + }); + }); + }); + + it('handles a running operation with wait ending in error', () => { + const stub = sinon.stub(HttpClient.prototype, 'send'); + stub.onCall(0).resolves(utils.responseFrom(OPERATION_NOT_DONE_RESPONSE)); + stub.onCall(1).resolves(utils.responseFrom(OPERATION_ERROR_RESPONSE)); + stubs.push(stub); + const expected = new FirebaseMachineLearningError('invalid-argument', STATUS_ERROR_MESSAGE); + return apiClient.handleOperation(OPERATION_NOT_DONE_RESPONSE, { + wait: true, + maxTimeMillis: 1000, + baseWaitMillis: 2, + maxWaitMillis: 5 }) + .should.eventually.be.rejected.and.deep.include(expected) + .then(() => { + expect(stub).to.have.been.calledTwice.and.calledWith({ + method: 'GET', + url: `${BASE_URL}/projects/${PROJECT_NUMBER}/operations/${OPERATION_ID}`, + headers: EXPECTED_HEADERS, + }); + }); + }); + + it('handles a running operation with wait ending in timeout', () => { + const stub = sinon.stub(HttpClient.prototype, 'send'); + stub.onCall(0).resolves(utils.responseFrom(OPERATION_NOT_DONE_RESPONSE)); + stub.onCall(1).resolves(utils.responseFrom(OPERATION_NOT_DONE_RESPONSE)); + stub.onCall(2).resolves(utils.responseFrom(OPERATION_NOT_DONE_RESPONSE)); + stubs.push(stub); + const expected = new Error('ExponentialBackoffPoller dealine exceeded - Master timeout reached'); + return apiClient.handleOperation(OPERATION_NOT_DONE_RESPONSE, { + wait: true, + maxTimeMillis: 1000, + baseWaitMillis: 500, + maxWaitMillis: 1000 }) + .should.eventually.be.rejected.and.deep.include(expected); + }); + + }); + describe('listModels', () => { const LIST_RESPONSE = { models: [MODEL_RESPONSE, MODEL_RESPONSE2], diff --git a/test/unit/machine-learning/machine-learning.spec.ts b/test/unit/machine-learning/machine-learning.spec.ts index ece3057ba6..c95afc9d5e 100644 --- a/test/unit/machine-learning/machine-learning.spec.ts +++ b/test/unit/machine-learning/machine-learning.spec.ts @@ -23,7 +23,7 @@ import { MachineLearning, Model } from '../../../src/machine-learning/machine-le import { FirebaseApp } from '../../../src/firebase-app'; import * as mocks from '../../resources/mocks'; import { MachineLearningApiClient, StatusErrorResponse, - ModelOptions, ModelResponse } from '../../../src/machine-learning/machine-learning-api-client'; + ModelOptions, ModelResponse, OperationResponse } from '../../../src/machine-learning/machine-learning-api-client'; import { FirebaseMachineLearningError } from '../../../src/machine-learning/machine-learning-utils'; import { deepCopy } from '../../../src/utils/deep-copy'; @@ -32,6 +32,10 @@ const expect = chai.expect; describe('MachineLearning', () => { const MODEL_ID = '1234567'; + const PROJECT_ID = 'test-project'; + const PROJECT_NUMBER = '987654'; + const OPERATION_ID = '456789'; + const OPERATION_NAME = `projects/${PROJECT_NUMBER}/operations/${OPERATION_ID}` const EXPECTED_ERROR = new FirebaseMachineLearningError('internal-error', 'message'); const CREATE_TIME_UTC = 'Fri, 07 Feb 2020 23:45:23 GMT'; const UPDATE_TIME_UTC = 'Sat, 08 Feb 2020 23:45:23 GMT'; @@ -68,7 +72,7 @@ describe('MachineLearning', () => { sizeBytes: 16900988, }, }; - const MODEL1 = new Model(MODEL_RESPONSE); + const MODEL_RESPONSE2: { name: string; @@ -103,7 +107,6 @@ describe('MachineLearning', () => { sizeBytes: 22200222, }, }; - const MODEL2 = new Model(MODEL_RESPONSE2); const STATUS_ERROR_RESPONSE: { code: number; @@ -115,6 +118,7 @@ describe('MachineLearning', () => { const OPERATION_RESPONSE: { name?: string; + metadata?: any; done: boolean; error?: StatusErrorResponse; response?: { @@ -144,6 +148,7 @@ describe('MachineLearning', () => { const OPERATION_RESPONSE_ERROR: { name?: string; + metadata?: any; done: boolean; error?: { code: number; @@ -155,18 +160,79 @@ describe('MachineLearning', () => { error: STATUS_ERROR_RESPONSE, }; + const OPERATION_RESPONSE_NOT_DONE: { + name?: string; + metadata?: any; + done: boolean; + error?: { + code: number; + message: string; + }; + response?: ModelResponse; + } = { + name: OPERATION_NAME, + metadata: { + '@type': 'type.googleapis.com/google.firebase.ml.v1beta2.ModelOperationMetadata', + name: `projects/${PROJECT_ID}/models/${MODEL_ID}`, + basicOperationStatus: 'BASIC_OPERATION_STATUS_UPLOADING' + }, + done: false, + }; + + const MODEL_RESPONSE_LOCKED: { + name: string; + createTime: string; + updateTime: string; + etag: string; + modelHash: string; + displayName?: string; + tags?: string[]; + activeOperations?: OperationResponse[]; + state?: { + validationError?: { + code: number; + message: string; + }; + published?: boolean; + }; + tfliteModel?: { + gcsTfliteUri: string; + sizeBytes: number; + }; + } = { + name: 'projects/test-project/models/1234567', + createTime: '2020-02-07T23:45:23.288047Z', + updateTime: '2020-02-08T23:45:23.288047Z', + etag: 'etag123', + modelHash: 'modelHash123', + displayName: 'model_1', + tags: ['tag_1', 'tag_2'], + activeOperations: [OPERATION_RESPONSE_NOT_DONE], + state: { published: true }, + tfliteModel: { + gcsTfliteUri: 'gs://test-project-bucket/Firebase/ML/Models/model1.tflite', + sizeBytes: 16900988, + }, + }; let machineLearning: MachineLearning; let mockApp: FirebaseApp; + let mockClient: MachineLearningApiClient; let mockCredentialApp: FirebaseApp; + let model1: Model; + let model2: Model; + const stubs: sinon.SinonStub[] = []; before(() => { mockApp = mocks.app(); + mockClient = new MachineLearningApiClient(mockApp); mockCredentialApp = mocks.mockCredentialApp(); machineLearning = new MachineLearning(mockApp); + model1 = new Model(MODEL_RESPONSE, mockClient); + model2 = new Model(MODEL_RESPONSE2, mockClient); }); after(() => { @@ -229,7 +295,7 @@ describe('MachineLearning', () => { describe('Model', () => { it('should successfully construct a model', () => { - const model = new Model(MODEL_RESPONSE); + const model = new Model(MODEL_RESPONSE, mockClient); expect(model.modelId).to.equal(MODEL_ID); expect(model.displayName).to.equal('model_1'); expect(model.tags).to.deep.equal(['tag_1', 'tag_2']); @@ -245,6 +311,53 @@ describe('MachineLearning', () => { 'gs://test-project-bucket/Firebase/ML/Models/model1.tflite'); expect(tflite.sizeBytes).to.be.equal(16900988); }); + + it('should successfully serialize a model to JSON', () => { + const model = new Model(MODEL_RESPONSE, mockClient); + const expectedModel = { + modelId: MODEL_ID, + displayName: 'model_1', + tags: ['tag_1', 'tag_2'], + createTime: CREATE_TIME_UTC, + updateTime: UPDATE_TIME_UTC, + published: true, + etag: 'etag123', + locked: false, + modelHash: 'modelHash123', + tfliteModel: { + gcsTfliteUri: 'gs://test-project-bucket/Firebase/ML/Models/model1.tflite', + sizeBytes: 16900988, + } + } + const jsonString = JSON.stringify(model); + expect(JSON.parse(jsonString)).to.deep.equal(expectedModel); + }) + + it('should return locked when active operations are present', () => { + const model = new Model(MODEL_RESPONSE_LOCKED, mockClient); + expect(model.locked).to.be.true; + }); + + it('should return locked as false when no active operations are present', () => { + const model = new Model(MODEL_RESPONSE, mockClient); + expect(model.locked).to.be.false; + }); + + it('should successfully update a model from a Response', () => { + const model = new Model(MODEL_RESPONSE_LOCKED, mockClient); + expect(model.locked).to.be.true; + + const stub = sinon + .stub(MachineLearningApiClient.prototype, 'handleOperation') + .resolves(MODEL_RESPONSE2); + stubs.push(stub); + + model.waitForUnlocked() + .then(() => { + expect(model.locked).to.be.false; + expect(model).to.deep.equal(model2); + }); + }); }); describe('getModel', () => { @@ -335,7 +448,7 @@ describe('MachineLearning', () => { return machineLearning.getModel(MODEL_ID) .then((model) => { - expect(model).to.deep.equal(MODEL1); + expect(model).to.deep.equal(model1); }); }); }); @@ -377,8 +490,8 @@ describe('MachineLearning', () => { return machineLearning.listModels() .then((result) => { expect(result.models.length).equals(2); - expect(result.models[0]).to.deep.equal(MODEL1); - expect(result.models[1]).to.deep.equal(MODEL2); + expect(result.models[0]).to.deep.equal(model1); + expect(result.models[1]).to.deep.equal(model2); expect(result.pageToken).to.equal(LIST_MODELS_RESPONSE.nextPageToken); }); }); @@ -505,7 +618,7 @@ describe('MachineLearning', () => { return machineLearning.createModel(MODEL_OPTIONS_WITH_GCS) .then((model) => { - expect(model).to.deep.equal(MODEL1); + expect(model).to.deep.equal(model1); }); }); @@ -622,7 +735,7 @@ describe('MachineLearning', () => { return machineLearning.updateModel(MODEL_ID, MODEL_OPTIONS_WITH_GCS) .then((model) => { - expect(model).to.deep.equal(MODEL1); + expect(model).to.deep.equal(model1); }); }); @@ -726,7 +839,7 @@ describe('MachineLearning', () => { return machineLearning.publishModel(MODEL_ID) .then((model) => { - expect(model).to.deep.equal(MODEL1); + expect(model).to.deep.equal(model1); }); }); @@ -830,7 +943,7 @@ describe('MachineLearning', () => { return machineLearning.unpublishModel(MODEL_ID) .then((model) => { - expect(model).to.deep.equal(MODEL1); + expect(model).to.deep.equal(model1); }); });