Skip to content

feat: add requestFactory to Configuration #165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion Sources/Segment/Configuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
//

import Foundation
#if os(Linux)
import FoundationNetworking
#endif

public typealias AdvertisingIdCallback = () -> String?

Expand All @@ -30,8 +33,10 @@ public class Configuration {
var autoAddSegmentDestination: Bool = true
var apiHost: String = HTTPClient.getDefaultAPIHost()
var cdnHost: String = HTTPClient.getDefaultCDNHost()
var errorHandler: ((Error) -> Void)?
var requestFactory: ((URLRequest) -> URLRequest)? = nil
var errorHandler: ((Error) -> Void)? = nil
}

internal var values: Values

public init(writeKey: String) {
Expand Down Expand Up @@ -97,6 +102,12 @@ public extension Configuration {
return self
}

@discardableResult
func requestFactory(_ value: @escaping (URLRequest) -> URLRequest) -> Configuration {
values.requestFactory = value
return self
}

@discardableResult
func errorHandler(_ value: @escaping (Error) -> Void) -> Configuration {
values.errorHandler = value
Expand Down
33 changes: 20 additions & 13 deletions Sources/Segment/Utilities/HTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,9 @@ public class HTTPClient {
completion(.failure(HTTPClientErrors.failedToOpenBatch))
return nil
}

var urlRequest = URLRequest(url: uploadURL)
urlRequest.httpMethod = "POST"


let urlRequest = configuredRequest(for: uploadURL, method: "POST")

let dataTask = session.uploadTask(with: urlRequest, fromFile: batch) { [weak self] (data, response, error) in
if let error = error {
self?.analytics?.log(message: "Error uploading request \(error.localizedDescription).")
Expand Down Expand Up @@ -106,9 +105,8 @@ public class HTTPClient {
return
}

var urlRequest = URLRequest(url: settingsURL)
urlRequest.httpMethod = "GET"

let urlRequest = configuredRequest(for: settingsURL, method: "GET")

let dataTask = session.dataTask(with: urlRequest) { [weak self] (data, response, error) in
if let error = error {
self?.analytics?.reportInternalError(AnalyticsError.networkUnknown(error))
Expand Down Expand Up @@ -169,15 +167,24 @@ extension HTTPClient {
return Self.defaultCDNHost
}

internal func configuredRequest(for url: URL, method: String) -> URLRequest {
var request = URLRequest(url: url, cachePolicy: .reloadIgnoringLocalCacheData, timeoutInterval: 60)
request.httpMethod = method
request.addValue("application/json; charset=utf-8", forHTTPHeaderField: "Content-Type")
request.addValue("Basic \(apiKey)", forHTTPHeaderField: "Authorization")
request.addValue("analytics-ios/\(Analytics.version())", forHTTPHeaderField: "User-Agent")
request.addValue("gzip", forHTTPHeaderField: "Accept-Encoding")

if let requestFactory = analytics?.configuration.values.requestFactory {
request = requestFactory(request)
}

return request
}

internal static func configuredSession(for writeKey: String) -> URLSession {
let configuration = URLSessionConfiguration.ephemeral
configuration.allowsCellularAccess = true
configuration.timeoutIntervalForResource = 30
configuration.timeoutIntervalForRequest = 60
configuration.httpMaximumConnectionsPerHost = 2
configuration.httpAdditionalHeaders = ["Content-Type": "application/json; charset=utf-8",
"Authorization": "Basic \(Self.authorizationHeaderForWriteKey(writeKey))",
"User-Agent": "analytics-ios/\(Analytics.version())"]
let session = URLSession(configuration: configuration, delegate: nil, delegateQueue: nil)
return session
}
Expand Down
23 changes: 23 additions & 0 deletions Tests/Segment-Tests/Analytics_Tests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -531,4 +531,27 @@ final class Analytics_Tests: XCTestCase {
XCTAssertEqual(metadata?.bundled, ["Mixpanel"])
XCTAssertEqual(metadata?.unbundled.sorted(), ["Amplitude", "Customer.io", "dest1"])
}

func testRequestFactory() {
let config = Configuration(writeKey: "test").requestFactory { request in
XCTAssertEqual(request.value(forHTTPHeaderField: "Accept-Encoding"), "gzip")
XCTAssertEqual(request.value(forHTTPHeaderField: "Content-Type"), "application/json; charset=utf-8")
XCTAssertEqual(request.value(forHTTPHeaderField: "Authorization"), "Basic test")
XCTAssertTrue(request.value(forHTTPHeaderField: "User-Agent")!.contains("analytics-ios/"))
return request
}.errorHandler { error in
XCTFail("\(error)")
}
let analytics = Analytics(configuration: config)
let outputReader = OutputReaderPlugin()
analytics.add(plugin: outputReader)

waitUntilStarted(analytics: analytics)

analytics.track(name: "something")

analytics.flush()

RunLoop.main.run(until: Date(timeIntervalSinceNow: 5))
}
}