@@ -1627,6 +1627,7 @@ def _reduce_batch_minus_min_and_max(x):
1627
1627
values = [3 , 2 , - 1 , 3 ],
1628
1628
dense_shape = [4 , 5 ]),
1629
1629
key = ['a' , 'a' , 'a' , 'b' ],
1630
+ reduce_instance_dims = True ,
1630
1631
expected_key_vocab = [b'a' , b'b' ],
1631
1632
expected_x_minus_min = [1 , - 3 ],
1632
1633
expected_x_max = [3 , 3 ],
@@ -1638,25 +1639,52 @@ def _reduce_batch_minus_min_and_max(x):
1638
1639
testcase_name = 'float' ,
1639
1640
x = [[1 ], [5 ], [2 ], [3 ]],
1640
1641
key = ['a' , 'a' , 'a' , 'b' ],
1642
+ reduce_instance_dims = True ,
1641
1643
expected_key_vocab = [b'a' , b'b' ],
1642
1644
expected_x_minus_min = [- 1 , - 3 ],
1643
1645
expected_x_max = [5 , 3 ],
1644
1646
input_signature = [
1645
1647
tf .TensorSpec ([None , None ], tf .float32 ),
1646
1648
tf .TensorSpec ([None ], tf .string )
1647
1649
]),
1650
+ dict (
1651
+ testcase_name = 'float_elementwise' ,
1652
+ x = [[1 ], [5 ], [2 ], [3 ]],
1653
+ key = ['a' , 'a' , 'a' , 'b' ],
1654
+ reduce_instance_dims = False ,
1655
+ expected_key_vocab = [b'a' , b'b' ],
1656
+ expected_x_minus_min = [[- 1 ], [- 3 ]],
1657
+ expected_x_max = [[5 ], [3 ]],
1658
+ input_signature = [
1659
+ tf .TensorSpec ([None , None ], tf .float32 ),
1660
+ tf .TensorSpec ([None ], tf .string )
1661
+ ]),
1648
1662
dict (
1649
1663
testcase_name = 'float3dims' ,
1650
1664
x = [[[1 , 5 ], [1 , 1 ]], [[5 , 1 ], [5 , 5 ]], [[2 , 2 ], [2 , 5 ]],
1651
1665
[[3 , - 3 ], [3 , 3 ]]],
1652
1666
key = ['a' , 'a' , 'a' , 'b' ],
1667
+ reduce_instance_dims = True ,
1653
1668
expected_key_vocab = [b'a' , b'b' ],
1654
1669
expected_x_minus_min = [- 1 , 3 ],
1655
1670
expected_x_max = [5 , 3 ],
1656
1671
input_signature = [
1657
1672
tf .TensorSpec ([None , None , None ], tf .float32 ),
1658
1673
tf .TensorSpec ([None ], tf .string )
1659
1674
]),
1675
+ dict (
1676
+ testcase_name = 'float3dims_elementwise' ,
1677
+ x = [[[1 , 5 ], [1 , 1 ]], [[5 , 1 ], [5 , 5 ]], [[2 , 2 ], [2 , 5 ]],
1678
+ [[3 , - 3 ], [3 , 3 ]]],
1679
+ key = ['a' , 'a' , 'a' , 'b' ],
1680
+ reduce_instance_dims = False ,
1681
+ expected_key_vocab = [b'a' , b'b' ],
1682
+ expected_x_minus_min = [[[- 1 , - 1 ], [- 1 , - 1 ]], [[- 3 , 3 ], [- 3 , - 3 ]]],
1683
+ expected_x_max = [[[5 , 5 ], [5 , 5 ]], [[3 , - 3 ], [3 , 3 ]]],
1684
+ input_signature = [
1685
+ tf .TensorSpec ([None , None , None ], tf .float32 ),
1686
+ tf .TensorSpec ([None ], tf .string )
1687
+ ]),
1660
1688
dict (
1661
1689
testcase_name = 'ragged' ,
1662
1690
x = tf .compat .v1 .ragged .RaggedTensorValue (
@@ -1673,6 +1701,7 @@ def _reduce_batch_minus_min_and_max(x):
1673
1701
row_splits = np .array ([0 , 2 , 3 , 4 , 5 ])),
1674
1702
row_splits = np .array ([0 , 2 , 3 , 4 ])),
1675
1703
row_splits = np .array ([0 , 2 , 3 ])),
1704
+ reduce_instance_dims = True ,
1676
1705
expected_key_vocab = [b'a' , b'b' ],
1677
1706
expected_x_minus_min = [- 2. , - 3. ],
1678
1707
expected_x_max = [4. , 5. ],
@@ -1682,12 +1711,13 @@ def _reduce_batch_minus_min_and_max(x):
1682
1711
]),
1683
1712
]))
1684
1713
def test_reduce_batch_minus_min_and_max_per_key (
1685
- self , x , key , expected_key_vocab , expected_x_minus_min , expected_x_max ,
1686
- input_signature , function_handler ):
1714
+ self , x , key , reduce_instance_dims , expected_key_vocab ,
1715
+ expected_x_minus_min , expected_x_max , input_signature , function_handler ):
1687
1716
1688
1717
@function_handler (input_signature = input_signature )
1689
1718
def _reduce_batch_minus_min_and_max_per_key (x , key ):
1690
- return tf_utils .reduce_batch_minus_min_and_max_per_key (x , key )
1719
+ return tf_utils .reduce_batch_minus_min_and_max_per_key (
1720
+ x , key , reduce_instance_dims = reduce_instance_dims )
1691
1721
1692
1722
key_vocab , x_minus_min , x_max = _reduce_batch_minus_min_and_max_per_key (
1693
1723
x , key )
0 commit comments