Skip to content

Commit c0e6438

Browse files
committed
Include extra_headers for Azure
1 parent 3fe0996 commit c0e6438

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

lib/openai/client.rb

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def translate(parameters: {})
6666
end
6767

6868
def azure?
69-
@api_type == :azure
69+
@api_type&.to_sym == :azure
7070
end
7171
end
7272
end

spec/openai/client/client_spec.rb

+10-3
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
let!(:c0) { OpenAI::Client.new }
1818
let!(:c1) do
1919
OpenAI::Client.new(
20+
api_type: "azure",
2021
access_token: "access_token1",
2122
organization_id: "organization_id1",
2223
request_timeout: 60,
23-
uri_base: "https://oai.hconeai.com/"
24+
uri_base: "https://oai.hconeai.com/",
25+
extra_headers: { "test" => "X-Test" }
2426
)
2527
end
2628
let!(:c2) do
@@ -33,6 +35,7 @@
3335
end
3436

3537
it "does not confuse the clients" do
38+
expect(c0.azure?).to eq(false)
3639
expect(c0.access_token).to eq(ENV.fetch("OPENAI_ACCESS_TOKEN", "dummy-token"))
3740
expect(c0.organization_id).to eq("organization_id0")
3841
expect(c0.request_timeout).to eq(OpenAI::Configuration::DEFAULT_REQUEST_TIMEOUT)
@@ -41,16 +44,19 @@
4144
expect(c0.send(:headers).values).to include(c0.organization_id)
4245
expect(c0.send(:conn).options.timeout).to eq(OpenAI::Configuration::DEFAULT_REQUEST_TIMEOUT)
4346
expect(c0.send(:uri, path: "")).to include(OpenAI::Configuration::DEFAULT_URI_BASE)
47+
expect(c0.send(:headers).values).not_to include("X-Test")
4448

49+
expect(c1.azure?).to eq(true)
4550
expect(c1.access_token).to eq("access_token1")
4651
expect(c1.organization_id).to eq("organization_id1")
4752
expect(c1.request_timeout).to eq(60)
4853
expect(c1.uri_base).to eq("https://oai.hconeai.com/")
49-
expect(c1.send(:headers).values).to include("Bearer #{c1.access_token}")
50-
expect(c1.send(:headers).values).to include(c1.organization_id)
54+
expect(c1.send(:headers).values).to include(c1.access_token)
5155
expect(c1.send(:conn).options.timeout).to eq(60)
5256
expect(c1.send(:uri, path: "")).to include("https://oai.hconeai.com/")
57+
expect(c1.send(:headers).values).to include("X-Test")
5358

59+
expect(c2.azure?).to eq(false)
5460
expect(c2.access_token).to eq("access_token2")
5561
expect(c2.organization_id).to eq("organization_id0") # Fall back to default.
5662
expect(c2.request_timeout).to eq(1)
@@ -59,6 +65,7 @@
5965
expect(c2.send(:headers).values).to include(c2.organization_id)
6066
expect(c2.send(:conn).options.timeout).to eq(1)
6167
expect(c2.send(:uri, path: "")).to include("https://example.com/")
68+
expect(c2.send(:headers).values).not_to include("X-Test")
6269
end
6370

6471
context "hitting other classes" do

0 commit comments

Comments
 (0)