Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 199fabd

Browse files
committed
revise test cases + remove dependency of scipy
1 parent 1b082ac commit 199fabd

File tree

2 files changed

+46
-3
lines changed

2 files changed

+46
-3
lines changed

python/mxnet/test_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@
3434
import numpy as np
3535
import numpy.testing as npt
3636
import numpy.random as rnd
37-
import scipy.stats as ss
37+
try:
38+
import scipy.stats as ss
39+
except:
40+
ss = None
3841
try:
3942
import requests
4043
except ImportError:

tests/python/unittest/test_random.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,11 @@ def test_normal_generator():
229229
buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.norm.ppf(x, mu, sigma), 5)
230230
generator_mx = lambda x: mx.nd.random.normal(mu, sigma, shape=x, ctx=ctx, dtype=dtype).asnumpy()
231231
verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
232+
generator_mx_same_seed =\
233+
lambda x: np.concatenate(
234+
[mx.nd.random.normal(mu, sigma, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
235+
for _ in range(10)])
236+
verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
232237

233238
def test_uniform_generator():
234239
ctx = mx.context.current_context()
@@ -238,6 +243,11 @@ def test_uniform_generator():
238243
buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=high - low), 5)
239244
generator_mx = lambda x: mx.nd.random.uniform(low, high, shape=x, ctx=ctx, dtype=dtype).asnumpy()
240245
verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
246+
generator_mx_same_seed = \
247+
lambda x: np.concatenate(
248+
[mx.nd.random.uniform(low, high, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
249+
for _ in range(10)])
250+
verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
241251

242252
def test_gamma_generator():
243253
ctx = mx.context.current_context()
@@ -247,6 +257,11 @@ def test_gamma_generator():
247257
buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.gamma.ppf(x, a=kappa, loc=0, scale=theta), 5)
248258
generator_mx = lambda x: mx.nd.random.gamma(kappa, theta, shape=x, ctx=ctx, dtype=dtype).asnumpy()
249259
verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
260+
generator_mx_same_seed = \
261+
lambda x: np.concatenate(
262+
[mx.nd.random.gamma(kappa, theta, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
263+
for _ in range(10)])
264+
verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
250265

251266
def test_exponential_generator():
252267
ctx = mx.context.current_context()
@@ -256,6 +271,11 @@ def test_exponential_generator():
256271
buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.expon.ppf(x, loc=0, scale=scale), 5)
257272
generator_mx = lambda x: mx.nd.random.exponential(scale, shape=x, ctx=ctx, dtype=dtype).asnumpy()
258273
verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
274+
generator_mx_same_seed = \
275+
lambda x: np.concatenate(
276+
[mx.nd.random.exponential(scale, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
277+
for _ in range(10)])
278+
verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
259279

260280
def test_poisson_generator():
261281
ctx = mx.context.current_context()
@@ -266,6 +286,11 @@ def test_poisson_generator():
266286
probs = [ss.poisson.cdf(bucket[1], lam) - ss.poisson.cdf(bucket[0], lam) for bucket in buckets]
267287
generator_mx = lambda x: mx.nd.random.poisson(lam, shape=x, ctx=ctx, dtype=dtype).asnumpy()
268288
verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
289+
generator_mx_same_seed = \
290+
lambda x: np.concatenate(
291+
[mx.nd.random.poisson(lam, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
292+
for _ in range(10)])
293+
verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
269294

270295
def test_negative_binomial_generator():
271296
ctx = mx.context.current_context()
@@ -279,24 +304,39 @@ def test_negative_binomial_generator():
279304
generator_mx = lambda x: mx.nd.random.negative_binomial(success_num, success_prob,
280305
shape=x, ctx=ctx, dtype=dtype).asnumpy()
281306
verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
307+
generator_mx_same_seed = \
308+
lambda x: np.concatenate(
309+
[mx.nd.random.negative_binomial(success_num, success_prob, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
310+
for _ in range(10)])
311+
verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
282312
# Also test the Gamm-Poisson Mixture
283313
print('Gamm-Poisson Mixture Test:')
284314
alpha = 1.0 / success_num
285315
mu = (1.0 - success_prob) / success_prob / alpha
286316
generator_mx = lambda x: mx.nd.random.generalized_negative_binomial(mu, alpha,
287317
shape=x, ctx=ctx, dtype=dtype).asnumpy()
288318
verify_generator(generator=generator_mx, buckets=buckets, probs=probs)
319+
generator_mx_same_seed = \
320+
lambda x: np.concatenate(
321+
[mx.nd.random.generalized_negative_binomial(mu, alpha, shape=x // 10, ctx=ctx, dtype=dtype).asnumpy()
322+
for _ in range(10)])
323+
verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
289324

290325
def test_multinomial_generator():
291326
ctx = mx.context.current_context()
292327
probs = [0.1, 0.2, 0.3, 0.05, 0.15, 0.2]
293-
sample_num = 1000000
294328
buckets = list(range(6))
295329
for dtype in ['float16', 'float32', 'float64']:
296330
print("ctx=%s, dtype=%s" %(ctx, dtype))
297331
generator_mx = lambda x: mx.nd.random.multinomial(data=mx.nd.array(np.array(probs), ctx=ctx, dtype=dtype),
298-
shape=sample_num).asnumpy()
332+
shape=x).asnumpy()
299333
verify_generator(generator_mx, buckets, probs)
334+
generator_mx_same_seed = \
335+
lambda x: np.concatenate(
336+
[mx.nd.random.multinomial(data=mx.nd.array(np.array(probs), ctx=ctx, dtype=dtype),
337+
shape=x // 10).asnumpy()
338+
for _ in range(10)])
339+
verify_generator(generator=generator_mx_same_seed, buckets=buckets, probs=probs)
300340

301341

302342
if __name__ == '__main__':

0 commit comments

Comments
 (0)