Skip to content

Commit 9e10601

Browse files
zhiweij1pytorchmergebot
authored andcommitted
[XPU] Add an implict conversion from XPUStream to sycl::queue* (pytorch#148646)
# Motivation Currently, in Pytorch XPU, `cudaStream_t` is mapped to `sycl::queue&`, so an implicit cast from `XPUStream` to `sycl::queue&` is provided just like `CUDAStream` has an implicit cast to `cudaStream_t`. But on the SYCLomatic side, we migrate `cudaStream_t` to `sycl::queue*` but not `sycl::queue&` (One reason is that `cudaStream_t` is actually a pointer so users can do anything with that integer. Another reason is that the early `sycl::queue` was not impl-ed by a pointer, so copy by value is not desirable.) Without this PR: ``` cudaStream_t a = getCurrentCUDAStream(); cudaStream_t b = getCurrentCUDAStream().stream(); ``` need be migrated to: ``` queue_ptr a = &(sycl::queue&)getCurrentXPUStream(); queue_ptr b = &(getCurrentXPUStream().queue()); ``` With this PR: ``` queue_ptr a = getCurrentXPUStream(); queue_ptr b = &(getCurrentXPUStream().queue()); ``` Pull Request resolved: pytorch#148646 Approved by: https://github.com/guangyey, https://github.com/EikanWang
1 parent c067127 commit 9e10601

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

c10/xpu/XPUStream.h

+5
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ class C10_XPU_API XPUStream {
5959
return queue();
6060
}
6161

62+
/// Implicit conversion to sycl::queue*.
63+
operator sycl::queue*() const {
64+
return &queue();
65+
}
66+
6267
/// Implicit conversion to Stream (a.k.a., forget that the stream is a
6368
/// XPU stream).
6469
operator Stream() const {

c10/xpu/test/impl/XPUStreamTest.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,9 @@ TEST(XPUStreamTest, ExternalTest) {
223223
ASSERT_TRUE(curStream == myStream);
224224
ASSERT_TRUE(&(curStream.queue()) == stream);
225225

226+
sycl::queue* q_ptr = curStream;
227+
ASSERT_TRUE(q_ptr == stream);
228+
226229
delete stream;
227230
}
228231

0 commit comments

Comments
 (0)