@@ -1812,6 +1812,12 @@ static bool llama_eval_internal(
     // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
     n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
 
+    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
+    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
+
+    LLAMA_ASSERT(strcmp(res->name, "result_output") == 0);
+    LLAMA_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
+
 #ifdef GGML_USE_MPI
     const int64_t n_layer = hparams.n_layer;
     ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
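
The lookup added here depends on the graph builder tagging its final two nodes with ggml_set_name, so the tail of gf->nodes is stable across backends. A minimal standalone sketch of that convention against the ggml API of this era (the tiny add-graph, buffer size, and main() scaffolding are illustrative, not taken from this commit):

// Sketch: name the final node while building the graph, then recover it from
// the tail of gf->nodes and verify the name, as the asserts above do.
#include <assert.h>
#include <string.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * b   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * out = ggml_add(ctx, a, b);

    ggml_set_name(out, "result_output"); // same convention the asserts rely on

    struct ggml_cgraph gf = ggml_build_forward(out);

    // the last node recorded in the graph is the named output
    struct ggml_tensor * res = gf.nodes[gf.n_nodes - 1];
    assert(strcmp(res->name, "result_output") == 0);

    ggml_free(ctx);
    return 0;
}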
@@ -1825,7 +1831,10 @@ static bool llama_eval_internal(
         // }
         ggml_metal_set_n_cb(lctx.ctx_metal, n_threads);
         ggml_metal_graph_compute(lctx.ctx_metal, gf);
-        ggml_metal_get_tensor(lctx.ctx_metal, cur);
+        ggml_metal_get_tensor(lctx.ctx_metal, res);
+        if (!lctx.embedding.empty()) {
+            ggml_metal_get_tensor(lctx.ctx_metal, embeddings);
+        }
     } else {
         // IMPORTANT:
         // Since we don't have efficient Matrix x Matrix Metal multiplication yet, we fallback to vanilla
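
On the Metal path the graph runs in GPU-side buffers, so only tensors the host actually reads need an explicit ggml_metal_get_tensor copy-back; previously that copied cur (whatever tensor variable the graph builder last held) and never fetched the embeddings, leaving lctx.embedding stale on this path. For context, a hedged sketch of the downstream consumer that the copy-back feeds; the names n_vocab, n_embd, N and the lctx members are assumed from the surrounding function rather than quoted from this commit:

// Hedged sketch of the readback consumer; names/shapes are assumptions.
lctx.logits.resize(n_vocab);
memcpy(lctx.logits.data(),
       (float *) ggml_get_data(res) + n_vocab*(N - 1), // logits of the last token
       sizeof(float)*n_vocab);

if (!lctx.embedding.empty()) { // caller requested embeddings
    lctx.embedding.resize(n_embd);
    memcpy(lctx.embedding.data(),
           (float *) ggml_get_data(embeddings) + n_embd*(N - 1),
           sizeof(float)*n_embd);
}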
@@ -1856,12 +1865,6 @@ static bool llama_eval_internal(
     // update kv token count
     lctx.kv_self.n = n_past + N;
 
-    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
-    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
-
-    LLAMA_ASSERT(strcmp(res->name, "result_output") == 0);
-    LLAMA_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
-
     if (cgraph_fname) {
         ggml_graph_export(gf, cgraph_fname);
     }
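
The surviving cgraph_fname branch serializes the evaluated graph via ggml_graph_export. A minimal standalone sketch of that export API under the same assumptions as above (the one-op graph and output file name are illustrative):

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * y = ggml_sqr(ctx, x);
    ggml_set_name(y, "result_output");

    struct ggml_cgraph gf = ggml_build_forward(y);

    // write the graph (structure + tensor data) so it can be inspected offline
    ggml_graph_export(&gf, "llama.ggml");

    ggml_free(ctx);
    return 0;
}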