Add test for _get_scaling #5515

Closed
ricardoV94 opened this issue Feb 23, 2022 · 2 comments · Fixed by #5544

Comments

@ricardoV94
Member

ricardoV94 commented Feb 23, 2022

This function has a somewhat long and complicated sequence of conditional branches. AFAICT there is no direct test, so we should add one that covers all the expected cases to facilitate future maintenance and refactoring.

Bonus points for adding type hints

import aesara.tensor as at
import numpy as np

from pymc.aesaraf import floatX


def _get_scaling(total_size, shape, ndim):
    """
    Gets scaling constant for logp.

    Parameters
    ----------
    total_size: int or list of (int, None or Ellipsis)
    shape: shape
        shape to scale
    ndim: int
        ndim hint

    Returns
    -------
    scalar
    """
    if total_size is None:
        coef = floatX(1)
    elif isinstance(total_size, int):
        if ndim >= 1:
            denom = shape[0]
        else:
            denom = 1
        coef = floatX(total_size) / floatX(denom)
    elif isinstance(total_size, (list, tuple)):
        if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)):
            # Wrap in a 1-tuple so a tuple `total_size` is formatted as a single value.
            raise TypeError(
                "Unrecognized `total_size` type, expected "
                "int or list of ints, got %r" % (total_size,)
            )
        if Ellipsis in total_size:
            sep = total_size.index(Ellipsis)
            begin = total_size[:sep]
            end = total_size[sep + 1 :]
            if Ellipsis in end:
                raise ValueError(
                    "Double Ellipsis in `total_size` is restricted, got %r" % (total_size,)
                )
        else:
            begin = total_size
            end = []
        if (len(begin) + len(end)) > ndim:
            raise ValueError(
                "Length of `total_size` is too big, "
                "number of scalings is bigger than ndim, got %r" % (total_size,)
            )
        elif (len(begin) + len(end)) == 0:
            return floatX(1)
        if len(end) > 0:
            shp_end = shape[-len(end) :]
        else:
            shp_end = np.asarray([])
        shp_begin = shape[: len(begin)]
        # None entries leave their dimension unscaled.
        begin_coef = [floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None]
        end_coef = [floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None]
        coefs = begin_coef + end_coef
        coef = at.prod(coefs)
    else:
        raise TypeError(
            "Unrecognized `total_size` type, expected int or list of ints, got %r" % (total_size,)
        )
    return at.as_tensor(floatX(coef))
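A test for each branch needs expected values. As a minimal sketch, assuming a concrete integer `shape` tuple in place of the symbolic shape and plain floats in place of Aesara tensors, the branches above can be modeled in pure Python. `get_scaling_reference` and `TotalSize` are hypothetical names (not part of PyMC), and the annotations are only one possible way to add the requested type hints. An int `total_size` rescales only the first dimension; list entries before an `Ellipsis` pair with leading dimensions, entries after it pair with trailing dimensions, and `None` entries are skipped. A real test would compare `_get_scaling(total_size, shape, ndim).eval()` against the same values.

```python
from math import prod
from typing import Optional, Sequence, Union

# Hypothetical alias: an int, or a list/tuple whose entries are int, None or Ellipsis.
TotalSize = Optional[Union[int, Sequence[object]]]


def get_scaling_reference(total_size: TotalSize, shape: Sequence[int], ndim: int) -> float:
    """Hypothetical float-only model of `_get_scaling` for deriving test expectations."""
    if total_size is None:
        return 1.0
    if isinstance(total_size, int):
        # An int rescales only the first (batch) dimension.
        denom = shape[0] if ndim >= 1 else 1
        return total_size / denom
    if not isinstance(total_size, (list, tuple)):
        raise TypeError(f"Unrecognized `total_size` type, got {total_size!r}")
    if not all(isinstance(i, int) for i in total_size if i is not Ellipsis and i is not None):
        raise TypeError(f"Expected int or list of ints, got {total_size!r}")
    entries = list(total_size)
    if Ellipsis in entries:
        sep = entries.index(Ellipsis)
        begin, end = entries[:sep], entries[sep + 1 :]
        if Ellipsis in end:
            raise ValueError(f"Double Ellipsis in `total_size`, got {total_size!r}")
    else:
        begin, end = entries, []
    if len(begin) + len(end) > ndim:
        raise ValueError(f"More `total_size` entries than ndim, got {total_size!r}")
    # Entries before the Ellipsis pair with leading dims, entries after it with trailing dims.
    shp_begin = shape[: len(begin)]
    shp_end = shape[len(shape) - len(end) :]
    # None entries leave their dimension unscaled.
    coefs = [t / s for t, s in zip(begin, shp_begin) if t is not None]
    coefs += [t / s for t, s in zip(end, shp_end) if t is not None]
    # prod([]) == 1, which also covers the empty-list branch.
    return float(prod(coefs))


# One case per non-error branch.
assert get_scaling_reference(None, (10,), 1) == 1.0
assert get_scaling_reference(100, (10,), 1) == 10.0
assert get_scaling_reference(100, (), 0) == 100.0
assert get_scaling_reference([100, None], (10, 3), 2) == 10.0
assert get_scaling_reference([100, 30], (10, 3), 2) == 100.0
assert get_scaling_reference([Ellipsis, 30], (10, 3), 2) == 10.0
assert get_scaling_reference([100, Ellipsis, 30], (10, 4, 3), 3) == 100.0
assert get_scaling_reference([], (10,), 1) == 1.0
```

The error branches (more entries than `ndim`, a double `Ellipsis`, a non-int entry, an unsupported type such as a string) map naturally onto `pytest.raises(ValueError)` and `pytest.raises(TypeError)` checks, and the cases above could be fed to `pytest.mark.parametrize`.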

Related to #4582

@Diwakar-Gupta

@ricardoV94, can you share some resources that explain what exactly _get_scaling does?

@chritter
Contributor

@Icyshaman and I are working on this issue.


3 participants