Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit fd9ee13

Browse files
committedJun 13, 2022
Replace long callback interval with request polling handled in renderer.
1 parent 3a207ce commit fd9ee13

File tree

12 files changed

+571
-319
lines changed

12 files changed

+571
-319
lines changed
 

Diff for: ‎dash/_callback.py

+194-8
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import collections
22
from functools import wraps
33

4+
import flask
5+
46
from .dependencies import (
57
handle_callback_args,
68
handle_grouped_callback_args,
79
Output,
810
)
9-
from .exceptions import PreventUpdate
11+
from .exceptions import PreventUpdate, WildcardInLongCallback, DuplicateCallback
1012

1113
from ._grouping import (
1214
flatten_grouping,
@@ -17,9 +19,11 @@
1719
create_callback_id,
1820
stringify_id,
1921
to_json,
22+
coerce_to_list,
2023
)
2124

2225
from . import _validate
26+
from .long_callback.managers import BaseLongCallbackManager
2327

2428

2529
class NoUpdate:
@@ -30,15 +34,28 @@ def to_plotly_json(self): # pylint: disable=no-self-use
3034

3135
@staticmethod
3236
def is_no_update(obj):
33-
return obj == {"_dash_no_update": "_dash_no_update"}
37+
return isinstance(obj, NoUpdate) or obj == {
38+
"_dash_no_update": "_dash_no_update"
39+
}
3440

3541

3642
GLOBAL_CALLBACK_LIST = []
3743
GLOBAL_CALLBACK_MAP = {}
3844
GLOBAL_INLINE_SCRIPTS = []
3945

4046

41-
def callback(*_args, **_kwargs):
47+
def callback(
48+
*_args,
49+
long=False,
50+
long_interval=1000,
51+
long_progress=None,
52+
long_progress_default=None,
53+
long_running=None,
54+
long_cancel=None,
55+
long_manager=None,
56+
long_cache_args_to_ignore=None,
57+
**_kwargs,
58+
):
4259
"""
4360
Normally used as a decorator, `@dash.callback` provides a server-side
4461
callback relating the values of one or more `Output` items to one or
@@ -56,15 +73,79 @@ def callback(*_args, **_kwargs):
5673
not to fire when its outputs are first added to the page. Defaults to
5774
`False` and unlike `app.callback` is not configurable at the app level.
5875
"""
76+
77+
long_spec = None
78+
79+
if long:
80+
long_spec = {
81+
"interval": long_interval,
82+
}
83+
84+
if long_manager:
85+
long_spec["manager"] = long_manager
86+
87+
if long_progress:
88+
long_spec["progress"] = coerce_to_list(long_progress)
89+
validate_long_inputs(long_spec["progress"])
90+
91+
if long_progress_default:
92+
long_spec["progressDefault"] = coerce_to_list(long_progress_default)
93+
94+
if not len(long_spec["progress"]) == len(long_spec["progressDefault"]):
95+
raise Exception(
96+
"Progress and progress default needs to be of same length"
97+
)
98+
99+
if long_running:
100+
long_spec["running"] = coerce_to_list(long_running)
101+
validate_long_inputs(x[0] for x in long_spec["running"])
102+
103+
if long_cancel:
104+
cancel_inputs = coerce_to_list(long_cancel)
105+
validate_long_inputs(cancel_inputs)
106+
107+
cancels_output = [Output(c.component_id, "id") for c in cancel_inputs]
108+
109+
try:
110+
111+
@callback(cancels_output, cancel_inputs, prevent_initial_call=True)
112+
def cancel_call(*_):
113+
job_ids = flask.request.args.getlist("cancelJob")
114+
manager = long_manager or flask.g.long_callback_manager
115+
if job_ids:
116+
for job_id in job_ids:
117+
manager.terminate_job(int(job_id))
118+
return NoUpdate()
119+
120+
except DuplicateCallback:
121+
pass # Already a callback to cancel, will get the proper jobs from the store.
122+
123+
long_spec["cancel"] = [c.to_dict() for c in cancel_inputs]
124+
125+
if long_cache_args_to_ignore:
126+
long_spec["cache_args_to_ignore"] = long_cache_args_to_ignore
127+
59128
return register_callback(
60129
GLOBAL_CALLBACK_LIST,
61130
GLOBAL_CALLBACK_MAP,
62131
False,
63132
*_args,
64133
**_kwargs,
134+
long=long_spec,
65135
)
66136

67137

138+
def validate_long_inputs(deps):
139+
for dep in deps:
140+
if dep.has_wildcard():
141+
raise WildcardInLongCallback(
142+
f"""
143+
long callbacks does not support dependencies with
144+
pattern-matching ids
145+
Received: {repr(dep)}\n"""
146+
)
147+
148+
68149
def clientside_callback(clientside_function, *args, **kwargs):
69150
return register_clientside_callback(
70151
GLOBAL_CALLBACK_LIST,
@@ -87,6 +168,7 @@ def insert_callback(
87168
state,
88169
inputs_state_indices,
89170
prevent_initial_call,
171+
long=None,
90172
):
91173
if prevent_initial_call is None:
92174
prevent_initial_call = config_prevent_initial_callbacks
@@ -98,19 +180,26 @@ def insert_callback(
98180
"state": [c.to_dict() for c in state],
99181
"clientside_function": None,
100182
"prevent_initial_call": prevent_initial_call,
183+
"long": long
184+
and {
185+
"interval": long["interval"],
186+
},
101187
}
188+
102189
callback_map[callback_id] = {
103190
"inputs": callback_spec["inputs"],
104191
"state": callback_spec["state"],
105192
"outputs_indices": outputs_indices,
106193
"inputs_state_indices": inputs_state_indices,
194+
"long": long,
107195
}
108196
callback_list.append(callback_spec)
109197

110198
return callback_id
111199

112200

113-
def register_callback(
201+
# pylint: disable=R0912
202+
def register_callback( # pylint: disable=R0914
114203
callback_list, callback_map, config_prevent_initial_callbacks, *_args, **_kwargs
115204
):
116205
(
@@ -129,6 +218,8 @@ def register_callback(
129218
insert_output = flatten_grouping(output)
130219
multi = True
131220

221+
long = _kwargs.get("long")
222+
132223
output_indices = make_grouping_by_index(output, list(range(grouping_len(output))))
133224
callback_id = insert_callback(
134225
callback_list,
@@ -140,23 +231,118 @@ def register_callback(
140231
flat_state,
141232
inputs_state_indices,
142233
prevent_initial_call,
234+
long=long,
143235
)
144236

145237
# pylint: disable=too-many-locals
146238
def wrap_func(func):
239+
240+
if long is not None:
241+
long_key = BaseLongCallbackManager.register_func(
242+
func, long.get("progress") is not None
243+
)
244+
147245
@wraps(func)
148246
def add_context(*args, **kwargs):
149247
output_spec = kwargs.pop("outputs_list")
248+
callback_manager = long.get(
249+
"manager", kwargs.pop("long_callback_manager", None)
250+
)
150251
_validate.validate_output_spec(insert_output, output_spec, Output)
151252

152253
func_args, func_kwargs = _validate.validate_and_group_input_args(
153254
args, inputs_state_indices
154255
)
155256

156-
# don't touch the comment on the next line - used by debugger
157-
output_value = func(*func_args, **func_kwargs) # %% callback invoked %%
257+
response = {"multi": True}
258+
259+
if long is not None:
260+
progress_outputs = long.get("progress")
261+
cache_key = flask.request.args.get("cacheKey")
262+
job_id = flask.request.args.get("job")
263+
264+
current_key = callback_manager.build_cache_key(
265+
func,
266+
# Inputs provided as dict is kwargs.
267+
func_args if func_args else func_kwargs,
268+
long.get("cache_args_to_ignore", []),
269+
)
270+
271+
if not cache_key:
272+
cache_key = current_key
273+
274+
job_fn = callback_manager.func_registry.get(long_key)
275+
276+
job = callback_manager.call_job_fn(
277+
cache_key,
278+
job_fn,
279+
args,
280+
)
281+
282+
data = {
283+
"cacheKey": cache_key,
284+
"job": job,
285+
}
286+
287+
running = long.get("running")
288+
289+
if running:
290+
data["running"] = {str(r[0]): r[1] for r in running}
291+
data["runningOff"] = {str(r[0]): r[2] for r in running}
292+
cancel = long.get("cancel")
293+
if cancel:
294+
data["cancel"] = cancel
295+
296+
progress_default = long.get("progressDefault")
297+
if progress_default:
298+
data["progressDefault"] = {
299+
str(o): x
300+
for o, x in zip(progress_outputs, progress_default)
301+
}
302+
return to_json(data)
303+
else:
304+
if progress_outputs:
305+
# Get the progress before the result as it would be erased after the results.
306+
progress = callback_manager.get_progress(cache_key)
307+
if progress:
308+
response["progress"] = {
309+
str(x): progress[i]
310+
for i, x in enumerate(progress_outputs)
311+
}
312+
313+
output_value = callback_manager.get_result(cache_key, job_id)
314+
# Must get job_running after get_result since get_results terminates it.
315+
job_running = callback_manager.job_running(job_id)
316+
if not job_running and output_value is callback_manager.UNDEFINED:
317+
# Job canceled -> no output to close the loop.
318+
output_value = NoUpdate()
319+
320+
elif (
321+
isinstance(output_value, dict)
322+
and "long_callback_error" in output_value
323+
):
324+
error = output_value.get("long_callback_error")
325+
raise Exception(
326+
f"An error occurred inside a long callback: {error['msg']}\n{error['tb']}"
327+
)
328+
329+
if job_running and output_value is not callback_manager.UNDEFINED:
330+
# cached results.
331+
callback_manager.terminate_job(job_id)
332+
333+
if multi and isinstance(output_value, (list, tuple)):
334+
output_value = [
335+
NoUpdate() if NoUpdate.is_no_update(r) else r
336+
for r in output_value
337+
]
338+
339+
if output_value is callback_manager.UNDEFINED:
340+
return to_json(response)
341+
else:
342+
# don't touch the comment on the next line - used by debugger
343+
output_value = func(*func_args, **func_kwargs) # %% callback invoked %%
158344

159-
if isinstance(output_value, NoUpdate):
345+
if NoUpdate.is_no_update(output_value):
160346
raise PreventUpdate
161347

162348
if not multi:
@@ -191,7 +377,7 @@ def add_context(*args, **kwargs):
191377
if not has_update:
192378
raise PreventUpdate
193379

194-
response = {"response": component_ids, "multi": True}
380+
response["response"] = component_ids
195381

196382
try:
197383
jsonResponse = to_json(response)

Diff for: ‎dash/_utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,9 @@ def gen_salt(chars):
217217
return "".join(
218218
secrets.choice(string.ascii_letters + string.digits) for _ in range(chars)
219219
)
220+
221+
222+
def coerce_to_list(obj):
223+
if not isinstance(obj, (list, tuple)):
224+
return [obj]
225+
return obj

Diff for: ‎dash/dash-renderer/src/actions/callbacks.ts

+179-41
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import {
22
concat,
33
flatten,
4+
intersection,
45
keys,
56
map,
67
mergeDeepRight,
78
path,
89
pick,
910
pluck,
11+
values,
12+
toPairs,
1013
zip
1114
} from 'ramda';
1215

@@ -24,13 +27,18 @@ import {
2427
ICallbackPayload,
2528
IStoredCallback,
2629
IBlockedCallback,
27-
IPrioritizedCallback
30+
IPrioritizedCallback,
31+
LongCallbackInfo,
32+
CallbackResponse,
33+
CallbackResponseData
2834
} from '../types/callbacks';
2935
import {isMultiValued, stringifyId, isMultiOutputProp} from './dependencies';
3036
import {urlBase} from './utils';
3137
import {getCSRFHeader} from '.';
3238
import {createAction, Action} from 'redux-actions';
3339
import {addHttpHeaders} from '../actions';
40+
import {updateProps} from './index';
41+
import {CallbackJobPayload} from '../reducers/callbackJobs';
3442

3543
export const addBlockedCallbacks = createAction<IBlockedCallback[]>(
3644
CallbackActionType.AddBlocked
@@ -83,6 +91,9 @@ export const aggregateCallbacks = createAction<
8391

8492
const updateResourceUsage = createAction('UPDATE_RESOURCE_USAGE');
8593

94+
const addCallbackJob = createAction('ADD_CALLBACK_JOB');
95+
const removeCallbackJob = createAction('REMOVE_CALLBACK_JOB');
96+
8697
function unwrapIfNotMulti(
8798
paths: any,
8899
idProps: any,
@@ -300,28 +311,73 @@ async function handleClientside(
300311
return result;
301312
}
302313

314+
function sideUpdate(outputs: any, dispatch: any, paths: any) {
315+
toPairs(outputs).forEach(([id, value]) => {
316+
const [componentId, propName] = id.split('.');
317+
const componentPath = paths.strs[componentId];
318+
dispatch(
319+
updateProps({
320+
props: {[propName]: value},
321+
itempath: componentPath
322+
})
323+
);
324+
});
325+
}
326+
303327
function handleServerside(
304328
dispatch: any,
305329
hooks: any,
306330
config: any,
307-
payload: any
308-
): Promise<any> {
331+
payload: any,
332+
paths: any,
333+
long?: LongCallbackInfo,
334+
additionalArgs?: [string, string][]
335+
): Promise<CallbackResponse> {
309336
if (hooks.request_pre) {
310337
hooks.request_pre(payload);
311338
}
312339

313340
const requestTime = Date.now();
314341
const body = JSON.stringify(payload);
342+
let cacheKey: string;
343+
let job: string;
344+
let runningOff: any;
345+
let progressDefault: any;
346+
347+
const fetchCallback = () => {
348+
const headers = getCSRFHeader() as any;
349+
let url = `${urlBase(config)}_dash-update-component`;
350+
351+
const addArg = (name: string, value: string) => {
352+
let delim = '?';
353+
if (url.includes('?')) {
354+
delim = '&';
355+
}
356+
url = `${url}${delim}${name}=${value}`;
357+
};
358+
if (cacheKey) {
359+
addArg('cacheKey', cacheKey);
360+
}
361+
if (job) {
362+
addArg('job', job);
363+
}
364+
365+
if (additionalArgs) {
366+
additionalArgs.forEach(([key, value]) => addArg(key, value));
367+
}
368+
369+
return fetch(
370+
url,
371+
mergeDeepRight(config.fetch, {
372+
method: 'POST',
373+
headers,
374+
body
375+
})
376+
);
377+
};
315378

316-
return fetch(
317-
`${urlBase(config)}_dash-update-component`,
318-
mergeDeepRight(config.fetch, {
319-
method: 'POST',
320-
headers: getCSRFHeader() as any,
321-
body
322-
})
323-
).then(
324-
(res: any) => {
379+
return new Promise((resolve, reject) => {
380+
const handleOutput = (res: any) => {
325381
const {status} = res;
326382

327383
function recordProfile(result: any) {
@@ -361,36 +417,86 @@ function handleServerside(
361417
}
362418
}
363419

420+
const finishLine = (data: CallbackResponseData) => {
421+
const {multi, response} = data;
422+
if (hooks.request_post) {
423+
hooks.request_post(payload, response);
424+
}
425+
426+
let result;
427+
if (multi) {
428+
result = response as CallbackResponse;
429+
} else {
430+
const {output} = payload;
431+
const id = output.substr(0, output.lastIndexOf('.'));
432+
result = {[id]: (response as CallbackResponse).props};
433+
}
434+
435+
recordProfile(result);
436+
resolve(result);
437+
};
438+
439+
const completeJob = () => {
440+
if (job) {
441+
dispatch(removeCallbackJob({jobId: job}));
442+
}
443+
if (runningOff) {
444+
sideUpdate(runningOff, dispatch, paths);
445+
}
446+
if (progressDefault) {
447+
sideUpdate(progressDefault, dispatch, paths);
448+
}
449+
};
450+
364451
if (status === STATUS.OK) {
365-
return res.json().then((data: any) => {
366-
const {multi, response} = data;
367-
if (hooks.request_post) {
368-
hooks.request_post(payload, response);
452+
res.json().then((data: CallbackResponseData) => {
453+
if (!cacheKey && data.cacheKey) {
454+
cacheKey = data.cacheKey;
369455
}
370456

371-
let result;
372-
if (multi) {
373-
result = response;
374-
} else {
375-
const {output} = payload;
376-
const id = output.substr(0, output.lastIndexOf('.'));
377-
result = {[id]: response.props};
457+
if (!job && data.job) {
458+
const jobInfo: CallbackJobPayload = {
459+
jobId: data.job,
460+
cacheKey: data.cacheKey as string,
461+
cancelInputs: data.cancel,
462+
progressDefault: data.progressDefault
463+
};
464+
dispatch(addCallbackJob(jobInfo));
465+
job = data.job;
378466
}
379467

380-
recordProfile(result);
381-
return result;
468+
if (data.progress) {
469+
sideUpdate(data.progress, dispatch, paths);
470+
}
471+
if (data.running) {
472+
sideUpdate(data.running, dispatch, paths);
473+
}
474+
if (!runningOff && data.runningOff) {
475+
runningOff = data.runningOff;
476+
}
477+
if (!progressDefault && data.progressDefault) {
478+
progressDefault = data.progressDefault;
479+
}
480+
481+
if (!long || data.response !== undefined) {
482+
completeJob();
483+
finishLine(data);
484+
} else {
485+
// Poll chain.
486+
setTimeout(handle, long.interval || 500);
487+
}
382488
});
383-
}
384-
if (status === STATUS.PREVENT_UPDATE) {
489+
} else if (status === STATUS.PREVENT_UPDATE) {
490+
completeJob();
385491
recordProfile({});
386-
return {};
492+
resolve({});
493+
} else {
494+
completeJob();
495+
reject(res);
387496
}
388-
throw res;
389-
},
390-
() => {
391-
// fetch rejection - this means the request didn't return,
392-
// we don't get here from 400/500 errors, only network
393-
// errors or unresponsive servers.
497+
};
498+
499+
const handleError = () => {
394500
if (config.ui) {
395501
dispatch(
396502
updateResourceUsage({
@@ -402,9 +508,14 @@ function handleServerside(
402508
})
403509
);
404510
}
405-
throw new Error('Callback failed: the server did not respond.');
406-
}
407-
);
511+
reject(new Error('Callback failed: the server did not respond.'));
512+
};
513+
514+
const handle = () => {
515+
fetchCallback().then(handleOutput, handleError);
516+
};
517+
handle();
518+
});
408519
}
409520

410521
function inputsToDict(inputs_list: any) {
@@ -443,10 +554,10 @@ export function executeCallback(
443554
paths: any,
444555
layout: any,
445556
{allOutputs}: any,
446-
dispatch: any
557+
dispatch: any,
558+
getState: any
447559
): IExecutingCallback {
448-
const {output, inputs, state, clientside_function} = cb.callback;
449-
560+
const {output, inputs, state, clientside_function, long} = cb.callback;
450561
try {
451562
const inVals = fillVals(paths, layout, cb, inputs, 'Input', true);
452563

@@ -518,13 +629,40 @@ export function executeCallback(
518629
let newHeaders: Record<string, string> | null = null;
519630
let lastError: any;
520631

632+
const additionalArgs: [string, string][] = [];
633+
values(getState().callbackJobs).forEach(
634+
(job: CallbackJobPayload) => {
635+
if (!job.cancelInputs) {
636+
return;
637+
}
638+
const inter = intersection(
639+
job.cancelInputs,
640+
cb.callback.inputs
641+
);
642+
if (inter.length) {
643+
additionalArgs.push(['cancelJob', job.jobId]);
644+
if (job.progressDefault) {
645+
console.log(job.progressDefault);
646+
sideUpdate(
647+
job.progressDefault,
648+
dispatch,
649+
paths
650+
);
651+
}
652+
}
653+
}
654+
);
655+
521656
for (let retry = 0; retry <= MAX_AUTH_RETRIES; retry++) {
522657
try {
523658
const data = await handleServerside(
524659
dispatch,
525660
hooks,
526661
newConfig,
527-
payload
662+
payload,
663+
paths,
664+
long,
665+
additionalArgs.length ? additionalArgs : undefined
528666
);
529667

530668
if (newHeaders) {

Diff for: ‎dash/dash-renderer/src/observers/prioritizedCallbacks.ts

+4-2
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ const observer: IStoreObserverDefinition<IStoreState> = {
110110
paths,
111111
layout,
112112
getStash(cb, paths),
113-
dispatch
113+
dispatch,
114+
getState
114115
),
115116
pickedSyncCallbacks
116117
)
@@ -162,7 +163,8 @@ const observer: IStoreObserverDefinition<IStoreState> = {
162163
paths,
163164
layout,
164165
cb,
165-
dispatch
166+
dispatch,
167+
getState
166168
);
167169

168170
dispatch(

Diff for: ‎dash/dash-renderer/src/reducers/callbackJobs.ts

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import {assoc, dissoc} from 'ramda';
2+
import {ICallbackProperty} from '../types/callbacks';
3+
4+
type CallbackJobState = {[k: string]: CallbackJobPayload};
5+
6+
export type CallbackJobPayload = {
7+
cancelInputs?: ICallbackProperty[];
8+
cacheKey: string;
9+
jobId: string;
10+
progressDefault?: any;
11+
};
12+
13+
type CallbackJobAction = {
14+
type: 'ADD_CALLBACK_JOB' | 'REMOVE_CALLBACK_JOB';
15+
payload: CallbackJobPayload;
16+
};
17+
18+
const setJob = (job: CallbackJobPayload, state: CallbackJobState) =>
19+
assoc(job.jobId, job, state);
20+
const removeJob = (jobId: string, state: CallbackJobState) =>
21+
dissoc(jobId, state);
22+
23+
export default function (
24+
state: CallbackJobState = {},
25+
action: CallbackJobAction
26+
) {
27+
switch (action.type) {
28+
case 'ADD_CALLBACK_JOB':
29+
return setJob(action.payload, state);
30+
case 'REMOVE_CALLBACK_JOB':
31+
return removeJob(action.payload.jobId, state);
32+
default:
33+
return state;
34+
}
35+
}

Diff for: ‎dash/dash-renderer/src/reducers/reducer.js

+3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import isLoading from './isLoading';
1717
import layout from './layout';
1818
import loadingMap from './loadingMap';
1919
import paths from './paths';
20+
import callbackJobs from './callbackJobs';
2021

2122
export const apiRequests = [
2223
'dependenciesRequest',
@@ -45,6 +46,8 @@ function mainReducer() {
4546
parts[r] = createApiReducer(r);
4647
}, apiRequests);
4748

49+
parts.callbackJobs = callbackJobs;
50+
4851
return combineReducers(parts);
4952
}
5053

Diff for: ‎dash/dash-renderer/src/types/callbacks.ts

+24-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ export interface ICallbackDefinition {
1111
outputs: ICallbackProperty[];
1212
prevent_initial_call: boolean;
1313
state: ICallbackProperty[];
14+
long?: LongCallbackInfo;
1415
}
1516

1617
export interface ICallbackProperty {
@@ -75,7 +76,29 @@ export interface ICallbackPayload {
7576
}
7677

7778
export type CallbackResult = {
78-
data?: any;
79+
data?: CallbackResponse;
7980
error?: Error;
8081
payload: ICallbackPayload | null;
8182
};
83+
84+
export type LongCallbackInfo = {
85+
interval?: number;
86+
progress?: any;
87+
running?: any;
88+
};
89+
90+
export type CallbackResponse = {
91+
[k: string]: any;
92+
};
93+
94+
export type CallbackResponseData = {
95+
response?: CallbackResponse;
96+
multi?: boolean;
97+
cacheKey?: string;
98+
job?: string;
99+
progressDefault?: CallbackResponse;
100+
progress?: CallbackResponse;
101+
running?: CallbackResponse;
102+
runningOff?: CallbackResponse;
103+
cancel?: ICallbackProperty[];
104+
};

Diff for: ‎dash/dash.py

+38-243
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@
2727
from .fingerprint import build_fingerprint, check_fingerprint
2828
from .resources import Scripts, Css
2929
from .dependencies import (
30-
handle_grouped_callback_args,
3130
Output,
32-
State,
3331
Input,
3432
)
3533
from .development.base_component import ComponentRegistry
@@ -61,7 +59,7 @@
6159
from . import _watch
6260
from . import _get_app
6361

64-
from ._grouping import flatten_grouping, map_grouping, grouping_len, update_args_group
62+
from ._grouping import map_grouping, grouping_len, update_args_group
6563

6664
from . import _pages
6765
from ._pages import (
@@ -1118,6 +1116,7 @@ def clientside_callback(self, clientside_function, *args, **kwargs):
11181116
**kwargs,
11191117
)
11201118

1119+
# pylint: disable=R0201
11211120
def callback(self, *_args, **_kwargs):
11221121
"""
11231122
Normally used as a decorator, `@app.callback` provides a server-side
@@ -1132,16 +1131,27 @@ def callback(self, *_args, **_kwargs):
11321131
11331132
11341133
"""
1135-
return _callback.register_callback(
1136-
self._callback_list,
1137-
self.callback_map,
1138-
self.config.prevent_initial_callbacks,
1134+
return _callback.callback(
11391135
*_args,
11401136
**_kwargs,
11411137
)
11421138

1143-
def long_callback(self, *_args, **_kwargs): # pylint: disable=too-many-statements
1139+
# pylint: disable=R0201
1140+
def long_callback(
1141+
self,
1142+
*_args,
1143+
manager=None,
1144+
interval=None,
1145+
running=None,
1146+
cancel=None,
1147+
progress=None,
1148+
progress_default=None,
1149+
cache_args_to_ignore=None,
1150+
**_kwargs,
1151+
): # pylint: disable=too-many-statements
11441152
"""
1153+
Deprecation Notice, use long
1154+
11451155
Normally used as a decorator, `@app.long_callback` is an alternative to
11461156
`@app.callback` designed for callbacks that take a long time to run,
11471157
without locking up the Dash app or timing out.
@@ -1200,242 +1210,19 @@ def long_callback(self, *_args, **_kwargs): # pylint: disable=too-many-statemen
12001210
this should be a list of argument names as strings. Otherwise,
12011211
this should be a list of argument indices as integers.
12021212
"""
1203-
# pylint: disable-next=import-outside-toplevel
1204-
from dash._callback_context import callback_context
1205-
1206-
# pylint: disable-next=import-outside-toplevel
1207-
from dash.exceptions import WildcardInLongCallback
1208-
1209-
# Get long callback manager
1210-
callback_manager = _kwargs.pop("manager", self._long_callback_manager)
1211-
if callback_manager is None:
1212-
raise ValueError(
1213-
"The @app.long_callback decorator requires a long callback manager\n"
1214-
"instance. This may be provided to the app using the \n"
1215-
"long_callback_manager argument to the dash.Dash constructor, or\n"
1216-
"it may be provided to the @app.long_callback decorator as the \n"
1217-
"manager argument"
1218-
)
1219-
1220-
# Extract special long_callback kwargs
1221-
running = _kwargs.pop("running", ())
1222-
cancel = _kwargs.pop("cancel", ())
1223-
progress = _kwargs.pop("progress", ())
1224-
progress_default = _kwargs.pop("progress_default", None)
1225-
interval_time = _kwargs.pop("interval", 1000)
1226-
cache_args_to_ignore = _kwargs.pop("cache_args_to_ignore", [])
1227-
1228-
# Parse remaining args just like app.callback
1229-
(
1230-
output,
1231-
flat_inputs,
1232-
flat_state,
1233-
inputs_state_indices,
1234-
prevent_initial_call,
1235-
) = handle_grouped_callback_args(_args, _kwargs)
1236-
inputs_and_state = flat_inputs + flat_state
1237-
args_deps = map_grouping(lambda i: inputs_and_state[i], inputs_state_indices)
1238-
multi_output = isinstance(output, (list, tuple)) and len(output) > 1
1239-
1240-
# Disallow wildcard dependencies
1241-
for deps in [output, flat_inputs, flat_state]:
1242-
for dep in flatten_grouping(deps):
1243-
if dep.has_wildcard():
1244-
raise WildcardInLongCallback(
1245-
f"""
1246-
@app.long_callback does not support dependencies with
1247-
pattern-matching ids
1248-
Received: {repr(dep)}\n"""
1249-
)
1250-
1251-
# Get unique id for this long_callback definition. This increment is not
1252-
# thread safe, but it doesn't need to be because callback definitions
1253-
# happen on the main thread before the app starts
1254-
self._long_callback_count += 1
1255-
long_callback_id = self._long_callback_count
1256-
1257-
# Create Interval and Store for long callback and add them to the app's
1258-
# _extra_components list
1259-
interval_id = f"_long_callback_interval_{long_callback_id}"
1260-
interval_component = dcc.Interval(
1261-
id=interval_id, interval=interval_time, disabled=prevent_initial_call
1262-
)
1263-
store_id = f"_long_callback_store_{long_callback_id}"
1264-
store_component = dcc.Store(id=store_id, data={})
1265-
error_id = f"_long_callback_error_{long_callback_id}"
1266-
error_store_component = dcc.Store(id=error_id)
1267-
error_dummy = f"_long_callback_error_dummy_{long_callback_id}"
1268-
self._extra_components.extend(
1269-
[
1270-
interval_component,
1271-
store_component,
1272-
error_store_component,
1273-
dcc.Store(id=error_dummy),
1274-
]
1275-
)
1276-
1277-
# Compute full component plus property name for the cancel dependencies
1278-
cancel_prop_ids = tuple(
1279-
".".join([dep.component_id, dep.component_property]) for dep in cancel
1213+
return _callback.callback(
1214+
*_args,
1215+
long=True,
1216+
long_manager=manager,
1217+
long_interval=interval,
1218+
long_progress=progress,
1219+
long_progress_default=progress_default,
1220+
long_running=running,
1221+
long_cancel=cancel,
1222+
long_cache_args_to_ignore=cache_args_to_ignore,
1223+
**_kwargs,
12801224
)
12811225

1282-
def wrapper(fn):
1283-
background_fn = callback_manager.make_job_fn(fn, bool(progress), args_deps)
1284-
1285-
def callback(_triggers, user_store_data, user_callback_args):
1286-
# Build result cache key from inputs
1287-
pending_key = callback_manager.build_cache_key(
1288-
fn, user_callback_args, cache_args_to_ignore
1289-
)
1290-
current_key = user_store_data.get("current_key", None)
1291-
pending_job = user_store_data.get("pending_job", None)
1292-
1293-
should_cancel = pending_key == current_key or any(
1294-
trigger["prop_id"] in cancel_prop_ids
1295-
for trigger in callback_context.triggered
1296-
)
1297-
1298-
# Compute grouping of values to set the progress component's to
1299-
# when cleared
1300-
if progress_default is None:
1301-
clear_progress = (
1302-
map_grouping(lambda x: None, progress) if progress else ()
1303-
)
1304-
else:
1305-
clear_progress = progress_default
1306-
1307-
if should_cancel:
1308-
user_store_data["current_key"] = None
1309-
user_store_data["pending_key"] = None
1310-
user_store_data["pending_job"] = None
1311-
1312-
callback_manager.terminate_job(pending_job)
1313-
1314-
return dict(
1315-
user_callback_output=map_grouping(lambda x: no_update, output),
1316-
interval_disabled=True,
1317-
in_progress=[val for (_, _, val) in running],
1318-
progress=clear_progress,
1319-
user_store_data=user_store_data,
1320-
error=no_update,
1321-
)
1322-
1323-
# Look up progress value if a job is in progress
1324-
if pending_job:
1325-
progress_value = callback_manager.get_progress(pending_key)
1326-
else:
1327-
progress_value = None
1328-
1329-
if callback_manager.result_ready(pending_key):
1330-
result = callback_manager.get_result(pending_key, pending_job)
1331-
# Set current key (hash of data stored in client)
1332-
# to pending key (hash of data requested by client)
1333-
user_store_data["current_key"] = pending_key
1334-
1335-
if isinstance(result, dict) and result.get("long_callback_error"):
1336-
error = result.get("long_callback_error")
1337-
print(
1338-
result["long_callback_error"]["tb"],
1339-
file=sys.stderr,
1340-
)
1341-
return dict(
1342-
error=f"An error occurred inside a long callback: {error['msg']}\n"
1343-
+ error["tb"],
1344-
user_callback_output=no_update,
1345-
in_progress=[val for (_, _, val) in running],
1346-
interval_disabled=pending_job is None,
1347-
progress=clear_progress,
1348-
user_store_data=user_store_data,
1349-
)
1350-
1351-
if _callback.NoUpdate.is_no_update(result):
1352-
result = no_update
1353-
1354-
if multi_output and isinstance(result, (list, tuple)):
1355-
result = [
1356-
no_update if _callback.NoUpdate.is_no_update(r) else r
1357-
for r in result
1358-
]
1359-
1360-
# Disable interval if this value was pulled from cache.
1361-
# If this value was the result of a background calculation, don't
1362-
# disable yet. If no other calculations are in progress,
1363-
# interval will be disabled in should_cancel logic above
1364-
# the next time the interval fires.
1365-
interval_disabled = pending_job is None
1366-
return dict(
1367-
user_callback_output=result,
1368-
interval_disabled=interval_disabled,
1369-
in_progress=[val for (_, _, val) in running],
1370-
progress=clear_progress,
1371-
user_store_data=user_store_data,
1372-
error=no_update,
1373-
)
1374-
if progress_value:
1375-
return dict(
1376-
user_callback_output=map_grouping(lambda x: no_update, output),
1377-
interval_disabled=False,
1378-
in_progress=[val for (_, val, _) in running],
1379-
progress=progress_value or {},
1380-
user_store_data=user_store_data,
1381-
error=no_update,
1382-
)
1383-
1384-
# Check if there is a running calculation that can now
1385-
# be canceled
1386-
old_pending_key = user_store_data.get("pending_key", None)
1387-
if (
1388-
old_pending_key
1389-
and old_pending_key != pending_key
1390-
and callback_manager.job_running(pending_job)
1391-
):
1392-
callback_manager.terminate_job(pending_job)
1393-
1394-
user_store_data["pending_key"] = pending_key
1395-
callback_manager.terminate_unhealthy_job(pending_job)
1396-
if not callback_manager.job_running(pending_job):
1397-
user_store_data["pending_job"] = callback_manager.call_job_fn(
1398-
pending_key, background_fn, user_callback_args
1399-
)
1400-
1401-
return dict(
1402-
user_callback_output=map_grouping(lambda x: no_update, output),
1403-
interval_disabled=False,
1404-
in_progress=[val for (_, val, _) in running],
1405-
progress=clear_progress,
1406-
user_store_data=user_store_data,
1407-
error=no_update,
1408-
)
1409-
1410-
self.clientside_callback(
1411-
"function (error) {throw new Error(error)}",
1412-
Output(error_dummy, "data"),
1413-
[Input(error_id, "data")],
1414-
prevent_initial_call=True,
1415-
)
1416-
1417-
return self.callback(
1418-
inputs=dict(
1419-
_triggers=dict(
1420-
n_intervals=Input(interval_id, "n_intervals"),
1421-
cancel=cancel,
1422-
),
1423-
user_store_data=State(store_id, "data"),
1424-
user_callback_args=args_deps,
1425-
),
1426-
output=dict(
1427-
user_callback_output=output,
1428-
interval_disabled=Output(interval_id, "disabled"),
1429-
in_progress=[dep for (dep, _, _) in running],
1430-
progress=progress,
1431-
user_store_data=Output(store_id, "data"),
1432-
error=Output(error_id, "data"),
1433-
),
1434-
prevent_initial_call=prevent_initial_call,
1435-
)(callback)
1436-
1437-
return wrapper
1438-
14391226
def dispatch(self):
14401227
body = flask.request.get_json()
14411228

@@ -1455,6 +1242,7 @@ def dispatch(self):
14551242
flask.g.state_values = inputs_to_dict( # pylint: disable=assigning-non-slot
14561243
state
14571244
)
1245+
flask.g.long_callback_manager = self._long_callback_manager # pylint: disable=E0237
14581246
changed_props = body.get("changedPropIds", [])
14591247
flask.g.triggered_inputs = [ # pylint: disable=assigning-non-slot
14601248
{"prop_id": x, "value": input_values.get(x)} for x in changed_props
@@ -1517,7 +1305,14 @@ def dispatch(self):
15171305
except KeyError as missing_callback_function:
15181306
msg = f"Callback function not found for output '{output}', perhaps you forgot to prepend the '@'?"
15191307
raise KeyError(msg) from missing_callback_function
1520-
response.set_data(func(*args, outputs_list=outputs_list))
1308+
# noinspection PyArgumentList
1309+
response.set_data(
1310+
func(
1311+
*args,
1312+
outputs_list=outputs_list,
1313+
long_callback_manager=self._long_callback_manager,
1314+
)
1315+
)
15211316
return response
15221317

15231318
def _setup_server(self):

Diff for: ‎dash/long_callback/managers/__init__.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,29 @@
44

55

66
class BaseLongCallbackManager(ABC):
7+
UNDEFINED = object()
8+
9+
# Keep a ref to all the ref to register every callback to every manager.
10+
managers = []
11+
12+
# Keep every function for late registering.
13+
functions = []
14+
715
def __init__(self, cache_by):
816
if cache_by is not None and not isinstance(cache_by, list):
917
cache_by = [cache_by]
1018

1119
self.cache_by = cache_by
1220

21+
BaseLongCallbackManager.managers.append(self)
22+
23+
self.func_registry = {}
24+
25+
# Register all funcs that were added before instantiation.
26+
# Ensure all celery task are registered.
27+
for fdetails in self.functions:
28+
self.register(*fdetails)
29+
1330
def terminate_job(self, job):
1431
raise NotImplementedError
1532

@@ -19,7 +36,7 @@ def terminate_unhealthy_job(self, job):
1936
def job_running(self, job):
2037
raise NotImplementedError
2138

22-
def make_job_fn(self, fn, progress, args_deps):
39+
def make_job_fn(self, fn, progress):
2340
raise NotImplementedError
2441

2542
def call_job_fn(self, key, job_fn, args):
@@ -58,6 +75,31 @@ def build_cache_key(self, fn, args, cache_args_to_ignore):
5875

5976
return hashlib.sha1(str(hash_dict).encode("utf-8")).hexdigest()
6077

78+
def register(self, key, fn, progress):
79+
self.func_registry[key] = self.make_job_fn(fn, progress)
80+
81+
@staticmethod
82+
def register_func(fn, progress):
83+
key = BaseLongCallbackManager.hash_function(fn)
84+
BaseLongCallbackManager.functions.append(
85+
(
86+
key,
87+
fn,
88+
progress,
89+
)
90+
)
91+
92+
for manager in BaseLongCallbackManager.managers:
93+
manager.register(key, fn, progress)
94+
95+
return key
96+
6197
@staticmethod
6298
def _make_progress_key(key):
6399
return key + "-progress"
100+
101+
@staticmethod
102+
def hash_function(fn):
103+
fn_source = inspect.getsource(fn)
104+
fn_str = fn_source
105+
return hashlib.sha1(fn_str.encode("utf-8")).hexdigest()

Diff for: ‎dash/long_callback/managers/celery_manager.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from _plotly_utils.utils import PlotlyJSONEncoder
77

8-
from dash._callback import NoUpdate
98
from dash.exceptions import PreventUpdate
109
from dash.long_callback.managers import BaseLongCallbackManager
1110

@@ -74,8 +73,8 @@ def job_running(self, job):
7473
"PROGRESS",
7574
)
7675

77-
def make_job_fn(self, fn, progress, args_deps):
78-
return _make_job_fn(fn, self.handle, progress, args_deps)
76+
def make_job_fn(self, fn, progress):
77+
return _make_job_fn(fn, self.handle, progress)
7978

8079
def get_task(self, job):
8180
if job:
@@ -105,7 +104,7 @@ def get_result(self, key, job):
105104
# Get result value
106105
result = self.handle.backend.get(key)
107106
if result is None:
108-
return None
107+
return self.UNDEFINED
109108

110109
result = json.loads(result)
111110

@@ -122,7 +121,7 @@ def get_result(self, key, job):
122121
return result
123122

124123

125-
def _make_job_fn(fn, celery_app, progress, args_deps):
124+
def _make_job_fn(fn, celery_app, progress):
126125
cache = celery_app.backend
127126

128127
# Hash function source and module to create a unique (but stable) celery task name
@@ -133,19 +132,28 @@ def _make_job_fn(fn, celery_app, progress, args_deps):
133132
@celery_app.task(name=f"long_callback_{fn_hash}")
134133
def job_fn(result_key, progress_key, user_callback_args, fn=fn):
135134
def _set_progress(progress_value):
135+
if not isinstance(progress_value, (list, tuple)):
136+
progress_value = [progress_value]
137+
136138
cache.set(progress_key, json.dumps(progress_value, cls=PlotlyJSONEncoder))
137139

138140
maybe_progress = [_set_progress] if progress else []
139141

140142
try:
141-
if isinstance(args_deps, dict):
143+
if isinstance(user_callback_args, dict):
142144
user_callback_output = fn(*maybe_progress, **user_callback_args)
143-
elif isinstance(args_deps, (list, tuple)):
145+
elif isinstance(user_callback_args, (list, tuple)):
144146
user_callback_output = fn(*maybe_progress, *user_callback_args)
145147
else:
146148
user_callback_output = fn(*maybe_progress, user_callback_args)
147149
except PreventUpdate:
148-
cache.set(result_key, json.dumps(NoUpdate(), cls=PlotlyJSONEncoder))
150+
# Put NoUpdate dict directly to avoid circular imports.
151+
cache.set(
152+
result_key,
153+
json.dumps(
154+
{"_dash_no_update": "_dash_no_update"}, cls=PlotlyJSONEncoder
155+
),
156+
)
149157
except Exception as err: # pylint: disable=broad-except
150158
cache.set(
151159
result_key,

Diff for: ‎dash/long_callback/managers/diskcache_manager.py

+22-13
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import traceback
22

33
from . import BaseLongCallbackManager
4-
from ..._callback import NoUpdate
54
from ...exceptions import PreventUpdate
65

76
_pending_value = "__$pending__"
@@ -48,16 +47,18 @@ def __init__(self, cache=None, cache_by=None, expire=None):
4847
)
4948
self.handle = cache
5049

51-
super().__init__(cache_by)
52-
self.expire = expire
5350
self.lock = diskcache.Lock(self.handle, "long-callback-lock")
51+
self.expire = expire
52+
super().__init__(cache_by)
5453

5554
def terminate_job(self, job):
5655
import psutil # pylint: disable=import-outside-toplevel,import-error
5756

5857
if job is None:
5958
return
6059

60+
job = int(job)
61+
6162
# Use diskcache transaction so multiple process don't try to kill the
6263
# process at the same time
6364
with self.handle.transact():
@@ -83,6 +84,8 @@ def terminate_job(self, job):
8384
def terminate_unhealthy_job(self, job):
8485
import psutil # pylint: disable=import-outside-toplevel,import-error
8586

87+
job = int(job)
88+
8689
if job and psutil.pid_exists(job):
8790
if not self.job_running(job):
8891
self.terminate_job(job)
@@ -93,13 +96,15 @@ def terminate_unhealthy_job(self, job):
9396
def job_running(self, job):
9497
import psutil # pylint: disable=import-outside-toplevel,import-error
9598

99+
job = int(job)
100+
96101
if job and psutil.pid_exists(job):
97102
proc = psutil.Process(job)
98103
return proc.status() != psutil.STATUS_ZOMBIE
99104
return False
100105

101-
def make_job_fn(self, fn, progress, args_deps):
102-
return _make_job_fn(fn, self.handle, progress, args_deps, self.lock)
106+
def make_job_fn(self, fn, progress):
107+
return _make_job_fn(fn, self.handle, progress, self.lock)
103108

104109
def clear_cache_entry(self, key):
105110
self.handle.delete(key)
@@ -122,9 +127,9 @@ def result_ready(self, key):
122127

123128
def get_result(self, key, job):
124129
# Get result value
125-
result = self.handle.get(key)
126-
if result is None:
127-
return None
130+
result = self.handle.get(key, self.UNDEFINED)
131+
if result is self.UNDEFINED:
132+
return self.UNDEFINED
128133

129134
# Clear result if not caching
130135
if self.cache_by is None:
@@ -135,28 +140,32 @@ def get_result(self, key, job):
135140

136141
self.clear_cache_entry(self._make_progress_key(key))
137142

138-
self.terminate_job(job)
143+
if job:
144+
self.terminate_job(job)
139145
return result
140146

141147

142-
def _make_job_fn(fn, cache, progress, args_deps, lock):
148+
def _make_job_fn(fn, cache, progress, lock):
143149
def job_fn(result_key, progress_key, user_callback_args):
144150
def _set_progress(progress_value):
151+
if not isinstance(progress_value, (list, tuple)):
152+
progress_value = [progress_value]
153+
145154
with lock:
146155
cache.set(progress_key, progress_value)
147156

148157
maybe_progress = [_set_progress] if progress else []
149158

150159
try:
151-
if isinstance(args_deps, dict):
160+
if isinstance(user_callback_args, dict):
152161
user_callback_output = fn(*maybe_progress, **user_callback_args)
153-
elif isinstance(args_deps, (list, tuple)):
162+
elif isinstance(user_callback_args, (list, tuple)):
154163
user_callback_output = fn(*maybe_progress, *user_callback_args)
155164
else:
156165
user_callback_output = fn(*maybe_progress, user_callback_args)
157166
except PreventUpdate:
158167
with lock:
159-
cache.set(result_key, NoUpdate())
168+
cache.set(result_key, {"_dash_no_update": "_dash_no_update"})
160169
except Exception as err: # pylint: disable=broad-except
161170
with lock:
162171
cache.set(

Diff for: ‎tests/integration/long_callback/test_basic_long_callback.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -451,9 +451,14 @@ def click_n_wait():
451451
dash_duo.wait_for_text_to_equal("#output", "Clicked 1 times")
452452

453453
click_n_wait()
454-
dash_duo.wait_for_contains_text(
455-
".dash-fe-error__title", "An error occurred inside a long callback:"
454+
dash_duo.wait_for_element(".dash-fe-error__title").click()
455+
456+
dash_duo.driver.switch_to.frame(dash_duo.find_element("iframe"))
457+
assert (
458+
"Exception: An error occurred inside a long callback:"
459+
in dash_duo.wait_for_element(".errormsg").text
456460
)
461+
dash_duo.driver.switch_to.default_content()
457462

458463
click_n_wait()
459464
dash_duo.wait_for_text_to_equal("#output", "Clicked 3 times")

0 commit comments

Comments
 (0)
Please sign in to comment.