Skip to content

Commit bbf9ce3

Browse files
fix(ui): single or collection field rendering
Fixes an issue where fields like control weight on ControlNet nodes and image on IP Adapter nodes didn't render. These are "single or collection" fields. They accept a single input object, or collection. They are supposed to render the UI input for a single object. In a7a71ca a performance optimisation for a hot code-path inadvertently broke this. The determination of which UI component to render for a given field was done using a type guard function for the field's template. Previously, this used a zod schema to parse the template. This is very slow, especially when the template was not the expected type. The optimization changed the type guards to check the field name (aka its type, integer, image, etc) and cardinality directly, without any zod parsing. It's much faster, but subtly changed the behaviour because it was a bit stricter. For some fields, it rejected "single or collection" cardinalities when it should have accepted them. When these fields - like the aforementioned Control Weight and Image - were being rendered, none of the type guards passed and they rendered nothing. The fix here updates the type guard functions to support multiple cardinalities. So now, when we go to render a "single or collection" field, we will render the "single" input component as it should be.
1 parent 7567ee2 commit bbf9ce3

File tree

1 file changed

+34
-11
lines changed
  • invokeai/frontend/web/src/features/nodes/types

1 file changed

+34
-11
lines changed

Diff for: invokeai/frontend/web/src/features/nodes/types/field.ts

+34-11
Original file line numberDiff line numberDiff line change
@@ -331,14 +331,25 @@ const buildInstanceTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
331331
return (val: unknown): val is z.infer<T> => schema.safeParse(val).success;
332332
};
333333

334+
/**
335+
* Builds a type guard for a specific field input template type.
336+
*
337+
* The output type guards are primarily used for determining which input component to render for fields in the
338+
* <InputFieldRenderer/> component.
339+
*
340+
* @param name The name of the field type.
341+
* @param cardinalities The allowed cardinalities for the field type. If omitted, all cardinalities are allowed.
342+
*
343+
* @returns A type guard for the specified field type.
344+
*/
334345
const buildTemplateTypeGuard =
335-
<T extends FieldInputTemplate>(name: string, cardinality?: 'SINGLE' | 'COLLECTION' | 'SINGLE_OR_COLLECTION') =>
346+
<T extends FieldInputTemplate>(name: string, cardinalities?: FieldType['cardinality'][]) =>
336347
(template: FieldInputTemplate): template is T => {
337348
if (template.type.name !== name) {
338349
return false;
339350
}
340-
if (cardinality) {
341-
return template.type.cardinality === cardinality;
351+
if (cardinalities) {
352+
return cardinalities.includes(template.type.cardinality);
342353
}
343354
return true;
344355
};
@@ -366,7 +377,10 @@ export type IntegerFieldValue = z.infer<typeof zIntegerFieldValue>;
366377
export type IntegerFieldInputInstance = z.infer<typeof zIntegerFieldInputInstance>;
367378
export type IntegerFieldInputTemplate = z.infer<typeof zIntegerFieldInputTemplate>;
368379
export const isIntegerFieldInputInstance = buildInstanceTypeGuard(zIntegerFieldInputInstance);
369-
export const isIntegerFieldInputTemplate = buildTemplateTypeGuard<IntegerFieldInputTemplate>('IntegerField', 'SINGLE');
380+
export const isIntegerFieldInputTemplate = buildTemplateTypeGuard<IntegerFieldInputTemplate>('IntegerField', [
381+
'SINGLE',
382+
'SINGLE_OR_COLLECTION',
383+
]);
370384
// #endregion
371385

372386
// #region IntegerField Collection
@@ -406,7 +420,7 @@ export type IntegerFieldCollectionInputTemplate = z.infer<typeof zIntegerFieldCo
406420
export const isIntegerFieldCollectionInputInstance = buildInstanceTypeGuard(zIntegerFieldCollectionInputInstance);
407421
export const isIntegerFieldCollectionInputTemplate = buildTemplateTypeGuard<IntegerFieldCollectionInputTemplate>(
408422
'IntegerField',
409-
'COLLECTION'
423+
['COLLECTION']
410424
);
411425
// #endregion
412426

@@ -432,7 +446,10 @@ export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
432446
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
433447
export type FloatFieldInputTemplate = z.infer<typeof zFloatFieldInputTemplate>;
434448
export const isFloatFieldInputInstance = buildInstanceTypeGuard(zFloatFieldInputInstance);
435-
export const isFloatFieldInputTemplate = buildTemplateTypeGuard<FloatFieldInputTemplate>('FloatField', 'SINGLE');
449+
export const isFloatFieldInputTemplate = buildTemplateTypeGuard<FloatFieldInputTemplate>('FloatField', [
450+
'SINGLE',
451+
'SINGLE_OR_COLLECTION',
452+
]);
436453
// #endregion
437454

438455
// #region FloatField Collection
@@ -471,7 +488,7 @@ export type FloatFieldCollectionInputTemplate = z.infer<typeof zFloatFieldCollec
471488
export const isFloatFieldCollectionInputInstance = buildInstanceTypeGuard(zFloatFieldCollectionInputInstance);
472489
export const isFloatFieldCollectionInputTemplate = buildTemplateTypeGuard<FloatFieldCollectionInputTemplate>(
473490
'FloatField',
474-
'COLLECTION'
491+
['COLLECTION']
475492
);
476493
// #endregion
477494

@@ -504,7 +521,10 @@ export type StringFieldValue = z.infer<typeof zStringFieldValue>;
504521
export type StringFieldInputInstance = z.infer<typeof zStringFieldInputInstance>;
505522
export type StringFieldInputTemplate = z.infer<typeof zStringFieldInputTemplate>;
506523
export const isStringFieldInputInstance = buildInstanceTypeGuard(zStringFieldInputInstance);
507-
export const isStringFieldInputTemplate = buildTemplateTypeGuard<StringFieldInputTemplate>('StringField', 'SINGLE');
524+
export const isStringFieldInputTemplate = buildTemplateTypeGuard<StringFieldInputTemplate>('StringField', [
525+
'SINGLE',
526+
'SINGLE_OR_COLLECTION',
527+
]);
508528
// #endregion
509529

510530
// #region StringField Collection
@@ -550,7 +570,7 @@ export type StringFieldCollectionInputTemplate = z.infer<typeof zStringFieldColl
550570
export const isStringFieldCollectionInputInstance = buildInstanceTypeGuard(zStringFieldCollectionInputInstance);
551571
export const isStringFieldCollectionInputTemplate = buildTemplateTypeGuard<StringFieldCollectionInputTemplate>(
552572
'StringField',
553-
'COLLECTION'
573+
['COLLECTION']
554574
);
555575
// #endregion
556576

@@ -613,7 +633,10 @@ export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
613633
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
614634
export type ImageFieldInputTemplate = z.infer<typeof zImageFieldInputTemplate>;
615635
export const isImageFieldInputInstance = buildInstanceTypeGuard(zImageFieldInputInstance);
616-
export const isImageFieldInputTemplate = buildTemplateTypeGuard<ImageFieldInputTemplate>('ImageField', 'SINGLE');
636+
export const isImageFieldInputTemplate = buildTemplateTypeGuard<ImageFieldInputTemplate>('ImageField', [
637+
'SINGLE',
638+
'SINGLE_OR_COLLECTION',
639+
]);
617640
// #endregion
618641

619642
// #region ImageField Collection
@@ -648,7 +671,7 @@ export type ImageFieldCollectionInputTemplate = z.infer<typeof zImageFieldCollec
648671
export const isImageFieldCollectionInputInstance = buildInstanceTypeGuard(zImageFieldCollectionInputInstance);
649672
export const isImageFieldCollectionInputTemplate = buildTemplateTypeGuard<ImageFieldCollectionInputTemplate>(
650673
'ImageField',
651-
'COLLECTION'
674+
['COLLECTION']
652675
);
653676
// #endregion
654677

0 commit comments

Comments
 (0)