Skip to content

Commit 8d60ea3

Browse files
tf-transform-team, tfx-copybara
tf-transform-team
authored and committed
Adding internal reduce_instance_dims=False support to tf_utils.reduce_batch_minus_min_and_max_per_key
PiperOrigin-RevId: 454047295
1 parent e812bef commit 8d60ea3

File tree

2 files changed

+47
-5
lines changed

2 files changed

+47
-5
lines changed

tensorflow_transform/tf_utils.py

+14-2
Original file line number | Diff line number | Diff line change
@@ -1540,7 +1540,9 @@ def reduce_batch_minus_min_and_max(
15401540

15411541
def reduce_batch_minus_min_and_max_per_key(
15421542
x: common_types.TensorType,
1543-
key: common_types.TensorType) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
1543+
key: common_types.TensorType,
1544+
reduce_instance_dims: bool = True
1545+
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
15441546
"""Computes the -min and max of a tensor x.
15451547
15461548
Args:
@@ -1552,6 +1554,10 @@ def reduce_batch_minus_min_and_max_per_key(
15521554
everything except values,
15531555
3. The axis=1 index of each element of sparse x matches its index of
15541556
dense key.
1557+
reduce_instance_dims: A bool indicating whether this should collapse the
1558+
batch and instance dimensions to arrive at a single scalar output, or only
1559+
collapse the batch dimension and output a vector of the same shape as the
1560+
input.
15551561
Returns:
15561562
A 3-tuple containing the `Tensor`s (key_vocab, min_per_key, max_per_key).
15571563
"""
@@ -1561,10 +1567,16 @@ def reduce_batch_minus_min_and_max_per_key(
15611567
elif x.dtype == tf.uint32 or x.dtype == tf.uint64:
15621568
raise TypeError('Tensor type %r is not supported' % x.dtype)
15631569

1570+
if not reduce_instance_dims and isinstance(
1571+
x, (tf.SparseTensor, tf.RaggedTensor)):
1572+
raise NotImplementedError(
1573+
'Elementwise reduction of composite tensors is not supported'
1574+
)
1575+
15641576
x, key = _validate_and_get_dense_value_key_inputs(x, key)
15651577

15661578
def get_batch_max_per_key(tensor, key_uniques): # pylint: disable=missing-docstring
1567-
if tensor.get_shape().ndims < 2:
1579+
if not reduce_instance_dims or tensor.get_shape().ndims < 2:
15681580
row_maxes = tensor
15691581
else:
15701582
row_maxes = tf.reduce_max(

tensorflow_transform/tf_utils_test.py

+33-3
Original file line number | Diff line number | Diff line change
@@ -1627,6 +1627,7 @@ def _reduce_batch_minus_min_and_max(x):
16271627
values=[3, 2, -1, 3],
16281628
dense_shape=[4, 5]),
16291629
key=['a', 'a', 'a', 'b'],
1630+
reduce_instance_dims=True,
16301631
expected_key_vocab=[b'a', b'b'],
16311632
expected_x_minus_min=[1, -3],
16321633
expected_x_max=[3, 3],
@@ -1638,25 +1639,52 @@ def _reduce_batch_minus_min_and_max(x):
16381639
testcase_name='float',
16391640
x=[[1], [5], [2], [3]],
16401641
key=['a', 'a', 'a', 'b'],
1642+
reduce_instance_dims=True,
16411643
expected_key_vocab=[b'a', b'b'],
16421644
expected_x_minus_min=[-1, -3],
16431645
expected_x_max=[5, 3],
16441646
input_signature=[
16451647
tf.TensorSpec([None, None], tf.float32),
16461648
tf.TensorSpec([None], tf.string)
16471649
]),
1650+
dict(
1651+
testcase_name='float_elementwise',
1652+
x=[[1], [5], [2], [3]],
1653+
key=['a', 'a', 'a', 'b'],
1654+
reduce_instance_dims=False,
1655+
expected_key_vocab=[b'a', b'b'],
1656+
expected_x_minus_min=[[-1], [-3]],
1657+
expected_x_max=[[5], [3]],
1658+
input_signature=[
1659+
tf.TensorSpec([None, None], tf.float32),
1660+
tf.TensorSpec([None], tf.string)
1661+
]),
16481662
dict(
16491663
testcase_name='float3dims',
16501664
x=[[[1, 5], [1, 1]], [[5, 1], [5, 5]], [[2, 2], [2, 5]],
16511665
[[3, -3], [3, 3]]],
16521666
key=['a', 'a', 'a', 'b'],
1667+
reduce_instance_dims=True,
16531668
expected_key_vocab=[b'a', b'b'],
16541669
expected_x_minus_min=[-1, 3],
16551670
expected_x_max=[5, 3],
16561671
input_signature=[
16571672
tf.TensorSpec([None, None, None], tf.float32),
16581673
tf.TensorSpec([None], tf.string)
16591674
]),
1675+
dict(
1676+
testcase_name='float3dims_elementwise',
1677+
x=[[[1, 5], [1, 1]], [[5, 1], [5, 5]], [[2, 2], [2, 5]],
1678+
[[3, -3], [3, 3]]],
1679+
key=['a', 'a', 'a', 'b'],
1680+
reduce_instance_dims=False,
1681+
expected_key_vocab=[b'a', b'b'],
1682+
expected_x_minus_min=[[[-1, -1], [-1, -1]], [[-3, 3], [-3, -3]]],
1683+
expected_x_max=[[[5, 5], [5, 5]], [[3, -3], [3, 3]]],
1684+
input_signature=[
1685+
tf.TensorSpec([None, None, None], tf.float32),
1686+
tf.TensorSpec([None], tf.string)
1687+
]),
16601688
dict(
16611689
testcase_name='ragged',
16621690
x=tf.compat.v1.ragged.RaggedTensorValue(
@@ -1673,6 +1701,7 @@ def _reduce_batch_minus_min_and_max(x):
16731701
row_splits=np.array([0, 2, 3, 4, 5])),
16741702
row_splits=np.array([0, 2, 3, 4])),
16751703
row_splits=np.array([0, 2, 3])),
1704+
reduce_instance_dims=True,
16761705
expected_key_vocab=[b'a', b'b'],
16771706
expected_x_minus_min=[-2., -3.],
16781707
expected_x_max=[4., 5.],
@@ -1682,12 +1711,13 @@ def _reduce_batch_minus_min_and_max(x):
16821711
]),
16831712
]))
16841713
def test_reduce_batch_minus_min_and_max_per_key(
1685-
self, x, key, expected_key_vocab, expected_x_minus_min, expected_x_max,
1686-
input_signature, function_handler):
1714+
self, x, key, reduce_instance_dims, expected_key_vocab,
1715+
expected_x_minus_min, expected_x_max, input_signature, function_handler):
16871716

16881717
@function_handler(input_signature=input_signature)
16891718
def _reduce_batch_minus_min_and_max_per_key(x, key):
1690-
return tf_utils.reduce_batch_minus_min_and_max_per_key(x, key)
1719+
return tf_utils.reduce_batch_minus_min_and_max_per_key(
1720+
x, key, reduce_instance_dims=reduce_instance_dims)
16911721

16921722
key_vocab, x_minus_min, x_max = _reduce_batch_minus_min_and_max_per_key(
16931723
x, key)

0 commit comments

Comments
 (0)