Skip to content

Commit 6ee8607

Browse files
psychedeliciousskunkworxdark
authored andcommitted
feat(ui): custom field types connection validation
In the initial commit, a custom field's original type was added to the *field templates* only as `originalType`. Custom fields' `type` property was `"Custom"`*. This allowed for type safety throughout the UI logic. *Actually, it was `"Unknown"`, but I changed it to custom for clarity. Connection validation logic, however, uses the *field instance* of the node/field. Like the templates, *field instances* with custom types have their `type` set to `"Custom"`, but they didn't have an `originalType` property. As a result, all custom fields could be connected to all other custom fields. To resolve this, we need to add `originalType` to the *field instances*, then switch the validation logic to use this instead of `type`. This ended up needing a bit of fanagling: - If we make `originalType` a required property on field instances, existing workflows will break during connection validation, because they won't have this property. We'd need a new layer of logic to migrate the workflows, adding the new `originalType` property. While this layer is probably needed anyways, typing `originalType` as optional is much simpler. Workflow migration logic can come layer. (Technically, we could remove all references to field types from the workflow files, and let the templates hold all this information. This feels like a significant change and I'm reluctant to do it now.) - Because `originalType` is optional, anywhere we care about the type of a field, we need to use it over `type`. So there are a number of `field.originalType ?? field.type` expressions. This is a bit of a gotcha, we'll need to remember this in the future. - We use `Array.prototype.includes()` often in the workflow editor, e.g. `COLLECTION_TYPES.includes(type)`. In these cases, the const array is of type `FieldType[]`, and `type` is is `FieldType`. Because we now support custom types, the arg `type` is now widened from `FieldType` to `string`. This causes a TS error. This behaviour is somewhat controversial (see microsoft/TypeScript#14520). These expressions are now rewritten as `COLLECTION_TYPES.some((t) => t === type)` to satisfy TS. It's logically equivalent.
1 parent 8e125b8 commit 6ee8607

File tree

13 files changed

+371
-178
lines changed

13 files changed

+371
-178
lines changed

invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,13 @@ const AddNodePopover = () => {
7373

7474
return some(handles, (handle) => {
7575
const sourceType =
76-
handleFilter == 'source' ? fieldFilter : handle.type;
76+
handleFilter == 'source'
77+
? fieldFilter
78+
: handle.originalType ?? handle.type;
7779
const targetType =
78-
handleFilter == 'target' ? fieldFilter : handle.type;
80+
handleFilter == 'target'
81+
? fieldFilter
82+
: handle.originalType ?? handle.type;
7983

8084
return validateSourceAndTargetTypes(sourceType, targetType);
8185
});

invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ import { useAppSelector } from 'app/store/storeHooks';
44
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
55
import { memo } from 'react';
66
import { ConnectionLineComponentProps, getBezierPath } from 'reactflow';
7-
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
7+
import { getFieldColor } from '../edges/util/getEdgeColor';
88

99
const selector = createSelector(stateSelector, ({ nodes }) => {
10-
const { shouldAnimateEdges, connectionStartFieldType, shouldColorEdges } =
10+
const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } =
1111
nodes;
1212

1313
const stroke = shouldColorEdges
14-
? getFieldColor(connectionStartFieldType)
14+
? getFieldColor(currentConnectionFieldType)
1515
: colorTokenToCssVar('base.500');
1616

1717
let className = 'react-flow__custom_connection-path';
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
2-
import { FIELD_COLORS } from 'features/nodes/types/constants';
3-
import { FieldType } from 'features/nodes/types/field';
2+
import { FIELDS } from 'features/nodes/types/constants';
3+
import { FieldType } from 'features/nodes/types/types';
44

5-
export const getFieldColor = (fieldType: FieldType | null): string => {
5+
export const getFieldColor = (fieldType: FieldType | string | null): string => {
66
if (!fieldType) {
77
return colorTokenToCssVar('base.500');
88
}
9-
const color = FIELD_COLORS[fieldType.name];
9+
const color = FIELDS[fieldType]?.color;
1010

1111
return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500');
1212
};

invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { createSelector } from '@reduxjs/toolkit';
22
import { stateSelector } from 'app/store/store';
33
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
44
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
5-
import { isInvocationNode } from 'features/nodes/types/invocation';
5+
import { isInvocationNode } from 'features/nodes/types/types';
66
import { getFieldColor } from './getEdgeColor';
77

88
export const makeEdgeSelector = (

invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import { Tooltip } from '@chakra-ui/react';
2-
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
32
import {
43
COLLECTION_TYPES,
5-
FIELDS,
64
HANDLE_TOOLTIP_OPEN_DELAY,
75
MODEL_TYPES,
86
POLYMORPHIC_TYPES,
@@ -13,6 +11,7 @@ import {
1311
} from 'features/nodes/types/types';
1412
import { CSSProperties, memo, useMemo } from 'react';
1513
import { Handle, HandleType, Position } from 'reactflow';
14+
import { getFieldColor } from '../../../edges/util/getEdgeColor';
1615

1716
export const handleBaseStyles: CSSProperties = {
1817
position: 'absolute',
@@ -47,14 +46,14 @@ const FieldHandle = (props: FieldHandleProps) => {
4746
isConnectionStartField,
4847
connectionError,
4948
} = props;
50-
const { name, type, originalType } = fieldTemplate;
51-
const { color: typeColor } = FIELDS[type];
49+
const { name } = fieldTemplate;
50+
const type = fieldTemplate.originalType ?? fieldTemplate.type;
5251

5352
const styles: CSSProperties = useMemo(() => {
54-
const isCollectionType = COLLECTION_TYPES.includes(type);
55-
const isPolymorphicType = POLYMORPHIC_TYPES.includes(type);
56-
const isModelType = MODEL_TYPES.includes(type);
57-
const color = colorTokenToCssVar(typeColor);
53+
const isCollectionType = COLLECTION_TYPES.some((t) => t === type);
54+
const isPolymorphicType = POLYMORPHIC_TYPES.some((t) => t === type);
55+
const isModelType = MODEL_TYPES.some((t) => t === type);
56+
const color = getFieldColor(type);
5857
const s: CSSProperties = {
5958
backgroundColor:
6059
isCollectionType || isPolymorphicType
@@ -97,23 +96,14 @@ const FieldHandle = (props: FieldHandleProps) => {
9796
isConnectionInProgress,
9897
isConnectionStartField,
9998
type,
100-
typeColor,
10199
]);
102100

103101
const tooltip = useMemo(() => {
104-
if (isConnectionInProgress && isConnectionStartField) {
105-
return originalType;
106-
}
107102
if (isConnectionInProgress && connectionError) {
108-
return connectionError ?? originalType;
103+
return connectionError;
109104
}
110-
return originalType;
111-
}, [
112-
connectionError,
113-
isConnectionInProgress,
114-
isConnectionStartField,
115-
originalType,
116-
]);
105+
return type;
106+
}, [connectionError, isConnectionInProgress, type]);
117107

118108
return (
119109
<Tooltip

invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ import { stateSelector } from 'app/store/store';
33
import { useAppSelector } from 'app/store/storeHooks';
44
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
55
import { useMemo } from 'react';
6-
import { KIND_MAP } from 'features/nodes/types/constants';
7-
import { isInvocationNode } from 'features/nodes/types/invocation';
6+
import { KIND_MAP } from '../types/constants';
7+
import { isInvocationNode } from '../types/types';
88

99
export const useFieldType = (
1010
nodeId: string,
@@ -21,7 +21,7 @@ export const useFieldType = (
2121
return;
2222
}
2323
const field = node.data[KIND_MAP[kind]][fieldName];
24-
return field?.type;
24+
return field?.originalType ?? field?.type;
2525
},
2626
defaultSelectorOptions
2727
),

0 commit comments

Comments
 (0)