Skip to content

Commit 799a085

Browse files
authored
Merge pull request #7 from vmarkovtsev/master
Fix casting negative floats to uint32_t
2 parents e0aa275 + 89f2bc3 commit 799a085

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

Diff for: kernel.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ __global__ void weighted_minhash_cuda(
105105
float ln_a = ln_cs[ci] - ln_y - r;
106106
if (ln_a < lnmins[s]) {
107107
lnmins[s] = ln_a;
108-
dtmins[s] = {d, static_cast<uint32_t>(t)};
108+
dtmins[s] = {d, static_cast<uint32_t>(static_cast<int32_t>(t))};
109109
}
110110
}
111111
}

Diff for: setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def is_pure(self):
6969
setup(
7070
name="libMHCUDA",
7171
description="Accelerated Weighted MinHash-ing on GPU",
72-
version="2.0.4",
72+
version="2.0.5",
7373
license="Apache Software License",
7474
author="Vadim Markovtsev",
7575
author_email="[email protected]",

Diff for: test.py

+29
Original file line numberDiff line numberDiff line change
@@ -181,5 +181,34 @@ def test_deferred(self):
181181
print(hashes)
182182
raise e from None
183183

184+
def test_float(self):
185+
v1 = [
186+
0, 1.0497366, 0.8494359, 0.66231006, 0.66231006, 0.8494359,
187+
0, 0.66231006, 0.33652836, 0, 0, 0.5359344,
188+
0.8494359, 0.66231006, 1.0497366, 0.33652836, 0.66231006, 0.8494359,
189+
0.6800841, 0.33652836]
190+
gen = libMHCUDA.minhash_cuda_init(len(v1), 128, devices=1, seed=7, verbosity=2)
191+
vars = libMHCUDA.minhash_cuda_retrieve_vars(gen)
192+
bgen = WeightedMinHashGenerator.__new__(WeightedMinHashGenerator)
193+
bgen.dim = len(v1)
194+
bgen.rs, bgen.ln_cs, bgen.betas = vars
195+
bgen.sample_size = 128
196+
bgen.seed = None
197+
m = csr_matrix(numpy.array(v1, dtype=numpy.float32))
198+
hashes = libMHCUDA.minhash_cuda_calc(gen, m).astype(numpy.int32)
199+
libMHCUDA.minhash_cuda_fini(gen)
200+
self.assertEqual(hashes.shape, (1, 128, 2))
201+
true_hashes = numpy.array([bgen.minhash(v1).hashvalues], dtype=numpy.int32)
202+
self.assertEqual(true_hashes.shape, (1, 128, 2))
203+
try:
204+
self.assertTrue((hashes == true_hashes).all())
205+
except AssertionError as e:
206+
print("---- TRUE ----")
207+
print(true_hashes)
208+
print("---- FALSE ----")
209+
print(hashes)
210+
raise e from None
211+
212+
184213
if __name__ == "__main__":
185214
unittest.main()

0 commit comments

Comments
 (0)