@@ -378,9 +378,13 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
378
378
}
379
379
}
380
380
381
- static void init_view (struct ggml_allocr * alloc , struct ggml_tensor * view ) {
381
+ static void init_view (struct ggml_allocr * alloc , struct ggml_tensor * view , bool update_backend ) {
382
382
assert (view -> view_src != NULL && view -> view_src -> data != NULL );
383
- view -> backend = view -> view_src -> backend ;
383
+
384
+ if (update_backend ) {
385
+ view -> backend = view -> view_src -> backend ;
386
+ }
387
+
384
388
view -> buffer = view -> view_src -> buffer ;
385
389
view -> data = (char * )view -> view_src -> data + view -> view_offs ;
386
390
@@ -394,7 +398,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
394
398
struct hash_node * ht = alloc -> hash_table ;
395
399
if (node -> data == NULL ) {
396
400
if (ggml_is_view (node )) {
397
- init_view (alloc , node );
401
+ init_view (alloc , node , true );
398
402
} else {
399
403
// see if we can reuse a parent's buffer (inplace)
400
404
if (ggml_op_can_inplace (node -> op )) {
@@ -424,15 +428,15 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
424
428
AT_PRINTF ("reusing view parent %s (%s) for %s\n" , parent -> name , view_src -> name , node -> name );
425
429
node -> view_src = view_src ;
426
430
view_src_hn -> n_views += 1 ;
427
- init_view (alloc , node );
431
+ init_view (alloc , node , false );
428
432
return ;
429
433
}
430
434
}
431
435
else {
432
436
AT_PRINTF ("reusing parent %s for %s\n" , parent -> name , node -> name );
433
437
node -> view_src = parent ;
434
438
p_hn -> n_views += 1 ;
435
- init_view (alloc , node );
439
+ init_view (alloc , node , false );
436
440
return ;
437
441
}
438
442
}
@@ -463,7 +467,7 @@ size_t ggml_allocr_alloc_graph_n(
463
467
hash_get (ht , view_src )-> n_views += 1 ;
464
468
if (node -> buffer == NULL && node -> data != NULL ) {
465
469
// view of a pre-allocated tensor, didn't call init_view() yet
466
- init_view (alloc , node );
470
+ init_view (alloc , node , true );
467
471
}
468
472
}
469
473
@@ -474,7 +478,7 @@ size_t ggml_allocr_alloc_graph_n(
474
478
}
475
479
hash_get (ht , parent )-> n_children += 1 ;
476
480
if (ggml_is_view (parent ) && parent -> buffer == NULL && parent -> data != NULL ) {
477
- init_view (alloc , parent );
481
+ init_view (alloc , parent , true );
478
482
}
479
483
}
480
484
}
0 commit comments