@@ -73,6 +73,9 @@ def test_random_vars(self):
73
73
gen = libMHCUDA .minhash_cuda_init (1000 , 128 , devices = 1 , verbosity = 2 )
74
74
rs , ln_cs , betas = libMHCUDA .minhash_cuda_retrieve_vars (gen )
75
75
libMHCUDA .minhash_cuda_fini (gen )
76
+ self .assertEqual (rs .shape , (128 , 1000 ))
77
+ self .assertEqual (ln_cs .shape , (128 , 1000 ))
78
+ self .assertEqual (betas .shape , (128 , 1000 ))
76
79
cs = numpy .exp (ln_cs )
77
80
a , loc , scale = gamma .fit (rs )
78
81
self .assertTrue (1.97 < a < 2.03 )
@@ -120,6 +123,34 @@ def test_slice(self):
120
123
libMHCUDA .minhash_cuda_fini (gen )
121
124
self .assertTrue ((hashes [3200 :4800 ] == hashes2 ).all ())
122
125
126
def test_backwards(self):
    """Compare libMHCUDA hashes with ``WeightedMinHashGenerator.minhash``.

    Two short weight vectors are hashed by libMHCUDA. The random variables
    drawn by the CUDA generator (``rs``, ``ln_cs``, ``betas``) are then
    injected into a ``WeightedMinHashGenerator`` built without calling its
    ``__init__``, so both implementations work from identical parameters
    and must yield identical hash values.
    """
    v1 = [1, 0, 0, 0, 3, 4, 5, 0, 0, 0, 0, 6, 7, 8, 0, 0, 0, 0, 0, 0, 9, 10, 4]
    v2 = [2, 0, 0, 0, 4, 3, 8, 0, 0, 0, 0, 4, 7, 10, 0, 0, 0, 0, 0, 0, 9, 0, 0]
    gen = libMHCUDA.minhash_cuda_init(len(v1), 128, devices=1, verbosity=2)
    try:
        rs, ln_cs, betas = libMHCUDA.minhash_cuda_retrieve_vars(gen)
        m = csr_matrix(numpy.array([v1, v2], dtype=numpy.float32))
        hashes = libMHCUDA.minhash_cuda_calc(gen, m)
    finally:
        # Release the generator handle even when retrieval or hashing
        # raises; otherwise a failing test would leak it.
        libMHCUDA.minhash_cuda_fini(gen)

    # Bypass __init__ so the reference generator uses the CUDA-drawn
    # variables instead of sampling its own.
    bgen = WeightedMinHashGenerator.__new__(WeightedMinHashGenerator)
    bgen.dim = len(v1)
    bgen.rs = rs
    bgen.ln_cs = ln_cs
    bgen.betas = betas
    bgen.sample_size = 128
    bgen.seed = None

    self.assertEqual(hashes.shape, (2, 128, 2))
    true_hashes = numpy.array(
        [bgen.minhash(v1).hashvalues, bgen.minhash(v2).hashvalues],
        dtype=numpy.uint32,
    )
    self.assertEqual(true_hashes.shape, (2, 128, 2))
    try:
        self.assertTrue((hashes == true_hashes).all())
    except AssertionError as e:
        # Dump both arrays so a mismatch can be inspected by eye.
        print("---- TRUE ----")
        print(true_hashes)
        print("---- FALSE ----")
        print(hashes)
        raise e from None
153
+
123
154
124
155
# Run the test suite only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    unittest.main()
0 commit comments