diff --git a/src/trace/patch-http.spec.ts b/src/trace/patch-http.spec.ts index 04330206..ecac7a81 100644 --- a/src/trace/patch-http.spec.ts +++ b/src/trace/patch-http.spec.ts @@ -7,6 +7,7 @@ import { LogLevel, setLogLevel } from "../utils"; import { parentIDHeader, SampleMode, samplingPriorityHeader, traceIDHeader, Source } from "./constants"; import { patchHttp, unpatchHttp } from "./patch-http"; import { TraceContextService } from "./trace-context-service"; +import { URL } from "url"; describe("patchHttp", () => { let traceWrapper = { @@ -131,4 +132,22 @@ describe("patchHttp", () => { expect(headers[parentIDHeader]).toBeUndefined(); expect(headers[samplingPriorityHeader]).toBeUndefined(); }); + it("injects tracing headers when using the new WHATWG URL object", () => { + nock("http://www.example.com").get("/").reply(200, {}); + patchHttp(contextService); + const url = new URL("http://www.example.com"); + const req = http.request(url); + expectHeaders(req); + }); + it("injects tracing headers when using the new WHATWG URL object and callback", (done) => { + nock("http://www.example.com").get("/").reply(200, {}); + patchHttp(contextService); + const url = new URL("http://www.example.com"); + const req = http.request(url, {}, () => { + done(); + }); + req.end(); + + expectHeaders(req); + }); }); diff --git a/src/trace/patch-http.ts b/src/trace/patch-http.ts index c6891c48..9e0db0c2 100644 --- a/src/trace/patch-http.ts +++ b/src/trace/patch-http.ts @@ -35,14 +35,13 @@ export function unpatchHttp() { function patchMethod(mod: typeof http | typeof https, method: "get" | "request", contextService: TraceContextService) { shimmer.wrap(mod, method, (original) => { const fn = (arg1: any, arg2: any, arg3: any) => { - const { options, callback } = normalizeArgs(arg1, arg2, arg3); - const requestOpts = getRequestOptionsWithTraceContext(options, contextService); + [arg1, arg2, arg3] = addTraceContextToArgs(contextService, arg1, arg2, arg3); - if (isIntegrationTest()) { - _logHttpRequest(requestOpts); + if (arg3 === undefined || arg3 === null) { + return original(arg1, arg2); + } else { + return original(arg1, arg2, arg3); } - - return original(requestOpts, callback); }; return fn as any; }); @@ -54,23 +53,37 @@ function unpatchMethod(mod: typeof http | typeof https, method: "get" | "request } /** - * The input into the http.request function has 6 different overloads. This method normalized the inputs - * into a consistent format. + * Finds the RequestOptions in the args and injects context into headers */ -function normalizeArgs( +function addTraceContextToArgs( + contextService: TraceContextService, arg1: string | URL | http.RequestOptions, arg2?: RequestCallback | http.RequestOptions, arg3?: RequestCallback, ) { - let options: http.RequestOptions = typeof arg1 === "string" ? parse(arg1) : { ...arg1 }; - options.headers = options.headers || {}; - let callback = arg3; - if (typeof arg2 === "function") { - callback = arg2; - } else if (typeof arg2 === "object") { - options = { ...options, ...arg2 }; + let requestOpts: http.RequestOptions | undefined; + if (typeof arg1 === "string" || arg1 instanceof URL) { + if (arg2 === undefined || arg2 === null) { + requestOpts = { + method: "GET", + }; + requestOpts = getRequestOptionsWithTraceContext(requestOpts, contextService); + return [arg1, requestOpts, arg3]; + } else if (typeof arg2 === "function") { + requestOpts = { + method: "GET", + }; + requestOpts = getRequestOptionsWithTraceContext(requestOpts, contextService); + return [arg1, requestOpts, arg2]; + } else { + requestOpts = arg2 as http.RequestOptions; + requestOpts = getRequestOptionsWithTraceContext(requestOpts, contextService); + return [arg1, requestOpts, arg3]; + } + } else { + requestOpts = getRequestOptionsWithTraceContext(arg1, contextService); + return [requestOpts, arg2, arg3]; } - return { options, callback }; } function getRequestOptionsWithTraceContext( @@ -86,10 +99,16 @@ function getRequestOptionsWithTraceContext( ...headers, ...traceHeaders, }; - return { + const requestOpts = { ...options, headers, }; + // Logging all http requests during integration tests let's + // us track traffic in our test snapshots + if (isIntegrationTest()) { + _logHttpRequest(requestOpts); + } + return requestOpts; } function isIntegrationTest() {