Skip to content

Commit c580008

Browse files
authored
feat: update user-agent and improve transport selection in queryClient (#92)
1 parent 933ecb5 commit c580008

File tree

10 files changed

+208
-8
lines changed

10 files changed

+208
-8
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,7 @@ pyinflux3*.egg-info
99
__pycache__
1010
.idea
1111
*.egg-info/
12+
temp/
13+
test-reports/
14+
coverage.xml
15+
.coverage

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## 0.6.0 [unreleased]
44

5+
### Features
6+
7+
1. [#92](https://github.com/InfluxCommunity/influxdb3-python/pull/92): Update `user-agent` header value to `influxdb3-python/{VERSION}` and add it to queries as well.
8+
59
## 0.5.0 [2024-05-17]
610

711
### Features

influxdb_client_3/__init__.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from influxdb_client_3.write_client.client.write_api import WriteApi as _WriteApi, SYNCHRONOUS, ASYNCHRONOUS, \
1111
PointSettings
1212
from influxdb_client_3.write_client.domain.write_precision import WritePrecision
13+
from influxdb_client_3.version import USER_AGENT
1314

1415
try:
1516
import polars as pl
@@ -147,7 +148,19 @@ def __init__(
147148

148149
if query_port_overwrite is not None:
149150
port = query_port_overwrite
150-
self._flight_client = FlightClient(f"grpc+tls://{hostname}:{port}", **self._flight_client_options)
151+
152+
gen_opts = [
153+
("grpc.secondary_user_agent", USER_AGENT)
154+
]
155+
156+
self._flight_client_options["generic_options"] = gen_opts
157+
158+
if scheme == 'https':
159+
connection_string = f"grpc+tls://{hostname}:{port}"
160+
else:
161+
connection_string = f"grpc+tcp://{hostname}:{port}"
162+
163+
self._flight_client = FlightClient(connection_string, **self._flight_client_options)
151164

152165
def write(self, record=None, database=None, **kwargs):
153166
"""

influxdb_client_3/version.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
"""Version of the Client that is used in User-Agent header."""
2+
3+
VERSION = '0.6.0dev0'
4+
USER_AGENT = f'influxdb3-python/{VERSION}'

influxdb_client_3/write_client/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,5 @@
2727
from influxdb_client_3.write_client.domain.write_precision import WritePrecision
2828

2929
from influxdb_client_3.write_client.configuration import Configuration
30-
from influxdb_client_3.write_client.version import VERSION
30+
from influxdb_client_3.version import VERSION
3131
__version__ = VERSION

influxdb_client_3/write_client/_sync/api_client.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def __init__(self, configuration=None, header_name=None, header_value=None,
7676
self.default_headers[header_name] = header_value
7777
self.cookie = cookie
7878
# Set default User-Agent.
79-
from influxdb_client_3.write_client.version import VERSION
80-
self.user_agent = f'influxdb-client-python/{VERSION}'
79+
from influxdb_client_3.version import USER_AGENT
80+
self.user_agent = USER_AGENT
8181

8282
def __del__(self):
8383
"""Dispose pools."""

influxdb_client_3/write_client/version.py

-3
This file was deleted.

tests/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# needed to resolve some module imports when running pytest

tests/test_api_client.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import unittest
2+
from unittest import mock
3+
4+
from influxdb_client_3.write_client._sync.api_client import ApiClient
5+
from influxdb_client_3.write_client.configuration import Configuration
6+
from influxdb_client_3.write_client.service import WriteService
7+
from influxdb_client_3.version import VERSION
8+
9+
10+
_package = "influxdb3-python"
11+
_sentHeaders = {}
12+
13+
14+
def mock_rest_request(method,
15+
url,
16+
query_params=None,
17+
headers=None,
18+
body=None,
19+
post_params=None,
20+
_preload_content=True,
21+
_request_timeout=None,
22+
**urlopen_kw):
23+
class MockResponse:
24+
def __init__(self, data, status_code):
25+
self.data = data
26+
self.status_code = status_code
27+
28+
def data(self):
29+
return self.data
30+
31+
global _sentHeaders
32+
_sentHeaders = headers
33+
34+
return MockResponse(None, 200)
35+
36+
37+
class ApiClientTests(unittest.TestCase):
38+
39+
def test_default_headers(self):
40+
global _package
41+
conf = Configuration()
42+
client = ApiClient(conf,
43+
header_name="Authorization",
44+
header_value="Bearer TEST_TOKEN")
45+
self.assertIsNotNone(client.default_headers["User-Agent"])
46+
self.assertIsNotNone(client.default_headers["Authorization"])
47+
self.assertEqual(f"{_package}/{VERSION}", client.default_headers["User-Agent"])
48+
self.assertEqual("Bearer TEST_TOKEN", client.default_headers["Authorization"])
49+
50+
@mock.patch("influxdb_client_3.write_client._sync.rest.RESTClientObject.request",
51+
side_effect=mock_rest_request)
52+
def test_call_api(self, mock_post):
53+
global _package
54+
global _sentHeaders
55+
_sentHeaders = {}
56+
57+
conf = Configuration()
58+
client = ApiClient(conf,
59+
header_name="Authorization",
60+
header_value="Bearer TEST_TOKEN")
61+
service = WriteService(client)
62+
service.post_write("TEST_ORG", "TEST_BUCKET", "data,foo=bar val=3.14")
63+
self.assertEqual(4, len(_sentHeaders.keys()))
64+
self.assertIsNotNone(_sentHeaders["Accept"])
65+
self.assertEqual("application/json", _sentHeaders["Accept"])
66+
self.assertIsNotNone(_sentHeaders["Content-Type"])
67+
self.assertEqual("text/plain", _sentHeaders["Content-Type"])
68+
self.assertIsNotNone(_sentHeaders["Authorization"])
69+
self.assertEqual("Bearer TEST_TOKEN", _sentHeaders["Authorization"])
70+
self.assertIsNotNone(_sentHeaders["User-Agent"])
71+
self.assertEqual(f"{_package}/{VERSION}", _sentHeaders["User-Agent"])

tests/test_query.py

+107-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,115 @@
11
import unittest
2+
import struct
23
from unittest.mock import Mock, patch, ANY
34

4-
from pyarrow.flight import Ticket
5+
from pyarrow import (
6+
array,
7+
Table
8+
)
9+
10+
from pyarrow.flight import (
11+
FlightServerBase,
12+
FlightUnauthenticatedError,
13+
GeneratorStream,
14+
ServerMiddleware,
15+
ServerMiddlewareFactory,
16+
ServerAuthHandler,
17+
Ticket
18+
)
519

620
from influxdb_client_3 import InfluxDBClient3
21+
from influxdb_client_3.version import USER_AGENT
22+
23+
24+
def case_insensitive_header_lookup(headers, lkey):
25+
"""Lookup the value of a given key in the given headers.
26+
The lkey is case-insensitive.
27+
"""
28+
for key in headers:
29+
if key.lower() == lkey.lower():
30+
return headers.get(key)
31+
32+
33+
class NoopAuthHandler(ServerAuthHandler):
34+
"""A no-op auth handler - as seen in pyarrow tests"""
35+
36+
def authenticate(self, outgoing, incoming):
37+
"""Do nothing"""
38+
39+
def is_valid(self, token):
40+
"""
41+
Return an empty string
42+
N.B. Returning None causes Type error
43+
:param token:
44+
:return:
45+
"""
46+
return ""
47+
48+
49+
_req_headers = {}
50+
51+
52+
class HeaderCheckServerMiddlewareFactory(ServerMiddlewareFactory):
53+
"""Factory to create HeaderCheckServerMiddleware and check header values"""
54+
def start_call(self, info, headers):
55+
auth_header = case_insensitive_header_lookup(headers, "Authorization")
56+
values = auth_header[0].split(' ')
57+
if values[0] != 'Bearer':
58+
raise FlightUnauthenticatedError("Token required")
59+
global _req_headers
60+
_req_headers = headers
61+
return HeaderCheckServerMiddleware(values[1])
62+
63+
64+
class HeaderCheckServerMiddleware(ServerMiddleware):
65+
"""
66+
Middleware needed to catch request headers via factory
67+
N.B. As found in pyarrow tests
68+
"""
69+
def __init__(self, token):
70+
self.token = token
71+
72+
def sending_headers(self):
73+
return {'authorization': 'Bearer ' + self.token}
74+
75+
76+
class HeaderCheckFlightServer(FlightServerBase):
77+
"""Mock server handle gRPC do_get calls"""
78+
def do_get(self, context, ticket):
79+
"""Return something to avoid needless errors"""
80+
data = [
81+
array([b"Vltava", struct.pack('<i', 105), b"FM"])
82+
]
83+
table = Table.from_arrays(data, names=['a'])
84+
return GeneratorStream(
85+
table.schema,
86+
self.number_batches(table),
87+
options={})
88+
89+
@staticmethod
90+
def number_batches(table):
91+
for idx, batch in enumerate(table.to_batches()):
92+
buf = struct.pack('<i', idx)
93+
yield batch, buf
94+
95+
96+
def test_influx_default_query_headers():
97+
with HeaderCheckFlightServer(
98+
auth_handler=NoopAuthHandler(),
99+
middleware={"check": HeaderCheckServerMiddlewareFactory()}) as server:
100+
global _req_headers
101+
_req_headers = {}
102+
client = InfluxDBClient3(
103+
host=f'http://localhost:{server.port}',
104+
org='test_org',
105+
databse='test_db',
106+
token='TEST_TOKEN'
107+
)
108+
client.query('SELECT * FROM test')
109+
assert len(_req_headers) > 0
110+
assert _req_headers['authorization'][0] == "Bearer TEST_TOKEN"
111+
assert _req_headers['user-agent'][0].find(USER_AGENT) > -1
112+
_req_headers = {}
7113

8114

9115
class QueryTests(unittest.TestCase):

0 commit comments

Comments
 (0)