@@ -70,34 +70,41 @@ struct ur_mem_handle_t_ : _ur_object {
70
70
// Keeps device of this memory handle
71
71
ur_device_handle_t UrDevice;
72
72
73
+ // Whether this is an image or buffer
74
+ enum mem_type_t { image, buffer };
75
+ mem_type_t mem_type;
76
+
73
77
// Enumerates all possible types of accesses.
74
78
enum access_mode_t { unknown, read_write, read_only, write_only };
75
79
76
80
// Interface of the _ur_mem object
77
81
78
82
// Get the Level Zero handle of the current memory object
79
- virtual ur_result_t getZeHandle (char *&ZeHandle, access_mode_t ,
80
- ur_device_handle_t Device,
81
- const ur_event_handle_t *phWaitEvents,
82
- uint32_t numWaitEvents) = 0 ;
83
+ ur_result_t getZeHandle (char *&ZeHandle, access_mode_t ,
84
+ ur_device_handle_t Device,
85
+ const ur_event_handle_t *phWaitEvents,
86
+ uint32_t numWaitEvents);
83
87
84
88
// Get a pointer to the Level Zero handle of the current memory object
85
- virtual ur_result_t getZeHandlePtr (char **&ZeHandlePtr, access_mode_t ,
86
- ur_device_handle_t Device,
87
- const ur_event_handle_t *phWaitEvents,
88
- uint32_t numWaitEvents) = 0 ;
89
+ ur_result_t getZeHandlePtr (char **&ZeHandlePtr, access_mode_t ,
90
+ ur_device_handle_t Device,
91
+ const ur_event_handle_t *phWaitEvents,
92
+ uint32_t numWaitEvents);
89
93
90
94
// Method to get type of the derived object (image or buffer)
91
- virtual bool isImage () const = 0;
92
-
93
- virtual ~ur_mem_handle_t_ () = default ;
95
+ bool isImage () const { return mem_type == mem_type_t ::image; }
94
96
95
97
protected:
96
- ur_mem_handle_t_ (ur_context_handle_t Context)
97
- : UrContext{Context}, UrDevice{nullptr } {}
98
+ ur_mem_handle_t_ (mem_type_t type, ur_context_handle_t Context)
99
+ : UrContext{Context}, UrDevice{nullptr }, mem_type(type) {}
98
100
99
- ur_mem_handle_t_ (ur_context_handle_t Context, ur_device_handle_t Device)
100
- : UrContext{Context}, UrDevice(Device) {}
101
+ ur_mem_handle_t_ (mem_type_t type, ur_context_handle_t Context,
102
+ ur_device_handle_t Device)
103
+ : UrContext{Context}, UrDevice(Device), mem_type(type) {}
104
+
105
+ // Since the destructor isn't virtual, callers must destruct it via _ur_buffer
106
+ // or _ur_image
107
+ ~ur_mem_handle_t_ () {};
101
108
};
102
109
103
110
struct _ur_buffer final : ur_mem_handle_t_ {
@@ -110,7 +117,7 @@ struct _ur_buffer final : ur_mem_handle_t_ {
110
117
111
118
// Sub-buffer constructor
112
119
_ur_buffer (_ur_buffer *Parent, size_t Origin, size_t Size )
113
- : ur_mem_handle_t_(Parent->UrContext), Size (Size ),
120
+ : ur_mem_handle_t_(mem_type_t ::buffer, Parent->UrContext), Size (Size ),
114
121
SubBuffer{{Parent, Origin}} {
115
122
// Retain the Parent Buffer due to the Creation of the SubBuffer.
116
123
Parent->RefCount .increment ();
@@ -127,16 +134,15 @@ struct _ur_buffer final : ur_mem_handle_t_ {
127
134
// up-to-date and any data copies needed for that are performed under
128
135
// the hood.
129
136
//
130
- virtual ur_result_t getZeHandle (char *&ZeHandle, access_mode_t ,
131
- ur_device_handle_t Device,
132
- const ur_event_handle_t *phWaitEvents,
133
- uint32_t numWaitEvents) override ;
134
- virtual ur_result_t getZeHandlePtr (char **&ZeHandlePtr, access_mode_t ,
135
- ur_device_handle_t Device,
136
- const ur_event_handle_t *phWaitEvents,
137
- uint32_t numWaitEvents) override ;
137
+ ur_result_t getBufferZeHandle (char *&ZeHandle, access_mode_t ,
138
+ ur_device_handle_t Device,
139
+ const ur_event_handle_t *phWaitEvents,
140
+ uint32_t numWaitEvents);
141
+ ur_result_t getBufferZeHandlePtr (char **&ZeHandlePtr, access_mode_t ,
142
+ ur_device_handle_t Device,
143
+ const ur_event_handle_t *phWaitEvents,
144
+ uint32_t numWaitEvents);
138
145
139
- bool isImage () const override { return false ; }
140
146
bool isSubBuffer () const { return SubBuffer != std::nullopt; }
141
147
142
148
// Frees all allocations made for the buffer.
@@ -206,35 +212,33 @@ struct _ur_buffer final : ur_mem_handle_t_ {
206
212
struct _ur_image final : ur_mem_handle_t_ {
207
213
// Image constructor
208
214
_ur_image (ur_context_handle_t UrContext, ze_image_handle_t ZeImage)
209
- : ur_mem_handle_t_(UrContext), ZeImage{ZeImage} {}
215
+ : ur_mem_handle_t_(mem_type_t ::image, UrContext), ZeImage{ZeImage} {}
210
216
211
217
_ur_image (ur_context_handle_t UrContext, ze_image_handle_t ZeImage,
212
218
bool OwnZeMemHandle)
213
- : ur_mem_handle_t_(UrContext), ZeImage{ZeImage} {
219
+ : ur_mem_handle_t_(mem_type_t ::image, UrContext), ZeImage{ZeImage} {
214
220
OwnNativeHandle = OwnZeMemHandle;
215
221
}
216
222
217
- virtual ur_result_t getZeHandle (char *&ZeHandle, access_mode_t ,
218
- ur_device_handle_t ,
219
- const ur_event_handle_t *phWaitEvents,
220
- uint32_t numWaitEvents) override {
223
+ ur_result_t getImageZeHandle (char *&ZeHandle, access_mode_t ,
224
+ ur_device_handle_t ,
225
+ const ur_event_handle_t *phWaitEvents,
226
+ uint32_t numWaitEvents) {
221
227
std::ignore = phWaitEvents;
222
228
std::ignore = numWaitEvents;
223
229
ZeHandle = reinterpret_cast <char *>(ZeImage);
224
230
return UR_RESULT_SUCCESS;
225
231
}
226
- virtual ur_result_t getZeHandlePtr (char **&ZeHandlePtr, access_mode_t ,
227
- ur_device_handle_t ,
228
- const ur_event_handle_t *phWaitEvents,
229
- uint32_t numWaitEvents) override {
232
+ ur_result_t getImageZeHandlePtr (char **&ZeHandlePtr, access_mode_t ,
233
+ ur_device_handle_t ,
234
+ const ur_event_handle_t *phWaitEvents,
235
+ uint32_t numWaitEvents) {
230
236
std::ignore = phWaitEvents;
231
237
std::ignore = numWaitEvents;
232
238
ZeHandlePtr = reinterpret_cast <char **>(&ZeImage);
233
239
return UR_RESULT_SUCCESS;
234
240
}
235
241
236
- bool isImage () const override { return true ; }
237
-
238
242
// Keep the descriptor of the image
239
243
ZeStruct<ze_image_desc_t > ZeImageDesc;
240
244
0 commit comments