diff --git a/CHANGELOG.md b/CHANGELOG.md index ef290ce19f..f5048a7d6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,17 @@ This project adheres to [Semantic Versioning](https://semver.org/). - [#2098](https://github.com/plotly/dash/pull/2098) Accept HTTP code 400 as well as 401 for JWT expiry - [#2097](https://github.com/plotly/dash/pull/2097) Fix bug [#2095](https://github.com/plotly/dash/issues/2095) with TypeScript compiler and `React.FC` empty valueDeclaration error & support empty props components. - [#2104](https://github.com/plotly/dash/pull/2104) Fix bug [#2099](https://github.com/plotly/dash/issues/2099) with Dropdown clearing search value when a value is selected. +- [#2039](https://github.com/plotly/dash/pull/2039) Fix bugs in long callbacks: + - Fix [#1769](https://github.com/plotly/dash/issues/1769) and [#1852](https://github.com/plotly/dash/issues/1852) short interval makes job run in loop. + - Fix [#1974](https://github.com/plotly/dash/issues/1974) returning `no_update` or raising `PreventUpdate` not supported with celery. + - Fix use of the callback context in celery long callbacks. + - Fix support of pattern matching for long callbacks. + +### Added + +- [#2039](https://github.com/plotly/dash/pull/2039) Long callback changes: + - Add `long=False` to `dash.callback` to use instead of `app.long_callback`. + - Add previous `app.long_callback` arguments to `dash.callback` prefixed with `long_` (`interval`, `running`, `cancel`, `progress`, `progress_default`, `cache_args_to_ignore`, `manager`) ## [2.5.1] - 2022-06-13 diff --git a/dash/_callback.py b/dash/_callback.py index 7653b29f7b..e7763b4702 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -1,12 +1,20 @@ import collections from functools import wraps +import flask + from .dependencies import ( handle_callback_args, handle_grouped_callback_args, Output, ) -from .exceptions import PreventUpdate +from .exceptions import ( + PreventUpdate, + WildcardInLongCallback, + DuplicateCallback, + MissingLongCallbackManagerError, + LongCallbackError, +) from ._grouping import ( flatten_grouping, @@ -17,14 +25,24 @@ create_callback_id, stringify_id, to_json, + coerce_to_list, + AttributeDict, ) from . import _validate +from .long_callback.managers import BaseLongCallbackManager +from ._callback_context import context_value class NoUpdate: - # pylint: disable=too-few-public-methods - pass + def to_plotly_json(self): # pylint: disable=no-self-use + return {"_dash_no_update": "_dash_no_update"} + + @staticmethod + def is_no_update(obj): + return isinstance(obj, NoUpdate) or obj == { + "_dash_no_update": "_dash_no_update" + } GLOBAL_CALLBACK_LIST = [] @@ -32,7 +50,19 @@ class NoUpdate: GLOBAL_INLINE_SCRIPTS = [] -def callback(*_args, **_kwargs): +# pylint: disable=too-many-locals +def callback( + *_args, + long=False, + long_interval=1000, + long_progress=None, + long_progress_default=None, + long_running=None, + long_cancel=None, + long_manager=None, + long_cache_args_to_ignore=None, + **_kwargs, +): """ Normally used as a decorator, `@dash.callback` provides a server-side callback relating the values of one or more `Output` items to one or @@ -49,16 +79,138 @@ def callback(*_args, **_kwargs): The last, optional argument `prevent_initial_call` causes the callback not to fire when its outputs are first added to the page. Defaults to `False` and unlike `app.callback` is not configurable at the app level. + + :Keyword Arguments: + :param long: + Mark the callback as a long callback to execute in a manager for + callbacks that take a long time without locking up the Dash app + or timing out. + :param long_manager: + A long callback manager instance. Currently an instance of one of + `DiskcacheLongCallbackManager` or `CeleryLongCallbackManager`. + Defaults to the `long_callback_manager` instance provided to the + `dash.Dash constructor`. + - A diskcache manager (`DiskcacheLongCallbackManager`) that runs callback + logic in a separate process and stores the results to disk using the + diskcache library. This is the easiest backend to use for local + development. + - A Celery manager (`CeleryLongCallbackManager`) that runs callback logic + in a celery worker and returns results to the Dash app through a Celery + broker like RabbitMQ or Redis. + :param long_running: + A list of 3-element tuples. The first element of each tuple should be + an `Output` dependency object referencing a property of a component in + the app layout. The second element is the value that the property + should be set to while the callback is running, and the third element + is the value the property should be set to when the callback completes. + :param long_cancel: + A list of `Input` dependency objects that reference a property of a + component in the app's layout. When the value of this property changes + while a callback is running, the callback is canceled. + Note that the value of the property is not significant, any change in + value will result in the cancellation of the running job (if any). + :param long_progress: + An `Output` dependency grouping that references properties of + components in the app's layout. When provided, the decorated function + will be called with an extra argument as the first argument to the + function. This argument, is a function handle that the decorated + function should call in order to provide updates to the app on its + current progress. This function accepts a single argument, which + correspond to the grouping of properties specified in the provided + `Output` dependency grouping + :param long_progress_default: + A grouping of values that should be assigned to the components + specified by the `progress` argument when the callback is not in + progress. If `progress_default` is not provided, all the dependency + properties specified in `progress` will be set to `None` when the + callback is not running. + :param long_cache_args_to_ignore: + Arguments to ignore when caching is enabled. If callback is configured + with keyword arguments (Input/State provided in a dict), + this should be a list of argument names as strings. Otherwise, + this should be a list of argument indices as integers. + :param long_interval: + Time to wait between the long callback update requests. """ + + long_spec = None + + config_prevent_initial_callbacks = _kwargs.pop( + "config_prevent_initial_callbacks", False + ) + callback_map = _kwargs.pop("callback_map", GLOBAL_CALLBACK_MAP) + callback_list = _kwargs.pop("callback_list", GLOBAL_CALLBACK_LIST) + + if long: + long_spec = { + "interval": long_interval, + } + + if long_manager: + long_spec["manager"] = long_manager + + if long_progress: + long_spec["progress"] = coerce_to_list(long_progress) + validate_long_inputs(long_spec["progress"]) + + if long_progress_default: + long_spec["progressDefault"] = coerce_to_list(long_progress_default) + + if not len(long_spec["progress"]) == len(long_spec["progressDefault"]): + raise Exception( + "Progress and progress default needs to be of same length" + ) + + if long_running: + long_spec["running"] = coerce_to_list(long_running) + validate_long_inputs(x[0] for x in long_spec["running"]) + + if long_cancel: + cancel_inputs = coerce_to_list(long_cancel) + validate_long_inputs(cancel_inputs) + + cancels_output = [Output(c.component_id, "id") for c in cancel_inputs] + + try: + + @callback(cancels_output, cancel_inputs, prevent_initial_call=True) + def cancel_call(*_): + job_ids = flask.request.args.getlist("cancelJob") + manager = long_manager or context_value.get().long_callback_manager + if job_ids: + for job_id in job_ids: + manager.terminate_job(job_id) + return NoUpdate() + + except DuplicateCallback: + pass # Already a callback to cancel, will get the proper jobs from the store. + + long_spec["cancel"] = [c.to_dict() for c in cancel_inputs] + + if long_cache_args_to_ignore: + long_spec["cache_args_to_ignore"] = long_cache_args_to_ignore + return register_callback( - GLOBAL_CALLBACK_LIST, - GLOBAL_CALLBACK_MAP, - False, + callback_list, + callback_map, + config_prevent_initial_callbacks, *_args, **_kwargs, + long=long_spec, ) +def validate_long_inputs(deps): + for dep in deps: + if dep.has_wildcard(): + raise WildcardInLongCallback( + f""" + long callbacks does not support dependencies with + pattern-matching ids + Received: {repr(dep)}\n""" + ) + + def clientside_callback(clientside_function, *args, **kwargs): return register_clientside_callback( GLOBAL_CALLBACK_LIST, @@ -81,6 +233,7 @@ def insert_callback( state, inputs_state_indices, prevent_initial_call, + long=None, ): if prevent_initial_call is None: prevent_initial_call = config_prevent_initial_callbacks @@ -92,19 +245,28 @@ def insert_callback( "state": [c.to_dict() for c in state], "clientside_function": None, "prevent_initial_call": prevent_initial_call, + "long": long + and { + "interval": long["interval"], + }, } + callback_map[callback_id] = { "inputs": callback_spec["inputs"], "state": callback_spec["state"], "outputs_indices": outputs_indices, "inputs_state_indices": inputs_state_indices, + "long": long, + "output": output, + "raw_inputs": inputs, } callback_list.append(callback_spec) return callback_id -def register_callback( +# pylint: disable=R0912, R0915 +def register_callback( # pylint: disable=R0914 callback_list, callback_map, config_prevent_initial_callbacks, *_args, **_kwargs ): ( @@ -123,6 +285,8 @@ def register_callback( insert_output = flatten_grouping(output) multi = True + long = _kwargs.get("long") + output_indices = make_grouping_by_index(output, list(range(grouping_len(output)))) callback_id = insert_callback( callback_list, @@ -134,23 +298,145 @@ def register_callback( flat_state, inputs_state_indices, prevent_initial_call, + long=long, ) # pylint: disable=too-many-locals def wrap_func(func): + + if long is not None: + long_key = BaseLongCallbackManager.register_func( + func, long.get("progress") is not None + ) + @wraps(func) def add_context(*args, **kwargs): output_spec = kwargs.pop("outputs_list") + app_callback_manager = kwargs.pop("long_callback_manager", None) + callback_ctx = kwargs.pop("callback_context", {}) + callback_manager = long and long.get("manager", app_callback_manager) _validate.validate_output_spec(insert_output, output_spec, Output) + context_value.set(callback_ctx) + func_args, func_kwargs = _validate.validate_and_group_input_args( args, inputs_state_indices ) - # don't touch the comment on the next line - used by debugger - output_value = func(*func_args, **func_kwargs) # %% callback invoked %% + response = {"multi": True} + + if long is not None: + if not callback_manager: + raise MissingLongCallbackManagerError( + "Running `long` callbacks requires a manager to be installed.\n" + "Available managers:\n" + "- Diskcache (`pip install dash[diskcache]`) to run callbacks in a separate Process" + " and store results on the local filesystem.\n" + "- Celery (`pip install dash[celery]`) to run callbacks in a celery worker" + " and store results on redis.\n" + ) + + progress_outputs = long.get("progress") + cache_key = flask.request.args.get("cacheKey") + job_id = flask.request.args.get("job") + old_job = flask.request.args.getlist("oldJob") + + current_key = callback_manager.build_cache_key( + func, + # Inputs provided as dict is kwargs. + func_args if func_args else func_kwargs, + long.get("cache_args_to_ignore", []), + ) + + if old_job: + for job in old_job: + callback_manager.terminate_job(job) + + if not cache_key: + cache_key = current_key + + job_fn = callback_manager.func_registry.get(long_key) + + job = callback_manager.call_job_fn( + cache_key, + job_fn, + args, + AttributeDict( + args_grouping=callback_ctx.args_grouping, + using_args_grouping=callback_ctx.using_args_grouping, + outputs_grouping=callback_ctx.outputs_grouping, + using_outputs_grouping=callback_ctx.using_outputs_grouping, + inputs_list=callback_ctx.inputs_list, + states_list=callback_ctx.states_list, + outputs_list=callback_ctx.outputs_list, + input_values=callback_ctx.input_values, + state_values=callback_ctx.state_values, + triggered_inputs=callback_ctx.triggered_inputs, + ), + ) + + data = { + "cacheKey": cache_key, + "job": job, + } + + running = long.get("running") + + if running: + data["running"] = {str(r[0]): r[1] for r in running} + data["runningOff"] = {str(r[0]): r[2] for r in running} + cancel = long.get("cancel") + if cancel: + data["cancel"] = cancel + + progress_default = long.get("progressDefault") + if progress_default: + data["progressDefault"] = { + str(o): x + for o, x in zip(progress_outputs, progress_default) + } + return to_json(data) + if progress_outputs: + # Get the progress before the result as it would be erased after the results. + progress = callback_manager.get_progress(cache_key) + if progress: + response["progress"] = { + str(x): progress[i] for i, x in enumerate(progress_outputs) + } + + output_value = callback_manager.get_result(cache_key, job_id) + # Must get job_running after get_result since get_results terminates it. + job_running = callback_manager.job_running(job_id) + if not job_running and output_value is callback_manager.UNDEFINED: + # Job canceled -> no output to close the loop. + output_value = NoUpdate() + + elif ( + isinstance(output_value, dict) + and "long_callback_error" in output_value + ): + error = output_value.get("long_callback_error") + raise LongCallbackError( + f"An error occurred inside a long callback: {error['msg']}\n{error['tb']}" + ) + + if job_running and output_value is not callback_manager.UNDEFINED: + # cached results. + callback_manager.terminate_job(job_id) + + if multi and isinstance(output_value, (list, tuple)): + output_value = [ + NoUpdate() if NoUpdate.is_no_update(r) else r + for r in output_value + ] + + if output_value is callback_manager.UNDEFINED: + return to_json(response) + else: + # don't touch the comment on the next line - used by debugger + output_value = func(*func_args, **func_kwargs) # %% callback invoked %% - if isinstance(output_value, NoUpdate): + if NoUpdate.is_no_update(output_value): raise PreventUpdate if not multi: @@ -185,7 +471,7 @@ def add_context(*args, **kwargs): if not has_update: raise PreventUpdate - response = {"response": component_ids, "multi": True} + response["response"] = component_ids try: jsonResponse = to_json(response) diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 0d3ed3d2e8..7638d0a860 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -1,16 +1,22 @@ import functools import warnings import json +import contextvars + import flask from . import exceptions from ._utils import AttributeDict +context_value = contextvars.ContextVar("callback_context") +context_value.set({}) + + def has_context(func): @functools.wraps(func) def assert_context(*args, **kwargs): - if not flask.has_request_context(): + if not context_value.get(): raise exceptions.MissingCallbackContextException( f"dash.callback_context.{getattr(func, '__name__')} is only available from a callback!" ) @@ -19,6 +25,10 @@ def assert_context(*args, **kwargs): return assert_context +def _get_context_value(): + return context_value.get() + + class FalsyList(list): def __bool__(self): # for Python 3 @@ -37,12 +47,12 @@ class CallbackContext: @property @has_context def inputs(self): - return getattr(flask.g, "input_values", {}) + return getattr(_get_context_value(), "input_values", {}) @property @has_context def states(self): - return getattr(flask.g, "state_values", {}) + return getattr(_get_context_value(), "state_values", {}) @property @has_context @@ -64,7 +74,7 @@ def triggered(self): # value - to avoid breaking existing apps, add a dummy item but # make the list still look falsy. So `if ctx.triggered` will make it # look empty, but you can still do `triggered[0]["prop_id"].split(".")` - return getattr(flask.g, "triggered_inputs", []) or falsy_triggered + return getattr(_get_context_value(), "triggered_inputs", []) or falsy_triggered @property @has_context @@ -90,7 +100,7 @@ def triggered_prop_ids(self): `if "btn-1.n_clicks" in ctx.triggered_prop_ids: do_something()` """ - triggered = getattr(flask.g, "triggered_inputs", []) + triggered = getattr(_get_context_value(), "triggered_inputs", []) ids = AttributeDict({}) for item in triggered: component_id, _, _ = item["prop_id"].rpartition(".") @@ -146,12 +156,12 @@ def display(btn1, btn2): return "No clicks yet" """ - return getattr(flask.g, "args_grouping", []) + return getattr(_get_context_value(), "args_grouping", []) @property @has_context def outputs_grouping(self): - return getattr(flask.g, "outputs_grouping", []) + return getattr(_get_context_value(), "outputs_grouping", []) @property @has_context @@ -162,7 +172,7 @@ def outputs_list(self): DeprecationWarning, ) - return getattr(flask.g, "outputs_list", []) + return getattr(_get_context_value(), "outputs_list", []) @property @has_context @@ -173,7 +183,7 @@ def inputs_list(self): DeprecationWarning, ) - return getattr(flask.g, "inputs_list", []) + return getattr(_get_context_value(), "inputs_list", []) @property @has_context @@ -183,12 +193,12 @@ def states_list(self): "states_list is deprecated, use args_grouping instead", DeprecationWarning, ) - return getattr(flask.g, "states_list", []) + return getattr(_get_context_value(), "states_list", []) @property @has_context def response(self): - return getattr(flask.g, "dash_response") + return getattr(_get_context_value(), "dash_response") @staticmethod @has_context @@ -221,7 +231,7 @@ def using_args_grouping(self): Return True if this callback is using dictionary or nested groupings for Input/State dependencies, or if Input and State dependencies are interleaved """ - return getattr(flask.g, "using_args_grouping", []) + return getattr(_get_context_value(), "using_args_grouping", []) @property @has_context @@ -230,7 +240,12 @@ def using_outputs_grouping(self): Return True if this callback is using dictionary or nested groupings for Output dependencies. """ - return getattr(flask.g, "using_outputs_grouping", []) + return getattr(_get_context_value(), "using_outputs_grouping", []) + + @property + @has_context + def timing_information(self): + return getattr(flask.g, "timing_information", {}) callback_context = CallbackContext() diff --git a/dash/_utils.py b/dash/_utils.py index aa0470f43d..4ce9697d0c 100644 --- a/dash/_utils.py +++ b/dash/_utils.py @@ -217,3 +217,9 @@ def gen_salt(chars): return "".join( secrets.choice(string.ascii_letters + string.digits) for _ in range(chars) ) + + +def coerce_to_list(obj): + if not isinstance(obj, (list, tuple)): + return [obj] + return obj diff --git a/dash/_validate.py b/dash/_validate.py index 0e5a097ec1..a4eb61cf1e 100644 --- a/dash/_validate.py +++ b/dash/_validate.py @@ -7,7 +7,7 @@ from ._grouping import grouping_len, map_grouping from .development.base_component import Component from . import exceptions -from ._utils import patch_collections_abc, stringify_id, to_json +from ._utils import patch_collections_abc, stringify_id, to_json, coerce_to_list def validate_callback(outputs, inputs, state, extra_args, types): @@ -479,3 +479,35 @@ def validate_module_name(module): "The first attribute of dash.register_page() must be a string or '__name__'" ) return module + + +def validate_long_callbacks(callback_map): + # Validate that long callback side output & inputs are not circular + # If circular, triggering a long callback would result in a fatal server/computer crash. + all_outputs = set() + input_indexed = {} + for callback in callback_map.values(): + out = coerce_to_list(callback["output"]) + all_outputs.update(out) + for o in out: + input_indexed.setdefault(o, set()) + input_indexed[o].update(coerce_to_list(callback["raw_inputs"])) + + for callback in (x for x in callback_map.values() if x.get("long")): + long_info = callback["long"] + progress = long_info.get("progress", []) + running = long_info.get("running", []) + + long_inputs = coerce_to_list(callback["raw_inputs"]) + outputs = set([x[0] for x in running] + progress) + circular = [ + x + for x in set(k for k, v in input_indexed.items() if v.intersection(outputs)) + if x in long_inputs + ] + + if circular: + raise exceptions.LongCallbackError( + f"Long callback circular error!\n{circular} is used as input for a long callback" + f" but also used as output from an input that is updated with progress or running argument." + ) diff --git a/dash/dash-renderer/src/actions/callbacks.ts b/dash/dash-renderer/src/actions/callbacks.ts index cf88cdee0a..4764bdfb6c 100644 --- a/dash/dash-renderer/src/actions/callbacks.ts +++ b/dash/dash-renderer/src/actions/callbacks.ts @@ -1,12 +1,15 @@ import { concat, flatten, + intersection, keys, map, mergeDeepRight, path, pick, pluck, + values, + toPairs, zip } from 'ramda'; @@ -24,13 +27,18 @@ import { ICallbackPayload, IStoredCallback, IBlockedCallback, - IPrioritizedCallback + IPrioritizedCallback, + LongCallbackInfo, + CallbackResponse, + CallbackResponseData } from '../types/callbacks'; import {isMultiValued, stringifyId, isMultiOutputProp} from './dependencies'; import {urlBase} from './utils'; import {getCSRFHeader} from '.'; import {createAction, Action} from 'redux-actions'; import {addHttpHeaders} from '../actions'; +import {notifyObservers, updateProps} from './index'; +import {CallbackJobPayload} from '../reducers/callbackJobs'; export const addBlockedCallbacks = createAction( CallbackActionType.AddBlocked @@ -83,6 +91,10 @@ export const aggregateCallbacks = createAction< const updateResourceUsage = createAction('UPDATE_RESOURCE_USAGE'); +const addCallbackJob = createAction('ADD_CALLBACK_JOB'); +const removeCallbackJob = createAction('REMOVE_CALLBACK_JOB'); +const setCallbackJobOutdated = createAction('CALLBACK_JOB_OUTDATED'); + function unwrapIfNotMulti( paths: any, idProps: any, @@ -300,30 +312,90 @@ async function handleClientside( return result; } +function sideUpdate(outputs: any, dispatch: any, paths: any) { + toPairs(outputs).forEach(([id, value]) => { + const [componentId, propName] = id.split('.'); + const componentPath = paths.strs[componentId]; + dispatch( + updateProps({ + props: {[propName]: value}, + itempath: componentPath + }) + ); + dispatch( + notifyObservers({id: componentId, props: {[propName]: value}}) + ); + }); +} + function handleServerside( dispatch: any, hooks: any, config: any, - payload: any -): Promise { + payload: any, + paths: any, + long: LongCallbackInfo | undefined, + additionalArgs: [string, string, boolean?][] | undefined, + getState: any, + output: string +): Promise { if (hooks.request_pre) { hooks.request_pre(payload); } const requestTime = Date.now(); const body = JSON.stringify(payload); + let cacheKey: string; + let job: string; + let runningOff: any; + let progressDefault: any; + let moreArgs = additionalArgs; + + const fetchCallback = () => { + const headers = getCSRFHeader() as any; + let url = `${urlBase(config)}_dash-update-component`; + + const addArg = (name: string, value: string) => { + let delim = '?'; + if (url.includes('?')) { + delim = '&'; + } + url = `${url}${delim}${name}=${value}`; + }; + if (cacheKey) { + addArg('cacheKey', cacheKey); + } + if (job) { + addArg('job', job); + } + + if (moreArgs) { + moreArgs.forEach(([key, value]) => addArg(key, value)); + moreArgs = moreArgs.filter(([_, __, single]) => !single); + } + + return fetch( + url, + mergeDeepRight(config.fetch, { + method: 'POST', + headers, + body + }) + ); + }; - return fetch( - `${urlBase(config)}_dash-update-component`, - mergeDeepRight(config.fetch, { - method: 'POST', - headers: getCSRFHeader() as any, - body - }) - ).then( - (res: any) => { + return new Promise((resolve, reject) => { + const handleOutput = (res: any) => { const {status} = res; + if (job) { + const callbackJob = getState().callbackJobs[job]; + if (callbackJob?.outdated) { + dispatch(removeCallbackJob({jobId: job})); + return resolve({}); + } + } + function recordProfile(result: any) { if (config.ui) { // Callback profiling - only relevant if we're showing the debug ui @@ -361,36 +433,90 @@ function handleServerside( } } + const finishLine = (data: CallbackResponseData) => { + const {multi, response} = data; + if (hooks.request_post) { + hooks.request_post(payload, response); + } + + let result; + if (multi) { + result = response as CallbackResponse; + } else { + const {output} = payload; + const id = output.substr(0, output.lastIndexOf('.')); + result = {[id]: (response as CallbackResponse).props}; + } + + recordProfile(result); + resolve(result); + }; + + const completeJob = () => { + if (job) { + dispatch(removeCallbackJob({jobId: job})); + } + if (runningOff) { + sideUpdate(runningOff, dispatch, paths); + } + if (progressDefault) { + sideUpdate(progressDefault, dispatch, paths); + } + }; + if (status === STATUS.OK) { - return res.json().then((data: any) => { - const {multi, response} = data; - if (hooks.request_post) { - hooks.request_post(payload, response); + res.json().then((data: CallbackResponseData) => { + if (!cacheKey && data.cacheKey) { + cacheKey = data.cacheKey; } - let result; - if (multi) { - result = response; - } else { - const {output} = payload; - const id = output.substr(0, output.lastIndexOf('.')); - result = {[id]: response.props}; + if (!job && data.job) { + const jobInfo: CallbackJobPayload = { + jobId: data.job, + cacheKey: data.cacheKey as string, + cancelInputs: data.cancel, + progressDefault: data.progressDefault, + output + }; + dispatch(addCallbackJob(jobInfo)); + job = data.job; } - recordProfile(result); - return result; + if (data.progress) { + sideUpdate(data.progress, dispatch, paths); + } + if (data.running) { + sideUpdate(data.running, dispatch, paths); + } + if (!runningOff && data.runningOff) { + runningOff = data.runningOff; + } + if (!progressDefault && data.progressDefault) { + progressDefault = data.progressDefault; + } + + if (!long || data.response !== undefined) { + completeJob(); + finishLine(data); + } else { + // Poll chain. + setTimeout( + handle, + long.interval !== undefined ? long.interval : 500 + ); + } }); - } - if (status === STATUS.PREVENT_UPDATE) { + } else if (status === STATUS.PREVENT_UPDATE) { + completeJob(); recordProfile({}); - return {}; + resolve({}); + } else { + completeJob(); + reject(res); } - throw res; - }, - () => { - // fetch rejection - this means the request didn't return, - // we don't get here from 400/500 errors, only network - // errors or unresponsive servers. + }; + + const handleError = () => { if (config.ui) { dispatch( updateResourceUsage({ @@ -402,9 +528,14 @@ function handleServerside( }) ); } - throw new Error('Callback failed: the server did not respond.'); - } - ); + reject(new Error('Callback failed: the server did not respond.')); + }; + + const handle = () => { + fetchCallback().then(handleOutput, handleError); + }; + handle(); + }); } function inputsToDict(inputs_list: any) { @@ -443,10 +574,10 @@ export function executeCallback( paths: any, layout: any, {allOutputs}: any, - dispatch: any + dispatch: any, + getState: any ): IExecutingCallback { - const {output, inputs, state, clientside_function} = cb.callback; - + const {output, inputs, state, clientside_function, long} = cb.callback; try { const inVals = fillVals(paths, layout, cb, inputs, 'Input', true); @@ -518,13 +649,51 @@ export function executeCallback( let newHeaders: Record | null = null; let lastError: any; + const additionalArgs: [string, string, boolean?][] = []; + console.log(cb.callback.output, getState().callbackJobs); + values(getState().callbackJobs).forEach( + (job: CallbackJobPayload) => { + if (cb.callback.output === job.output) { + // Terminate the old jobs that are not completed + // set as outdated for the callback promise to + // resolve and remove after. + additionalArgs.push(['oldJob', job.jobId, true]); + dispatch( + setCallbackJobOutdated({jobId: job.jobId}) + ); + } + if (!job.cancelInputs) { + return; + } + const inter = intersection( + job.cancelInputs, + cb.callback.inputs + ); + if (inter.length) { + additionalArgs.push(['cancelJob', job.jobId]); + if (job.progressDefault) { + sideUpdate( + job.progressDefault, + dispatch, + paths + ); + } + } + } + ); + for (let retry = 0; retry <= MAX_AUTH_RETRIES; retry++) { try { const data = await handleServerside( dispatch, hooks, newConfig, - payload + payload, + paths, + long, + additionalArgs.length ? additionalArgs : undefined, + getState, + cb.callback.output ); if (newHeaders) { diff --git a/dash/dash-renderer/src/observers/prioritizedCallbacks.ts b/dash/dash-renderer/src/observers/prioritizedCallbacks.ts index 7407e24271..bf557913f8 100644 --- a/dash/dash-renderer/src/observers/prioritizedCallbacks.ts +++ b/dash/dash-renderer/src/observers/prioritizedCallbacks.ts @@ -110,7 +110,8 @@ const observer: IStoreObserverDefinition = { paths, layout, getStash(cb, paths), - dispatch + dispatch, + getState ), pickedSyncCallbacks ) @@ -162,7 +163,8 @@ const observer: IStoreObserverDefinition = { paths, layout, cb, - dispatch + dispatch, + getState ); dispatch( diff --git a/dash/dash-renderer/src/reducers/callbackJobs.ts b/dash/dash-renderer/src/reducers/callbackJobs.ts new file mode 100644 index 0000000000..658401b4d8 --- /dev/null +++ b/dash/dash-renderer/src/reducers/callbackJobs.ts @@ -0,0 +1,41 @@ +import {assoc, assocPath, dissoc} from 'ramda'; +import {ICallbackProperty} from '../types/callbacks'; + +type CallbackJobState = {[k: string]: CallbackJobPayload}; + +export type CallbackJobPayload = { + cancelInputs?: ICallbackProperty[]; + cacheKey: string; + jobId: string; + progressDefault?: any; + output?: string; + outdated?: boolean; +}; + +type CallbackJobAction = { + type: 'ADD_CALLBACK_JOB' | 'REMOVE_CALLBACK_JOB' | 'CALLBACK_JOB_OUTDATED'; + payload: CallbackJobPayload; +}; + +const setJob = (job: CallbackJobPayload, state: CallbackJobState) => + assoc(job.jobId, job, state); +const removeJob = (jobId: string, state: CallbackJobState) => + dissoc(jobId, state); +const setOutdated = (jobId: string, state: CallbackJobState) => + assocPath([jobId, 'outdated'], true, state); + +export default function ( + state: CallbackJobState = {}, + action: CallbackJobAction +) { + switch (action.type) { + case 'ADD_CALLBACK_JOB': + return setJob(action.payload, state); + case 'REMOVE_CALLBACK_JOB': + return removeJob(action.payload.jobId, state); + case 'CALLBACK_JOB_OUTDATED': + return setOutdated(action.payload.jobId, state); + default: + return state; + } +} diff --git a/dash/dash-renderer/src/reducers/reducer.js b/dash/dash-renderer/src/reducers/reducer.js index 0e0d844704..97b71b6bce 100644 --- a/dash/dash-renderer/src/reducers/reducer.js +++ b/dash/dash-renderer/src/reducers/reducer.js @@ -17,6 +17,7 @@ import isLoading from './isLoading'; import layout from './layout'; import loadingMap from './loadingMap'; import paths from './paths'; +import callbackJobs from './callbackJobs'; export const apiRequests = [ 'dependenciesRequest', @@ -45,6 +46,8 @@ function mainReducer() { parts[r] = createApiReducer(r); }, apiRequests); + parts.callbackJobs = callbackJobs; + return combineReducers(parts); } diff --git a/dash/dash-renderer/src/store.ts b/dash/dash-renderer/src/store.ts index 7244077460..09bf21be8b 100644 --- a/dash/dash-renderer/src/store.ts +++ b/dash/dash-renderer/src/store.ts @@ -82,7 +82,12 @@ export default class RendererStore { const reduxDTEC = (window as any) .__REDUX_DEVTOOLS_EXTENSION_COMPOSE__; if (reduxDTEC) { - this.createAppStore(reducer, reduxDTEC(applyMiddleware(thunk))); + this.createAppStore( + reducer, + reduxDTEC({actionsDenylist: ['reloadRequest']})( + applyMiddleware(thunk) + ) + ); } else { this.createAppStore(reducer, applyMiddleware(thunk)); } diff --git a/dash/dash-renderer/src/types/callbacks.ts b/dash/dash-renderer/src/types/callbacks.ts index 5e286fca97..0c51e13c85 100644 --- a/dash/dash-renderer/src/types/callbacks.ts +++ b/dash/dash-renderer/src/types/callbacks.ts @@ -11,6 +11,7 @@ export interface ICallbackDefinition { outputs: ICallbackProperty[]; prevent_initial_call: boolean; state: ICallbackProperty[]; + long?: LongCallbackInfo; } export interface ICallbackProperty { @@ -75,7 +76,29 @@ export interface ICallbackPayload { } export type CallbackResult = { - data?: any; + data?: CallbackResponse; error?: Error; payload: ICallbackPayload | null; }; + +export type LongCallbackInfo = { + interval?: number; + progress?: any; + running?: any; +}; + +export type CallbackResponse = { + [k: string]: any; +}; + +export type CallbackResponseData = { + response?: CallbackResponse; + multi?: boolean; + cacheKey?: string; + job?: string; + progressDefault?: CallbackResponse; + progress?: CallbackResponse; + running?: CallbackResponse; + runningOff?: CallbackResponse; + cancel?: ICallbackProperty[]; +}; diff --git a/dash/dash.py b/dash/dash.py index b0c597ec1a..7786f73667 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -1,7 +1,9 @@ +import functools import os import sys import collections import importlib +from contextvars import copy_context from importlib.machinery import ModuleSpec import pkgutil import threading @@ -27,9 +29,7 @@ from .fingerprint import build_fingerprint, check_fingerprint from .resources import Scripts, Css from .dependencies import ( - handle_grouped_callback_args, Output, - State, Input, ) from .development.base_component import ComponentRegistry @@ -61,7 +61,7 @@ from . import _watch from . import _get_app -from ._grouping import flatten_grouping, map_grouping, grouping_len, update_args_group +from ._grouping import map_grouping, grouping_len, update_args_group from . import _pages from ._pages import ( @@ -166,11 +166,6 @@ def _get_skip(text, divider=2): return tb[0] + "".join(tb[skip:]) -class _NoUpdate: - # pylint: disable=too-few-public-methods - pass - - # Singleton signal to not update an output, alternative to PreventUpdate no_update = _callback.NoUpdate() # pylint: disable=protected-access @@ -1132,288 +1127,73 @@ def callback(self, *_args, **_kwargs): """ - return _callback.register_callback( - self._callback_list, - self.callback_map, - self.config.prevent_initial_callbacks, + return _callback.callback( *_args, + config_prevent_initial_callbacks=self.config.prevent_initial_callbacks, + callback_list=self._callback_list, + callback_map=self.callback_map, **_kwargs, ) - def long_callback(self, *_args, **_kwargs): # pylint: disable=too-many-statements + def long_callback( + self, + *_args, + manager=None, + interval=None, + running=None, + cancel=None, + progress=None, + progress_default=None, + cache_args_to_ignore=None, + **_kwargs, + ): """ - Normally used as a decorator, `@app.long_callback` is an alternative to - `@app.callback` designed for callbacks that take a long time to run, - without locking up the Dash app or timing out. - - `@long_callback` is designed to support multiple callback managers. - Two long callback managers are currently implemented: - - - A diskcache manager (`DiskcacheLongCallbackManager`) that runs callback - logic in a separate process and stores the results to disk using the - diskcache library. This is the easiest backend to use for local - development. - - A Celery manager (`CeleryLongCallbackManager`) that runs callback logic - in a celery worker and returns results to the Dash app through a Celery - broker like RabbitMQ or Redis. - - The following arguments may include any valid arguments to `@app.callback`. - In addition, `@app.long_callback` supports the following optional - keyword arguments: - - :Keyword Arguments: - :param manager: - A long callback manager instance. Currently an instance of one of - `DiskcacheLongCallbackManager` or `CeleryLongCallbackManager`. - Defaults to the `long_callback_manager` instance provided to the - `dash.Dash constructor`. - :param running: - A list of 3-element tuples. The first element of each tuple should be - an `Output` dependency object referencing a property of a component in - the app layout. The second element is the value that the property - should be set to while the callback is running, and the third element - is the value the property should be set to when the callback completes. - :param cancel: - A list of `Input` dependency objects that reference a property of a - component in the app's layout. When the value of this property changes - while a callback is running, the callback is canceled. - Note that the value of the property is not significant, any change in - value will result in the cancellation of the running job (if any). - :param progress: - An `Output` dependency grouping that references properties of - components in the app's layout. When provided, the decorated function - will be called with an extra argument as the first argument to the - function. This argument, is a function handle that the decorated - function should call in order to provide updates to the app on its - current progress. This function accepts a single argument, which - correspond to the grouping of properties specified in the provided - `Output` dependency grouping - :param progress_default: - A grouping of values that should be assigned to the components - specified by the `progress` argument when the callback is not in - progress. If `progress_default` is not provided, all the dependency - properties specified in `progress` will be set to `None` when the - callback is not running. - :param cache_args_to_ignore: - Arguments to ignore when caching is enabled. If callback is configured - with keyword arguments (Input/State provided in a dict), - this should be a list of argument names as strings. Otherwise, - this should be a list of argument indices as integers. + Deprecated: long callbacks are now supported natively with regular callbacks, + use `long=True` with `dash.callback` or `app.callback` instead. """ - # pylint: disable-next=import-outside-toplevel - from dash._callback_context import callback_context - - # pylint: disable-next=import-outside-toplevel - from dash.exceptions import WildcardInLongCallback - - # Get long callback manager - callback_manager = _kwargs.pop("manager", self._long_callback_manager) - if callback_manager is None: - raise ValueError( - "The @app.long_callback decorator requires a long callback manager\n" - "instance. This may be provided to the app using the \n" - "long_callback_manager argument to the dash.Dash constructor, or\n" - "it may be provided to the @app.long_callback decorator as the \n" - "manager argument" - ) - - # Extract special long_callback kwargs - running = _kwargs.pop("running", ()) - cancel = _kwargs.pop("cancel", ()) - progress = _kwargs.pop("progress", ()) - progress_default = _kwargs.pop("progress_default", None) - interval_time = _kwargs.pop("interval", 1000) - cache_args_to_ignore = _kwargs.pop("cache_args_to_ignore", []) - - # Parse remaining args just like app.callback - ( - output, - flat_inputs, - flat_state, - inputs_state_indices, - prevent_initial_call, - ) = handle_grouped_callback_args(_args, _kwargs) - inputs_and_state = flat_inputs + flat_state - args_deps = map_grouping(lambda i: inputs_and_state[i], inputs_state_indices) - - # Disallow wildcard dependencies - for deps in [output, flat_inputs, flat_state]: - for dep in flatten_grouping(deps): - if dep.has_wildcard(): - raise WildcardInLongCallback( - f""" - @app.long_callback does not support dependencies with - pattern-matching ids - Received: {repr(dep)}\n""" - ) - - # Get unique id for this long_callback definition. This increment is not - # thread safe, but it doesn't need to be because callback definitions - # happen on the main thread before the app starts - self._long_callback_count += 1 - long_callback_id = self._long_callback_count - - # Create Interval and Store for long callback and add them to the app's - # _extra_components list - interval_id = f"_long_callback_interval_{long_callback_id}" - interval_component = dcc.Interval( - id=interval_id, interval=interval_time, disabled=prevent_initial_call - ) - store_id = f"_long_callback_store_{long_callback_id}" - store_component = dcc.Store(id=store_id, data={}) - self._extra_components.extend([interval_component, store_component]) - - # Compute full component plus property name for the cancel dependencies - cancel_prop_ids = tuple( - ".".join([dep.component_id, dep.component_property]) for dep in cancel + return _callback.callback( + *_args, + long=True, + long_manager=manager, + long_interval=interval, + long_progress=progress, + long_progress_default=progress_default, + long_running=running, + long_cancel=cancel, + long_cache_args_to_ignore=cache_args_to_ignore, + callback_map=self.callback_map, + callback_list=self._callback_list, + config_prevent_initial_callbacks=self.config.prevent_initial_callbacks, + **_kwargs, ) - def wrapper(fn): - background_fn = callback_manager.make_job_fn(fn, bool(progress), args_deps) - - def callback(_triggers, user_store_data, user_callback_args): - # Build result cache key from inputs - pending_key = callback_manager.build_cache_key( - fn, user_callback_args, cache_args_to_ignore - ) - current_key = user_store_data.get("current_key", None) - pending_job = user_store_data.get("pending_job", None) - - should_cancel = pending_key == current_key or any( - trigger["prop_id"] in cancel_prop_ids - for trigger in callback_context.triggered - ) - - # Compute grouping of values to set the progress component's to - # when cleared - if progress_default is None: - clear_progress = ( - map_grouping(lambda x: None, progress) if progress else () - ) - else: - clear_progress = progress_default - - if should_cancel: - user_store_data["current_key"] = None - user_store_data["pending_key"] = None - user_store_data["pending_job"] = None - - callback_manager.terminate_job(pending_job) - - return dict( - user_callback_output=map_grouping(lambda x: no_update, output), - interval_disabled=True, - in_progress=[val for (_, _, val) in running], - progress=clear_progress, - user_store_data=user_store_data, - ) - - # Look up progress value if a job is in progress - if pending_job: - progress_value = callback_manager.get_progress(pending_key) - else: - progress_value = None - - if callback_manager.result_ready(pending_key): - result = callback_manager.get_result(pending_key, pending_job) - # Set current key (hash of data stored in client) - # to pending key (hash of data requested by client) - user_store_data["current_key"] = pending_key - - # Disable interval if this value was pulled from cache. - # If this value was the result of a background calculation, don't - # disable yet. If no other calculations are in progress, - # interval will be disabled in should_cancel logic above - # the next time the interval fires. - interval_disabled = pending_job is None - return dict( - user_callback_output=result, - interval_disabled=interval_disabled, - in_progress=[val for (_, _, val) in running], - progress=clear_progress, - user_store_data=user_store_data, - ) - if progress_value: - return dict( - user_callback_output=map_grouping(lambda x: no_update, output), - interval_disabled=False, - in_progress=[val for (_, val, _) in running], - progress=progress_value or {}, - user_store_data=user_store_data, - ) - - # Check if there is a running calculation that can now - # be canceled - old_pending_key = user_store_data.get("pending_key", None) - if ( - old_pending_key - and old_pending_key != pending_key - and callback_manager.job_running(pending_job) - ): - callback_manager.terminate_job(pending_job) - - user_store_data["pending_key"] = pending_key - callback_manager.terminate_unhealthy_job(pending_job) - if not callback_manager.job_running(pending_job): - user_store_data["pending_job"] = callback_manager.call_job_fn( - pending_key, background_fn, user_callback_args - ) - - return dict( - user_callback_output=map_grouping(lambda x: no_update, output), - interval_disabled=False, - in_progress=[val for (_, val, _) in running], - progress=clear_progress, - user_store_data=user_store_data, - ) - - return self.callback( - inputs=dict( - _triggers=dict( - n_intervals=Input(interval_id, "n_intervals"), - cancel=cancel, - ), - user_store_data=State(store_id, "data"), - user_callback_args=args_deps, - ), - output=dict( - user_callback_output=output, - interval_disabled=Output(interval_id, "disabled"), - in_progress=[dep for (dep, _, _) in running], - progress=progress, - user_store_data=Output(store_id, "data"), - ), - prevent_initial_call=prevent_initial_call, - )(callback) - - return wrapper - def dispatch(self): body = flask.request.get_json() - flask.g.inputs_list = inputs = body.get( # pylint: disable=assigning-non-slot + g = AttributeDict({}) + + g.inputs_list = inputs = body.get( # pylint: disable=assigning-non-slot "inputs", [] ) - flask.g.states_list = state = body.get( # pylint: disable=assigning-non-slot + g.states_list = state = body.get( # pylint: disable=assigning-non-slot "state", [] ) output = body["output"] outputs_list = body.get("outputs") or split_callback_id(output) - flask.g.outputs_list = outputs_list # pylint: disable=assigning-non-slot + g.outputs_list = outputs_list # pylint: disable=assigning-non-slot - flask.g.input_values = ( # pylint: disable=assigning-non-slot + g.input_values = ( # pylint: disable=assigning-non-slot input_values ) = inputs_to_dict(inputs) - flask.g.state_values = inputs_to_dict( # pylint: disable=assigning-non-slot - state - ) + g.state_values = inputs_to_dict(state) # pylint: disable=assigning-non-slot + g.long_callback_manager = self._long_callback_manager # pylint: disable=E0237 changed_props = body.get("changedPropIds", []) - flask.g.triggered_inputs = [ # pylint: disable=assigning-non-slot + g.triggered_inputs = [ # pylint: disable=assigning-non-slot {"prop_id": x, "value": input_values.get(x)} for x in changed_props ] response = ( - flask.g.dash_response # pylint: disable=assigning-non-slot + g.dash_response # pylint: disable=assigning-non-slot ) = flask.Response(mimetype="application/json") args = inputs_to_vals(inputs + state) @@ -1428,19 +1208,19 @@ def dispatch(self): inputs_state = convert_to_AttributeDict(inputs_state) # update args_grouping attributes - for g in inputs_state: + for s in inputs_state: # check for pattern matching: list of inputs or state - if isinstance(g, list): - for pattern_match_g in g: + if isinstance(s, list): + for pattern_match_g in s: update_args_group(pattern_match_g, changed_props) - update_args_group(g, changed_props) + update_args_group(s, changed_props) args_grouping = map_grouping( lambda ind: inputs_state[ind], inputs_state_indices ) - flask.g.args_grouping = args_grouping # pylint: disable=assigning-non-slot - flask.g.using_args_grouping = ( # pylint: disable=assigning-non-slot + g.args_grouping = args_grouping # pylint: disable=assigning-non-slot + g.using_args_grouping = ( # pylint: disable=assigning-non-slot not isinstance(inputs_state_indices, int) and ( inputs_state_indices @@ -1458,10 +1238,8 @@ def dispatch(self): outputs_grouping = map_grouping( lambda ind: flat_outputs[ind], outputs_indices ) - flask.g.outputs_grouping = ( # pylint: disable=assigning-non-slot - outputs_grouping - ) - flask.g.using_outputs_grouping = ( # pylint: disable=assigning-non-slot + g.outputs_grouping = outputs_grouping # pylint: disable=assigning-non-slot + g.using_outputs_grouping = ( # pylint: disable=assigning-non-slot not isinstance(outputs_indices, int) and outputs_indices != list(range(grouping_len(outputs_indices))) ) @@ -1469,7 +1247,19 @@ def dispatch(self): except KeyError as missing_callback_function: msg = f"Callback function not found for output '{output}', perhaps you forgot to prepend the '@'?" raise KeyError(msg) from missing_callback_function - response.set_data(func(*args, outputs_list=outputs_list)) + ctx = copy_context() + # noinspection PyArgumentList + response.set_data( + ctx.run( + functools.partial( + func, + *args, + outputs_list=outputs_list, + long_callback_manager=self._long_callback_manager, + callback_context=g, + ) + ) + ) return response def _setup_server(self): @@ -1508,6 +1298,8 @@ def _setup_server(self): self._callback_list.extend(_callback.GLOBAL_CALLBACK_LIST) _callback.GLOBAL_CALLBACK_LIST.clear() + _validate.validate_long_callbacks(self.callback_map) + def _add_assets_resource(self, url_path, file_path): res = {"asset_path": url_path, "filepath": file_path} if self.config.assets_external_path: diff --git a/dash/exceptions.py b/dash/exceptions.py index fd22dfa050..d2fa911a85 100644 --- a/dash/exceptions.py +++ b/dash/exceptions.py @@ -85,3 +85,11 @@ class ProxyError(DashException): class DuplicateCallback(DashException): pass + + +class LongCallbackError(DashException): + pass + + +class MissingLongCallbackManagerError(DashException): + pass diff --git a/dash/long_callback/managers/__init__.py b/dash/long_callback/managers/__init__.py index b6b05e89a8..b7bf175c4f 100644 --- a/dash/long_callback/managers/__init__.py +++ b/dash/long_callback/managers/__init__.py @@ -4,12 +4,29 @@ class BaseLongCallbackManager(ABC): + UNDEFINED = object() + + # Keep a ref to all the ref to register every callback to every manager. + managers = [] + + # Keep every function for late registering. + functions = [] + def __init__(self, cache_by): if cache_by is not None and not isinstance(cache_by, list): cache_by = [cache_by] self.cache_by = cache_by + BaseLongCallbackManager.managers.append(self) + + self.func_registry = {} + + # Register all funcs that were added before instantiation. + # Ensure all celery task are registered. + for fdetails in self.functions: + self.register(*fdetails) + def terminate_job(self, job): raise NotImplementedError @@ -19,10 +36,10 @@ def terminate_unhealthy_job(self, job): def job_running(self, job): raise NotImplementedError - def make_job_fn(self, fn, progress, args_deps): + def make_job_fn(self, fn, progress): raise NotImplementedError - def call_job_fn(self, key, job_fn, args): + def call_job_fn(self, key, job_fn, args, context): raise NotImplementedError def get_progress(self, key): @@ -58,6 +75,31 @@ def build_cache_key(self, fn, args, cache_args_to_ignore): return hashlib.sha1(str(hash_dict).encode("utf-8")).hexdigest() + def register(self, key, fn, progress): + self.func_registry[key] = self.make_job_fn(fn, progress) + + @staticmethod + def register_func(fn, progress): + key = BaseLongCallbackManager.hash_function(fn) + BaseLongCallbackManager.functions.append( + ( + key, + fn, + progress, + ) + ) + + for manager in BaseLongCallbackManager.managers: + manager.register(key, fn, progress) + + return key + @staticmethod def _make_progress_key(key): return key + "-progress" + + @staticmethod + def hash_function(fn): + fn_source = inspect.getsource(fn) + fn_str = fn_source + return hashlib.sha1(fn_str.encode("utf-8")).hexdigest() diff --git a/dash/long_callback/managers/celery_manager.py b/dash/long_callback/managers/celery_manager.py index 863daa6816..ae6f14903e 100644 --- a/dash/long_callback/managers/celery_manager.py +++ b/dash/long_callback/managers/celery_manager.py @@ -1,8 +1,14 @@ import json import inspect import hashlib +import traceback +from contextvars import copy_context from _plotly_utils.utils import PlotlyJSONEncoder + +from dash._callback_context import context_value +from dash._utils import AttributeDict +from dash.exceptions import PreventUpdate from dash.long_callback.managers import BaseLongCallbackManager @@ -44,9 +50,9 @@ def __init__(self, celery_app, cache_by=None, expire=None): if isinstance(celery_app.backend, DisabledBackend): raise ValueError("Celery instance must be configured with a result backend") - super().__init__(cache_by) self.handle = celery_app self.expire = expire + super().__init__(cache_by) def terminate_job(self, job): if job is None: @@ -70,8 +76,8 @@ def job_running(self, job): "PROGRESS", ) - def make_job_fn(self, fn, progress, args_deps): - return _make_job_fn(fn, self.handle, progress, args_deps) + def make_job_fn(self, fn, progress): + return _make_job_fn(fn, self.handle, progress) def get_task(self, job): if job: @@ -82,8 +88,8 @@ def get_task(self, job): def clear_cache_entry(self, key): self.handle.backend.delete(key) - def call_job_fn(self, key, job_fn, args): - task = job_fn.delay(key, self._make_progress_key(key), args) + def call_job_fn(self, key, job_fn, args, context): + task = job_fn.delay(key, self._make_progress_key(key), args, context) return task.task_id def get_progress(self, key): @@ -101,7 +107,7 @@ def get_result(self, key, job): # Get result value result = self.handle.backend.get(key) if result is None: - return None + return self.UNDEFINED result = json.loads(result) @@ -118,7 +124,7 @@ def get_result(self, key, job): return result -def _make_job_fn(fn, celery_app, progress, args_deps): +def _make_job_fn(fn, celery_app, progress): cache = celery_app.backend # Hash function source and module to create a unique (but stable) celery task name @@ -127,18 +133,51 @@ def _make_job_fn(fn, celery_app, progress, args_deps): fn_hash = hashlib.sha1(fn_str.encode("utf-8")).hexdigest() @celery_app.task(name=f"long_callback_{fn_hash}") - def job_fn(result_key, progress_key, user_callback_args, fn=fn): + def job_fn(result_key, progress_key, user_callback_args, context=None): def _set_progress(progress_value): + if not isinstance(progress_value, (list, tuple)): + progress_value = [progress_value] + cache.set(progress_key, json.dumps(progress_value, cls=PlotlyJSONEncoder)) maybe_progress = [_set_progress] if progress else [] - if isinstance(args_deps, dict): - user_callback_output = fn(*maybe_progress, **user_callback_args) - elif isinstance(args_deps, (list, tuple)): - user_callback_output = fn(*maybe_progress, *user_callback_args) - else: - user_callback_output = fn(*maybe_progress, user_callback_args) - cache.set(result_key, json.dumps(user_callback_output, cls=PlotlyJSONEncoder)) + ctx = copy_context() + + def run(): + context_value.set(AttributeDict(**context)) + try: + if isinstance(user_callback_args, dict): + user_callback_output = fn(*maybe_progress, **user_callback_args) + elif isinstance(user_callback_args, (list, tuple)): + user_callback_output = fn(*maybe_progress, *user_callback_args) + else: + user_callback_output = fn(*maybe_progress, user_callback_args) + except PreventUpdate: + # Put NoUpdate dict directly to avoid circular imports. + cache.set( + result_key, + json.dumps( + {"_dash_no_update": "_dash_no_update"}, cls=PlotlyJSONEncoder + ), + ) + except Exception as err: # pylint: disable=broad-except + cache.set( + result_key, + json.dumps( + { + "long_callback_error": { + "msg": str(err), + "tb": traceback.format_exc(), + } + }, + ), + ) + else: + cache.set( + result_key, json.dumps(user_callback_output, cls=PlotlyJSONEncoder) + ) + + ctx.run(run) return job_fn diff --git a/dash/long_callback/managers/diskcache_manager.py b/dash/long_callback/managers/diskcache_manager.py index ec0c92f981..44979c7d09 100644 --- a/dash/long_callback/managers/diskcache_manager.py +++ b/dash/long_callback/managers/diskcache_manager.py @@ -1,4 +1,7 @@ +import traceback + from . import BaseLongCallbackManager +from ...exceptions import PreventUpdate _pending_value = "__$pending__" @@ -44,8 +47,9 @@ def __init__(self, cache=None, cache_by=None, expire=None): ) self.handle = cache - super().__init__(cache_by) + self.lock = diskcache.Lock(self.handle, "long-callback-lock") self.expire = expire + super().__init__(cache_by) def terminate_job(self, job): import psutil # pylint: disable=import-outside-toplevel,import-error @@ -53,6 +57,8 @@ def terminate_job(self, job): if job is None: return + job = int(job) + # Use diskcache transaction so multiple process don't try to kill the # process at the same time with self.handle.transact(): @@ -78,6 +84,8 @@ def terminate_job(self, job): def terminate_unhealthy_job(self, job): import psutil # pylint: disable=import-outside-toplevel,import-error + job = int(job) + if job and psutil.pid_exists(job): if not self.job_running(job): self.terminate_job(job) @@ -88,18 +96,20 @@ def terminate_unhealthy_job(self, job): def job_running(self, job): import psutil # pylint: disable=import-outside-toplevel,import-error + job = int(job) + if job and psutil.pid_exists(job): proc = psutil.Process(job) return proc.status() != psutil.STATUS_ZOMBIE return False - def make_job_fn(self, fn, progress, args_deps): - return _make_job_fn(fn, self.handle, progress, args_deps) + def make_job_fn(self, fn, progress): + return _make_job_fn(fn, self.handle, progress, self.lock) def clear_cache_entry(self, key): self.handle.delete(key) - def call_job_fn(self, key, job_fn, args): + def call_job_fn(self, key, job_fn, args, context): # pylint: disable-next=import-outside-toplevel,no-name-in-module,import-error from multiprocess import Process @@ -117,9 +127,9 @@ def result_ready(self, key): def get_result(self, key, job): # Get result value - result = self.handle.get(key) - if result is None: - return None + result = self.handle.get(key, self.UNDEFINED) + if result is self.UNDEFINED: + return self.UNDEFINED # Clear result if not caching if self.cache_by is None: @@ -130,22 +140,45 @@ def get_result(self, key, job): self.clear_cache_entry(self._make_progress_key(key)) - self.terminate_job(job) + if job: + self.terminate_job(job) return result -def _make_job_fn(fn, cache, progress, args_deps): +def _make_job_fn(fn, cache, progress, lock): def job_fn(result_key, progress_key, user_callback_args): def _set_progress(progress_value): - cache.set(progress_key, progress_value) + if not isinstance(progress_value, (list, tuple)): + progress_value = [progress_value] + + with lock: + cache.set(progress_key, progress_value) maybe_progress = [_set_progress] if progress else [] - if isinstance(args_deps, dict): - user_callback_output = fn(*maybe_progress, **user_callback_args) - elif isinstance(args_deps, (list, tuple)): - user_callback_output = fn(*maybe_progress, *user_callback_args) + + try: + if isinstance(user_callback_args, dict): + user_callback_output = fn(*maybe_progress, **user_callback_args) + elif isinstance(user_callback_args, (list, tuple)): + user_callback_output = fn(*maybe_progress, *user_callback_args) + else: + user_callback_output = fn(*maybe_progress, user_callback_args) + except PreventUpdate: + with lock: + cache.set(result_key, {"_dash_no_update": "_dash_no_update"}) + except Exception as err: # pylint: disable=broad-except + with lock: + cache.set( + result_key, + { + "long_callback_error": { + "msg": str(err), + "tb": traceback.format_exc(), + } + }, + ) else: - user_callback_output = fn(*maybe_progress, user_callback_args) - cache.set(result_key, user_callback_output) + with lock: + cache.set(result_key, user_callback_output) return job_fn diff --git a/requires-install.txt b/requires-install.txt index 7e0ff3ec4f..ca4774ceff 100644 --- a/requires-install.txt +++ b/requires-install.txt @@ -5,3 +5,4 @@ dash_html_components==2.0.0 dash_core_components==2.0.0 dash_table==5.0.0 importlib-metadata==4.8.3;python_version<"3.7" +contextvars==2.4;python_version<"3.7" diff --git a/requires-testing.txt b/requires-testing.txt index ee5aa609c4..c78142892b 100644 --- a/requires-testing.txt +++ b/requires-testing.txt @@ -5,5 +5,5 @@ lxml>=4.6.2 percy>=2.0.2 pytest>=6.0.2 requests[security]>=2.21.0 -selenium>=3.141.0 +selenium>=3.141.0,<=4.2.0 waitress>=1.4.4 diff --git a/tests/integration/long_callback/app_callback_ctx.py b/tests/integration/long_callback/app_callback_ctx.py new file mode 100644 index 0000000000..0bb2ff1edd --- /dev/null +++ b/tests/integration/long_callback/app_callback_ctx.py @@ -0,0 +1,36 @@ +import json + +from dash import Dash, Input, Output, html, callback, ALL, ctx + +from tests.integration.long_callback.utils import get_long_callback_manager + +long_callback_manager = get_long_callback_manager() +handle = long_callback_manager.handle + +app = Dash(__name__, long_callback_manager=long_callback_manager) + +app.layout = html.Div( + [ + html.Button(id={"type": "run-button", "index": 0}, children="Run 1"), + html.Button(id={"type": "run-button", "index": 1}, children="Run 2"), + html.Button(id={"type": "run-button", "index": 2}, children="Run 3"), + html.Div(id="result", children="No results"), + html.Div(id="running"), + ] +) + + +@callback( + Output("result", "children"), + [Input({"type": "run-button", "index": ALL}, "n_clicks")], + long=True, + prevent_initial_call=True, + long_running=[(Output("running", "children"), "on", "off")], +) +def update_output(n_clicks): + triggered = json.loads(ctx.triggered[0]["prop_id"].split(".")[0]) + return json.dumps(dict(triggered=triggered, value=n_clicks[triggered["index"]])) + + +if __name__ == "__main__": + app.run_server(debug=True) diff --git a/tests/integration/long_callback/app_error.py b/tests/integration/long_callback/app_error.py new file mode 100644 index 0000000000..c0e47abc3e --- /dev/null +++ b/tests/integration/long_callback/app_error.py @@ -0,0 +1,70 @@ +import os +import time + +import dash +from dash import html, no_update +from dash.dependencies import Input, Output +from dash.exceptions import PreventUpdate + +from tests.integration.long_callback.utils import get_long_callback_manager + +long_callback_manager = get_long_callback_manager() +handle = long_callback_manager.handle + +app = dash.Dash(__name__, long_callback_manager=long_callback_manager) +app.enable_dev_tools(debug=True, dev_tools_ui=True) +app.layout = html.Div( + [ + html.Div([html.P(id="output", children=["Button not clicked"])]), + html.Button(id="button", children="Run Job!"), + html.Div(id="output-status"), + html.Div(id="output1"), + html.Div(id="output2"), + html.Div(id="output3"), + html.Button("multi-output", id="multi-output"), + ] +) +app.test_lock = lock = long_callback_manager.test_lock + + +@app.long_callback( + output=Output("output", "children"), + inputs=Input("button", "n_clicks"), + running=[ + (Output("button", "disabled"), True, False), + ], + prevent_initial_call=True, +) +def callback(n_clicks): + if os.getenv("LONG_CALLBACK_MANAGER") != "celery": + # Diskmanager needs some time, celery takes too long. + time.sleep(1) + with lock: + if n_clicks == 2: + raise Exception("bad error") + + if n_clicks == 4: + raise PreventUpdate + return f"Clicked {n_clicks} times" + + +@app.long_callback( + output=[Output("output-status", "children")] + + [Output(f"output{i}", "children") for i in range(1, 4)], + inputs=[Input("multi-output", "n_clicks")], + running=[ + (Output("multi-output", "disabled"), True, False), + ], + prevent_initial_call=True, +) +def long_multi(n_clicks): + with lock: + return ( + [f"Updated: {n_clicks}"] + + [i for i in range(1, n_clicks + 1)] + + [no_update for _ in range(n_clicks + 1, 4)] + ) + + +if __name__ == "__main__": + app.run_server(debug=True) diff --git a/tests/integration/long_callback/app_pattern_matching.py b/tests/integration/long_callback/app_pattern_matching.py new file mode 100644 index 0000000000..aa176dda32 --- /dev/null +++ b/tests/integration/long_callback/app_pattern_matching.py @@ -0,0 +1,32 @@ +from dash import Dash, Input, Output, html, callback, ALL + +from tests.integration.long_callback.utils import get_long_callback_manager + +long_callback_manager = get_long_callback_manager() +handle = long_callback_manager.handle + +app = Dash(__name__, long_callback_manager=long_callback_manager) + +app.layout = html.Div( + [ + html.Button(id={"type": "run-button", "index": 0}, children="Run 1"), + html.Button(id={"type": "run-button", "index": 1}, children="Run 2"), + html.Button(id={"type": "run-button", "index": 2}, children="Run 3"), + html.Div(id="result", children="No results"), + ] +) + + +@callback( + Output("result", "children"), + [Input({"type": "run-button", "index": ALL}, "n_clicks")], + long=True, + prevent_initial_call=True, +) +def update_output(n_clicks): + found = max(x for x in n_clicks if x is not None) + return f"Clicked '{found}'" + + +if __name__ == "__main__": + app.run_server(debug=True) diff --git a/tests/integration/long_callback/app_short_interval.py b/tests/integration/long_callback/app_short_interval.py new file mode 100644 index 0000000000..249a65e5a6 --- /dev/null +++ b/tests/integration/long_callback/app_short_interval.py @@ -0,0 +1,39 @@ +from dash import Dash, Input, Output, html, callback +import time + +from tests.integration.long_callback.utils import get_long_callback_manager + +long_callback_manager = get_long_callback_manager() +handle = long_callback_manager.handle + +app = Dash(__name__, long_callback_manager=long_callback_manager) + +app.layout = html.Div( + [ + html.Button(id="run-button", children="Run"), + html.Button(id="cancel-button", children="Cancel"), + html.Div(id="status", children="Finished"), + html.Div(id="result", children="No results"), + ] +) + + +@callback( + Output("result", "children"), + [Input("run-button", "n_clicks")], + long=True, + long_progress=Output("status", "children"), + long_progress_default="Finished", + long_cancel=[Input("cancel-button", "n_clicks")], + long_interval=0, + prevent_initial_call=True, +) +def update_output(set_progress, n_clicks): + for i in range(4): + set_progress(f"Progress {i}/4") + time.sleep(1) + return f"Clicked '{n_clicks}'" + + +if __name__ == "__main__": + app.run_server(debug=True) diff --git a/tests/integration/long_callback/app_side_update.py b/tests/integration/long_callback/app_side_update.py new file mode 100644 index 0000000000..6d0cf2cd0e --- /dev/null +++ b/tests/integration/long_callback/app_side_update.py @@ -0,0 +1,50 @@ +from dash import Dash, Input, Output, html, callback +import time + +from tests.integration.long_callback.utils import get_long_callback_manager + +long_callback_manager = get_long_callback_manager() +handle = long_callback_manager.handle + +app = Dash(__name__, long_callback_manager=long_callback_manager) + +app.layout = html.Div( + [ + html.Button(id="run-button", children="Run"), + html.Button(id="cancel-button", children="Cancel"), + html.Div(id="status", children="Finished"), + html.Div(id="result", children="No results"), + html.Div(id="side-status"), + ] +) + + +@callback( + Output("result", "children"), + [Input("run-button", "n_clicks")], + long=True, + long_progress=Output("status", "children"), + long_progress_default="Finished", + long_cancel=[Input("cancel-button", "n_clicks")], + long_interval=0, + prevent_initial_call=True, +) +def update_output(set_progress, n_clicks): + print("trigger") + for i in range(4): + set_progress(f"Progress {i}/4") + time.sleep(1) + return f"Clicked '{n_clicks}'" + + +@callback( + Output("side-status", "children"), + [Input("status", "children")], + prevent_initial_call=True, +) +def update_side(progress): + return f"Side {progress}" + + +if __name__ == "__main__": + app.run_server(debug=True) diff --git a/tests/integration/long_callback/test_basic_long_callback.py b/tests/integration/long_callback/test_basic_long_callback.py index bbd7972934..70cc821019 100644 --- a/tests/integration/long_callback/test_basic_long_callback.py +++ b/tests/integration/long_callback/test_basic_long_callback.py @@ -1,3 +1,4 @@ +import json from multiprocessing import Lock import os from contextlib import contextmanager @@ -21,7 +22,7 @@ def kill(proc_pid): if "REDIS_URL" in os.environ: - managers = ["diskcache", "celery"] + managers = ["celery", "diskcache"] else: print("Skipping celery tests because REDIS_URL is not defined") managers = ["diskcache"] @@ -426,3 +427,100 @@ def test_lcbc007_validation_layout(dash_duo, manager): assert not dash_duo.redux_state_is_loading assert dash_duo.get_logs() == [] + + +def test_lcbc008_long_callbacks_error(dash_duo, manager): + with setup_long_callback_app(manager, "app_error") as app: + dash_duo.start_server( + app, + debug=True, + use_reloader=False, + use_debugger=True, + dev_tools_hot_reload=False, + dev_tools_ui=True, + ) + + clicker = dash_duo.find_element("#button") + + def click_n_wait(): + with app.test_lock: + clicker.click() + dash_duo.wait_for_element("#button:disabled") + dash_duo.wait_for_element("#button:not([disabled])") + + clicker.click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1 times") + + click_n_wait() + dash_duo.wait_for_element(".dash-fe-error__title").click() + + dash_duo.driver.switch_to.frame(dash_duo.find_element("iframe")) + assert ( + "dash.exceptions.LongCallbackError: An error occurred inside a long callback:" + in dash_duo.wait_for_element(".errormsg").text + ) + dash_duo.driver.switch_to.default_content() + + click_n_wait() + dash_duo.wait_for_text_to_equal("#output", "Clicked 3 times") + + click_n_wait() + dash_duo.wait_for_text_to_equal("#output", "Clicked 3 times") + click_n_wait() + dash_duo.wait_for_text_to_equal("#output", "Clicked 5 times") + + def make_expect(n): + return [str(x) for x in range(1, n + 1)] + ["" for _ in range(n + 1, 4)] + + multi = dash_duo.wait_for_element("#multi-output") + + for i in range(1, 4): + with app.test_lock: + multi.click() + dash_duo.wait_for_element("#multi-output:disabled") + expect = make_expect(i) + dash_duo.wait_for_text_to_equal("#output-status", f"Updated: {i}") + for j, e in enumerate(expect): + assert dash_duo.find_element(f"#output{j + 1}").text == e + + +def test_lcbc009_short_interval(dash_duo, manager): + with setup_long_callback_app(manager, "app_short_interval") as app: + dash_duo.start_server(app) + dash_duo.find_element("#run-button").click() + dash_duo.wait_for_text_to_equal("#status", "Progress 2/4", 20) + dash_duo.wait_for_text_to_equal("#status", "Finished", 12) + dash_duo.wait_for_text_to_equal("#result", "Clicked '1'") + + time.sleep(2) + # Ensure the progress is still not running + assert dash_duo.find_element("#status").text == "Finished" + + +def test_lcbc010_side_updates(dash_duo, manager): + with setup_long_callback_app(manager, "app_side_update") as app: + dash_duo.start_server(app) + dash_duo.find_element("#run-button").click() + for i in range(1, 4): + dash_duo.wait_for_text_to_equal("#side-status", f"Side Progress {i}/4") + + +def test_lcbc011_long_pattern_matching(dash_duo, manager): + with setup_long_callback_app(manager, "app_pattern_matching") as app: + dash_duo.start_server(app) + for i in range(1, 4): + for _ in range(i): + dash_duo.find_element(f"button:nth-child({i})").click() + + dash_duo.wait_for_text_to_equal("#result", f"Clicked '{i}'") + + +def test_lcbc012_long_callback_ctx(dash_duo, manager): + with setup_long_callback_app(manager, "app_callback_ctx") as app: + dash_duo.start_server(app) + dash_duo.find_element("button:nth-child(1)").click() + dash_duo.wait_for_text_to_equal("#running", "off") + + output = json.loads(dash_duo.find_element("#result").text) + + assert output["triggered"]["index"] == 0 diff --git a/tests/integration/long_callback/utils.py b/tests/integration/long_callback/utils.py index c6882df1fc..262471d0f8 100644 --- a/tests/integration/long_callback/utils.py +++ b/tests/integration/long_callback/utils.py @@ -8,6 +8,7 @@ def get_long_callback_manager(): if os.environ.get("LONG_CALLBACK_MANAGER", None) == "celery": from dash.long_callback import CeleryLongCallbackManager from celery import Celery + import redis celery_app = Celery( __name__, @@ -15,12 +16,15 @@ def get_long_callback_manager(): backend=os.environ.get("CELERY_BACKEND"), ) long_callback_manager = CeleryLongCallbackManager(celery_app) + redis_conn = redis.Redis(host="localhost", port=6379, db=1) + long_callback_manager.test_lock = redis_conn.lock("test-lock") elif os.environ.get("LONG_CALLBACK_MANAGER", None) == "diskcache": from dash.long_callback import DiskcacheLongCallbackManager import diskcache cache = diskcache.Cache(os.environ.get("DISKCACHE_DIR")) long_callback_manager = DiskcacheLongCallbackManager(cache) + long_callback_manager.test_lock = diskcache.Lock(cache, "test-lock") else: raise ValueError( "Invalid long callback manager specified as LONG_CALLBACK_MANAGER " diff --git a/tests/unit/dash/long_callback_validation.py b/tests/unit/dash/long_callback_validation.py deleted file mode 100644 index 7d4542bedc..0000000000 --- a/tests/unit/dash/long_callback_validation.py +++ /dev/null @@ -1,46 +0,0 @@ -import pytest -import mock - -import dash -from dash.exceptions import WildcardInLongCallback -from dash.dependencies import Input, Output, State, ALL, MATCH, ALLSMALLER - - -def test_wildcard_ids_no_allowed_in_long_callback(): - """ - @app.long_callback doesn't support wildcard dependencies yet. This test can - be removed if wildcard support is added to @app.long_callback in the future. - """ - app = dash.Dash(long_callback_manager=mock.Mock()) - - # ALL - with pytest.raises(WildcardInLongCallback): - - @app.long_callback( - Output("output", "children"), - Input({"type": "filter", "index": ALL}, "value"), - ) - def callback(*args, **kwargs): - pass - - # MATCH - with pytest.raises(WildcardInLongCallback): - - @app.long_callback( - Output({"type": "dynamic-output", "index": MATCH}, "children"), - Input({"type": "dynamic-dropdown", "index": MATCH}, "value"), - State({"type": "dynamic-dropdown", "index": MATCH}, "id"), - ) - def callback(*args, **kwargs): - pass - - # ALLSMALLER - with pytest.raises(WildcardInLongCallback): - - @app.long_callback( - Output({"type": "output-ex3", "index": MATCH}, "children"), - Input({"type": "filter-dropdown-ex3", "index": MATCH}, "value"), - Input({"type": "filter-dropdown-ex3", "index": ALLSMALLER}, "value"), - ) - def callback(*args, **kwargs): - pass diff --git a/tests/unit/dash/test_long_callback_validation.py b/tests/unit/dash/test_long_callback_validation.py new file mode 100644 index 0000000000..db0561233c --- /dev/null +++ b/tests/unit/dash/test_long_callback_validation.py @@ -0,0 +1,26 @@ +import pytest + +from dash.exceptions import LongCallbackError +from dash.dependencies import Input, Output +from dash._validate import validate_long_callbacks + + +def test_circular_long_callback_progress(): + callback_map = { + "side": { + "output": [Output("side-progress", "children")], + "raw_inputs": [Input("progress", "children")], + }, + "long": { + "output": [Output("result", "children")], + "raw_inputs": [ + Input("click", "n_clicks"), + Input("side-progress", "children"), + ], + "long": {"progress": [Output("progress", "children")]}, + }, + } + + with pytest.raises(LongCallbackError): + + validate_long_callbacks(callback_map)