move kv_chunk_size_min concern to callsite: if the caller knows the final kv_chunk_size, they can notice when no chunking would happen at all, and use a fast path (sketched below). note: there's a question of whether that concern belongs *inside* the algorithm, but it'd feel weird for chunked attention to have a no-chunking-at-all branch.
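A minimal sketch of what that callsite logic might look like, assuming a Python/PyTorch setting; both function names (`attention` and `chunked_attention`) are hypothetical placeholders, not necessarily this repo's actual API:

```python
import math
from typing import Optional

import torch


def attention(
    query: torch.Tensor,  # [batch * num_heads, query_tokens, channels_per_head]
    key: torch.Tensor,    # [batch * num_heads, key_tokens, channels_per_head]
    value: torch.Tensor,  # [batch * num_heads, key_tokens, channels_per_head]
    kv_chunk_size: Optional[int] = None,
    kv_chunk_size_min: Optional[int] = None,
    use_checkpoint: bool = True,
) -> torch.Tensor:
    _, key_tokens, _ = key.shape
    # resolve the kv_chunk_size default at the callsite, not inside the chunked algorithm
    if kv_chunk_size is None:
        kv_chunk_size = int(math.sqrt(key_tokens))
        if kv_chunk_size_min is not None:
            # keep chunks from getting too small (more chunks = less concurrent work)
            kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
    kv_chunk_size = min(kv_chunk_size, key_tokens)

    if kv_chunk_size == key_tokens:
        # a single chunk would cover every key token: no chunking would happen at all,
        # so skip the chunked algorithm entirely and run plain attention
        scale = query.shape[-1] ** -0.5
        attn = (query * scale) @ key.transpose(-2, -1)
        return attn.softmax(dim=-1) @ value

    # otherwise hand the final chunk size to the chunked algorithm
    # (chunked_attention is a placeholder; the real implementation lives elsewhere)
    return chunked_attention(
        query, key, value,
        kv_chunk_size=kv_chunk_size,
        use_checkpoint=use_checkpoint,
    )
```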
kv_chunk_size: Optional[int]: key/value chunk size. if None: defaults to sqrt(key_tokens)
kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
Returns:
Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
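A hedged usage sketch against the hypothetical wrapper above, checking the documented output shape; all sizes are arbitrary, and `use_checkpoint=False` follows the docstring's inference recommendation:

```python
batch_x_heads, q_tokens, k_tokens, ch = 2 * 8, 77, 77, 64
q = torch.randn(batch_x_heads, q_tokens, ch)
k = torch.randn(batch_x_heads, k_tokens, ch)
v = torch.randn(batch_x_heads, k_tokens, ch)
# kv_chunk_size == key_tokens forces the no-chunking fast path from the sketch above
out = attention(q, k, v, kv_chunk_size=k_tokens, use_checkpoint=False)
assert out.shape == (batch_x_heads, q_tokens, ch)  # [batch * num_heads, query_tokens, channels_per_head]
```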