Specify casting rules and accepted input dtypes for reductions better #202
I feel part of this issue is related to the fact that different implementations handle the intermediate type during reduction differently. On the C++ side this is also a known problem, and there is an ongoing proposal to fix this ambiguity (and all the associated issues) by requiring the intermediate type to be either the dtype of the initial value (if there is one) or the input iterator's value type (https://wg21.link/P0571). I would suggest following the C++ behavior if possible (assuming the proposal is accepted), as most libraries have a C++ implementation under the hood. The net effect is likely leaning toward keeping the input dtype. Pinging the proposal author @brycelelbach for awareness (maybe he could comment on potential pitfalls or challenges). On the CuPy side, I see no reason not to follow this behavior, except that we want to be NumPy compliant in the main namespace.
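To make the intermediate-vs-output dtype question concrete, here is a minimal NumPy sketch (purely illustrative, not from the proposal): the dtype passed to the reduction determines both the accumulator and the user-visible result.

import numpy as np

x = np.full(300, 2, dtype=np.uint8)        # true sum is 600, which does not fit in uint8

# Reduce with a uint8 accumulator: the sum wraps around modulo 256.
print(np.add.reduce(x, dtype=np.uint8))    # 88

# Reduce with a uint64 accumulator (what NumPy's sum picks for uint8 input
# on most 64-bit platforms): the exact value is returned.
print(np.add.reduce(x, dtype=np.uint64))   # 600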
Interesting, thanks @leofang. That proposal goes into a lot of depth about the intermediary type, which matters for implementers but is something we should be agnostic about here imho. The only thing that matters is the output dtype, because that's user-observable behavior. If someone wants to write an implementation where the intermediate type is something wider than the input dtype, that should be perfectly fine.

I'm very unsure about what the output dtype should be. Preserving the input dtype sounds nice in theory, but I expect the TF/MXNet behavior to be a foot gun in practice. Take for example summing uint8 image data, as below.

Reductions in NumPy have an initial keyword to seed the accumulation. Side note: NumPy does have something weird going on as well; it's happy to use scalar negative values in unsigned integer reductions, and it's unclear to me why:

>>> image = np.ones((100, 2), dtype=np.uint8)
>>> np.sum(image, axis=0)
array([100, 100], dtype=uint64)
>>> np.sum(image, axis=0, initial=1)
array([101, 101], dtype=uint64)
>>> np.sum(image, axis=0, initial=-1)
array([99, 99], dtype=uint64)
>>> np.sum(image, axis=0, initial=np.array(-1))
...
TypeError: Cannot cast scalar from dtype('int64') to dtype('uint64') according to the rule 'safe'
Searching for " // Specialization for which we do the reduction in IntermediateType to
// avoid integer overflow.
#define CASTING_SPECIALIZATION(ScalarType, IntermediateType)
...
CASTING_SPECIALIZATION(uint8, uint64);
CASTING_SPECIALIZATION(uint16, uint64);
<etc> And same for floating point: // Specialization for BF16 Reducer to fix accuracy.
// TODO: All BF16 reducers should have specializations to fix accuracy. |
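The same idea in a toy Python sketch (an illustration of the pattern, not TensorFlow's actual code): accumulate in a wider intermediate type, then cast back so the output dtype still matches the input.

import numpy as np

def naive_sum(values, intermediate_dtype):
    # Left-to-right reduction that accumulates in intermediate_dtype,
    # then casts back so the output dtype matches the input dtype.
    acc = intermediate_dtype(0)
    for v in values:
        acc = intermediate_dtype(acc + intermediate_dtype(v))
    return values.dtype.type(acc)

x = np.full(4096, 0.01, dtype=np.float16)   # true sum is about 40.97

print(naive_sum(x, np.float16))   # stalls near 32, once additions fall below 1 ulp of the accumulator
print(naive_sum(x, np.float32))   # ~40.97, cast back to float16 only at the end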
JAX discussion: jax-ml/jax#3154. That leans towards changing its current behavior to preserving the input dtype. The discussion there does make sense for some dtypes at least.

EDIT: copying the example from that issue to show JAX behavior for all dtypes:

In [1]: import jax.numpy as jnp
   ...: from jax.test_util import dtypes
   ...: from jax import config; config.update('jax_enable_x64', True)
   ...: for dtype in dtypes.all:
   ...:     print(dtype.__name__, "->", jnp.zeros(2, dtype).sum().dtype)
   ...:
bfloat16 -> bfloat16
float16 -> float16
float32 -> float32
float64 -> float64
int8 -> int64
int16 -> int64
int32 -> int64
int64 -> int64
uint8 -> uint64
uint16 -> uint64
uint32 -> uint64
uint64 -> uint64
complex64 -> complex64
complex128 -> complex128
bool_ -> int64

Also, JAX has a global setting to switch the default dtype to 64-bit??
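For reference, with that flag left at its default (64-bit mode off), the upcast target for the integer cases above shrinks to the 32-bit defaults; a quick check (assumes JAX is installed):

from jax import config
config.update('jax_enable_x64', False)   # this is the default setting
import jax.numpy as jnp

print(jnp.zeros(2, jnp.int8).sum().dtype)     # int32 rather than int64
print(jnp.zeros(2, jnp.uint8).sum().dtype)    # uint32 rather than uint64
print(jnp.zeros(2, jnp.float32).sum().dtype)  # float32, unchanged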
Personally, I think pushing considerations of overflow to userland is fine. If a user has reason to be concerned about overflow during summation or multiplication, then explicitly casting an array to a dtype capable of handling larger values without overflow should be okay. That a user would be forced to explicitly think about desired dtypes is not necessarily a bad thing, imo. As has been discussed elsewhere, requiring explicit casting may incur costs, such as performance overhead due to memory allocation and/or multiple data passes. However, those costs are likely to be incurred regardless. While libraries such as NumPy may not be able to benefit from whole-graph optimization, others may be able to combine casts/reductions into a single operation.
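A minimal sketch of that userland approach, for a library whose sum preserves the input dtype (illustrative only):

import numpy as np

x = np.full(300, 2, dtype=np.uint8)

# The caller opts in to a wider dtype explicitly before reducing,
# instead of relying on the library to upcast for them.
total = x.astype(np.uint64).sum()
print(total, total.dtype)   # 600 uint64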
I don't agree. Naive summation techniques may overflow, especially when given a large set of monotonically increasing positive values. However, for summands of mixed sign, various summation techniques are possible which guard against overflow by using correction terms. So I don't think overflow is guaranteed to be rampant.
Another, potentially left-field, option is to support specifying the output dtype (e.g., via a dtype keyword argument). The issue here, however, is that the specification would be underspecified for mixed-kind input/output dtypes. So one potential constraint could be to require that the output dtype be of the same kind, and possibly of equal or greater size.
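A rough sketch of that constraint (the helper names here are hypothetical, not part of any spec):

import numpy as np

def check_reduction_dtype(input_dtype, output_dtype):
    # Hypothetical validation: require the requested output dtype to be of the
    # same kind as the input and at least as large.
    input_dtype = np.dtype(input_dtype)
    output_dtype = np.dtype(output_dtype)
    if output_dtype.kind != input_dtype.kind:
        raise TypeError(f"output dtype kind {output_dtype.kind!r} does not match "
                        f"input dtype kind {input_dtype.kind!r}")
    if output_dtype.itemsize < input_dtype.itemsize:
        raise TypeError("output dtype is smaller than the input dtype")

def sum_with_dtype(x, dtype=None):
    if dtype is None:
        return np.sum(x)
    check_reduction_dtype(x.dtype, dtype)
    return np.sum(x, dtype=dtype)

print(sum_with_dtype(np.ones(10, dtype=np.int8), dtype=np.int32))   # 10
# sum_with_dtype(np.ones(10, dtype=np.int8), dtype=np.float64)      # would raise TypeError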
To be clear, what we are discussing here only applies to a subset of statistical reductions, namely sum and prod.

For min and max, the output dtype naturally matches the input dtype.

For mean, std, and var, the specification prescribes a floating-point result.

Thus leaving sum and prod as the reductions whose casting behavior needs to be nailed down.
The new C++ guidance is to infer the intermediate type from the operator; see P2322.
Yes indeed. There's yet another set, namely the reductions that return bool dtype (any and all).
Agreed. The specification is probably wrong though. It says to return the default floating-point dtype, but it should be preserving dtypes (while accumulating in higher precision as needed):

>>> np.std(np.ones(3, dtype=np.float32)).dtype
dtype('float32')
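NumPy already separates the accumulation dtype from the reported dtype, so preserving the input dtype while accumulating in higher precision is straightforward (illustrative):

>>> x = np.ones(3, dtype=np.float32)
>>> np.std(x, dtype=np.float64).astype(np.float32).dtype
dtype('float32')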
Not all that left-field. From the numpy.sum documentation, the dtype argument is "the type of the returned array and of the accumulator in which the elements are summed".
You are right. At minimum, we should clarify that returning the default floating-point dtype applies when the input has an integer dtype.
For dealing with int8, the main use case in deep learning is probably quantization for inference. Intel has a comprehensive library for such usage: https://github.com/intel/lpot
I don't have a formed opinion about special-casing integers; I like the consistency of not "up-promoting". A few NumPy-specific points, most of which are just to read and forget :)
Just a little context from the PyTorch side:
This is also relevant to the trace function.
@asmeurer are there plans to update trace sometime soon? Encountering dtype issues in ivy testing. |
I opened a new issue to track this: #493. I would just make the change, but I'm not sure if we should also add a dtype keyword.
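For reference, NumPy's trace already takes a dtype keyword that controls both the accumulator and the returned dtype, analogous to sum (illustrative; the default upcast target is the platform default integer, typically int64 on 64-bit Linux/macOS):

>>> import numpy as np
>>> a = np.eye(3, dtype=np.int8)
>>> np.trace(a).dtype
dtype('int64')
>>> np.trace(a, dtype=np.int8).dtype
dtype('int8')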
Reductions were added in PR gh-17, based on discussion in gh-10. There was quite a bit of discussion in calls as well around reductions (e.g., which ones to support, returning 0-D arrays and not scalars, naming) but not about casting rules and accepted input dtypes. It turns out that this is pretty inconsistent between libraries. Here's a script that compares sum, std and prod across libraries, and the result of that:
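A minimal sketch of such a comparison (assuming NumPy, PyTorch and TensorFlow are installed; std is skipped for the libraries that reject integer input):

import numpy as np
import torch
import tensorflow as tf

x = np.ones(5, dtype=np.int8)
print("numpy     :", np.sum(x).dtype, np.prod(x).dtype, np.std(x).dtype)

t = torch.ones(5, dtype=torch.int8)
print("torch     :", torch.sum(t).dtype, torch.prod(t).dtype)            # torch.std rejects integer input

tt = tf.ones(5, dtype=tf.int8)
print("tensorflow:", tf.reduce_sum(tt).dtype, tf.reduce_prod(tt).dtype)  # tf.math.reduce_std rejects integer input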
Conclusions
For sum(int8) and prod(int8) there appear to be two options: upcast to a larger integer dtype (int64), as NumPy does, or preserve the input dtype, as TensorFlow and MXNet do.

The TensorFlow docs do note this as the one inconsistency with NumPy: https://www.tensorflow.org/api_docs/python/tf/math/reduce_sum says "Equivalent to np.sum apart the fact that numpy upcast uint8 and int32 to int64 while tensorflow returns the same dtype as the input."
The MXNet docs at https://mxnet.apache.org/versions/master/api/python/docs/api/np/generated/mxnet.np.sum.html#mxnet-np-sum do not clearly say that this is expected, even though those docs do have a list of differences with NumPy (@szha thoughts on this?).
For std(int8) there appear to be three options across libraries.

This is all quite inconsistent, and needs to be considered more carefully for all reductions and dtypes.