Add API to donate input buffer for dynamo execution #6587
Merged

13 commits:
507ebc1 Add api to buffer donation (JackCaoG)
ebbafda add SetBufferDonors (JackCaoG)
ccbdc53 Add get_buffer_donation, fix a bug where mark+step with devicedata wi… (JackCaoG)
92480b4 add python tests, currently they will fail if being run together due … (JackCaoG)
5671ee5 add test to testing script (JackCaoG)
0b6c3a8 make sure compilation hash tracks buffer donor index (JackCaoG)
3561bdf only enable buffer donor aliasing in dynamo (JackCaoG)
e1c3a44 Fix a bug where warm up cache might accidentally execute the graph (JackCaoG)
0b60229 Add test for non-dynamo buffer donation (JackCaoG)
d12446b add more test (JackCaoG)
93af291 remove debugging messages (JackCaoG)
fda1c41 add comment (JackCaoG)
866b2bc remove debug messages and fix tests (JackCaoG)
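The tests in this PR rely on three behaviors of the donation flag: it attaches to the input's device buffer, a donated buffer's handle is freed by execution, and the flag is not inherited by the output buffer. The following pure-Python mock (not torch_xla internals; `Buffer` and `execute` are hypothetical names for illustration) sketches those semantics:

```python
# Conceptual mock of buffer-donation semantics (NOT torch_xla code).
# The should_donate flag lives on the input buffer, a donated buffer is
# consumed by execution, and the output starts as a non-donor.

class Buffer:

  def __init__(self, data, should_donate=False):
    self.data = data
    self.should_donate = should_donate
    self.deleted = False  # mimics 'Data Handle: Deleted' in the debug info

def execute(fn, buf):
  if buf.deleted:
    raise RuntimeError('buffer was already donated')
  out = Buffer(fn(buf.data))  # output flag defaults to False: not inherited
  if buf.should_donate:
    buf.deleted = True  # donated input handle is freed after execution
  return out

inp = Buffer(5.0, should_donate=True)
res = execute(lambda x: x + 1, inp)
print(res.data, res.should_donate, inp.deleted)  # 6.0 False True
```

This mirrors what `test_manual_buffer_donation_for_non_inplce_op` asserts against the real API: after the compiled call, the input's data handle is deleted while the result carries no donor flag.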
import sys
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from torch_xla.core.dynamo_bridge import AliasWithBufferDonorContext


class TestBufferDonationUtil(unittest.TestCase):

  def test_hash_with_buffer_donor(self):
    device = xm.xla_device()
    input = torch.randn(5, 5).to(device)
    res = torch.cos(input)
    hash_no_donor = torch_xla._XLAC._get_graph_hash([res])
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
    # Without the AliasWithBufferDonorContext, the buffer donor will be
    # ignored, so we still expect the hash to be the same.
    hash_with_donor = torch_xla._XLAC._get_graph_hash([res])
    self.assertEqual(hash_no_donor, hash_with_donor)

    with AliasWithBufferDonorContext(True) as context:
      hash_with_donor_and_context = torch_xla._XLAC._get_graph_hash([res])
      self.assertNotEqual(hash_no_donor, hash_with_donor_and_context)


class TestDynamoBufferDonationAliasing(unittest.TestCase):

  def dummy_inplace_add(self, input):
    input += 1
    return

  def dummy_add(self, input):
    return input + 1

  def test_manual_buffer_donation(self):
    device = xm.xla_device()
    input = torch.randn(5, 5).to(device)
    input_cloned = torch.clone(input)
    dummy_inplace_add_compiled = torch.compile(
        self.dummy_inplace_add, backend='openxla')

    met.clear_all()
    # input is a device_data, so we should be able to set the buffer donation field.
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
    # Make sure the buffer donation setting is correctly updated.
    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
    self.assertIn('XlaSetBufferDonation', met.counter_names())
    self.assertEqual(met.counter_value('XlaSetBufferDonation'), 1)
    dummy_inplace_add_compiled(input)
    torch.allclose(input_cloned.cpu() + 1, input.cpu())

  def test_manual_buffer_donation_for_non_inplace_op(self):
    device = xm.xla_device()
    input = torch.randn(5, 5).to(device)
    input_cloned = torch.clone(input)
    dummy_add_compiled = torch.compile(self.dummy_add, backend='openxla')

    met.clear_all()
    # input is a device_data, so we should be able to set the buffer donation field.
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
    # Make sure the buffer donation setting is correctly updated.
    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
    self.assertIn('XlaSetBufferDonation', met.counter_names())
    self.assertEqual(met.counter_value('XlaSetBufferDonation'), 1)

    res = dummy_add_compiled(input)
    # Check that input's buffer has been aliased.
    xm.wait_device_ops()
    self.assertIn('Data Handle: Deleted',
                  torch_xla._XLAC._get_xla_tensor_debug_info(input))
    torch.allclose(input_cloned.cpu() + 1, res.cpu())

  def test_manual_buffer_donation_for_inplace_op_repeat(self):
    # Use a different function than the dummy add above, otherwise XLA
    # won't recompile.
    def dummy_inplace(input):
      input += (0.3 * torch.cos(input))

    device = xm.xla_device()
    input = torch.randn(5, 5).to(device)
    input_cloned = torch.clone(input)
    dummy_inplace_add_compiled = torch.compile(dummy_inplace, backend='openxla')
    xm.mark_step()
    met.clear_all()
    # input is a device_data, so we should be able to set the buffer donation field.
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
    # Make sure the buffer donation setting is correctly updated.
    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))

    for _ in range(100):
      dummy_inplace_add_compiled(input)
    # The should_donate_buffer field is attached to the buffer and won't be
    # inherited by the output buffer (unless the execution is a no-op). However,
    # dynamo doesn't track this field, so it will keep executing the graph with
    # the input buffer being aliased.
    self.assertFalse(torch_xla._XLAC._get_buffer_donation(input))
    # There shouldn't be any recompilation even though the
    # `should_donate_buffer` field changed after the first execution. This is
    # because dynamo does not trace this internal field for XLA.
    self.assertEqual(met.metric_data('CompileTime')[0], 1)

  def test_buffer_donation_on_non_data_tensor(self):
    device = xm.xla_device()
    input = torch.randn(5, 5).to(device)
    res = input + 1

    met.clear_all()
    # res now points to an `Add` IR; only a data tensor's buffer can be aliased.
    self.assertFalse(torch_xla._XLAC._set_buffer_donation(res, True))
    self.assertFalse(torch_xla._XLAC._get_buffer_donation(res))
    self.assertNotIn('XlaSetBufferDonation', met.counter_names())


class TestNonDynamoBufferDonationAliasing(unittest.TestCase):

  def dummy_fn(self, input):
    return torch.cos(torch.sin(input))

  # Currently we skip the buffer donation API for the non-dynamo use case.
  def test_buffer_donation_skip_for_non_dynamo(self):
    device = xm.xla_device()
    input = torch.randn(5, 5).to(device)
    xm.mark_step()
    met.clear_all()

    # We should be able to set buffer donation for the input tensor, but when
    # mark_step is triggered, the buffer donation should be ignored.
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
    res = self.dummy_fn(input)
    xm.mark_step()
    # Make sure the input buffer is not aliased and can be used for other
    # computations. Also make sure that buffer donation will not trigger
    # recompilation in non-dynamo.
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, False))
    res2 = self.dummy_fn(input)
    xm.mark_step()
    torch.allclose(res.cpu(), res2.cpu())
    self.assertEqual(met.metric_data('CompileTime')[0], 1)

  def test_no_op_mark_step_keep_buffer_donation(self):
    device = xm.xla_device()
    input = torch.randn(5, 5).to(device)
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
    xm.mark_step()
    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
    xm.mark_step()
    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))


if __name__ == '__main__':
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
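The test_hash_with_buffer_donor test and the commit "make sure compilation hash tracks buffer donor index" encode the same requirement: two otherwise identical graphs need different compiled executables when different inputs are donated, but donor flags must only matter while donor aliasing is enabled. A hypothetical sketch of such a hashing scheme (the function and its parameters are illustrative, not the real torch_xla hashing):

```python
import hashlib

def graph_hash(graph_repr, donor_indices=(), alias_enabled=False):
  # Hash the graph itself first.
  h = hashlib.sha256(graph_repr.encode())
  if alias_enabled:
    # Fold the sorted donor indices into the hash only when aliasing is
    # active, so donor flags are ignored outside the context (mirroring
    # how AliasWithBufferDonorContext gates buffer donation in dynamo).
    h.update(repr(sorted(donor_indices)).encode())
  return h.hexdigest()

base = graph_hash('cos(x)')
with_donor = graph_hash('cos(x)', donor_indices=(0,))
with_donor_ctx = graph_hash('cos(x)', donor_indices=(0,), alias_enabled=True)
print(base == with_donor, base == with_donor_ctx)  # True False
```

Without the gating, setting a donor flag anywhere would invalidate cached executables even for plain lazy-tensor execution, which is exactly what the non-dynamo tests guard against.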