Skip to content

Commit 96134f6

Browse files
authored
Merge pull request #158 from luzfcb/fix-155
Create DummySegment if SDK is disabled
2 parents ae3f213 + 9f6197f commit 96134f6

File tree

5 files changed

+135
-19
lines changed

5 files changed

+135
-19
lines changed

aws_xray_sdk/core/context.py

+6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import os
44

55
from .exceptions.exceptions import SegmentNotFoundException
6+
from .models.dummy_entities import DummySegment
7+
from aws_xray_sdk import global_sdk_config
8+
69

710
log = logging.getLogger(__name__)
811

@@ -88,8 +91,11 @@ def get_trace_entity(self):
8891
"""
8992
Return the current trace entity(segment/subsegment). If there is none,
9093
it behaves based on pre-defined ``context_missing`` strategy.
94+
If the SDK is disabled, returns a DummySegment
9195
"""
9296
if not getattr(self._local, 'entities', None):
97+
if not global_sdk_config.sdk_enabled():
98+
return DummySegment()
9399
return self.handle_context_missing()
94100

95101
return self._local.entities[-1]

aws_xray_sdk/core/recorder.py

+28-17
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,10 @@ def begin_segment(self, name=None, traceid=None,
211211
:param str traceid: trace id of the segment
212212
:param int sampling: 0 means not sampled, 1 means sampled
213213
"""
214+
# Disable the recorder; return a generated dummy segment.
215+
if not global_sdk_config.sdk_enabled():
216+
return DummySegment(global_sdk_config.DISABLED_ENTITY_NAME)
217+
214218
seg_name = name or self.service
215219
if not seg_name:
216220
raise SegmentNameMissingException("Segment name is required.")
@@ -220,12 +224,6 @@ def begin_segment(self, name=None, traceid=None,
220224
# depending on if centralized or local sampling rule takes effect.
221225
decision = True
222226

223-
# To disable the recorder, we set the sampling decision to always be false.
224-
# This way, when segments are generated, they become dummy segments and are ultimately never sent.
225-
# The call to self._sampler.should_trace() is never called either so the poller threads are never started.
226-
if not global_sdk_config.sdk_enabled():
227-
sampling = 0
228-
229227
# we respect the input sampling decision
230228
# regardless of recorder configuration.
231229
if sampling == 0:
@@ -251,8 +249,12 @@ def end_segment(self, end_time=None):
251249
if it is ready to send. Ready means segment and
252250
all its subsegments are closed.
253251
254-
:param float end_time: segment compeletion in unix epoch in seconds.
252+
:param float end_time: segment completion in unix epoch in seconds.
255253
"""
254+
# When the SDK is disabled we return
255+
if not global_sdk_config.sdk_enabled():
256+
return
257+
256258
self.context.end_segment(end_time)
257259
segment = self.current_segment()
258260
if segment and segment.ready_to_send():
@@ -264,6 +266,7 @@ def current_segment(self):
264266
this will make sure the segment returned is the one created by the
265267
same thread.
266268
"""
269+
267270
entity = self.get_trace_entity()
268271
if self._is_subsegment(entity):
269272
return entity.parent_segment
@@ -280,6 +283,11 @@ def begin_subsegment(self, name, namespace='local'):
280283
:param str name: the name of the subsegment.
281284
:param str namespace: currently can only be 'local', 'remote', 'aws'.
282285
"""
286+
# Generating the parent dummy segment is necessary.
287+
# We don't need to store anything in context. Assumption here
288+
# is that we only work with recorder-level APIs.
289+
if not global_sdk_config.sdk_enabled():
290+
return DummySubsegment(DummySegment(global_sdk_config.DISABLED_ENTITY_NAME))
283291

284292
segment = self.current_segment()
285293
if not segment:
@@ -301,6 +309,9 @@ def current_subsegment(self):
301309
this will make sure the subsegment returned is one created
302310
by the same thread.
303311
"""
312+
if not global_sdk_config.sdk_enabled():
313+
return DummySubsegment(DummySegment(global_sdk_config.DISABLED_ENTITY_NAME))
314+
304315
entity = self.get_trace_entity()
305316
if self._is_subsegment(entity):
306317
return entity
@@ -314,6 +325,9 @@ def end_subsegment(self, end_time=None):
314325
315326
:param float end_time: subsegment compeletion in unix epoch in seconds.
316327
"""
328+
if not global_sdk_config.sdk_enabled():
329+
return
330+
317331
if not self.context.end_subsegment(end_time):
318332
return
319333

@@ -333,6 +347,8 @@ def put_annotation(self, key, value):
333347
:param object value: annotation value. Any type other than
334348
string/number/bool will be dropped
335349
"""
350+
if not global_sdk_config.sdk_enabled():
351+
return
336352
entity = self.get_trace_entity()
337353
if entity and entity.sampled:
338354
entity.put_annotation(key, value)
@@ -348,6 +364,8 @@ def put_metadata(self, key, value, namespace='default'):
348364
:param str key: metadata key under specified namespace
349365
:param object value: any object that can be serialized into JSON string
350366
"""
367+
if not global_sdk_config.sdk_enabled():
368+
return
351369
entity = self.get_trace_entity()
352370
if entity and entity.sampled:
353371
entity.put_metadata(key, value, namespace)
@@ -357,6 +375,9 @@ def is_sampled(self):
357375
Check if the current trace entity is sampled or not.
358376
Return `False` if no active entity found.
359377
"""
378+
if not global_sdk_config.sdk_enabled():
379+
# Disabled SDK is never sampled
380+
return False
360381
entity = self.get_trace_entity()
361382
if entity:
362383
return entity.sampled
@@ -404,16 +425,6 @@ def capture(self, name=None):
404425
def record_subsegment(self, wrapped, instance, args, kwargs, name,
405426
namespace, meta_processor):
406427

407-
# In the case when the SDK is disabled, we ensure that a parent segment exists, because this is usually
408-
# handled by the middleware. We generate a dummy segment as the parent segment if one doesn't exist.
409-
# This is to allow potential segment method calls to not throw exceptions in the captured method.
410-
if not global_sdk_config.sdk_enabled():
411-
try:
412-
self.current_segment()
413-
except SegmentNotFoundException:
414-
segment = DummySegment(name)
415-
self.context.put_segment(segment)
416-
417428
subsegment = self.begin_subsegment(name, namespace)
418429

419430
exception = None

aws_xray_sdk/sdk_config.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import logging
3+
from distutils.util import strtobool
34

45
log = logging.getLogger(__name__)
56

@@ -27,13 +28,32 @@ class SDKConfig(object):
2728
to occur.
2829
"""
2930
XRAY_ENABLED_KEY = 'AWS_XRAY_SDK_ENABLED'
30-
__SDK_ENABLED = str(os.getenv(XRAY_ENABLED_KEY, 'true')).lower() != 'false'
31+
DISABLED_ENTITY_NAME = 'dummy'
32+
33+
__SDK_ENABLED = None
34+
35+
@classmethod
36+
def __get_enabled_from_env(cls):
37+
"""
38+
Searches for the environment variable to see if the SDK should be disabled.
39+
If no environment variable is found, it returns True by default.
40+
41+
:return: bool - True if it is enabled, False otherwise.
42+
"""
43+
env_var_str = os.getenv(cls.XRAY_ENABLED_KEY, 'true')
44+
try:
45+
return bool(strtobool(env_var_str))
46+
except ValueError:
47+
log.warning("Invalid literal passed into environment variable `AWS_XRAY_SDK_ENABLED`. Defaulting to True...")
48+
return True # If an invalid parameter is passed in, we return True.
3149

3250
@classmethod
3351
def sdk_enabled(cls):
3452
"""
3553
Returns whether the SDK is enabled or not.
3654
"""
55+
if cls.__SDK_ENABLED is None:
56+
cls.__SDK_ENABLED = cls.__get_enabled_from_env()
3757
return cls.__SDK_ENABLED
3858

3959
@classmethod
@@ -49,7 +69,7 @@ def set_sdk_enabled(cls, value):
4969
"""
5070
# Environment Variables take precedence over hardcoded configurations.
5171
if cls.XRAY_ENABLED_KEY in os.environ:
52-
cls.__SDK_ENABLED = str(os.getenv(cls.XRAY_ENABLED_KEY, 'true')).lower() != 'false'
72+
cls.__SDK_ENABLED = cls.__get_enabled_from_env()
5373
else:
5474
if type(value) == bool:
5575
cls.__SDK_ENABLED = value

tests/test_recorder.py

+59
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,65 @@ def test_disable_is_dummy():
201201
assert type(xray_recorder.current_subsegment()) is DummySubsegment
202202

203203

204+
def test_disabled_empty_context_current_calls():
205+
global_sdk_config.set_sdk_enabled(False)
206+
assert type(xray_recorder.current_segment()) is DummySegment
207+
assert type(xray_recorder.current_subsegment()) is DummySubsegment
208+
209+
210+
def test_disabled_out_of_order_begins():
211+
global_sdk_config.set_sdk_enabled(False)
212+
xray_recorder.begin_subsegment("Test")
213+
xray_recorder.begin_segment("Test")
214+
xray_recorder.begin_subsegment("Test1")
215+
xray_recorder.begin_subsegment("Test2")
216+
assert type(xray_recorder.begin_subsegment("Test3")) is DummySubsegment
217+
assert type(xray_recorder.begin_segment("Test4")) is DummySegment
218+
219+
220+
def test_disabled_put_methods():
221+
global_sdk_config.set_sdk_enabled(False)
222+
xray_recorder.put_annotation("Test", "Value")
223+
xray_recorder.put_metadata("Test", "Value", "Namespace")
224+
225+
226+
# Test for random end segments/subsegments without any entities in context.
227+
# Should not throw any exceptions
228+
def test_disabled_ends():
229+
global_sdk_config.set_sdk_enabled(False)
230+
xray_recorder.end_segment()
231+
xray_recorder.end_subsegment()
232+
xray_recorder.end_segment()
233+
xray_recorder.end_segment()
234+
xray_recorder.end_subsegment()
235+
xray_recorder.end_subsegment()
236+
237+
238+
# Begin subsegment should not fail on its own.
239+
def test_disabled_begin_subsegment():
240+
global_sdk_config.set_sdk_enabled(False)
241+
subsegment_entity = xray_recorder.begin_subsegment("Test")
242+
assert type(subsegment_entity) is DummySubsegment
243+
244+
245+
# When disabled, force sampling should still return dummy entities.
246+
def test_disabled_force_sampling():
247+
global_sdk_config.set_sdk_enabled(False)
248+
xray_recorder.configure(sampling=True)
249+
segment_entity = xray_recorder.begin_segment("Test1")
250+
subsegment_entity = xray_recorder.begin_subsegment("Test2")
251+
assert type(segment_entity) is DummySegment
252+
assert type(subsegment_entity) is DummySubsegment
253+
254+
255+
# When disabled, get_trace_entity should return DummySegment if an entity is not present in the context
256+
def test_disabled_get_context_entity():
257+
global_sdk_config.set_sdk_enabled(False)
258+
entity = xray_recorder.get_trace_entity()
259+
assert type(entity) is DummySegment
260+
261+
262+
204263
def test_max_stack_trace_zero():
205264
xray_recorder.configure(max_trace_back=1)
206265
with pytest.raises(Exception):

tests/test_sdk_config.py

+20
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,18 @@ def test_env_enable_case():
5656
global_sdk_config.set_sdk_enabled(True)
5757
assert global_sdk_config.sdk_enabled() is True
5858

59+
os.environ[XRAY_ENABLED_KEY] = "1"
60+
global_sdk_config.set_sdk_enabled(True)
61+
assert global_sdk_config.sdk_enabled() is True
62+
63+
os.environ[XRAY_ENABLED_KEY] = "y"
64+
global_sdk_config.set_sdk_enabled(True)
65+
assert global_sdk_config.sdk_enabled() is True
66+
67+
os.environ[XRAY_ENABLED_KEY] = "t"
68+
global_sdk_config.set_sdk_enabled(True)
69+
assert global_sdk_config.sdk_enabled() is True
70+
5971
os.environ[XRAY_ENABLED_KEY] = "False"
6072
global_sdk_config.set_sdk_enabled(True)
6173
assert global_sdk_config.sdk_enabled() is False
@@ -64,6 +76,10 @@ def test_env_enable_case():
6476
global_sdk_config.set_sdk_enabled(True)
6577
assert global_sdk_config.sdk_enabled() is False
6678

79+
os.environ[XRAY_ENABLED_KEY] = "0"
80+
global_sdk_config.set_sdk_enabled(True)
81+
assert global_sdk_config.sdk_enabled() is False
82+
6783

6884
def test_invalid_env_string():
6985
os.environ[XRAY_ENABLED_KEY] = "INVALID"
@@ -78,3 +94,7 @@ def test_invalid_env_string():
7894
os.environ[XRAY_ENABLED_KEY] = "1-.0"
7995
global_sdk_config.set_sdk_enabled(False)
8096
assert global_sdk_config.sdk_enabled() is True
97+
98+
os.environ[XRAY_ENABLED_KEY] = "T RUE"
99+
global_sdk_config.set_sdk_enabled(True)
100+
assert global_sdk_config.sdk_enabled() is True

0 commit comments

Comments
 (0)