Rescalability layer #1455
Conversation
Thanks for the work, @daviswer! Some first-level comments:
Thanks @scotts, I made changes 2-4 and am working on unit tests now. I'll note that
A preferred format as we can load document chunks without having to ever pull
the entire document or shard file, allowing for graceful handling of large documents.
Non-standard data format, though.
"""
I wanted to confirm my understanding of the format of the pyarrow shard files. I am imagining a very large text file made up of thousands of tokens. That file is broken into multiple PyArrow shard files. Each of those PyArrow shard files is made up of multiple RecordBatches, each with a tokens field which is a list of tokens. That means each token is a 'row' (in the sense that RecordBatches are supposed to be a batch of records/rows). Is that right?
Additionally, why do we not consider having the list of tokens as a single row in the RecordBatch? What is the value added by using RecordBatches here? Thanks.
The general assumption is that each pyarrow shard file represents a collection of documents, rather than a single large document getting split over multiple files. But yes, each file is a collection of RecordBatches, each of which contains a single 'row' of text (a single document) in the tokens field. We use RecordBatches because that's how the pyarrow docs suggest reading/writing random-access memory-mapped files, and we put a single document per RecordBatch to minimize the overhead of loading individual documents in random order.
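For concreteness, here is a minimal sketch of how such a shard file could be written with the pyarrow IPC file format, one document per RecordBatch with a single-row tokens list column. This is illustrative only, not the PR's actual writer; the file name and token values are made up.

import pyarrow as pa

# One RecordBatch per document; each batch holds a single row whose
# "tokens" field is that document's token list.
schema = pa.schema([("tokens", pa.list_(pa.int32()))])
documents = [[1, 2, 3, 4], [5, 6], [7, 8, 9]]  # toy token lists

with pa.OSFile("shard_000.arrow", "wb") as sink:
    with pa.ipc.new_file(sink, schema) as writer:
        for doc in documents:
            batch = pa.record_batch(
                [pa.array([doc], type=pa.list_(pa.int32()))], schema=schema
            )
            writer.write_batch(batch)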
Cool. Can every document (even large ones) fit in a single RecordBatch?
The docstring mentions "as we can load document chunks without having to ever pull the entire document"; here we are still referring to a single document being loaded in a RecordBatch?
Yes - because pyarrow files are memory-mapped, RecordBatches (and slices) are loaded lazily. Until you request the actual slice, the RecordBatch is just metadata, so it can hold a document of any size without extra overhead.
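A quick sketch of that read path (file name and indices are illustrative, continuing the toy writer above): the batch lookup itself only touches metadata, and token data is materialized when the slice is actually accessed.

import pyarrow as pa

# Memory-map the shard; no token data is read yet.
with pa.memory_map("shard_000.arrow", "r") as source:
    reader = pa.ipc.open_file(source)
    doc_batch = reader.get_batch(1)             # one document per batch
    tokens = doc_batch.column("tokens").values  # flat token array, zero-copy
    chunk = tokens[1:3]                         # chunk of the document, pulled lazily
    print(chunk.to_pylist())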
Sharing some points that we discussed over the call.
Some thoughts on point 2 @divyanshk : we could definitely separate this out into a file-interface system, plus a separate rescaling-enabled interface between nested iterable Datasets and generic indexable structures. However, we may lose some capabilities in the process. In particular, the current approach is set up to a) handle sharding of indexable structures where the total number of indices is not known in advance (i.e. many shard files containing many documents, with limited access bandwidth) and b) ensure that no more than one file per device is open/active at a time, regardless of number of files/devices. If we abstract away notions of files/items behind a generic indexable interface, it becomes harder to maintain these guarantees. It may be possible to still make that work but I'd have to think through the approach some more.
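To make those two guarantees concrete, here is a rough, hypothetical sketch (the helper names are not from the PR): shards are assigned round-robin so no divisibility between shard and worker counts is required, and each worker maps only its current shard file, so at most one file is open per worker at any time.

import pyarrow as pa
from typing import Iterator, List

def worker_shards(n_shards: int, worker_id: int, n_workers: int) -> List[int]:
    # Round-robin assignment: valid for any shard/worker counts.
    return [s for s in range(n_shards) if s % n_workers == worker_id]

def iter_documents(shard_paths: List[str], worker_id: int, n_workers: int) -> Iterator[pa.Array]:
    for idx in worker_shards(len(shard_paths), worker_id, n_workers):
        # Only the current shard is memory-mapped; it is released before the
        # next one is opened, so each worker holds at most one open file.
        with pa.memory_map(shard_paths[idx], "r") as source:
            reader = pa.ipc.open_file(source)
            for b in range(reader.num_record_batches):
                yield reader.get_batch(b).column("tokens").values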
ranked_state = {k: dstate.pop(k) for k in keys if "rank" in k}
ranked_keylist = sorted(list(ranked_state.keys()))
compiled_ranked = [ranked_state[k] for k in ranked_keylist]
dstate[ranked_keylist[0][6:]] = compiled_ranked  # Drop "rank0." prefix
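As a toy illustration of what this fusion does to the state dict (the keys below are made up), the per-rank entries are popped, sorted, and collapsed into a single list under the unprefixed key:

dstate = {"rank0.buffer": [1, 2], "rank1.buffer": [3, 4], "epoch": 0}
keys = list(dstate.keys())

ranked_state = {k: dstate.pop(k) for k in keys if "rank" in k}
ranked_keylist = sorted(ranked_state.keys())
dstate[ranked_keylist[0][6:]] = [ranked_state[k] for k in ranked_keylist]

# dstate is now {"epoch": 0, "buffer": [[1, 2], [3, 4]]}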
I need some clarity on the custom re-scaling state. Will sync on this offline.
def __pop_dstate(state, device_mesh, placements, create_dtensor=False):
    """
    Removes worker states from the StatefulDataLoader state dict, and fuses them into a single dict
Similar to the question I asked on Slack the other day: does it even make sense to load state after rescaling if what is loaded is only partial, e.g. when we throw away loader state here, or disregard scalars or RNG?
Yes, partial loading makes sense - for example, we might have a data buffer coupled with an RNG. On rescaling, we would want to preserve and reshard the data buffer, but we don't need to hold onto the RNG state(s), since we now have a different number of them.
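A hedged sketch of that idea (function and key names are hypothetical, not the PR's API): keep and reshard the fused data buffers across the new worker count, and simply drop the old per-worker RNG states.

def rescale_state(fused_state: dict, new_num_workers: int) -> list:
    """Redistribute buffered-but-unseen items over a new worker count."""
    buffers = fused_state["buffer"]  # one buffer per old worker
    flat = [item for buf in buffers for item in buf]
    # Round-robin the surviving items across the new workers; old per-worker
    # RNG states are intentionally not carried over.
    return [{"buffer": flat[w::new_num_workers]} for w in range(new_num_workers)]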
Implements rescaling of checkpoints to different world sizes and numbers of workers. The user specifies the number of data partitions in advance, and when saving/loading checkpoints with a different total number of workers, stateful guarantees are maintained: seen data is not revisited until the next epoch.
Based on the datasets in the corresponding IBM torchtitan PR, but with an adjusted rescaling and iteration mechanism to support greater flexibility and robustness (removes divisibility constraints from worker and shard counts, and guarantees only one open file per physical worker regardless of the number of logical shards). Uses StatefulDataLoader and DCP to manage checkpointing from the master process. An epoch completion testing script is included for demo purposes. It is possible that the IBM datasets can be merged into the existing torchdata Nodes structure.
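For orientation, a hedged end-to-end usage sketch of checkpointing the loader state with StatefulDataLoader and DCP and then resuming under a different worker count; the dataset and checkpoint path are placeholders, and the actual rescaling plumbing lives in the files listed below.

import torch.distributed.checkpoint as dcp
from torchdata.stateful_dataloader import StatefulDataLoader

# Train with 2 workers and checkpoint the loader state via DCP.
loader = StatefulDataLoader(my_rescalable_dataset, batch_size=8, num_workers=2)
# ... consume some batches ...
dcp.save({"loader": loader.state_dict()}, checkpoint_id="ckpt/step_1000")

# Resume later with a different number of workers (or world size).
new_loader = StatefulDataLoader(my_rescalable_dataset, batch_size=8, num_workers=3)
state = {"loader": new_loader.state_dict()}
dcp.load(state, checkpoint_id="ckpt/step_1000")
new_loader.load_state_dict(state["loader"])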
Changes
torchdata/stateful_dataloader/ibm_rescalable.py
examples/ibm_rescaling/rescaling_demo.py