Skip to content

[Vertex AI] Add APIConfig to userInfo dictionary in coders #14592

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions FirebaseVertexAI/Sources/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ struct GenerateContentRequest: Sendable {
let toolConfig: ToolConfig?
let systemInstruction: ModelContent?

let apiConfig: APIConfig
let apiMethod: APIMethod
let options: RequestOptions
}
Expand Down Expand Up @@ -73,7 +72,7 @@ extension GenerateContentRequest {
extension GenerateContentRequest: GenerativeAIRequest {
typealias Response = GenerateContentResponse

var url: URL {
func requestURL(apiConfig: APIConfig) -> URL {
let modelURL = "\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model)"
switch apiMethod {
case .generateContent:
Expand Down
4 changes: 2 additions & 2 deletions FirebaseVertexAI/Sources/GenerativeAIRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ import Foundation
protocol GenerativeAIRequest: Sendable, Encodable {
associatedtype Response: Decodable

var url: URL { get }

var options: RequestOptions { get }

func requestURL(apiConfig: APIConfig) -> URL
}

/// Configuration parameters for sending requests to the backend.
Expand Down
22 changes: 14 additions & 8 deletions FirebaseVertexAI/Sources/GenerativeAIService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,19 @@ struct GenerativeAIService {

private let firebaseInfo: FirebaseInfo

/// Configuration for the backend API used by this model.
private let apiConfig: APIConfig

private let jsonDecoder: JSONDecoder
private let jsonEncoder: JSONEncoder

private let urlSession: URLSession

init(firebaseInfo: FirebaseInfo, urlSession: URLSession) {
init(firebaseInfo: FirebaseInfo, apiConfig: APIConfig, urlSession: URLSession) {
self.firebaseInfo = firebaseInfo
self.apiConfig = apiConfig
jsonDecoder = JSONDecoder(apiConfig: apiConfig)
jsonEncoder = JSONEncoder(apiConfig: apiConfig)
self.urlSession = urlSession
}

Expand Down Expand Up @@ -125,8 +134,6 @@ struct GenerativeAIService {
// Received lines that are not server-sent events (SSE); these are not prefixed with "data:"
var extraLines = ""

let decoder = JSONDecoder()
decoder.keyDecodingStrategy = .convertFromSnakeCase
for try await line in stream.lines {
VertexLog.debug(code: .loadRequestStreamResponseLine, "Stream response: \(line)")

Expand Down Expand Up @@ -167,7 +174,7 @@ struct GenerativeAIService {
// MARK: - Private Helpers

private func urlRequest<T: GenerativeAIRequest>(request: T) async throws -> URLRequest {
var urlRequest = URLRequest(url: request.url)
var urlRequest = URLRequest(url: request.requestURL(apiConfig: apiConfig))
urlRequest.httpMethod = "POST"
urlRequest.setValue(firebaseInfo.apiKey, forHTTPHeaderField: "x-goog-api-key")
urlRequest.setValue(
Expand Down Expand Up @@ -200,8 +207,7 @@ struct GenerativeAIService {
}
}

let encoder = JSONEncoder()
urlRequest.httpBody = try encoder.encode(request)
urlRequest.httpBody = try jsonEncoder.encode(request)
urlRequest.timeoutInterval = request.options.timeout

return urlRequest
Expand Down Expand Up @@ -246,7 +252,7 @@ struct GenerativeAIService {

private func parseError(responseData: Data) -> Error {
do {
let rpcError = try JSONDecoder().decode(BackendError.self, from: responseData)
let rpcError = try jsonDecoder.decode(BackendError.self, from: responseData)
logRPCError(rpcError)
return rpcError
} catch {
Expand All @@ -273,7 +279,7 @@ struct GenerativeAIService {

private func parseResponse<T: Decodable>(_ type: T.Type, from data: Data) throws -> T {
do {
return try JSONDecoder().decode(type, from: data)
return try jsonDecoder.decode(type, from: data)
} catch {
if let json = String(data: data, encoding: .utf8) {
VertexLog.error(code: .loadRequestParseResponseFailedJSON, "JSON response: \(json)")
Expand Down
4 changes: 1 addition & 3 deletions FirebaseVertexAI/Sources/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public final class GenerativeModel: Sendable {
self.apiConfig = apiConfig
generativeAIService = GenerativeAIService(
firebaseInfo: firebaseInfo,
apiConfig: apiConfig,
urlSession: urlSession
)
self.generationConfig = generationConfig
Expand Down Expand Up @@ -137,7 +138,6 @@ public final class GenerativeModel: Sendable {
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
apiConfig: apiConfig,
apiMethod: .generateContent,
options: requestOptions
)
Expand Down Expand Up @@ -197,7 +197,6 @@ public final class GenerativeModel: Sendable {
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
apiConfig: apiConfig,
apiMethod: .streamGenerateContent,
options: requestOptions
)
Expand Down Expand Up @@ -279,7 +278,6 @@ public final class GenerativeModel: Sendable {
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
apiConfig: apiConfig,
apiMethod: .countTokens,
options: requestOptions
)
Expand Down
53 changes: 53 additions & 0 deletions FirebaseVertexAI/Sources/Types/Internal/APIConfig.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

/// Configuration for the generative AI backend API used by this SDK.
struct APIConfig: Sendable, Hashable {
/// The service to use for generative AI.
Expand Down Expand Up @@ -90,3 +92,54 @@ extension APIConfig {
case v1beta
}
}

// MARK: - Coding Utilities

extension CodingUserInfoKey {
  /// The `userInfo` key under which an `APIConfig` is stored in encoders and decoders.
  ///
  /// Built once from a fixed raw value; if `CodingUserInfoKey` rejects that raw value the
  /// process is terminated, since the key name is a compile-time constant.
  static let apiConfig: CodingUserInfoKey = {
    let keyName = "com.google.firebase.VertexAI.APIConfig"
    if let key = CodingUserInfoKey(rawValue: keyName) {
      return key
    }
    fatalError("The key name '\(keyName)' is not a valid raw value for CodingUserInfoKey.")
  }()
}

extension APIConfig {
  /// Retrieves the `APIConfig` stored in a coder's `userInfo` dictionary.
  ///
  /// The value is looked up under `CodingUserInfoKey.apiConfig`. A missing value, or a value
  /// that is not an `APIConfig`, terminates the process via `fatalError`.
  ///
  /// - Parameter userInfo: The `userInfo` dictionary of an encoder or decoder.
  /// - Returns: The `APIConfig` found in `userInfo`.
  static func from(userInfo: [CodingUserInfoKey: Any]) -> APIConfig {
    guard let value = userInfo[CodingUserInfoKey.apiConfig] else {
      fatalError(
        "No value provided for '\(CodingUserInfoKey.apiConfig)' in the coder's userInfo."
      )
    }
    guard let config = value as? APIConfig else {
      // The message promises the *type* that was found, so report the dynamic type of the
      // unexpected value rather than the value's own description.
      fatalError("""
      The value provided for '\(CodingUserInfoKey.apiConfig)' in the coder's userInfo is not of \
      type '\(APIConfig.self)'; found type '\(type(of: value))'.
      """)
    }
    return config
  }
}

extension Decoder {
  /// The `APIConfig` obtained by passing this decoder's `userInfo` to
  /// `APIConfig.from(userInfo:)`.
  var apiConfig: APIConfig {
    let info = userInfo
    return APIConfig.from(userInfo: info)
  }
}

extension JSONDecoder {
  /// Creates a `JSONDecoder` whose `userInfo` carries the given `APIConfig` under
  /// `CodingUserInfoKey.apiConfig`.
  ///
  /// - Parameter apiConfig: The configuration to store in the decoder's `userInfo`.
  convenience init(apiConfig: APIConfig) {
    self.init()
    userInfo[.apiConfig] = apiConfig
  }
}

extension Encoder {
  /// The `APIConfig` obtained by passing this encoder's `userInfo` to
  /// `APIConfig.from(userInfo:)`.
  var apiConfig: APIConfig {
    let info = userInfo
    return APIConfig.from(userInfo: info)
  }
}

extension JSONEncoder {
  /// Creates a `JSONEncoder` whose `userInfo` carries the given `APIConfig` under
  /// `CodingUserInfoKey.apiConfig`.
  ///
  /// - Parameter apiConfig: The configuration to store in the encoder's `userInfo`.
  convenience init(apiConfig: APIConfig) {
    self.init()
    userInfo[.apiConfig] = apiConfig
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,15 @@ import Foundation
@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
struct ImagenGenerationRequest<ImageType: ImagenImageRepresentable>: Sendable {
let model: String
let apiConfig: APIConfig
let options: RequestOptions
let instances: [ImageGenerationInstance]
let parameters: ImageGenerationParameters

init(model: String,
apiConfig: APIConfig,
options: RequestOptions,
instances: [ImageGenerationInstance],
parameters: ImageGenerationParameters) {
self.model = model
self.apiConfig = apiConfig
self.options = options
self.instances = instances
self.parameters = parameters
Expand All @@ -39,7 +36,7 @@ struct ImagenGenerationRequest<ImageType: ImagenImageRepresentable>: Sendable {
extension ImagenGenerationRequest: GenerativeAIRequest where ImageType: Decodable {
typealias Response = ImagenGenerationResponse<ImageType>

var url: URL {
func requestURL(apiConfig: APIConfig) -> URL {
return URL(string:
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(model):predict")!
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ extension CountTokensRequest: GenerativeAIRequest {

var options: RequestOptions { generateContentRequest.options }

var apiConfig: APIConfig { generateContentRequest.apiConfig }

var url: URL {
func requestURL(apiConfig: APIConfig) -> URL {
let version = apiConfig.version.rawValue
let endpoint = apiConfig.service.endpoint.rawValue
return URL(string: "\(endpoint)/\(version)/\(generateContentRequest.model):countTokens")!
Expand Down Expand Up @@ -66,7 +64,7 @@ extension CountTokensRequest: Encodable {
}

func encode(to encoder: any Encoder) throws {
switch apiConfig.service {
switch encoder.apiConfig.service {
case .vertexAI:
try encodeForVertexAI(to: encoder)
case .developer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ public final class ImagenModel {
/// The resource name of the model in the backend; has the format "models/model-name".
let modelResourceName: String

/// Configuration for the backend API used by this model.
let apiConfig: APIConfig

/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService

Expand All @@ -52,9 +49,9 @@ public final class ImagenModel {
requestOptions: RequestOptions,
urlSession: URLSession = .shared) {
modelResourceName = name
self.apiConfig = apiConfig
generativeAIService = GenerativeAIService(
firebaseInfo: firebaseInfo,
apiConfig: apiConfig,
urlSession: urlSession
)
self.generationConfig = generationConfig
Expand Down Expand Up @@ -129,7 +126,6 @@ public final class ImagenModel {
-> ImagenGenerationResponse<T> where T: Decodable, T: ImagenImageRepresentable {
let request = ImagenGenerationRequest<T>(
model: modelResourceName,
apiConfig: apiConfig,
options: requestOptions,
instances: [ImageGenerationInstance(prompt: prompt)],
parameters: parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ final class ImagenGenerationRequestTests: XCTestCase {
func testInitializeRequest_inlineDataImage() throws {
let request = ImagenGenerationRequest<ImagenInlineImage>(
model: modelName,
apiConfig: apiConfig,
options: requestOptions,
instances: [instance],
parameters: parameters
Expand All @@ -58,7 +57,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
XCTAssertEqual(request.instances, [instance])
XCTAssertEqual(request.parameters, parameters)
XCTAssertEqual(
request.url,
request.requestURL(apiConfig: apiConfig),
URL(string:
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(modelName):predict")
)
Expand All @@ -67,7 +66,6 @@ final class ImagenGenerationRequestTests: XCTestCase {
func testInitializeRequest_fileDataImage() throws {
let request = ImagenGenerationRequest<ImagenGCSImage>(
model: modelName,
apiConfig: apiConfig,
options: requestOptions,
instances: [instance],
parameters: parameters
Expand All @@ -78,7 +76,7 @@ final class ImagenGenerationRequestTests: XCTestCase {
XCTAssertEqual(request.instances, [instance])
XCTAssertEqual(request.parameters, parameters)
XCTAssertEqual(
request.url,
request.requestURL(apiConfig: apiConfig),
URL(string:
"\(apiConfig.service.endpoint.rawValue)/\(apiConfig.version.rawValue)/\(modelName):predict")
)
Expand All @@ -89,7 +87,6 @@ final class ImagenGenerationRequestTests: XCTestCase {
func testEncodeRequest_inlineDataImage() throws {
let request = ImagenGenerationRequest<ImagenInlineImage>(
model: modelName,
apiConfig: apiConfig,
options: RequestOptions(),
instances: [instance],
parameters: parameters
Expand Down Expand Up @@ -118,7 +115,6 @@ final class ImagenGenerationRequestTests: XCTestCase {
func testEncodeRequest_fileDataImage() throws {
let request = ImagenGenerationRequest<ImagenGCSImage>(
model: modelName,
apiConfig: apiConfig,
options: RequestOptions(),
instances: [instance],
parameters: parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,16 @@ import XCTest

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
final class CountTokensRequestTests: XCTestCase {
let encoder = JSONEncoder()

let modelResourceName = "models/test-model-name"
let textPart = TextPart("test-prompt")
let vertexAPIConfig = APIConfig(service: .vertexAI, version: .v1beta)
let developerAPIConfig = APIConfig(
service: .developer(endpoint: .firebaseVertexAIProd),
version: .v1beta
let vertexEncoder = CountTokensRequestTests.encoder(
apiConfig: APIConfig(service: .vertexAI, version: .v1beta)
)
let developerEncoder = CountTokensRequestTests.encoder(
apiConfig: APIConfig(service: .developer(endpoint: .firebaseVertexAIProd), version: .v1beta)
)
let requestOptions = RequestOptions()

override func setUp() {
encoder.outputFormatting = .init(
arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes
)
}

// MARK: CountTokensRequest Encoding

func testEncodeCountTokensRequest_vertexAI_minimal() throws {
Expand All @@ -48,13 +41,12 @@ final class CountTokensRequestTests: XCTestCase {
tools: nil,
toolConfig: nil,
systemInstruction: nil,
apiConfig: vertexAPIConfig,
apiMethod: .countTokens,
options: requestOptions
)
let request = CountTokensRequest(generateContentRequest: generateContentRequest)

let jsonData = try encoder.encode(request)
let jsonData = try vertexEncoder.encode(request)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
Expand Down Expand Up @@ -82,13 +74,12 @@ final class CountTokensRequestTests: XCTestCase {
tools: nil,
toolConfig: nil,
systemInstruction: nil,
apiConfig: developerAPIConfig,
apiMethod: .countTokens,
options: requestOptions
)
let request = CountTokensRequest(generateContentRequest: generateContentRequest)

let jsonData = try encoder.encode(request)
let jsonData = try developerEncoder.encode(request)

let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
Expand All @@ -108,4 +99,12 @@ final class CountTokensRequestTests: XCTestCase {
}
""")
}

/// Builds a `JSONEncoder` carrying `apiConfig` in its `userInfo`, configured to emit
/// pretty-printed JSON with sorted keys and unescaped slashes.
///
/// - Parameter apiConfig: The configuration handed to `JSONEncoder(apiConfig:)`.
/// - Returns: The configured encoder.
static func encoder(apiConfig: APIConfig) -> JSONEncoder {
  let formatting: JSONEncoder.OutputFormatting = [
    .prettyPrinted, .sortedKeys, .withoutEscapingSlashes,
  ]
  let jsonEncoder = JSONEncoder(apiConfig: apiConfig)
  jsonEncoder.outputFormatting = formatting
  return jsonEncoder
}
}
Loading