Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds support for modern xray patching #112

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions fleece/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,11 @@ def get_logger(name=None, level=None, stream=DEFAULT_STREAM,
return log


def initial_trace_and_context_binds(logger, trace_id, lambda_context):
def get_initial_trace_and_context_binds(trace_id, lambda_context):
"""A helper to set up standard trace_id and lambda_context binds"""
return logger.new(
trace_id=trace_id,
lambda_context={
return {
"trace_id": trace_id,
"lambda_context": {
"function_name": lambda_context.function_name,
"function_version": lambda_context.function_version,
"invoked_function_arn": lambda_context.invoked_function_arn,
Expand All @@ -216,6 +216,13 @@ def initial_trace_and_context_binds(logger, trace_id, lambda_context):
"log_group_name": lambda_context.log_group_name,
"log_stream_name": lambda_context.log_stream_name,
}
}


def initial_trace_and_context_binds(logger, trace_id, lambda_context):
"""A helper to set up standard trace_id and lambda_context binds"""
return logger.new(
**get_initial_trace_and_context_binds(trace_id, lambda_context)
)


Expand Down
26 changes: 26 additions & 0 deletions fleece/xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
import time
import uuid

from aws_xray_sdk.core import patch as aws_xray_patch
from aws_xray_sdk.core import patch_all
from aws_xray_sdk.core import xray_recorder

from botocore.exceptions import ClientError
import wrapt

Expand Down Expand Up @@ -496,3 +500,25 @@ def to_safe_annotation_key(key):
def to_safe_annotation_value(value):
"""Xray doesn't like values that are not strings."""
return str(value)


def patch(modules=None): # type: (t.Optional[t.List[str]]) -> None
"""
Patch known external packages, (requests, pynamo, etc) along with the
given set of modules.
"""
if os.environ.get("AWS_XRAY_DAEMON_ADDRESS") is not None:
patch_all()
if modules:
aws_xray_patch(modules)


def log_args(lambda_context): # type: (t.Any) -> t.Mapping
"""
Returns arguments that should be bound to a log (via .new or .bind) when a
lambda handler is first called.
"""
return log.get_initial_trace_and_context_binds(
trace_id=xray_recorder.get_trace_entity().trace_id,
lambda_context=lambda_context
)
70 changes: 70 additions & 0 deletions tests/test_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import mock

from fleece import xray
from fleece import testing


class GetTraceIDTestCase(unittest.TestCase):
Expand Down Expand Up @@ -124,3 +125,72 @@ def test_full_segment_document_with_extra_data(self):
self.assertEqual(segment_document['end_time'], end_time)
self.assertEqual(segment_document['name'], 'NAME')
self.assertEqual(segment_document['foo'], 'BAR')


class TestPatch(unittest.TestCase):
def setUp(self):
self.patch = mock.patch('fleece.xray.aws_xray_patch').start()
self.patch_all = mock.patch('fleece.xray.patch_all').start()

def tearDown(self):
self.patch.stop()
self.patch_all.stop()

def test_patch_when_not_in_lambda(self):
xray.patch()
self.patch_all.assert_not_called()
self.patch.assert_not_called()

@mock.patch.dict(
os.environ,
{
'AWS_XRAY_DAEMON_ADDRESS': 'http://localhost', # noqa
}
)
def test_patch_when_in_lambda(self):
xray.patch()
self.patch_all.assert_called_once()
self.patch.assert_not_called()

@mock.patch.dict(
os.environ,
{
'AWS_XRAY_DAEMON_ADDRESS': 'http://localhost', # noqa
}
)
def test_patch_when_in_lambda_2(self):
xray.patch(["fleece.utils"])
self.patch_all.assert_called_once()
self.patch.assert_called_once()


class TestLogArgs(unittest.TestCase):
def setUp(self):
self.maxDiff = None
self.get_trace_entity = mock.patch(
'fleece.xray.xray_recorder.get_trace_entity').start()

def tearDown(self):
self.get_trace_entity.stop()

def test_patch_when_not_in_lambda(self):
class Trace:
def __init__(self):
self.trace_id = "trace-id"

self.get_trace_entity.return_value = Trace()
lambda_context = testing.LambdaContext("FakeFunction")
actual = xray.log_args(lambda_context)
expected = {
"trace_id": 'trace-id',
"lambda_context": {
"function_name": lambda_context.function_name,
"function_version": lambda_context.function_version,
"invoked_function_arn": lambda_context.invoked_function_arn,
"memory_limit_in_mb": lambda_context.memory_limit_in_mb,
"aws_request_id": lambda_context.aws_request_id,
"log_group_name": lambda_context.log_group_name,
"log_stream_name": lambda_context.log_stream_name,
}
}
self.assertEqual(expected, actual)