Skip to content

Adding attach/detach methods as per spec #429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Feb 26, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 22 additions & 28 deletions opentelemetry-api/src/opentelemetry/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,6 @@ def set_value(
return Context(new_values)


def remove_value(
key: str, context: typing.Optional[Context] = None
) -> Context:
"""To remove a value, this method returns a new context with the key
cleared. Note that the removed value still remains present in the old
context.

Args:
key: The key of the entry to remove
context: The context to copy, if None, the current context is used
"""
if context is None:
context = get_current()
new_values = context.copy()
new_values.pop(key, None)
return Context(new_values)


def get_current() -> Context:
"""To access the context associated with program execution,
the RuntimeContext API provides a function which takes no arguments
Expand Down Expand Up @@ -104,16 +86,29 @@ def get_current() -> Context:
return _RUNTIME_CONTEXT.get_current() # type:ignore


def set_current(context: Context) -> Context:
"""To associate a context with program execution, the Context
API provides a function which takes a Context.
def attach(context: Context) -> object:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you try making the token type a TypeVar? Just curious, I don't know that it's worth the extra typing boilerplate to do so.

"""Associates a Context with the caller's current execution unit. Returns
a token that can be used to restore the previous Context.

Args:
context: The Context to set as current.
"""
get_current()

return _RUNTIME_CONTEXT.set_current(context) # type:ignore


def detach(token: object) -> None:
"""Resets the Context associated with the caller's current execution unit
to the value it had before attaching a specified Context.

Args:
context: The context to use as current.
token: The Token that was returned by a previous call to attach a Context.
"""
old_context = get_current()
_RUNTIME_CONTEXT.set_current(context) # type:ignore
return old_context
try:
_RUNTIME_CONTEXT.reset(token) # type: ignore
except ValueError:
logger.error("Failed to detach context")


def with_current_context(
Expand All @@ -127,10 +122,9 @@ def call_with_current_context(
*args: "object", **kwargs: "object"
) -> "object":
try:
backup = get_current()
set_current(caller_context)
token = attach(caller_context)
return func(*args, **kwargs)
finally:
set_current(backup)
detach(token)

return call_with_current_context
13 changes: 11 additions & 2 deletions opentelemetry-api/src/opentelemetry/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ class RuntimeContext(ABC):
"""

@abstractmethod
def set_current(self, context: Context) -> None:
""" Sets the current `Context` object.
def set_current(self, context: Context) -> object:
""" Sets the current `Context` object. Returns a
token that can be used to reset to the previous `Context`.

Args:
context: The Context to set.
Expand All @@ -40,5 +41,13 @@ def set_current(self, context: Context) -> None:
def get_current(self) -> Context:
""" Returns the current `Context` object. """

@abstractmethod
def reset(self, token: object) -> None:
""" Resets Context to a previous value

Args:
token: A reference to a previous Context.
"""


__all__ = ["Context", "RuntimeContext"]
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextvars import ContextVar
from contextvars import ContextVar, Token
from sys import version_info

from opentelemetry.context.context import Context, RuntimeContext
Expand All @@ -35,13 +35,19 @@ def __init__(self) -> None:
self._CONTEXT_KEY, default=Context()
)

def set_current(self, context: Context) -> None:
def set_current(self, context: Context) -> object:
"""See `opentelemetry.context.RuntimeContext.set_current`."""
self._current_context.set(context)
return self._current_context.set(context)

def get_current(self) -> Context:
"""See `opentelemetry.context.RuntimeContext.get_current`."""
return self._current_context.get()

def reset(self, token: object) -> None:
"""See `opentelemetry.context.RuntimeContext.reset`."""
if not isinstance(token, Token):
raise ValueError("invalid token")
self._current_context.reset(token)


__all__ = ["ContextVarsRuntimeContext"]
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@ class ThreadLocalRuntimeContext(RuntimeContext):
def __init__(self) -> None:
self._current_context = threading.local()

def set_current(self, context: Context) -> None:
def set_current(self, context: Context) -> object:
"""See `opentelemetry.context.RuntimeContext.set_current`."""
current = self.get_current()
setattr(self._current_context, str(id(current)), current)
setattr(self._current_context, self._CONTEXT_KEY, context)
return str(id(current))

def get_current(self) -> Context:
"""See `opentelemetry.context.RuntimeContext.get_current`."""
Expand All @@ -43,5 +46,12 @@ def get_current(self) -> Context:
) # type: Context
return context

def reset(self, token: object) -> None:
"""See `opentelemetry.context.RuntimeContext.reset`."""
if not hasattr(self._current_context, str(token)):
raise ValueError("invalid token")
context = getattr(self._current_context, str(token)) # type: Context
setattr(self._current_context, self._CONTEXT_KEY, context)


__all__ = ["ThreadLocalRuntimeContext"]
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import typing
from contextlib import contextmanager

from opentelemetry.context import get_value, set_current, set_value
from opentelemetry.context import attach, get_value, set_value
from opentelemetry.context.context import Context

PRINTABLE = frozenset(
Expand Down Expand Up @@ -142,4 +142,4 @@ def distributed_context_from_context(
def with_distributed_context(
dctx: DistributedContext, context: typing.Optional[Context] = None
) -> None:
set_current(set_value(_DISTRIBUTED_CONTEXT_KEY, dctx, context=context))
attach(set_value(_DISTRIBUTED_CONTEXT_KEY, dctx, context=context))
11 changes: 5 additions & 6 deletions opentelemetry-api/tests/context/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@


def do_work() -> None:
context.set_current(context.set_value("say", "bar"))
context.attach(context.set_value("say", "bar"))


class TestContext(unittest.TestCase):
def setUp(self):
context.set_current(Context())
context.attach(Context())

def test_context(self):
self.assertIsNone(context.get_value("say"))
Expand Down Expand Up @@ -55,11 +55,10 @@ def test_context_is_immutable(self):
context.get_current()["test"] = "cant-change-immutable"

def test_set_current(self):
context.set_current(context.set_value("a", "yyy"))
context.attach(context.set_value("a", "yyy"))

old_context = context.set_current(context.set_value("a", "zzz"))
self.assertEqual("yyy", context.get_value("a", context=old_context))
token = context.attach(context.set_value("a", "zzz"))
self.assertEqual("zzz", context.get_value("a"))

context.set_current(old_context)
context.detach(token)
self.assertEqual("yyy", context.get_value("a"))
27 changes: 23 additions & 4 deletions opentelemetry-api/tests/context/test_contextvars_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import unittest
from logging import ERROR
from unittest.mock import patch

from opentelemetry import context
Expand All @@ -27,18 +28,19 @@


def do_work() -> None:
context.set_current(context.set_value("say", "bar"))
context.attach(context.set_value("say", "bar"))


class TestContextVarsContext(unittest.TestCase):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we create a base class of the context tests, which we can extend and set the context constructor? seems like a great way to ensure standard behavior across context implementations.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

def setUp(self):
self.previous_context = context.get_current()

def tearDown(self):
context.set_current(self.previous_context)
context.attach(self.previous_context)

@patch(
"opentelemetry.context._RUNTIME_CONTEXT", ContextVarsRuntimeContext() # type: ignore
"opentelemetry.context._RUNTIME_CONTEXT", # type: ignore
ContextVarsRuntimeContext(),
)
def test_context(self):
self.assertIsNone(context.get_value("say"))
Expand All @@ -56,7 +58,8 @@ def test_context(self):
self.assertEqual(context.get_value("say", context=third), "bar")

@patch(
"opentelemetry.context._RUNTIME_CONTEXT", ContextVarsRuntimeContext() # type: ignore
"opentelemetry.context._RUNTIME_CONTEXT", # type: ignore
ContextVarsRuntimeContext(),
)
def test_set_value(self):
first = context.set_value("a", "yyy")
Expand All @@ -66,3 +69,19 @@ def test_set_value(self):
self.assertEqual("zzz", context.get_value("a", context=second))
self.assertEqual("---", context.get_value("a", context=third))
self.assertEqual(None, context.get_value("a"))

@patch(
"opentelemetry.context._RUNTIME_CONTEXT", # type: ignore
ContextVarsRuntimeContext(),
)
def test_set_current(self):
context.attach(context.set_value("a", "yyy"))

token = context.attach(context.set_value("a", "zzz"))
self.assertEqual("zzz", context.get_value("a"))

context.detach(token)
self.assertEqual("yyy", context.get_value("a"))

with self.assertLogs(level=ERROR):
context.detach("some garbage")
27 changes: 23 additions & 4 deletions opentelemetry-api/tests/context/test_threadlocal_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,27 @@
# limitations under the License.

import unittest
from logging import ERROR
from unittest.mock import patch

from opentelemetry import context
from opentelemetry.context.threadlocal_context import ThreadLocalRuntimeContext


def do_work() -> None:
context.set_current(context.set_value("say", "bar"))
context.attach(context.set_value("say", "bar"))


class TestThreadLocalContext(unittest.TestCase):
def setUp(self):
self.previous_context = context.get_current()

def tearDown(self):
context.set_current(self.previous_context)
context.attach(self.previous_context)

@patch(
"opentelemetry.context._RUNTIME_CONTEXT", ThreadLocalRuntimeContext() # type: ignore
"opentelemetry.context._RUNTIME_CONTEXT", # type: ignore
ThreadLocalRuntimeContext(),
)
def test_context(self):
self.assertIsNone(context.get_value("say"))
Expand All @@ -49,7 +51,8 @@ def test_context(self):
self.assertEqual(context.get_value("say", context=third), "bar")

@patch(
"opentelemetry.context._RUNTIME_CONTEXT", ThreadLocalRuntimeContext() # type: ignore
"opentelemetry.context._RUNTIME_CONTEXT", # type: ignore
ThreadLocalRuntimeContext(),
)
def test_set_value(self):
first = context.set_value("a", "yyy")
Expand All @@ -59,3 +62,19 @@ def test_set_value(self):
self.assertEqual("zzz", context.get_value("a", context=second))
self.assertEqual("---", context.get_value("a", context=third))
self.assertEqual(None, context.get_value("a"))

@patch(
"opentelemetry.context._RUNTIME_CONTEXT", # type: ignore
ThreadLocalRuntimeContext(),
)
def test_set_current(self):
context.attach(context.set_value("a", "yyy"))

token = context.attach(context.set_value("a", "zzz"))
self.assertEqual("zzz", context.get_value("a"))

context.detach(token)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be worth checking that we can detach contexts out of order too, maybe in a separate test_detach test:

In [2]: from opentelemetry import context

In [3]: t1 = context.attach(context.set_value('c', 1)); context.get_current()
Out[3]: {'c': 1}

In [4]: t2 = context.attach(context.set_value('c', 2)); context.get_current()
Out[4]: {'c': 2}

In [5]: context.detach(t1); context.get_current()
Out[5]: {}

In [6]: context.detach(t2); context.get_current()
Out[6]: {'c': 1}

I was surprised to see contextvars works this way. One difference between contextvars and threadlocals is that contextvars will raise if you try to reset using the same token twice. I don't know whether it's worth enforcing this behavoir for threadlocals though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test added. Didn't know about the reset behaviour either, good thinng to know!

self.assertEqual("yyy", context.get_value("a"))

with self.assertLogs(level=ERROR):
context.detach("some garbage")
5 changes: 2 additions & 3 deletions opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,14 +541,13 @@ def use_span(
) -> Iterator[trace_api.Span]:
"""See `opentelemetry.trace.Tracer.use_span`."""
try:
context_snapshot = context_api.get_current()
context_api.set_current(
token = context_api.attach(
context_api.set_value(self.source.key, span)
)
try:
yield span
finally:
context_api.set_current(context_snapshot)
context_api.detach(token)

except Exception as error: # pylint: disable=broad-except
if (
Expand Down
12 changes: 5 additions & 7 deletions opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import typing
from enum import Enum

from opentelemetry.context import get_current, set_current, set_value
from opentelemetry.context import attach, detach, get_current, set_value
from opentelemetry.trace import DefaultSpan
from opentelemetry.util import time_ns

Expand Down Expand Up @@ -75,14 +75,13 @@ def on_start(self, span: Span) -> None:
pass

def on_end(self, span: Span) -> None:
backup_context = get_current()
set_current(set_value("suppress_instrumentation", True))
token = attach(set_value("suppress_instrumentation", True))
try:
self.span_exporter.export((span,))
# pylint: disable=broad-except
except Exception:
logger.exception("Exception while exporting Span.")
set_current(backup_context)
detach(token)

def shutdown(self) -> None:
self.span_exporter.shutdown()
Expand Down Expand Up @@ -202,16 +201,15 @@ def export(self) -> None:
else:
self.spans_list[idx] = span
idx += 1
backup_context = get_current()
set_current(set_value("suppress_instrumentation", True))
token = attach(set_value("suppress_instrumentation", True))
try:
# Ignore type b/c the Optional[None]+slicing is too "clever"
# for mypy
self.span_exporter.export(self.spans_list[:idx]) # type: ignore
# pylint: disable=broad-except
except Exception:
logger.exception("Exception while exporting Span batch.")
set_current(backup_context)
detach(token)

if notify_flush:
with self.flush_condition:
Expand Down
Loading