Skip to content

Add Redis instrumentation query sanitization #1572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Feb 4, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- `opentelemetry-instrumentation-redis` Add `sanitize_query` config option to allow query sanitization. ([#1572](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1572))
- `opentelemetry-instrumentation-celery` Record exceptions as events on the span.
([#1573](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1573))
- Add metric instrumentation for urllib
Original file line number Diff line number Diff line change
@@ -64,6 +64,8 @@ async def redis_get():
response_hook (Callable) - a function with extra user-defined logic to be performed after performing the request
this function signature is: def response_hook(span: Span, instance: redis.connection.Connection, response) -> None

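sanitize_query (Boolean) - defaults to False; when True, every argument after the command name is replaced with "?" in the recorded statement (for example ``SET key value`` is recorded as ``SET ? ?``)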

for example:

.. code: python
@@ -139,9 +141,11 @@ def _instrument(
tracer,
request_hook: _RequestHookT = None,
response_hook: _ResponseHookT = None,
sanitize_query: bool = False,
):
def _traced_execute_command(func, instance, args, kwargs):
query = _format_command_args(args)
query = _format_command_args(args, sanitize_query)

if len(args) > 0 and args[0]:
name = args[0]
else:
@@ -169,7 +173,9 @@ def _traced_execute_pipeline(func, instance, args, kwargs):
)

cmds = [
_format_command_args(c.args if hasattr(c, "args") else c[0])
_format_command_args(
c.args if hasattr(c, "args") else c[0], sanitize_query
)
for c in command_stack
]
resource = "\n".join(cmds)
@@ -281,6 +287,7 @@ def _instrument(self, **kwargs):
tracer,
request_hook=kwargs.get("request_hook"),
response_hook=kwargs.get("response_hook"),
sanitize_query=kwargs.get("sanitize_query", False),
)

def _uninstrument(self, **kwargs):
Original file line number Diff line number Diff line change
@@ -48,11 +48,27 @@ def _extract_conn_attributes(conn_kwargs):
return attributes


def _format_command_args(args):
"""Format command arguments and trim them as needed"""
value_max_len = 100
value_too_long_mark = "..."
def _format_command_args(args, sanitize_query):
"""Format and sanitize command arguments, and trim them as needed"""
cmd_max_len = 1000
value_too_long_mark = "..."
if sanitize_query:
# Sanitized query format: "COMMAND ? ?"
args_length = len(args)
if args_length > 0:
out = [str(args[0])] + ["?"] * (args_length - 1)
out_str = " ".join(out)

if len(out_str) > cmd_max_len:
out_str = (
out_str[: cmd_max_len - len(value_too_long_mark)]
+ value_too_long_mark
)
else:
out_str = ""
return out_str

value_max_len = 100
length = 0
out = []
for arg in args:
Original file line number Diff line number Diff line change
@@ -148,6 +148,40 @@ def request_hook(span, conn, args, kwargs):
span = spans[0]
self.assertEqual(span.attributes.get(custom_attribute_name), "GET")

def test_query_sanitizer_enabled(self):
    """Check that ``sanitize_query=True`` masks the arguments in ``db.statement``."""
    client = redis.Redis()
    client.connection = redis.connection.Connection()

    # Re-instrument so the sanitizing option is in effect for this test.
    instrumentor_kwargs = {
        "tracer_provider": self.tracer_provider,
        "sanitize_query": True,
    }
    RedisInstrumentor().uninstrument()
    RedisInstrumentor().instrument(**instrumentor_kwargs)

    # The connection attribute is mocked out while the command is issued.
    with mock.patch.object(client, "connection"):
        client.set("key", "value")

    finished = self.memory_exporter.get_finished_spans()
    self.assertEqual(len(finished), 1)

    # Only the command name survives; both arguments become "?".
    recorded_statement = finished[0].attributes.get("db.statement")
    self.assertEqual(recorded_statement, "SET ? ?")

def test_query_sanitizer_disabled(self):
    """Check that ``db.statement`` keeps the raw arguments when sanitization is off.

    This test does not re-instrument; it relies on whatever instrumentation
    is already active for the test case (presumably installed during set-up
    without ``sanitize_query`` -- confirm against the fixture).
    """
    client = redis.Redis()
    client.connection = redis.connection.Connection()

    # The connection attribute is mocked out while the command is issued.
    with mock.patch.object(client, "connection"):
        client.set("key", "value")

    finished = self.memory_exporter.get_finished_spans()
    self.assertEqual(len(finished), 1)

    recorded_statement = finished[0].attributes.get("db.statement")
    self.assertEqual(recorded_statement, "SET key value")

def test_no_op_tracer_provider(self):
RedisInstrumentor().uninstrument()
tracer_provider = trace.NoOpTracerProvider()
Original file line number Diff line number Diff line change
@@ -45,6 +45,27 @@ def _check_span(self, span, name):
)
self.assertEqual(span.attributes[SpanAttributes.NET_PEER_PORT], 6379)

def test_long_command_sanitized(self):
    """Check that a very long sanitized command is masked and ends with "..."."""
    # Re-instrument so the sanitizing option is in effect for this test.
    RedisInstrumentor().uninstrument()
    RedisInstrumentor().instrument(
        sanitize_query=True,
        tracer_provider=self.tracer_provider,
    )

    # 2000 keys give a statement long enough to be cut off.
    keys = range(2000)
    self.redis_client.mget(*keys)

    finished = self.memory_exporter.get_finished_spans()
    self.assertEqual(len(finished), 1)
    mget_span = finished[0]
    self._check_span(mget_span, "MGET")

    # The statement starts with masked arguments...
    self.assertTrue(
        mget_span.attributes.get(SpanAttributes.DB_STATEMENT).startswith(
            "MGET ? ? ? ?"
        )
    )
    # ...and carries the "too long" marker at the end.
    self.assertTrue(
        mget_span.attributes.get(SpanAttributes.DB_STATEMENT).endswith("...")
    )

def test_long_command(self):
self.redis_client.mget(*range(1000))

@@ -61,6 +82,22 @@ def test_long_command(self):
span.attributes.get(SpanAttributes.DB_STATEMENT).endswith("...")
)

def test_basics_sanitized(self):
    """Check that a simple GET is recorded as ``GET ?`` when sanitization is on."""
    # Re-instrument so the sanitizing option is in effect for this test.
    RedisInstrumentor().uninstrument()
    RedisInstrumentor().instrument(
        sanitize_query=True,
        tracer_provider=self.tracer_provider,
    )

    lookup_result = self.redis_client.get("cheese")
    self.assertIsNone(lookup_result)

    finished = self.memory_exporter.get_finished_spans()
    self.assertEqual(len(finished), 1)
    get_span = finished[0]
    self._check_span(get_span, "GET")

    recorded_statement = get_span.attributes.get(SpanAttributes.DB_STATEMENT)
    self.assertEqual(recorded_statement, "GET ?")
    # The argument count is still reported even though the key is masked.
    self.assertEqual(get_span.attributes.get("db.redis.args_length"), 2)

def test_basics(self):
self.assertIsNone(self.redis_client.get("cheese"))
spans = self.memory_exporter.get_finished_spans()
@@ -72,6 +109,28 @@ def test_basics(self):
)
self.assertEqual(span.attributes.get("db.redis.args_length"), 2)

def test_pipeline_traced_sanitized(self):
    """Check that a non-transactional pipeline yields one span with masked commands."""
    # Re-instrument so the sanitizing option is in effect for this test.
    RedisInstrumentor().uninstrument()
    RedisInstrumentor().instrument(
        sanitize_query=True,
        tracer_provider=self.tracer_provider,
    )

    with self.redis_client.pipeline(transaction=False) as pipe:
        pipe.set("blah", 32)
        pipe.rpush("foo", "éé")
        pipe.hgetall("xxx")
        pipe.execute()

    finished = self.memory_exporter.get_finished_spans()
    self.assertEqual(len(finished), 1)
    pipeline_span = finished[0]
    self._check_span(pipeline_span, "SET RPUSH HGETALL")

    # One line per queued command, each with its arguments masked.
    expected_statement = "SET ? ?\nRPUSH ? ?\nHGETALL ?"
    self.assertEqual(
        pipeline_span.attributes.get(SpanAttributes.DB_STATEMENT),
        expected_statement,
    )
    self.assertEqual(
        pipeline_span.attributes.get("db.redis.pipeline_length"), 3
    )

def test_pipeline_traced(self):
with self.redis_client.pipeline(transaction=False) as pipeline:
pipeline.set("blah", 32)
@@ -89,6 +148,27 @@ def test_pipeline_traced(self):
)
self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3)

def test_pipeline_immediate_sanitized(self):
    """Check that an immediate command in a pipeline gets its own masked span."""
    # Re-instrument so the sanitizing option is in effect for this test.
    RedisInstrumentor().uninstrument()
    RedisInstrumentor().instrument(
        sanitize_query=True,
        tracer_provider=self.tracer_provider,
    )

    with self.redis_client.pipeline() as pipe:
        pipe.set("a", 1)
        pipe.immediate_execute_command("SET", "b", 2)
        pipe.execute()

    finished = self.memory_exporter.get_finished_spans()
    # expecting two separate spans here, rather than a
    # single span for the whole pipeline
    self.assertEqual(len(finished), 2)

    first_span = finished[0]
    self._check_span(first_span, "SET")
    recorded_statement = first_span.attributes.get(SpanAttributes.DB_STATEMENT)
    self.assertEqual(recorded_statement, "SET ? ?")

def test_pipeline_immediate(self):
with self.redis_client.pipeline() as pipeline:
pipeline.set("a", 1)