Skip to content

Commit 01d5c36

Browse files
authored
Merge pull request #1185 from martindevans/feat/tensor-override
Removed Tensor Overrides Native Memory Allocations
2 parents daffe73 + dfd72dc commit 01d5c36

File tree

3 files changed

+72
-148
lines changed

3 files changed

+72
-148
lines changed

LLama/Extensions/IModelParamsExtensions.cs

Lines changed: 70 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Text;
44
using LLama.Abstractions;
55
using LLama.Native;
6+
using System.Collections.Generic;
67

78
namespace LLama.Extensions;
89

@@ -45,20 +46,13 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
4546
result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
4647
}
4748

48-
// Add tensor buffer overrides, if any
49-
if (@params.TensorBufferOverrides.Count > 0)
49+
// Add tensor buffer overrides
50+
unsafe
5051
{
51-
var bufferOverrideHelper = new LLamaTensorBufferOverrideHelper();
52-
disposer.Add(bufferOverrideHelper);
53-
54-
foreach (var tensorOverride in @params.TensorBufferOverrides)
55-
{
56-
bufferOverrideHelper.AddOverride(tensorOverride.Pattern, tensorOverride.BufferType);
57-
}
58-
59-
bufferOverrideHelper.ApplyToModelParams(ref result);
52+
result.tensor_buft_overrides = ConvertOverrides(@params.TensorBufferOverrides, disposer);
6053
}
6154

55+
// Add metadata overrides
6256
if (@params.MetadataOverrides.Count == 0)
6357
{
6458
unsafe
@@ -106,4 +100,69 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
106100

107101
return disposer;
108102
}
103+
104+
/// <summary>
/// Walk every device reported by `ggml_backend_dev_count` and build a lookup from the name of the
/// device's buffer type (`ggml_backend_buft_name`) to the buffer type handle (`ggml_backend_dev_buffer_type`).
/// </summary>
/// <remarks>
/// Buffer types whose name is null or empty are left out. If two devices report the same
/// buffer type name, the handle of the later device replaces the earlier one.
/// </remarks>
/// <returns>Dictionary mapping buffer type names to their handles</returns>
private static IReadOnlyDictionary<string, IntPtr> GetAvailableBufferTypes()
{
    var handlesByName = new Dictionary<string, IntPtr>();

    var deviceCount = NativeApi.ggml_backend_dev_count();
    for (nuint deviceIndex = 0; deviceIndex < deviceCount; deviceIndex++)
    {
        // Resolve device -> buffer type handle -> native name pointer
        var device = NativeApi.ggml_backend_dev_get(deviceIndex);
        var bufferTypeHandle = NativeApi.ggml_backend_dev_buffer_type(device);
        var namePointer = NativeApi.ggml_backend_buft_name(bufferTypeHandle);

        // Only buffer types with a usable name can be looked up later
        var bufferTypeName = Marshal.PtrToStringAnsi(namePointer);
        if (!string.IsNullOrEmpty(bufferTypeName))
        {
            handlesByName[bufferTypeName] = bufferTypeHandle;
        }
    }

    return handlesByName;
}
127+
128+
/// <summary>
/// Convert the managed tensor buffer overrides into a pinned array of
/// <see cref="LLamaModelTensorBufferOverride"/> that can be handed to native code.
/// </summary>
/// <param name="overrides">Requested overrides (tensor name pattern + buffer type name)</param>
/// <param name="disposer">
/// Receives every memory pin created here (one per pattern string, plus one for the array itself),
/// so the returned pointer is only usable until this is disposed.
/// </param>
/// <returns>
/// Pointer to the first element of the pinned array, or null if there were no overrides
/// or none of them named a buffer type returned by <c>GetAvailableBufferTypes</c>.
/// </returns>
private static unsafe LLamaModelTensorBufferOverride* ConvertOverrides(List<TensorBufferOverride> overrides, GroupDisposable disposer)
{
    // Nothing requested, so there is nothing to pass across
    if (overrides.Count == 0)
    {
        return null;
    }

    var knownBufferTypes = GetAvailableBufferTypes();

    // One slot more than can ever be filled, so the array always ends with a default (zeroed) entry.
    // Presumably native code uses that entry as the list terminator - confirm against llama.cpp.
    var nativeOverrides = new LLamaModelTensorBufferOverride[overrides.Count + 1];
    var writtenCount = 0;

    foreach (var requested in overrides)
    {
        // Overrides naming a buffer type that is not in the lookup are skipped
        if (knownBufferTypes.TryGetValue(requested.BufferType, out var bufferTypeHandle))
        {
            // Null terminated UTF8 copy of the pattern, pinned so the pointer stays valid
            var utf8Pattern = Encoding.UTF8.GetBytes(requested.Pattern + "\0");
            var pinnedPattern = utf8Pattern.AsMemory().Pin();
            disposer.Add(pinnedPattern);

            nativeOverrides[writtenCount] = new LLamaModelTensorBufferOverride
            {
                Pattern = (byte*)pinnedPattern.Pointer,
                BufferType = bufferTypeHandle
            };
            writtenCount++;
        }
    }

    // Every requested override was skipped
    if (writtenCount == 0)
    {
        return null;
    }

    // Pin the array itself and hand ownership of the pin to the disposer
    var pinnedOverrides = disposer.Add(nativeOverrides.AsMemory().Pin());
    return (LLamaModelTensorBufferOverride*)pinnedOverrides.Pointer;
}
109168
}

LLama/Native/LLamaModelTensorBufferOverride.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ namespace LLama.Native
77
/// Original type: llama_model_tensor_buft_override
88
/// </summary>
99
[StructLayout(LayoutKind.Sequential)]
10-
public struct LLamaModelTensorBufferOverride
10+
public unsafe struct LLamaModelTensorBufferOverride
1111
{
1212
/// <summary>
1313
/// Tensor name pattern to match
1414
/// </summary>
15-
public IntPtr Pattern;
15+
public byte* Pattern;
1616

1717
/// <summary>
1818
/// Backend buffer type to use for matching tensors, as obtained via ggml_backend_dev_buffer_type

LLama/Native/LLamaTensorBufferOverrideHelper.cs

Lines changed: 0 additions & 135 deletions
This file was deleted.

0 commit comments

Comments
 (0)