Skip to content

Added request_options parameter to various classes and functions for canceling requests #1190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/configs.js
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ export class PretrainedConfig {
cache_dir = null,
local_files_only = false,
revision = 'main',
request_options = {},
} = {}) {
if (config && !(config instanceof PretrainedConfig)) {
config = new PretrainedConfig(config);
Expand All @@ -378,6 +379,7 @@ export class PretrainedConfig {
cache_dir,
local_files_only,
revision,
request_options
})
return new this(data);
}
Expand Down
5 changes: 5 additions & 0 deletions src/env.js
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ const localModelPath = RUNNING_LOCALLY
* @property {string} cacheDir The directory to use for caching files with the file system. By default, it is `./.cache`.
* @property {boolean} useCustomCache Whether to use a custom cache system (defined by `customCache`), defaults to `false`.
* @property {Object} customCache The custom cache to use. Defaults to `null`. Note: this must be an object which
* @property {(input: RequestInfo | URL, init?: RequestInit) => Promise<Response>} customFetch A custom fetch function to use. Defaults to `null`. Note: this must be a function which
* implements the `match` and `put` functions of the Web Cache API. For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache
*/

Expand Down Expand Up @@ -150,6 +151,10 @@ export const env = {
useCustomCache: false,
customCache: null,
//////////////////////////////////////////////////////

/////////////////// custom settings ///////////////////
customFetch: null,
//////////////////////////////////////////////////////
}


Expand Down
4 changes: 4 additions & 0 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,7 @@ export class PreTrainedModel extends Callable {
dtype = null,
use_external_data_format = null,
session_options = {},
request_options = {}
} = {}) {

let options = {
Expand All @@ -999,6 +1000,7 @@ export class PreTrainedModel extends Callable {
dtype,
use_external_data_format,
session_options,
request_options
}

const modelName = MODEL_CLASS_TO_NAME_MAPPING.get(this);
Expand Down Expand Up @@ -6999,6 +7001,7 @@ export class PretrainedMixin {
dtype = null,
use_external_data_format = null,
session_options = {},
request_options = {}
} = {}) {

const options = {
Expand All @@ -7013,6 +7016,7 @@ export class PretrainedMixin {
dtype,
use_external_data_format,
session_options,
request_options,
}
options.config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options);

Expand Down
2 changes: 2 additions & 0 deletions src/pipelines.js
Original file line number Diff line number Diff line change
Expand Up @@ -3301,6 +3301,7 @@ export async function pipeline(
dtype = null,
model_file_name = null,
session_options = {},
request_options = {}
} = {}
) {
// Helper method to construct pipeline
Expand Down Expand Up @@ -3331,6 +3332,7 @@ export async function pipeline(
dtype,
model_file_name,
session_options,
request_options,
}

const classes = new Map([
Expand Down
4 changes: 4 additions & 0 deletions src/tokenizers.js
Original file line number Diff line number Diff line change
Expand Up @@ -2682,6 +2682,7 @@ export class PreTrainedTokenizer extends Callable {
local_files_only = false,
revision = 'main',
legacy = null,
request_options = {},
} = {}) {

const info = await loadTokenizer(pretrained_model_name_or_path, {
Expand All @@ -2691,6 +2692,7 @@ export class PreTrainedTokenizer extends Callable {
local_files_only,
revision,
legacy,
request_options,
})

// @ts-ignore
Expand Down Expand Up @@ -4351,6 +4353,7 @@ export class AutoTokenizer {
local_files_only = false,
revision = 'main',
legacy = null,
request_options = {}
} = {}) {

const [tokenizerJSON, tokenizerConfig] = await loadTokenizer(pretrained_model_name_or_path, {
Expand All @@ -4360,6 +4363,7 @@ export class AutoTokenizer {
local_files_only,
revision,
legacy,
request_options
})

// Some tokenizers are saved with the "Fast" suffix, so we remove that if present.
Expand Down
30 changes: 23 additions & 7 deletions src/utils/hub.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import { dispatchCallback } from './core.js';
* @property {string} [cache_dir=null] Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
* @property {boolean} [local_files_only=false] Whether or not to only look at local files (e.g., not try downloading the model).
* @property {string} [revision='main'] The specific model version to use. It can be a branch name, a tag name, or a commit id,
* @property {RequestInit} [request_options] The options to use when making the request.
* since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
* NOTE: This setting is ignored for local requests.
*/
Expand Down Expand Up @@ -185,18 +186,23 @@ function isValidUrl(string, protocols = null, validHosts = null) {
* Helper function to get a file, using either the Fetch API or FileSystem API.
*
* @param {URL|string} urlOrPath The URL/path of the file to get.
* @param {RequestInit} [request_options] The options to use when making the request.
* @returns {Promise<FileResponse|Response>} A promise that resolves to a FileResponse object (if the file is retrieved using the FileSystem API), or a Response object (if the file is retrieved using the Fetch API).
*/
export async function getFile(urlOrPath) {
export async function getFile(urlOrPath, request_options) {

/**
* @type {Headers} The headers to use when making the request.
*/
let headers

if (env.useFS && !isValidUrl(urlOrPath, ['http:', 'https:', 'blob:'])) {
return new FileResponse(urlOrPath);

} else if (typeof process !== 'undefined' && process?.release?.name === 'node') {
const IS_CI = !!process.env?.TESTING_REMOTELY;
const version = env.version;

const headers = new Headers();
headers = new Headers();
headers.set('User-Agent', `transformers.js/${version}; is_ci/${IS_CI};`);

// Check whether we are making a request to the Hugging Face Hub.
Expand All @@ -210,13 +216,23 @@ export async function getFile(urlOrPath) {
headers.set('Authorization', `Bearer ${token}`);
}
}
return fetch(urlOrPath, { headers });
} else {
// Running in a browser-environment, so we use default headers
// NOTE: We do not allow passing authorization headers in the browser,
// since this would require exposing the token to the client.
return fetch(urlOrPath);
}

/**
* @type {(input: RequestInfo | URL, init?: RequestInit) => Promise<Response>} A custom fetch function to use. Defaults to `null`. Note: this must be a function which
*/
let resolvedFetch;
if (env.customFetch) {
resolvedFetch = env.customFetch;
} else {
resolvedFetch = fetch
}

return resolvedFetch(urlOrPath, {headers, ...request_options});
}

const ERROR_MAPPING = {
Expand Down Expand Up @@ -447,7 +463,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
const isURL = isValidUrl(requestURL, ['http:', 'https:']);
if (!isURL) {
try {
response = await getFile(localPath);
response = await getFile(localPath, options.request_options);
cacheKey = localPath; // Update the cache key to be the local path
} catch (e) {
// Something went wrong while trying to get the file locally.
Expand Down Expand Up @@ -479,7 +495,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
}

// File not found locally, so we try to download it from the remote server
response = await getFile(remoteURL);
response = await getFile(remoteURL, options.request_options);

if (response.status !== 200) {
return handleError(response.status, remoteURL, fatal);
Expand Down
30 changes: 30 additions & 0 deletions tests/utils/hub.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,35 @@ describe("Hub", () => {
},
MAX_TEST_EXECUTION_TIME,
);

it("should cancel model loading", async () => {
const controller = new AbortController();
const signal = controller.signal;
setTimeout(() => controller.abort(), 10);
try {
await AutoModel.from_pretrained("hf-internal-testing/this-model-does-not-exist", { ...DEFAULT_MODEL_OPTIONS, request_options: { signal } })
} catch (error) {
expect(error.name).toBe("AbortError");
}
}, MAX_TEST_EXECUTION_TIME + 1000);

it("should cancel multiple model loading", async () => {
const controller = new AbortController();
const signal = controller.signal;
setTimeout(() => controller.abort(), 10);

try {
await AutoModel.from_pretrained("hf-internal-testing/this-model-does-not-exist", { ...DEFAULT_MODEL_OPTIONS, request_options: { signal } })
} catch (error) {
expect(error.name).toBe("AbortError");
}

try {
await AutoModel.from_pretrained("hf-internal-testing/this-model-does-not-exist", { ...DEFAULT_MODEL_OPTIONS, request_options: { signal } })
} catch (error) {
expect(error.name).toBe("AbortError");
}

}, MAX_TEST_EXECUTION_TIME + 1000);
});
});