3
3
using System . Text ;
4
4
using LLama . Abstractions ;
5
5
using LLama . Native ;
6
+ using System . Collections . Generic ;
6
7
7
8
namespace LLama . Extensions ;
8
9
@@ -45,20 +46,13 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
45
46
result . tensor_split = ( float * ) disposer . Add ( @params . TensorSplits . Pin ( ) ) . Pointer ;
46
47
}
47
48
48
- // Add tensor buffer overrides, if any
49
- if ( @params . TensorBufferOverrides . Count > 0 )
49
+ // Add tensor buffer overrides
50
+ unsafe
50
51
{
51
- var bufferOverrideHelper = new LLamaTensorBufferOverrideHelper ( ) ;
52
- disposer . Add ( bufferOverrideHelper ) ;
53
-
54
- foreach ( var tensorOverride in @params . TensorBufferOverrides )
55
- {
56
- bufferOverrideHelper . AddOverride ( tensorOverride . Pattern , tensorOverride . BufferType ) ;
57
- }
58
-
59
- bufferOverrideHelper . ApplyToModelParams ( ref result ) ;
52
+ result . tensor_buft_overrides = ConvertOverrides ( @params . TensorBufferOverrides , disposer ) ;
60
53
}
61
54
55
+ // Add metadata overrides
62
56
if ( @params . MetadataOverrides . Count == 0 )
63
57
{
64
58
unsafe
@@ -106,4 +100,69 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
106
100
107
101
return disposer ;
108
102
}
103
+
104
+ /// <summary>
105
+ /// Get a map from name of device (`ggml_backend_buft_name`) to the device type (`ggml_backend_dev_buffer_type`)
106
+ /// </summary>
107
+ /// <returns>Dictionary mapping buffer type names to their handles</returns>
108
+ private static IReadOnlyDictionary < string , IntPtr > GetAvailableBufferTypes ( )
109
+ {
110
+ var result = new Dictionary < string , IntPtr > ( ) ;
111
+
112
+ var count = NativeApi . ggml_backend_dev_count ( ) ;
113
+ for ( nuint i = 0 ; i < count ; i ++ )
114
+ {
115
+ var dev = NativeApi . ggml_backend_dev_get ( i ) ;
116
+ var buft = NativeApi . ggml_backend_dev_buffer_type ( dev ) ;
117
+
118
+ var name = Marshal . PtrToStringAnsi ( NativeApi . ggml_backend_buft_name ( buft ) ) ;
119
+ if ( string . IsNullOrEmpty ( name ) )
120
+ continue ;
121
+
122
+ result [ name ] = buft ;
123
+ }
124
+
125
+ return result ;
126
+ }
127
+
128
+ private static unsafe LLamaModelTensorBufferOverride * ConvertOverrides ( List < TensorBufferOverride > overrides , GroupDisposable disposer )
129
+ {
130
+ // Early out if there are no overrides
131
+ if ( overrides . Count == 0 )
132
+ return null ;
133
+
134
+ var bufferTypes = GetAvailableBufferTypes ( ) ;
135
+
136
+ var overridesCount = 0 ;
137
+ var overridesArray = new LLamaModelTensorBufferOverride [ overrides . Count + 1 ] ;
138
+
139
+ foreach ( var @override in overrides )
140
+ {
141
+ // Check if we have this buffer type
142
+ if ( ! bufferTypes . TryGetValue ( @override . BufferType , out var bufferType ) )
143
+ continue ;
144
+
145
+ // Create null terminated string and pin this memory so it can be passed to native code
146
+ var patternBytes = Encoding . UTF8 . GetBytes ( @override . Pattern + "\0 " ) ;
147
+ var patternPin = patternBytes . AsMemory ( ) . Pin ( ) ;
148
+ disposer . Add ( patternPin ) ;
149
+
150
+ // Add the item to the overridesArray
151
+ overridesArray [ overridesCount ++ ] = new ( )
152
+ {
153
+ Pattern = ( byte * ) patternPin . Pointer ,
154
+ BufferType = bufferType
155
+ } ;
156
+ }
157
+
158
+ // Early out if there were no valid overrides
159
+ if ( overridesCount == 0 )
160
+ return null ;
161
+
162
+ // Pin it so it can be safely passed across to native code
163
+ var overrideArrayPin = overridesArray . AsMemory ( ) . Pin ( ) ;
164
+ disposer . Add ( overrideArrayPin ) ;
165
+
166
+ return ( LLamaModelTensorBufferOverride * ) overrideArrayPin . Pointer ;
167
+ }
109
168
}
0 commit comments