Skip to content

Commit 88b315e

Browse files
authored
[SYCL][Graph] Implementation of whole graph update (#13220)
Implementation of spec PR #13253
1 parent d8c0a93 commit 88b315e

18 files changed

+630
-300
lines changed

sycl/doc/design/CommandGraph.md

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,9 @@ yet been implemented.
234234

235235
### Design Challenges
236236

237-
Graph update faces significant design challenges in SYCL:
237+
#### Explicit Update
238+
239+
Explicit updates of individual nodes faces significant design challenges in SYCL:
238240

239241
* Lambda capture order is explicitly undefined in C++, so the user cannot reason
240242
about the indices of arguments captured by kernel lambdas.
@@ -256,9 +258,18 @@ can be used:
256258
extension](../extensions/proposed/sycl_ext_oneapi_free_function_kernels.asciidoc)
257259
* OpenCL interop kernels created from SPIR-V source at runtime.
258260

259-
A possible future workaround lambda capture issues could be "Whole-Graph Update"
260-
where if we can guarantee that lambda capture order is the same across two
261-
different recordings we can then match parameter order when updating.
261+
A workaround for the lambda capture issues is the "Whole-Graph Update" feature.
262+
Since the lambda capture order is the same across two different recordings, we
263+
can match the parameter order when updating.
264+
265+
#### Whole-Graph Update
266+
267+
The current implementation of the whole-graph update feature relies on the
268+
assumption that both graphs should have a similar topology. Currently, the
269+
implementation only checks that both graphs have an identical number of nodes
270+
and that each node contains the same number of edges. Further investigation
271+
should be done to see if it is possible to add extra checks (e.g. check that the
272+
nodes and edges were added in the same order).
262273

263274
### Scheduler Integration
264275

sycl/source/detail/graph_impl.cpp

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -757,8 +757,9 @@ void exec_graph_impl::createCommandBuffers(
757757
exec_graph_impl::exec_graph_impl(sycl::context Context,
758758
const std::shared_ptr<graph_impl> &GraphImpl,
759759
const property_list &PropList)
760-
: MSchedule(), MGraphImpl(GraphImpl), MPiSyncPoints(), MContext(Context),
761-
MRequirements(), MExecutionEvents(),
760+
: MSchedule(), MGraphImpl(GraphImpl), MPiSyncPoints(),
761+
MDevice(GraphImpl->getDevice()), MContext(Context), MRequirements(),
762+
MExecutionEvents(),
762763
MIsUpdatable(PropList.has_property<property::graph::updatable>()) {
763764

764765
// If the graph has been marked as updatable then check if the backend
@@ -1155,9 +1156,56 @@ void exec_graph_impl::duplicateNodes() {
11551156
MNodeStorage.insert(MNodeStorage.begin(), NewNodes.begin(), NewNodes.end());
11561157
}
11571158

1159+
void exec_graph_impl::update(std::shared_ptr<graph_impl> GraphImpl) {
1160+
1161+
if (MDevice != GraphImpl->getDevice()) {
1162+
throw sycl::exception(
1163+
sycl::make_error_code(errc::invalid),
1164+
"Cannot update using a graph created with a different device.");
1165+
}
1166+
if (MContext != GraphImpl->getContext()) {
1167+
throw sycl::exception(
1168+
sycl::make_error_code(errc::invalid),
1169+
"Cannot update using a graph created with a different context.");
1170+
}
1171+
1172+
if (MNodeStorage.size() != GraphImpl->MNodeStorage.size()) {
1173+
throw sycl::exception(sycl::make_error_code(errc::invalid),
1174+
"Cannot update using a graph with a different "
1175+
"topology. Mismatch found in the number of nodes.");
1176+
} else {
1177+
for (uint32_t i = 0; i < MNodeStorage.size(); ++i) {
1178+
if (MNodeStorage[i]->MSuccessors.size() !=
1179+
GraphImpl->MNodeStorage[i]->MSuccessors.size() ||
1180+
MNodeStorage[i]->MPredecessors.size() !=
1181+
GraphImpl->MNodeStorage[i]->MPredecessors.size()) {
1182+
throw sycl::exception(
1183+
sycl::make_error_code(errc::invalid),
1184+
"Cannot update using a graph with a different topology. Mismatch "
1185+
"found in the number of edges.");
1186+
}
1187+
1188+
if (MNodeStorage[i]->MCGType != GraphImpl->MNodeStorage[i]->MCGType) {
1189+
throw sycl::exception(
1190+
sycl::make_error_code(errc::invalid),
1191+
"Cannot update using a graph with mismatched node types. Each pair "
1192+
"of nodes being updated must have the same type");
1193+
}
1194+
}
1195+
}
1196+
1197+
for (uint32_t i = 0; i < MNodeStorage.size(); ++i) {
1198+
MIDCache.insert(
1199+
std::make_pair(GraphImpl->MNodeStorage[i]->MID, MNodeStorage[i]));
1200+
}
1201+
1202+
update(GraphImpl->MNodeStorage);
1203+
}
1204+
11581205
void exec_graph_impl::update(std::shared_ptr<node_impl> Node) {
11591206
this->update(std::vector<std::shared_ptr<node_impl>>{Node});
11601207
}
1208+
11611209
void exec_graph_impl::update(
11621210
const std::vector<std::shared_ptr<node_impl>> Nodes) {
11631211

@@ -1598,9 +1646,7 @@ void executable_command_graph::finalizeImpl() {
15981646

15991647
void executable_command_graph::update(
16001648
const command_graph<graph_state::modifiable> &Graph) {
1601-
(void)Graph;
1602-
throw sycl::exception(sycl::make_error_code(errc::invalid),
1603-
"Method not yet implemented");
1649+
impl->update(sycl::detail::getSyclObjImpl(Graph));
16041650
}
16051651

16061652
void executable_command_graph::update(const node &Node) {

sycl/source/detail/graph_impl.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,6 +1281,10 @@ class exec_graph_impl {
12811281
void createCommandBuffers(sycl::device Device,
12821282
std::shared_ptr<partition> &Partition);
12831283

1284+
/// Query for the device tied to this graph.
1285+
/// @return Device associated with graph.
1286+
sycl::device getDevice() const { return MDevice; }
1287+
12841288
/// Query for the context tied to this graph.
12851289
/// @return Context associated with graph.
12861290
sycl::context getContext() const { return MContext; }
@@ -1320,6 +1324,7 @@ class exec_graph_impl {
13201324
return MRequirements;
13211325
}
13221326

1327+
void update(std::shared_ptr<graph_impl> GraphImpl);
13231328
void update(std::shared_ptr<node_impl> Node);
13241329
void update(const std::vector<std::shared_ptr<node_impl>> Nodes);
13251330

@@ -1408,6 +1413,8 @@ class exec_graph_impl {
14081413
/// Map of nodes in the exec graph to the partition number to which they
14091414
/// belong.
14101415
std::unordered_map<std::shared_ptr<node_impl>, int> MPartitionNodes;
1416+
/// Device associated with this executable graph.
1417+
sycl::device MDevice;
14111418
/// Context associated with this executable graph.
14121419
sycl::context MContext;
14131420
/// List of requirements for enqueueing this command graph, accumulated from

sycl/test-e2e/Graph/Explicit/executable_graph_update_ordering.cpp

Lines changed: 0 additions & 16 deletions
This file was deleted.

sycl/test-e2e/Graph/Inputs/double_buffer.cpp

Lines changed: 0 additions & 104 deletions
This file was deleted.

0 commit comments

Comments
 (0)