-
Notifications
You must be signed in to change notification settings - Fork 131
/
Copy pathmodel.dart
398 lines (378 loc) · 13.7 KB
/
model.dart
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import 'dart:async';
import 'package:http/http.dart' as http;
import 'api.dart';
import 'client.dart';
import 'content.dart';
import 'function_calling.dart';
/// The default API version used when [RequestOptions.apiVersion] is unset.
const _apiVersion = 'v1beta';

/// Builds the base URI for the Google AI (Generative Language) backend.
///
/// Uses the API version from [options] when provided, otherwise falls back
/// to [_apiVersion].
Uri _googleAIBaseUri(RequestOptions? options) {
  final version = options?.apiVersion ?? _apiVersion;
  return Uri.https('generativelanguage.googleapis.com', version);
}
/// The API task names understood by the backend.
///
/// Each value's [Enum.name] is appended to the model path when building a
/// request URI (see `GenerativeModel._taskUri`), e.g. `model:generateContent`.
enum Task {
  /// Unary content generation.
  generateContent,

  /// Server-streamed content generation.
  streamGenerateContent,

  /// Count the tokens in a prospective request.
  countTokens,

  /// Embed a single piece of content.
  embedContent,

  /// Embed a batch of contents in one request.
  batchEmbedContents;
}
/// Configuration for how a [GenerativeModel] makes requests.
///
/// This allows overriding the API version in use which may be required to use
/// some beta features.
final class RequestOptions {
  /// Creates request options, optionally overriding the [apiVersion].
  const RequestOptions({this.apiVersion});

  /// The API version used to make requests.
  ///
  /// By default the version is `v1beta`.
  /// See https://ai.google.dev/gemini-api/docs/api-versions for details.
  final String? apiVersion;
}
/// A multimodal generative model (like Gemini).
///
/// Allows generating content, creating embeddings, and counting the number of
/// tokens in a piece of content.
final class GenerativeModel {
  /// The full model code split into a prefix ("models" or "tunedModels") and
  /// the model name.
  final ({String prefix, String name}) _model;

  // Model-level defaults; each can be overridden per request by the
  // argument of the same name on the request methods.
  final List<SafetySetting> _safetySettings;
  final GenerationConfig? _generationConfig;
  final List<Tool>? _tools;

  // Transport used for all API requests.
  final ApiClient _client;

  // Host and API-version path; task segments are appended per request.
  final Uri _baseUri;

  // High priority instructions included with every generate request.
  final Content? _systemInstruction;
  final ToolConfig? _toolConfig;

  /// Create a [GenerativeModel] backed by the generative model named [model].
  ///
  /// The [model] argument can be a model name (such as
  /// `'gemini-1.5-flash-latest'`) or a model code (such as
  /// `'models/gemini-1.5-flash-latest'` or `'tunedModels/my-model'`).
  /// There is no creation time check for whether the `model` string identifies
  /// a known and supported model. If not, attempts to generate content
  /// will fail.
  ///
  /// A Google Cloud [apiKey] is required for all requests.
  /// See documentation about [API keys][] for more information.
  ///
  /// [API keys]: https://cloud.google.com/docs/authentication/api-keys "Google Cloud API keys"
  ///
  /// The optional [safetySettings] and [generationConfig] can be used to
  /// control and guide the generation. See [SafetySetting] and
  /// [GenerationConfig] for details.
  ///
  /// Content creation requests are sent to a server through the [httpClient],
  /// which can be used to control, for example, the number of allowed
  /// concurrent requests.
  /// If the `httpClient` is omitted, a new [http.Client] is created for each
  /// request.
  ///
  /// Functions that the model may call while generating content can be passed
  /// in [tools]. Tool usage by the model can be configured with [toolConfig].
  /// Tools and tool configuration can be overridden for individual requests
  /// with arguments to [generateContent] or [generateContentStream].
  ///
  /// A [Content.system] can be passed to [systemInstruction] to give
  /// high priority instructions to the model.
  factory GenerativeModel({
    required String model,
    required String apiKey,
    List<SafetySetting> safetySettings = const [],
    GenerationConfig? generationConfig,
    List<Tool>? tools,
    http.Client? httpClient,
    RequestOptions? requestOptions,
    Content? systemInstruction,
    ToolConfig? toolConfig,
  }) =>
      GenerativeModel._withClient(
        client: HttpApiClient(apiKey: apiKey, httpClient: httpClient),
        model: model,
        safetySettings: safetySettings,
        generationConfig: generationConfig,
        baseUri: _googleAIBaseUri(requestOptions),
        tools: tools,
        systemInstruction: systemInstruction,
        toolConfig: toolConfig,
      );

  // Shared by the public factory and the package-private `createModel*`
  // helpers, which substitute a custom [ApiClient] or base URI.
  GenerativeModel._withClient({
    required ApiClient client,
    required String model,
    required List<SafetySetting> safetySettings,
    required GenerationConfig? generationConfig,
    required Uri baseUri,
    required List<Tool>? tools,
    required Content? systemInstruction,
    required ToolConfig? toolConfig,
  })  : _model = _normalizeModelName(model),
        _baseUri = baseUri,
        _safetySettings = safetySettings,
        _generationConfig = generationConfig,
        _tools = tools,
        _systemInstruction = systemInstruction,
        _toolConfig = toolConfig,
        _client = client;

  /// Returns the model code for a user friendly model name.
  ///
  /// If the model name is already a model code (contains a `/`), use the parts
  /// directly. Otherwise, return a `models/` model code.
  static ({String prefix, String name}) _normalizeModelName(String modelName) {
    if (!modelName.contains('/')) return (prefix: 'models', name: modelName);
    final parts = modelName.split('/');
    return (prefix: parts.first, name: parts.skip(1).join('/'));
  }

  // Appends `<prefix>/<name>:<task>` to the base URI's existing path.
  Uri _taskUri(Task task) => _baseUri.replace(
      pathSegments: _baseUri.pathSegments
          .followedBy([_model.prefix, '${_model.name}:${task.name}']));

  /// Generates content responding to [prompt].
  ///
  /// Sends a "generateContent" API request for the configured model,
  /// and waits for the response.
  ///
  /// The [safetySettings], [generationConfig], [tools], and [toolConfig],
  /// override the arguments of the same name passed to the
  /// [GenerativeModel.new] constructor. Each argument, when non-null,
  /// overrides the model level configuration in its entirety.
  ///
  /// Example:
  /// ```dart
  /// final response = await model.generateContent([Content.text(prompt)]);
  /// print(response.text);
  /// ```
  Future<GenerateContentResponse> generateContent(
    Iterable<Content> prompt, {
    List<SafetySetting>? safetySettings,
    GenerationConfig? generationConfig,
    List<Tool>? tools,
    ToolConfig? toolConfig,
  }) =>
      makeRequest(
          Task.generateContent,
          _generateContentRequest(
            prompt,
            safetySettings: safetySettings,
            generationConfig: generationConfig,
            tools: tools,
            toolConfig: toolConfig,
          ),
          parseGenerateContentResponse);

  /// Generates a stream of content responding to [prompt].
  ///
  /// Sends a "streamGenerateContent" API request for the configured model,
  /// and emits a [GenerateContentResponse] for each chunk of the response
  /// stream.
  ///
  /// The [safetySettings], [generationConfig], [tools], and [toolConfig],
  /// override the arguments of the same name passed to the
  /// [GenerativeModel.new] constructor. Each argument, when non-null,
  /// overrides the model level configuration in its entirety.
  ///
  /// Example:
  /// ```dart
  /// final responses = model.generateContentStream([Content.text(prompt)]);
  /// await for (final response in responses) {
  ///   print(response.text);
  /// }
  /// ```
  Stream<GenerateContentResponse> generateContentStream(
    Iterable<Content> prompt, {
    List<SafetySetting>? safetySettings,
    GenerationConfig? generationConfig,
    List<Tool>? tools,
    ToolConfig? toolConfig,
  }) {
    final response = _client.streamRequest(
        _taskUri(Task.streamGenerateContent),
        _generateContentRequest(
          prompt,
          safetySettings: safetySettings,
          generationConfig: generationConfig,
          tools: tools,
          toolConfig: toolConfig,
        ));
    return response.map(parseGenerateContentResponse);
  }

  /// Counts the total number of tokens in [contents].
  ///
  /// Sends a "countTokens" API request for the configured model,
  /// and waits for the response.
  ///
  /// The [safetySettings], [generationConfig], [tools], and [toolConfig],
  /// override the arguments of the same name passed to the
  /// [GenerativeModel.new] constructor. Each argument, when non-null,
  /// overrides the model level configuration in its entirety.
  ///
  /// Example:
  /// ```dart
  /// final promptContent = [Content.text(prompt)];
  /// final totalTokens =
  ///     (await model.countTokens(promptContent)).totalTokens;
  /// if (totalTokens > maxPromptSize) {
  ///   print('Prompt is too long!');
  /// } else {
  ///   final response = await model.generateContent(promptContent);
  ///   print(response.text);
  /// }
  /// ```
  Future<CountTokensResponse> countTokens(
    Iterable<Content> contents, {
    List<SafetySetting>? safetySettings,
    GenerationConfig? generationConfig,
    List<Tool>? tools,
    ToolConfig? toolConfig,
  }) =>
      makeRequest(
          Task.countTokens,
          // The token count is computed for the full request the model would
          // receive, so the payload wraps an entire generateContent request.
          {
            'generateContentRequest': _generateContentRequest(
              contents,
              safetySettings: safetySettings,
              generationConfig: generationConfig,
              tools: tools,
              toolConfig: toolConfig,
            )
          },
          parseCountTokensResponse);

  /// Creates an embedding (list of float values) representing [content].
  ///
  /// Sends an "embedContent" API request for the configured model,
  /// and waits for the response.
  ///
  /// Example:
  /// ```dart
  /// final promptEmbedding =
  ///     (await model.embedContent([Content.text(prompt)])).embedding.values;
  /// ```
  Future<EmbedContentResponse> embedContent(
    Content content, {
    TaskType? taskType,
    String? title,
    int? outputDimensionality,
  }) =>
      makeRequest(
          Task.embedContent,
          {
            'content': content.toJson(),
            if (taskType != null) 'taskType': taskType.toJson(),
            if (title != null) 'title': title,
            if (outputDimensionality != null)
              'outputDimensionality': outputDimensionality,
          },
          parseEmbedContentResponse);

  /// Creates embeddings (list of float values) representing each content in
  /// [requests].
  ///
  /// Sends a "batchEmbedContents" API request for the configured model.
  ///
  /// Example:
  /// ```dart
  /// final requests = [
  ///   EmbedContentRequest(Content.text(first)),
  ///   EmbedContentRequest(Content.text(second))
  /// ];
  /// final promptEmbeddings =
  ///     (await model.batchEmbedContents(requests)).embeddings;
  /// ```
  Future<BatchEmbedContentsResponse> batchEmbedContents(
    Iterable<EmbedContentRequest> requests,
  ) =>
      makeRequest(
          Task.batchEmbedContents,
          {
            // Requests that omit a model fall back to this model's code.
            'requests': requests
                .map((r) =>
                    r.toJson(defaultModel: '${_model.prefix}/${_model.name}'))
                .toList()
          },
          parseBatchEmbedContentsResponse);

  /// Builds the JSON payload for a generate-content style request.
  ///
  /// Each nullable argument, when null, falls back to the model-level
  /// configuration captured at construction time; a non-null argument
  /// replaces the model-level value entirely.
  Map<String, Object?> _generateContentRequest(
    Iterable<Content> contents, {
    List<SafetySetting>? safetySettings,
    GenerationConfig? generationConfig,
    List<Tool>? tools,
    ToolConfig? toolConfig,
  }) {
    safetySettings ??= _safetySettings;
    generationConfig ??= _generationConfig;
    tools ??= _tools;
    toolConfig ??= _toolConfig;
    return {
      'model': '${_model.prefix}/${_model.name}',
      'contents': contents.map((c) => c.toJson()).toList(),
      if (safetySettings.isNotEmpty)
        'safetySettings': safetySettings.map((s) => s.toJson()).toList(),
      if (generationConfig != null)
        'generationConfig': generationConfig.toJson(),
      if (tools != null) 'tools': tools.map((t) => t.toJson()).toList(),
      if (toolConfig != null) 'toolConfig': toolConfig.toJson(),
      if (_systemInstruction case final systemInstruction?)
        'systemInstruction': systemInstruction.toJson(),
    };
  }
}
/// Request plumbing shared with the Vertex AI SDK.
extension VertexExtensions on GenerativeModel {
  /// Makes a unary request for [task] with JSON encodable [params], decoding
  /// the JSON response with [parse].
  Future<T> makeRequest<T>(Task task, Map<String, Object?> params,
      T Function(Map<String, Object?>) parse) async {
    final response = await _client.makeRequest(_taskUri(task), params);
    return parse(response);
  }
}
/// Creates a model with an overridden [ApiClient] for testing.
///
/// Package private test-only method.
GenerativeModel createModelWithClient({
  required String model,
  required ApiClient client,
  List<SafetySetting> safetySettings = const [],
  GenerationConfig? generationConfig,
  RequestOptions? requestOptions,
  Content? systemInstruction,
  List<Tool>? tools,
  ToolConfig? toolConfig,
}) {
  return GenerativeModel._withClient(
    model: model,
    client: client,
    baseUri: _googleAIBaseUri(requestOptions),
    safetySettings: safetySettings,
    generationConfig: generationConfig,
    systemInstruction: systemInstruction,
    tools: tools,
    toolConfig: toolConfig,
  );
}
/// Creates a model with an overridden base URL to communicate with a different
/// backend.
///
/// Used from a `src/` import in the Vertex AI SDK.
// TODO: https://github.com/google/generative-ai-dart/issues/111 - Changes to
// this API need to be coordinated with the vertex AI SDK.
GenerativeModel createModelWithBaseUri({
  required String model,
  required String apiKey,
  required Uri baseUri,
  FutureOr<Map<String, String>> Function()? requestHeaders,
  List<SafetySetting> safetySettings = const [],
  GenerationConfig? generationConfig,
  List<Tool>? tools,
  Content? systemInstruction,
  ToolConfig? toolConfig,
}) {
  final client = HttpApiClient(apiKey: apiKey, requestHeaders: requestHeaders);
  return GenerativeModel._withClient(
    model: model,
    client: client,
    baseUri: baseUri,
    safetySettings: safetySettings,
    generationConfig: generationConfig,
    systemInstruction: systemInstruction,
    tools: tools,
    toolConfig: toolConfig,
  );
}