1
1
import platform
2
+ import time
2
3
3
4
import pytest
4
5
6
+ from aws_xray_sdk .core .sampling .sampling_rule import SamplingRule
7
+ from aws_xray_sdk .core .sampling .rule_cache import RuleCache
8
+ from aws_xray_sdk .core .sampling .sampler import DefaultSampler
5
9
from aws_xray_sdk .version import VERSION
6
10
from .util import get_new_stubbed_recorder
7
11
@@ -38,7 +42,6 @@ def test_default_runtime_context():
38
42
39
43
40
44
def test_subsegment_parenting ():
41
-
42
45
segment = xray_recorder .begin_segment ('name' )
43
46
subsegment = xray_recorder .begin_subsegment ('name' )
44
47
xray_recorder .end_subsegment ('name' )
@@ -97,7 +100,6 @@ def test_put_annotation_metadata():
97
100
98
101
99
102
def test_pass_through_with_missing_context ():
100
-
101
103
xray_recorder = get_new_stubbed_recorder ()
102
104
xray_recorder .configure (sampling = False , context_missing = 'LOG_ERROR' )
103
105
assert not xray_recorder .is_sampled ()
@@ -175,7 +177,6 @@ def test_in_segment_exception():
175
177
assert segment .fault is True
176
178
assert len (segment .cause ['exceptions' ]) == 1
177
179
178
-
179
180
with pytest .raises (Exception ):
180
181
with xray_recorder .in_segment ('name' ) as segment :
181
182
with xray_recorder .in_subsegment ('name' ) as subsegment :
@@ -259,7 +260,6 @@ def test_disabled_get_context_entity():
259
260
assert type (entity ) is DummySegment
260
261
261
262
262
-
263
263
def test_max_stack_trace_zero ():
264
264
xray_recorder .configure (max_trace_back = 1 )
265
265
with pytest .raises (Exception ):
@@ -279,3 +279,41 @@ def test_max_stack_trace_zero():
279
279
280
280
assert len (segment_with_stack .cause ['exceptions' ][0 ].stack ) == 1
281
281
assert len (segment_no_stack .cause ['exceptions' ][0 ].stack ) == 0
282
+
283
+
284
+ # CustomSampler to mimic the DefaultSampler,
285
+ # but without the rule and target polling logic.
286
+ class CustomSampler (DefaultSampler ):
287
+ def start (self ):
288
+ pass
289
+
290
+ def should_trace (self , sampling_req = None ):
291
+ rule_cache = RuleCache ()
292
+ rule_cache .last_updated = int (time .time ())
293
+ sampling_rule_a = SamplingRule (name = 'rule_a' ,
294
+ priority = 2 ,
295
+ rate = 0.5 ,
296
+ reservoir_size = 1 ,
297
+ service = 'app_a' )
298
+ sampling_rule_b = SamplingRule (name = 'rule_b' ,
299
+ priority = 2 ,
300
+ rate = 0.5 ,
301
+ reservoir_size = 1 ,
302
+ service = 'app_b' )
303
+ rule_cache .load_rules ([sampling_rule_a , sampling_rule_b ])
304
+ now = int (time .time ())
305
+ if sampling_req and not sampling_req .get ('service_type' , None ):
306
+ sampling_req ['service_type' ] = self ._origin
307
+ elif sampling_req is None :
308
+ sampling_req = {'service_type' : self ._origin }
309
+ matched_rule = rule_cache .get_matched_rule (sampling_req , now )
310
+ if matched_rule :
311
+ return self ._process_matched_rule (matched_rule , now )
312
+ else :
313
+ return self ._local_sampler .should_trace (sampling_req )
314
+
315
+
316
+ def test_begin_segment_matches_sampling_rule_on_name ():
317
+ xray_recorder .configure (sampling = True , sampler = CustomSampler ())
318
+ segment = xray_recorder .begin_segment ("app_b" )
319
+ assert segment .aws .get ('xray' ).get ('sampling_rule_name' ) == 'rule_b'
0 commit comments