@@ -378,9 +378,13 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
378
378
}
379
379
}
380
380
381
- static void init_view (struct ggml_allocr * alloc , struct ggml_tensor * view ) {
381
+ static void init_view (struct ggml_allocr * alloc , struct ggml_tensor * view , bool update_backend ) {
382
382
assert (view -> view_src != NULL && view -> view_src -> data != NULL );
383
- view -> backend = view -> view_src -> backend ;
383
+
384
+ if (update_backend ) {
385
+ view -> backend = view -> view_src -> backend ;
386
+ }
387
+
384
388
view -> buffer = view -> view_src -> buffer ;
385
389
view -> data = (char * )view -> view_src -> data + view -> view_offs ;
386
390
@@ -394,7 +398,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
394
398
struct hash_node * ht = alloc -> hash_table ;
395
399
if (node -> data == NULL ) {
396
400
if (ggml_is_view (node )) {
397
- init_view (alloc , node );
401
+ init_view (alloc , node , true );
398
402
} else {
399
403
// see if we can reuse a parent's buffer (inplace)
400
404
if (ggml_op_can_inplace (node -> op )) {
@@ -424,15 +428,14 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
424
428
AT_PRINTF ("reusing view parent %s (%s) for %s\n" , parent -> name , view_src -> name , node -> name );
425
429
node -> view_src = view_src ;
426
430
view_src_hn -> n_views += 1 ;
427
- init_view (alloc , node );
431
+ init_view (alloc , node , false );
428
432
return ;
429
433
}
430
- }
431
- else {
434
+ } else {
432
435
AT_PRINTF ("reusing parent %s for %s\n" , parent -> name , node -> name );
433
436
node -> view_src = parent ;
434
437
p_hn -> n_views += 1 ;
435
- init_view (alloc , node );
438
+ init_view (alloc , node , false );
436
439
return ;
437
440
}
438
441
}
@@ -463,7 +466,7 @@ size_t ggml_allocr_alloc_graph_n(
463
466
hash_get (ht , view_src )-> n_views += 1 ;
464
467
if (node -> buffer == NULL && node -> data != NULL ) {
465
468
// view of a pre-allocated tensor, didn't call init_view() yet
466
- init_view (alloc , node );
469
+ init_view (alloc , node , true );
467
470
}
468
471
}
469
472
@@ -474,7 +477,7 @@ size_t ggml_allocr_alloc_graph_n(
474
477
}
475
478
hash_get (ht , parent )-> n_children += 1 ;
476
479
if (ggml_is_view (parent ) && parent -> buffer == NULL && parent -> data != NULL ) {
477
- init_view (alloc , parent );
480
+ init_view (alloc , parent , true );
478
481
}
479
482
}
480
483
}
0 commit comments