As a sub-item for enabling 2D sharding with minibatch=True, we want to create a prototype of SPMD that uses local batching. This will hopefully be the beginning of an implementation that can replace the current workaround.
We plan to build this prototype on top of the existing mark_sharding. We need an API such that each host calls TransferShardsToDevice only for the devices it is associated with. We can consider the existing pipeline for mark_sharding:
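To make the per-host idea concrete, here is a minimal, self-contained C++ sketch; it is not the existing mark_sharding pipeline, and every type and the TransferShardsToDevice signature below are placeholders for illustration rather than the real torch_xla runtime API:

```cpp
// Hypothetical sketch: each host builds and transfers shards only for its
// locally attached devices. Types and the TransferShardsToDevice signature
// are placeholders, not the real torch_xla runtime API.
#include <iostream>
#include <string>
#include <vector>

struct Shard {
  std::string device;       // e.g. "TPU:0" (placeholder device naming)
  std::vector<float> data;  // flattened shard contents
};

// Placeholder for the runtime call that moves shards onto devices.
void TransferShardsToDevice(const std::vector<Shard>& shards) {
  for (const auto& s : shards) {
    std::cout << "transferring " << s.data.size() << " elements to "
              << s.device << "\n";
  }
}

int main() {
  // Only the devices attached to *this* host, not the global SPMD device list.
  const std::vector<std::string> local_devices = {"TPU:0", "TPU:1"};

  // Each host slices its local minibatch into one shard per local device.
  std::vector<Shard> local_shards;
  for (const auto& dev : local_devices) {
    local_shards.push_back({dev, std::vector<float>(128, 0.0f)});
  }
  TransferShardsToDevice(local_shards);
  return 0;
}
```

The point is only that each host filters down to its local devices before any transfer happens, rather than materializing shards for the whole global device list.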
I am still working on this prototype; I wanted to give a follow-up on the progress:
I made an initial prototyping attempt by wiring up CreateGlobalTensorData -> CreateGlobalShardedData -> TransferShardsToDevice and creating something like load_local_shards_. However, I ran into a series of issues related to object creation.
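For reference, this is roughly the shape of the flow I was trying to wire up, written as a self-contained sketch; the function names mirror the ones mentioned above, but all signatures and types here are assumptions for illustration, not the actual torch_xla ones:

```cpp
// Hypothetical sketch of the attempted flow: build global tensor data, split
// it into per-device shards, then transfer them, with a load_local_shards_-
// style entry point. All signatures are assumptions, not torch_xla's.
#include <cstdint>
#include <string>
#include <vector>

struct TensorData { std::vector<float> values; };
struct ShardedData {
  std::vector<TensorData> shards;
  std::vector<std::string> devices;
};

// Step 1: assemble the (logically) global tensor on the host.
TensorData CreateGlobalTensorData(int64_t numel) {
  return TensorData{std::vector<float>(numel, 0.0f)};
}

// Step 2: split it into one shard per participating device.
ShardedData CreateGlobalShardedData(const TensorData& global,
                                    const std::vector<std::string>& devices) {
  ShardedData out;
  out.devices = devices;
  const int64_t per_shard = global.values.size() / devices.size();
  for (size_t i = 0; i < devices.size(); ++i) {
    out.shards.push_back(TensorData{std::vector<float>(
        global.values.begin() + i * per_shard,
        global.values.begin() + (i + 1) * per_shard)});
  }
  return out;
}

// Step 3: hand the shards to the runtime (stubbed out here).
void TransferShardsToDevice(const ShardedData& /*sharded*/) {}

// A load_local_shards_-style wrapper: a host only materializes and transfers
// the shards for the devices it owns.
void LoadLocalShards(const std::vector<std::string>& local_devices,
                     int64_t numel) {
  TensorData global = CreateGlobalTensorData(numel);
  ShardedData sharded = CreateGlobalShardedData(global, local_devices);
  TransferShardsToDevice(sharded);
}

int main() {
  LoadLocalShards({"TPU:0", "TPU:1"}, 256);
  return 0;
}
```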
I have expanded the scope of this prototype to cover the entire flow described in #8842 (comment).
I have run into a couple of type conversion issues with xla::Shape, but I currently believe there is a path forward by leveraging xla::ShapeUtil::MakeShape. This should unblock us to run some tests.
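For context, this is the kind of construction I mean; xla::ShapeUtil::MakeShape is a real XLA helper, but the include paths below follow the recent OpenXLA layout and may differ by checkout, and the dimensions are just placeholders:

```cpp
// Minimal sketch: build an xla::Shape directly with an explicit element type
// and dimensions, instead of converting from another shape representation.
// Include paths follow recent OpenXLA layout and may differ by checkout.
#include "xla/shape.h"
#include "xla/shape_util.h"

xla::Shape MakeShardShape() {
  // f32[8,128]; MakeShape takes a PrimitiveType and a span of dimensions.
  // The {8, 128} extents are placeholders, not taken from the prototype.
  return xla::ShapeUtil::MakeShape(xla::F32, {8, 128});
}
```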