Skip to content

Commit 3c4a626

Browse files
committed
Using factory method approach instread
1 parent 4c16eda commit 3c4a626

File tree

2 files changed

+188
-52
lines changed

2 files changed

+188
-52
lines changed

opentelemetry-sdk/src/opentelemetry/sdk/trace/sampling.py

+58-36
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@
7373
* parentbased_always_off - Sampler that respects its parent span's sampling decision, but otherwise never samples.
7474
* parentbased_traceidratio - Sampler that respects its parent span's sampling decision, but otherwise samples probabalistically based on rate.
7575
76-
In order to configure a custom sampler via environment variables, create an entry point for the custom sampler class under the entry point group, ``opentelemtry_traces_sampler``. Then, set the ``OTEL_TRACES_SAMPLER`` environment variable to the key name of the entry point. For example, set ``OTEL_TRACES_SAMPLER=custom_sampler_name`` and ``OTEL_TRACES_SAMPLER_ARG=0.5`` after creating the following entry point:
76+
In order to configure a custom sampler via environment variables, create an entry point for the custom sampler factory method under the entry point group, ``opentelemtry_traces_sampler``. Then, set
77+
the ``OTEL_TRACES_SAMPLER`` environment variable to the key name of the entry point. The custom sampler factory method must take a single string argument. This input will come from
78+
``OTEL_TRACES_SAMPLER_ARG``. If ``OTEL_TRACES_SAMPLER_ARG`` is not configured, the input will be an empty string. For example, set ``OTEL_TRACES_SAMPLER=custom_sampler_name`` and
79+
``OTEL_TRACES_SAMPLER_ARG=0.5`` after creating the following entry point:
7780
7881
.. code:: python
7982
@@ -82,15 +85,22 @@
8285
entry_points={
8386
...
8487
"opentelemtry_traces_sampler": [
85-
"custom_sampler_name = path.to.sampler.module:CustomSampler"
88+
"custom_sampler_name = path.to.sampler.factory.method:CustomSamplerFactory.get_sampler"
8689
]
8790
}
8891
)
8992
...
90-
class CustomSampler(TraceIdRatioBased):
91-
...
93+
class CustomRatioSampler(Sampler):
94+
def __init__(rate):
95+
...
96+
...
97+
class CustomSamplerFactory:
98+
get_sampler(sampler_argument_str):
99+
rate = float(sampler_argument_str)
100+
return CustomSampler(rate)
92101
93-
Sampling probability can be set with ``OTEL_TRACES_SAMPLER_ARG`` if the sampler is a ``TraceIdRatioBased`` Sampler, such as ``traceidratio`` and ``parentbased_traceidratio``. When not provided rate will be set to 1.0 (maximum rate possible).
102+
Sampling probability can be set with ``OTEL_TRACES_SAMPLER_ARG`` if the sampler is a ``TraceIdRatioBased`` Sampler, such as ``traceidratio`` and ``parentbased_traceidratio``. When not provided rate
103+
will be set to 1.0 (maximum rate possible).
94104
95105
96106
Prev example but with environment variables. Please make sure to set the env ``OTEL_TRACES_SAMPLER=traceidratio`` and ``OTEL_TRACES_SAMPLER_ARG=0.001``.
@@ -372,60 +382,76 @@ def get_description(self):
372382
"""Sampler that respects its parent span's sampling decision, but otherwise always samples."""
373383

374384

375-
class ParentBasedTraceIdRatio(ParentBased, TraceIdRatioBased):
385+
class ParentBasedTraceIdRatio(ParentBased):
376386
"""
377387
Sampler that respects its parent span's sampling decision, but otherwise
378388
samples probabalistically based on `rate`.
379389
"""
380390

381391
def __init__(self, rate: float):
382-
# TODO: If TraceIdRatioBased inheritance is kept, change this
383392
root = TraceIdRatioBased(rate=rate)
384393
super().__init__(root=root)
385394

386395

387-
_KNOWN_INITIALIZED_SAMPLERS = {
396+
_KNOWN_SAMPLERS = {
388397
"always_on": ALWAYS_ON,
389398
"always_off": ALWAYS_OFF,
390399
"parentbased_always_on": DEFAULT_ON,
391400
"parentbased_always_off": DEFAULT_OFF,
392-
}
393-
394-
_KNOWN_SAMPLER_CLASSES = {
395401
"traceidratio": TraceIdRatioBased,
396402
"parentbased_traceidratio": ParentBasedTraceIdRatio,
397403
}
398404

399405

400406
def _get_from_env_or_default() -> Sampler:
401-
trace_sampler_name = os.getenv(
407+
trace_sampler = os.getenv(
402408
OTEL_TRACES_SAMPLER, "parentbased_always_on"
403409
).lower()
410+
if trace_sampler not in _KNOWN_SAMPLERS:
411+
_logger.warning("Couldn't recognize sampler %s.", trace_sampler)
412+
trace_sampler = "parentbased_always_on"
404413

405-
if trace_sampler_name in _KNOWN_INITIALIZED_SAMPLERS:
406-
return _KNOWN_INITIALIZED_SAMPLERS[trace_sampler_name]
407-
408-
trace_sampler_impl = None
409-
if trace_sampler_name in _KNOWN_SAMPLER_CLASSES:
410-
trace_sampler_impl = _KNOWN_SAMPLER_CLASSES[trace_sampler_name]
411-
else:
412-
try:
413-
trace_sampler_impl = _import_sampler(trace_sampler_name)
414-
except RuntimeError as err:
415-
_logger.warning(
416-
"Unable to recognize sampler %s: %s", trace_sampler_name, err
417-
)
418-
return _KNOWN_INITIALIZED_SAMPLERS["parentbased_always_on"]
419-
420-
if issubclass(trace_sampler_impl, TraceIdRatioBased):
414+
if trace_sampler in ("traceidratio", "parentbased_traceidratio"):
421415
try:
422416
rate = float(os.getenv(OTEL_TRACES_SAMPLER_ARG))
423417
except ValueError:
424418
_logger.warning("Could not convert TRACES_SAMPLER_ARG to float.")
425419
rate = 1.0
426-
return trace_sampler_impl(rate)
420+
return _KNOWN_SAMPLERS[trace_sampler](rate)
427421

428-
return trace_sampler_impl()
422+
return _KNOWN_SAMPLERS[trace_sampler]
423+
424+
425+
def _get_from_env_or_default() -> Sampler:
426+
trace_sampler_name = os.getenv(
427+
OTEL_TRACES_SAMPLER, "parentbased_always_on"
428+
).lower()
429+
430+
if trace_sampler_name in _KNOWN_SAMPLERS:
431+
if trace_sampler_name in ("traceidratio", "parentbased_traceidratio"):
432+
try:
433+
rate = float(os.getenv(OTEL_TRACES_SAMPLER_ARG))
434+
except ValueError:
435+
_logger.warning("Could not convert TRACES_SAMPLER_ARG to float.")
436+
rate = 1.0
437+
return _KNOWN_SAMPLERS[trace_sampler_name](rate)
438+
return _KNOWN_SAMPLERS[trace_sampler_name]
439+
else:
440+
try:
441+
trace_sampler_factory = _import_sampler_factory(trace_sampler_name)
442+
sampler_arg = os.getenv(OTEL_TRACES_SAMPLER_ARG, "")
443+
trace_sampler = trace_sampler_factory(sampler_arg)
444+
_logger.warning("JEREVOSS: trace_sampler: %s" % trace_sampler)
445+
if not issubclass(type(trace_sampler), Sampler):
446+
message = "Output of traces sampler factory, %s, was not a Sampler object." % trace_sampler_factory
447+
_logger.warning(message)
448+
raise ValueError(message)
449+
return trace_sampler
450+
except Exception as exc:
451+
_logger.warning(
452+
"Failed to initialize custom sampler, %s: %s", trace_sampler_name, exc
453+
)
454+
return _KNOWN_SAMPLERS["parentbased_always_on"]
429455

430456

431457
def _get_parent_trace_state(parent_context) -> Optional["TraceState"]:
@@ -435,13 +461,9 @@ def _get_parent_trace_state(parent_context) -> Optional["TraceState"]:
435461
return parent_span_context.trace_state
436462

437463

438-
def _import_sampler(sampler_name: str) -> Sampler:
464+
def _import_sampler_factory(sampler_name: str) -> Sampler:
439465
# pylint: disable=unbalanced-tuple-unpacking
440466
[(sampler_name, sampler_impl)] = _import_config_components(
441467
[sampler_name.strip()], _OTEL_SAMPLER_ENTRY_POINT_GROUP
442468
)
443-
444-
if issubclass(sampler_impl, Sampler):
445-
return sampler_impl
446-
447-
raise RuntimeError(f"{sampler_name} is not an Sampler")
469+
return sampler_impl

opentelemetry-sdk/tests/trace/test_trace.py

+130-16
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,26 @@ def should_sample(
196196
)
197197

198198

199+
class CustomSamplerFactory:
200+
def get_custom_sampler(unused_sampler_arg):
201+
return CustomSampler()
202+
203+
def get_custom_ratio_sampler(sampler_arg):
204+
return CustomRatioSampler(float(sampler_arg))
205+
206+
def empty_get_custom_sampler(sampler_arg):
207+
return
208+
209+
210+
class IterEntryPoint:
211+
def __init__(self, name, class_type):
212+
self.name = name
213+
self.class_type = class_type
214+
215+
def load(self):
216+
return self.class_type
217+
218+
199219
class TestTracerSampling(unittest.TestCase):
200220
def tearDown(self):
201221
reload(trace)
@@ -222,9 +242,7 @@ def test_default_sampler(self):
222242

223243
def test_default_sampler_type(self):
224244
tracer_provider = trace.TracerProvider()
225-
self.assertIsInstance(tracer_provider.sampler, sampling.ParentBased)
226-
# pylint: disable=protected-access
227-
self.assertEqual(tracer_provider.sampler._root, sampling.ALWAYS_ON)
245+
self.verify_default_sampler(tracer_provider)
228246

229247
def test_sampler_no_sampling(self):
230248
tracer_provider = trace.TracerProvider(sampling.ALWAYS_OFF)
@@ -283,43 +301,139 @@ def test_sampler_with_env_non_existent_entry_point(self):
283301
# pylint: disable=protected-access
284302
reload(trace)
285303
tracer_provider = trace.TracerProvider()
286-
self.assertIsInstance(tracer_provider.sampler, sampling.ParentBased)
287-
# pylint: disable=protected-access
288-
self.assertEqual(tracer_provider.sampler._root, sampling.ALWAYS_ON)
304+
self.verify_default_sampler(tracer_provider)
289305

290-
@mock.patch("opentelemetry.sdk.trace.sampling._import_config_components")
291-
@mock.patch.dict("os.environ", {OTEL_TRACES_SAMPLER: "custom_sampler"})
306+
@mock.patch("opentelemetry.sdk.trace.util.iter_entry_points")
307+
@mock.patch.dict("os.environ", {OTEL_TRACES_SAMPLER: "custom_sampler_factory"})
292308
def test_custom_sampler_with_env(
293-
self, mock_sampling_import_config_components
309+
self, mock_iter_entry_points
310+
):
311+
# mock_iter_entry_points.return_value = [
312+
# ("custom_sampler_factory", CustomSamplerFactory.get_custom_sampler)
313+
# ]
314+
mock_iter_entry_points.return_value=[
315+
IterEntryPoint("custom_sampler_factory", CustomSamplerFactory.get_custom_sampler)
316+
]
317+
# pylint: disable=protected-access
318+
reload(trace)
319+
tracer_provider = trace.TracerProvider()
320+
self.assertIsInstance(tracer_provider.sampler, CustomSampler)
321+
322+
@mock.patch("opentelemetry.sdk.trace.util.iter_entry_points")
323+
@mock.patch.dict("os.environ", {OTEL_TRACES_SAMPLER: "custom_sampler_factory"})
324+
def test_custom_sampler_with_env_bad_factory(
325+
self, mock_iter_entry_points
326+
):
327+
mock_iter_entry_points.return_value = [
328+
IterEntryPoint("custom_sampler_factory", CustomSamplerFactory.empty_get_custom_sampler)
329+
]
330+
# pylint: disable=protected-access
331+
reload(trace)
332+
tracer_provider = trace.TracerProvider()
333+
self.verify_default_sampler(tracer_provider)
334+
335+
@mock.patch("opentelemetry.sdk.trace.util.iter_entry_points")
336+
@mock.patch.dict(
337+
"os.environ",
338+
{
339+
OTEL_TRACES_SAMPLER: "custom_sampler_factory",
340+
OTEL_TRACES_SAMPLER_ARG: "0.5",
341+
},
342+
)
343+
def test_custom_sampler_with_env_unused_arg(
344+
self, mock_iter_entry_points
294345
):
295-
mock_sampling_import_config_components.return_value = [
296-
("custom_sampler", CustomSampler)
346+
mock_iter_entry_points.return_value = [
347+
IterEntryPoint("custom_sampler_factory", CustomSamplerFactory.get_custom_sampler)
297348
]
298349
# pylint: disable=protected-access
299350
reload(trace)
300351
tracer_provider = trace.TracerProvider()
301352
self.assertIsInstance(tracer_provider.sampler, CustomSampler)
302353

303-
@mock.patch("opentelemetry.sdk.trace.sampling._import_config_components")
354+
@mock.patch("opentelemetry.sdk.trace.util.iter_entry_points")
304355
@mock.patch.dict(
305356
"os.environ",
306357
{
307-
OTEL_TRACES_SAMPLER: "custom_ratio_sampler",
358+
OTEL_TRACES_SAMPLER: "custom_ratio_sampler_factory",
308359
OTEL_TRACES_SAMPLER_ARG: "0.5",
309360
},
310361
)
311362
def test_custom_ratio_sampler_with_env(
312-
self, mock_sampling_import_config_components
363+
self, mock_iter_entry_points
313364
):
314-
mock_sampling_import_config_components.return_value = [
315-
("custom_ratio_sampler", CustomRatioSampler)
365+
mock_iter_entry_points.return_value = [
366+
IterEntryPoint("custom_ratio_sampler_factory", CustomSamplerFactory.get_custom_ratio_sampler)
316367
]
317368
# pylint: disable=protected-access
318369
reload(trace)
319370
tracer_provider = trace.TracerProvider()
320371
self.assertIsInstance(tracer_provider.sampler, CustomRatioSampler)
321372
self.assertEqual(tracer_provider.sampler.ratio, 0.5)
322373

374+
@mock.patch("opentelemetry.sdk.trace.util.iter_entry_points")
375+
@mock.patch.dict(
376+
"os.environ",
377+
{
378+
OTEL_TRACES_SAMPLER: "custom_ratio_sampler_factory",
379+
OTEL_TRACES_SAMPLER_ARG: "foobar",
380+
},
381+
)
382+
def test_custom_ratio_sampler_with_env_bad_arg(
383+
self, mock_iter_entry_points
384+
):
385+
mock_iter_entry_points.return_value = [
386+
IterEntryPoint("custom_ratio_sampler_factory", CustomSamplerFactory.get_custom_ratio_sampler)
387+
]
388+
# pylint: disable=protected-access
389+
reload(trace)
390+
tracer_provider = trace.TracerProvider()
391+
self.verify_default_sampler(tracer_provider)
392+
393+
@mock.patch("opentelemetry.sdk.trace.util.iter_entry_points")
394+
@mock.patch.dict(
395+
"os.environ",
396+
{
397+
OTEL_TRACES_SAMPLER: "custom_ratio_sampler_factory",
398+
},
399+
)
400+
def test_custom_ratio_sampler_with_env_no_arg(
401+
self, mock_iter_entry_points
402+
):
403+
mock_iter_entry_points.return_value = [
404+
IterEntryPoint("custom_ratio_sampler_factory", CustomSamplerFactory.get_custom_ratio_sampler)
405+
]
406+
# pylint: disable=protected-access
407+
reload(trace)
408+
tracer_provider = trace.TracerProvider()
409+
self.verify_default_sampler(tracer_provider)
410+
411+
@mock.patch("opentelemetry.sdk.trace.util.iter_entry_points")
412+
@mock.patch.dict(
413+
"os.environ",
414+
{
415+
OTEL_TRACES_SAMPLER: "custom_sampler_factory",
416+
OTEL_TRACES_SAMPLER_ARG: "0.5",
417+
},
418+
)
419+
def test_custom_ratio_sampler_with_env_multiple_entry_points(
420+
self, mock_iter_entry_points
421+
):
422+
mock_iter_entry_points.return_value = [
423+
IterEntryPoint("custom_ratio_sampler_factory", CustomSamplerFactory.get_custom_ratio_sampler),
424+
IterEntryPoint("custom_sampler_factory", CustomSamplerFactory.get_custom_sampler),
425+
IterEntryPoint("custom_z_sampler_factory", CustomSamplerFactory.empty_get_custom_sampler)
426+
]
427+
# pylint: disable=protected-access
428+
reload(trace)
429+
tracer_provider = trace.TracerProvider()
430+
self.assertIsInstance(tracer_provider.sampler, CustomSampler)
431+
432+
def verify_default_sampler(self, tracer_provider):
433+
self.assertIsInstance(tracer_provider.sampler, sampling.ParentBased)
434+
# pylint: disable=protected-access
435+
self.assertEqual(tracer_provider.sampler._root, sampling.ALWAYS_ON)
436+
323437

324438
class TestSpanCreation(unittest.TestCase):
325439
def test_start_span_invalid_spancontext(self):

0 commit comments

Comments
 (0)