Skip to content

Added CreateModel functionality and tests #788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 20, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5235,9 +5235,7 @@ declare namespace admin.machineLearning {
displayName?: string;
tags?: string[];

tfLiteModel?: {gcsTFLiteUri: string;};

toJSON(forUpload?: boolean): object;
tfliteModel?: {gcsTfliteUri: string;};
}

/**
Expand All @@ -5247,8 +5245,8 @@ declare namespace admin.machineLearning {
readonly modelId: string;
readonly displayName: string;
readonly tags?: string[];
readonly createTime: number;
readonly updateTime: number;
readonly createTime: string;
readonly updateTime: string;
readonly validationError?: string;
readonly published: boolean;
readonly etag: string;
Expand Down
25 changes: 25 additions & 0 deletions src/machine-learning/machine-learning-api-client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ export interface ModelResponse extends ModelContent {
readonly modelHash?: string;
}

export interface OperationResponse {
readonly name?: string;
readonly done: boolean;
readonly error?: StatusErrorResponse;
readonly response?: ModelResponse;
}


/**
* Class that facilitates sending requests to the Firebase ML backend API.
Expand All @@ -73,6 +80,24 @@ export class MachineLearningApiClient {
this.httpClient = new AuthorizedHttpClient(app);
}

public createModel(model: ModelContent): Promise<OperationResponse> {
if (!validator.isNonNullObject(model) ||
!validator.isNonEmptyString(model.displayName)) {
const err = new FirebaseMachineLearningError('invalid-argument', 'Invalid model content.');
return Promise.reject(err);
}
return this.getUrl()
.then((url) => {
const request: HttpRequestConfig = {
method: 'POST',
url: `${url}/models`,
data: model,
};
return this.sendRequest<OperationResponse>(request);
});
}


public getModel(modelId: string): Promise<ModelResponse> {
return Promise.resolve()
.then(() => {
Expand Down
32 changes: 31 additions & 1 deletion src/machine-learning/machine-learning-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,39 @@ export type MachineLearningErrorCode =
| 'not-found'
| 'resource-exhausted'
| 'service-unavailable'
| 'unknown-error';
| 'unknown-error'
| 'cancelled'
| 'deadline-exceeded'
| 'permission-denied'
| 'failed-precondition'
| 'aborted'
| 'out-of-range'
| 'data-loss'
| 'unauthenticated';

export class FirebaseMachineLearningError extends PrefixedFirebaseError {
public static fromOperationError(code: number, message: string): FirebaseMachineLearningError {
switch (code) {
case 1: return new FirebaseMachineLearningError('cancelled', message);
case 2: return new FirebaseMachineLearningError('unknown-error', message);
case 3: return new FirebaseMachineLearningError('invalid-argument', message);
case 4: return new FirebaseMachineLearningError('deadline-exceeded', message);
case 5: return new FirebaseMachineLearningError('not-found', message);
case 6: return new FirebaseMachineLearningError('already-exists', message);
case 7: return new FirebaseMachineLearningError('permission-denied', message);
case 8: return new FirebaseMachineLearningError('resource-exhausted', message);
case 9: return new FirebaseMachineLearningError('failed-precondition', message);
case 10: return new FirebaseMachineLearningError('aborted', message);
case 11: return new FirebaseMachineLearningError('out-of-range', message);
case 13: return new FirebaseMachineLearningError('internal-error', message);
case 14: return new FirebaseMachineLearningError('service-unavailable', message);
case 15: return new FirebaseMachineLearningError('data-loss', message);
case 16: return new FirebaseMachineLearningError('unauthenticated', message);
default:
return new FirebaseMachineLearningError('unknown-error', message);
}
}

constructor(code: MachineLearningErrorCode, message: string) {
super('machine-learning', code, message);
}
Expand Down
123 changes: 113 additions & 10 deletions src/machine-learning/machine-learning.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@
* limitations under the License.
*/


import {Storage as StorageClient} from '@google-cloud/storage';
import {FirebaseApp} from '../firebase-app';
import {FirebaseServiceInterface, FirebaseServiceInternalsInterface} from '../firebase-service';
import {MachineLearningApiClient, ModelResponse} from './machine-learning-api-client';
import {ServiceAccountCredential, isApplicationDefault} from '../auth/credential';
import {MachineLearningApiClient,
ModelResponse, OperationResponse, StatusErrorResponse} from './machine-learning-api-client';
import {FirebaseError} from '../utils/error';

import * as utils from '../utils/index';
import * as validator from '../utils/validator';
import {FirebaseMachineLearningError} from './machine-learning-utils';

// const ML_HOST = 'mlkit.googleapis.com';
import { deepCopy } from '../utils/deep-copy';

/**
* Internals of an ML instance.
Expand Down Expand Up @@ -62,6 +64,7 @@ export class MachineLearning implements FirebaseServiceInterface {

private readonly client: MachineLearningApiClient;
private readonly appInternal: FirebaseApp;
private readonly storageClient: StorageClient;

/**
* @param {FirebaseApp} app The app for this ML service.
Expand All @@ -76,6 +79,44 @@ export class MachineLearning implements FirebaseServiceInterface {
});
}

let storage: typeof StorageClient;
try {
storage = require('@google-cloud/storage').Storage;
} catch (err) {
throw new FirebaseError({
code: 'ml/missing-dependencies',
message:
'Failed to import the Cloud Storage client library for Node.js. ' +
'Make sure to install the "@google-cloud/storage" npm package. ' +
`Original error: ${err}`,
});
}

const projectId: string | null = utils.getExplicitProjectId(app);
const credential = app.options.credential;
if (credential instanceof ServiceAccountCredential) {
this.storageClient = new storage({
// When the SDK is initialized with ServiceAccountCredentials an
// explicit projectId is guaranteed to be available.
projectId: projectId!,
credentials: {
private_key: credential.privateKey,
client_email: credential.clientEmail,
},
});
} else if (isApplicationDefault(app.options.credential)) {
// Try to use the Google application default credentials.
this.storageClient = new storage();
} else {
throw new FirebaseError({
code: 'ml/invalid-credential',
message:
'Failed to initialize ML client with the available credential. ' +
'Must initialize the SDK with a certificate credential or ' +
'application default credentials to use Firebase ML API.',
});
}

this.appInternal = app;
this.client = new MachineLearningApiClient(app);
}
Expand All @@ -97,7 +138,9 @@ export class MachineLearning implements FirebaseServiceInterface {
* @return {Promise<Model>} A Promise fulfilled with the created model.
*/
public createModel(model: ModelOptions): Promise<Model> {
throw new Error('NotImplemented');
return convertOptionstoContent(model, true, this.storageClient)
.then((modelContent) => this.client.createModel(modelContent))
.then((operation) => handleOperation(operation));
}

/**
Expand Down Expand Up @@ -173,7 +216,7 @@ export class MachineLearning implements FirebaseServiceInterface {
}

/**
* A Firebase ML Model output object
* A Firebase ML Model output object.
*/
export class Model {
public readonly modelId: string;
Expand All @@ -196,7 +239,7 @@ export class Model {
!validator.isNonEmptyString(model.displayName) ||
!validator.isNonEmptyString(model.etag)) {
throw new FirebaseMachineLearningError(
'invalid-argument',
'invalid-server-response',
`Invalid Model response: ${JSON.stringify(model)}`);
}

Expand Down Expand Up @@ -252,13 +295,73 @@ export class ModelOptions {
public displayName?: string;
public tags?: string[];

public tfliteModel?: { gcsTFLiteUri: string; };
public tfliteModel?: { gcsTfliteUri: string; };
}

protected toJSON(forUpload?: boolean): object {
throw new Error('NotImplemented');
async function convertOptionstoContent(
options: ModelOptions, forUpload?: boolean,
storageClient?: StorageClient): Promise<object> {
const modelContent = deepCopy(options);
if (forUpload && modelContent.tfliteModel?.gcsTfliteUri) {
if (!storageClient) {
throw new FirebaseMachineLearningError(
'invalid-argument',
'Must specify storage client if forUpload and gcs Uri are specified.',
);
}
modelContent.tfliteModel.gcsTfliteUri = await signUrl(modelContent.tfliteModel.gcsTfliteUri, storageClient!);
}
return modelContent;
}

async function signUrl(unsignedUrl: string, storageClient: StorageClient): Promise<string> {
const MINUTES = 60 * 1000; // A minute in milliseconds.
const URL_VALID_DURATION = 10 * MINUTES;

const gcsRegex = /^gs:\/\/([a-z0-9_.-]{3,63})\/(.+)$/;
const matches = gcsRegex.exec(unsignedUrl);
if (!matches) {
throw new FirebaseMachineLearningError(
'invalid-argument',
`Invalid unsigned url: ${unsignedUrl}`);
}
const bucketName = matches[1];
const blobName = matches[2];
const bucket = storageClient.bucket(bucketName);
const blob = bucket.file(blobName);

try {
const signedUrl = blob.getSignedUrl({
action: 'read',
expires: Date.now() + URL_VALID_DURATION,
}).then((x) => x[0]);
return signedUrl;
} catch (err) {
throw new FirebaseMachineLearningError(
'internal-error',
`Error during signing upload url: ${err.message}`,
);
}
}

function extractModelId(resourceName: string): string {
return resourceName.split('/').pop()!;
}


function handleOperation(op: OperationResponse): Model {
if (op.done) {
if (op.response) {
return new Model(op.response);
} else if (op.error) {
handleOperationError(op.error);
}
}
throw new FirebaseMachineLearningError(
'invalid-server-response',
`Invalid Operation response: ${JSON.stringify(op)}`);
}

function handleOperationError(err: StatusErrorResponse) {
throw FirebaseMachineLearningError.fromOperationError(err.code, err.message);
}
Loading