Skip to content

Commit c142fe0

Browse files
committed
Send GenerateContentRequest in CountTokensRequest
1 parent edc9de3 commit c142fe0

File tree

3 files changed

+67
-23
lines changed

3 files changed

+67
-23
lines changed

Sources/GoogleAI/CountTokensRequest.swift

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import Foundation
1717
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
1818
struct CountTokensRequest {
1919
let model: String
20-
let contents: [ModelContent]
20+
let generateContentRequest: GenerateContentRequest
2121
let options: RequestOptions
2222
}
2323

@@ -42,7 +42,7 @@ public struct CountTokensResponse {
4242
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
4343
extension CountTokensRequest: Encodable {
4444
enum CodingKeys: CodingKey {
45-
case contents
45+
case generateContentRequest
4646
}
4747
}
4848

Sources/GoogleAI/GenerateContentRequest.swift

+29-1
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@ import Foundation
1616

1717
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
1818
struct GenerateContentRequest {
19-
/// Model name.
19+
// Model name.
2020
let model: String
21+
// If true, the `model` field above is encoded in requests; currently only required when nested in
22+
// a `CountTokensRequest`.
23+
let isModelEncoded: Bool
2124
let contents: [ModelContent]
2225
let generationConfig: GenerationConfig?
2326
let safetySettings: [SafetySetting]?
@@ -31,13 +34,38 @@ struct GenerateContentRequest {
3134
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
3235
extension GenerateContentRequest: Encodable {
3336
enum CodingKeys: String, CodingKey {
37+
case model
3438
case contents
3539
case generationConfig
3640
case safetySettings
3741
case tools
3842
case toolConfig
3943
case systemInstruction
4044
}
45+
46+
func encode(to encoder: any Encoder) throws {
47+
var container = encoder.container(keyedBy: CodingKeys.self)
48+
49+
if isModelEncoded {
50+
try container.encode(model, forKey: .model)
51+
}
52+
try container.encode(contents, forKey: .contents)
53+
if let generationConfig {
54+
try container.encode(generationConfig, forKey: .generationConfig)
55+
}
56+
if let safetySettings {
57+
try container.encode(safetySettings, forKey: .safetySettings)
58+
}
59+
if let tools {
60+
try container.encode(tools, forKey: .tools)
61+
}
62+
if let toolConfig {
63+
try container.encode(toolConfig, forKey: .toolConfig)
64+
}
65+
if let systemInstruction {
66+
try container.encode(systemInstruction, forKey: .systemInstruction)
67+
}
68+
}
4169
}
4270

4371
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)

Sources/GoogleAI/GenerativeModel.swift

+36-20
Original file line numberDiff line numberDiff line change
@@ -175,15 +175,18 @@ public final class GenerativeModel {
175175
-> GenerateContentResponse {
176176
let response: GenerateContentResponse
177177
do {
178-
let generateContentRequest = try GenerateContentRequest(model: modelResourceName,
179-
contents: content(),
180-
generationConfig: generationConfig,
181-
safetySettings: safetySettings,
182-
tools: tools,
183-
toolConfig: toolConfig,
184-
systemInstruction: systemInstruction,
185-
isStreaming: false,
186-
options: requestOptions)
178+
let generateContentRequest = try GenerateContentRequest(
179+
model: modelResourceName,
180+
isModelEncoded: false,
181+
contents: content(),
182+
generationConfig: generationConfig,
183+
safetySettings: safetySettings,
184+
tools: tools,
185+
toolConfig: toolConfig,
186+
systemInstruction: systemInstruction,
187+
isStreaming: false,
188+
options: requestOptions
189+
)
187190
response = try await generativeAIService.loadRequest(request: generateContentRequest)
188191
} catch {
189192
if let imageError = error as? ImageConversionError {
@@ -249,15 +252,18 @@ public final class GenerativeModel {
249252
}
250253
}
251254

252-
let generateContentRequest = GenerateContentRequest(model: modelResourceName,
253-
contents: evaluatedContent,
254-
generationConfig: generationConfig,
255-
safetySettings: safetySettings,
256-
tools: tools,
257-
toolConfig: toolConfig,
258-
systemInstruction: systemInstruction,
259-
isStreaming: true,
260-
options: requestOptions)
255+
let generateContentRequest = GenerateContentRequest(
256+
model: modelResourceName,
257+
isModelEncoded: false,
258+
contents: evaluatedContent,
259+
generationConfig: generationConfig,
260+
safetySettings: safetySettings,
261+
tools: tools,
262+
toolConfig: toolConfig,
263+
systemInstruction: systemInstruction,
264+
isStreaming: true,
265+
options: requestOptions
266+
)
261267

262268
var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
263269
.makeAsyncIterator()
@@ -325,9 +331,19 @@ public final class GenerativeModel {
325331
public func countTokens(_ content: @autoclosure () throws -> [ModelContent]) async throws
326332
-> CountTokensResponse {
327333
do {
328-
let countTokensRequest = try CountTokensRequest(
334+
let generateContentRequest = try GenerateContentRequest(model: modelResourceName,
335+
isModelEncoded: true,
336+
contents: content(),
337+
generationConfig: generationConfig,
338+
safetySettings: safetySettings,
339+
tools: tools,
340+
toolConfig: toolConfig,
341+
systemInstruction: systemInstruction,
342+
isStreaming: false,
343+
options: requestOptions)
344+
let countTokensRequest = CountTokensRequest(
329345
model: modelResourceName,
330-
contents: content(),
346+
generateContentRequest: generateContentRequest,
331347
options: requestOptions
332348
)
333349
return try await generativeAIService.loadRequest(request: countTokensRequest)

0 commit comments

Comments
 (0)