@@ -229,6 +229,11 @@ def test_normal_generator():
229229 buckets , probs = gen_buckets_probs_with_ppf (lambda x : ss .norm .ppf (x , mu , sigma ), 5 )
230230 generator_mx = lambda x : mx .nd .random .normal (mu , sigma , shape = x , ctx = ctx , dtype = dtype ).asnumpy ()
231231 verify_generator (generator = generator_mx , buckets = buckets , probs = probs )
232+ generator_mx_same_seed = \
233+ lambda x : np .concatenate (
234+ [mx .nd .random .normal (mu , sigma , shape = x // 10 , ctx = ctx , dtype = dtype ).asnumpy ()
235+ for _ in range (10 )])
236+ verify_generator (generator = generator_mx_same_seed , buckets = buckets , probs = probs )
232237
233238def test_uniform_generator ():
234239 ctx = mx .context .current_context ()
@@ -238,6 +243,11 @@ def test_uniform_generator():
238243 buckets , probs = gen_buckets_probs_with_ppf (lambda x : ss .uniform .ppf (x , loc = low , scale = high - low ), 5 )
239244 generator_mx = lambda x : mx .nd .random .uniform (low , high , shape = x , ctx = ctx , dtype = dtype ).asnumpy ()
240245 verify_generator (generator = generator_mx , buckets = buckets , probs = probs )
246+ generator_mx_same_seed = \
247+ lambda x : np .concatenate (
248+ [mx .nd .random .uniform (low , high , shape = x // 10 , ctx = ctx , dtype = dtype ).asnumpy ()
249+ for _ in range (10 )])
250+ verify_generator (generator = generator_mx_same_seed , buckets = buckets , probs = probs )
241251
242252def test_gamma_generator ():
243253 ctx = mx .context .current_context ()
@@ -247,6 +257,11 @@ def test_gamma_generator():
247257 buckets , probs = gen_buckets_probs_with_ppf (lambda x : ss .gamma .ppf (x , a = kappa , loc = 0 , scale = theta ), 5 )
248258 generator_mx = lambda x : mx .nd .random .gamma (kappa , theta , shape = x , ctx = ctx , dtype = dtype ).asnumpy ()
249259 verify_generator (generator = generator_mx , buckets = buckets , probs = probs )
260+ generator_mx_same_seed = \
261+ lambda x : np .concatenate (
262+ [mx .nd .random .gamma (kappa , theta , shape = x // 10 , ctx = ctx , dtype = dtype ).asnumpy ()
263+ for _ in range (10 )])
264+ verify_generator (generator = generator_mx_same_seed , buckets = buckets , probs = probs )
250265
251266def test_exponential_generator ():
252267 ctx = mx .context .current_context ()
@@ -256,6 +271,11 @@ def test_exponential_generator():
256271 buckets , probs = gen_buckets_probs_with_ppf (lambda x : ss .expon .ppf (x , loc = 0 , scale = scale ), 5 )
257272 generator_mx = lambda x : mx .nd .random .exponential (scale , shape = x , ctx = ctx , dtype = dtype ).asnumpy ()
258273 verify_generator (generator = generator_mx , buckets = buckets , probs = probs )
274+ generator_mx_same_seed = \
275+ lambda x : np .concatenate (
276+ [mx .nd .random .exponential (scale , shape = x // 10 , ctx = ctx , dtype = dtype ).asnumpy ()
277+ for _ in range (10 )])
278+ verify_generator (generator = generator_mx_same_seed , buckets = buckets , probs = probs )
259279
260280def test_poisson_generator ():
261281 ctx = mx .context .current_context ()
@@ -266,6 +286,11 @@ def test_poisson_generator():
266286 probs = [ss .poisson .cdf (bucket [1 ], lam ) - ss .poisson .cdf (bucket [0 ], lam ) for bucket in buckets ]
267287 generator_mx = lambda x : mx .nd .random .poisson (lam , shape = x , ctx = ctx , dtype = dtype ).asnumpy ()
268288 verify_generator (generator = generator_mx , buckets = buckets , probs = probs )
289+ generator_mx_same_seed = \
290+ lambda x : np .concatenate (
291+ [mx .nd .random .poisson (lam , shape = x // 10 , ctx = ctx , dtype = dtype ).asnumpy ()
292+ for _ in range (10 )])
293+ verify_generator (generator = generator_mx_same_seed , buckets = buckets , probs = probs )
269294
270295def test_negative_binomial_generator ():
271296 ctx = mx .context .current_context ()
@@ -279,24 +304,39 @@ def test_negative_binomial_generator():
279304 generator_mx = lambda x : mx .nd .random .negative_binomial (success_num , success_prob ,
280305 shape = x , ctx = ctx , dtype = dtype ).asnumpy ()
281306 verify_generator (generator = generator_mx , buckets = buckets , probs = probs )
307+ generator_mx_same_seed = \
308+ lambda x : np .concatenate (
309+ [mx .nd .random .negative_binomial (success_num , success_prob , shape = x // 10 , ctx = ctx , dtype = dtype ).asnumpy ()
310+ for _ in range (10 )])
311+ verify_generator (generator = generator_mx_same_seed , buckets = buckets , probs = probs )
282312 # Also test the Gamm-Poisson Mixture
283313 print ('Gamm-Poisson Mixture Test:' )
284314 alpha = 1.0 / success_num
285315 mu = (1.0 - success_prob ) / success_prob / alpha
286316 generator_mx = lambda x : mx .nd .random .generalized_negative_binomial (mu , alpha ,
287317 shape = x , ctx = ctx , dtype = dtype ).asnumpy ()
288318 verify_generator (generator = generator_mx , buckets = buckets , probs = probs )
319+ generator_mx_same_seed = \
320+ lambda x : np .concatenate (
321+ [mx .nd .random .generalized_negative_binomial (mu , alpha , shape = x // 10 , ctx = ctx , dtype = dtype ).asnumpy ()
322+ for _ in range (10 )])
323+ verify_generator (generator = generator_mx_same_seed , buckets = buckets , probs = probs )
289324
290325def test_multinomial_generator ():
291326 ctx = mx .context .current_context ()
292327 probs = [0.1 , 0.2 , 0.3 , 0.05 , 0.15 , 0.2 ]
293- sample_num = 1000000
294328 buckets = list (range (6 ))
295329 for dtype in ['float16' , 'float32' , 'float64' ]:
296330 print ("ctx=%s, dtype=%s" % (ctx , dtype ))
297331 generator_mx = lambda x : mx .nd .random .multinomial (data = mx .nd .array (np .array (probs ), ctx = ctx , dtype = dtype ),
298- shape = sample_num ).asnumpy ()
332+ shape = x ).asnumpy ()
299333 verify_generator (generator_mx , buckets , probs )
334+ generator_mx_same_seed = \
335+ lambda x : np .concatenate (
336+ [mx .nd .random .multinomial (data = mx .nd .array (np .array (probs ), ctx = ctx , dtype = dtype ),
337+ shape = x // 10 ).asnumpy ()
338+ for _ in range (10 )])
339+ verify_generator (generator = generator_mx_same_seed , buckets = buckets , probs = probs )
300340
301341
302342if __name__ == '__main__' :
0 commit comments