Skip to content

Commit 7783744

Browse files
authored
[data] update tool_helpers version and add unittest (#9093)
1 parent f4cff96 commit 7783744

File tree

2 files changed

+138
-1
lines changed

2 files changed

+138
-1
lines changed

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ uvicorn
2020
typer
2121
rich
2222
safetensors
23-
tool_helpers==0.1.1 ; platform_system == "Linux"
23+
tool_helpers>=0.1.1 ; platform_system == "Linux"
2424
aistudio-sdk>=0.1.3
2525
jinja2
2626
regex

tests/data/test_blendable_dataset.py

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import importlib.metadata
16+
import unittest
17+
18+
import numpy as np
19+
from parameterized import parameterized
20+
21+
22+
def build_blending_indices_python(dataset_index, dataset_sample_index, weights, num_datasets, size, verbose):
23+
"""
24+
Given multiple datasets and a weighting array, build samples such that it follows those weights.
25+
26+
Parameters:
27+
- dataset_index: NumPy array to store the dataset index for each sample.
28+
- dataset_sample_index: NumPy array to store the sample index within each dataset.
29+
- weights: NumPy array of weights for each dataset.
30+
- num_datasets: Integer, the number of datasets.
31+
- size: Integer, the total number of samples to generate.
32+
- verbose: Boolean, whether to print verbose output.
33+
"""
34+
if verbose:
35+
print("> building indices for blendable datasets ...")
36+
37+
# Initialize buffer for number of samples used for each dataset.
38+
current_samples = np.zeros(num_datasets, dtype=np.int64)
39+
40+
# For each sample:
41+
for sample_idx in range(size):
42+
# Determine where the max error in sampling is happening.
43+
sample_idx_double = max(sample_idx, 1)
44+
max_error_index = 0
45+
max_error = weights[0] * sample_idx_double - current_samples[0]
46+
for dataset_idx in range(1, num_datasets):
47+
error = weights[dataset_idx] * sample_idx_double - current_samples[dataset_idx]
48+
if error > max_error:
49+
max_error = error
50+
max_error_index = dataset_idx
51+
52+
# Populate the indices.
53+
dataset_index[sample_idx] = max_error_index
54+
dataset_sample_index[sample_idx] = current_samples[max_error_index]
55+
56+
# Update the total samples.
57+
current_samples[max_error_index] += 1
58+
59+
# Print info
60+
if verbose:
61+
print(" > sample ratios:")
62+
for dataset_idx in range(num_datasets):
63+
ratio = current_samples[dataset_idx] / size
64+
print(f" dataset {dataset_idx}, input: {weights[dataset_idx]}, achieved: {ratio}")
65+
66+
67+
def skip_if_version_not_equal(version="0.1.1", package_name="tool_helpers"):
68+
try:
69+
importlib.import_module(package_name)
70+
except ImportError:
71+
return True, f"package<{package_name}> not found, so to skip this test"
72+
package_version = importlib.metadata.version(package_name)
73+
if package_version != version:
74+
return True, f"{package_name} version must be equal to {version}, but got {package_version}!"
75+
return False, f"{package_name} version is ok!"
76+
77+
78+
class TestToolHelpers(unittest.TestCase):
79+
def _test_build_blending_indices(
80+
self, num_datasets=128, size=8192, dataset_index_dtype="uint8", verbose=False, seed=42, assert_true=True
81+
):
82+
if isinstance(dataset_index_dtype, str):
83+
dataset_index_dtype = np.dtype(dataset_index_dtype)
84+
assert dataset_index_dtype in [np.uint8, np.int16], "dataset_index_dtype must be uint8 or int16!"
85+
86+
np.random.seed(seed)
87+
random_numbers = np.random.rand(num_datasets)
88+
random_numbers[0] = 200
89+
weights = random_numbers / random_numbers.sum()
90+
weights = weights.astype(np.float64)
91+
92+
# for ground truth, so we use np.int32
93+
python_dataset_index = np.zeros(size, dtype=np.int32)
94+
python_dataset_sample_index = np.zeros(size, dtype=np.int64)
95+
build_blending_indices_python(
96+
python_dataset_index, python_dataset_sample_index, weights, num_datasets, size, verbose
97+
)
98+
99+
from tool_helpers import helpers
100+
101+
c_dataset_index = np.zeros(size, dtype=dataset_index_dtype)
102+
c_dataset_sample_index = np.zeros(size, dtype=np.int64)
103+
helpers.build_blending_indices(c_dataset_index, c_dataset_sample_index, weights, num_datasets, size, verbose)
104+
105+
assert_func = self.assertTrue if assert_true else self.assertFalse
106+
assert_func(np.all(python_dataset_index == c_dataset_index.astype(python_dataset_index.dtype)))
107+
self.assertTrue(
108+
np.all(python_dataset_sample_index == c_dataset_sample_index.astype(python_dataset_sample_index.dtype))
109+
)
110+
111+
@parameterized.expand(
112+
[
113+
(128, 8192, "uint8", False, 42, True),
114+
(1024, 8192, "uint8", False, 42, False),
115+
(128, 8192, "int16", False, 42, False),
116+
(1024, 8192, "int16", False, 42, False),
117+
]
118+
)
119+
@unittest.skipIf(*skip_if_version_not_equal(version="0.1.1", package_name="tool_helpers"))
120+
def test_build_blending_indices_version_0_1_1(
121+
self, num_datasets=128, size=8192, dataset_index_dtype="uint8", verbose=False, seed=42, assert_true=True
122+
):
123+
self._test_build_blending_indices(num_datasets, size, dataset_index_dtype, verbose, seed, assert_true)
124+
125+
@parameterized.expand(
126+
[
127+
(128, 8192, "uint8", False, 42, True),
128+
(1024, 8192, "uint8", False, 42, False),
129+
(128, 8192, "int16", False, 42, True),
130+
(1024, 8192, "int16", False, 42, True),
131+
]
132+
)
133+
@unittest.skipIf(*skip_if_version_not_equal(version="0.1.2", package_name="tool_helpers"))
134+
def test_build_blending_indices_version_0_1_2(
135+
self, num_datasets=128, size=8192, dataset_index_dtype="uint8", verbose=False, seed=42, assert_true=True
136+
):
137+
self._test_build_blending_indices(num_datasets, size, dataset_index_dtype, verbose, seed, assert_true)

0 commit comments

Comments
 (0)