Skip to content

Prevent before callsites targeting constructors in super calls #8549

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static com.github.javaparser.ast.Modifier.Keyword.PUBLIC;
import static datadog.trace.plugin.csi.impl.CallSiteFactory.typeResolver;
import static datadog.trace.plugin.csi.util.CallSiteConstants.ADVICE_TYPE_CLASS;
import static datadog.trace.plugin.csi.util.CallSiteConstants.AUTO_SERVICE_FQDN;
import static datadog.trace.plugin.csi.util.CallSiteConstants.CALL_SITES_CLASS;
import static datadog.trace.plugin.csi.util.CallSiteConstants.CALL_SITES_FQCN;
Expand Down Expand Up @@ -185,20 +186,24 @@ private void addAdviceLambda(
final MethodType pointCut = spec.getPointcut();
final BlockStmt adviceBody = new BlockStmt();
final Expression advice;
final String type;
if (spec.isInvokeDynamic()) {
advice = invokeDynamicAdviceSignature(adviceBody);
} else {
advice = invokeAdviceSignature(adviceBody);
}
if (spec instanceof BeforeSpecification) {
type = "BEFORE";
writeStackOperations(spec, adviceBody);
writeAdviceMethodCall(spec, adviceBody);
writeOriginalMethodCall(spec, adviceBody);
} else if (spec instanceof AfterSpecification) {
type = "AFTER";
writeStackOperations(spec, adviceBody);
writeOriginalMethodCall(spec, adviceBody);
writeAdviceMethodCall(spec, adviceBody);
} else {
type = "AROUND";
writeAdviceMethodCall(spec, adviceBody);
}
body.addStatement(
Expand All @@ -207,6 +212,10 @@ private void addAdviceLambda(
.setName("addAdvice")
.setArguments(
new NodeList<>(
new FieldAccessExpr()
.setScope(
new TypeExpr(new ClassOrInterfaceType().setName(ADVICE_TYPE_CLASS)))
.setName(type),
new StringLiteralExpr(pointCut.getOwner().getInternalName()),
new StringLiteralExpr(pointCut.getMethodName()),
new StringLiteralExpr(pointCut.getMethodType().getDescriptor()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,19 +364,19 @@ private static LambdaExpr findAdviceLambda(
final MethodType pointcut = spec.getPointcut();
for (final MethodCallExpr add : addAdvices) {
final NodeList<Expression> arguments = add.getArguments();
final String owner = arguments.get(0).asStringLiteralExpr().asString();
final String owner = arguments.get(1).asStringLiteralExpr().asString();
if (!owner.equals(pointcut.getOwner().getInternalName())) {
continue;
}
final String method = arguments.get(1).asStringLiteralExpr().asString();
final String method = arguments.get(2).asStringLiteralExpr().asString();
if (!method.equals(pointcut.getMethodName())) {
continue;
}
final String description = arguments.get(2).asStringLiteralExpr().asString();
final String description = arguments.get(3).asStringLiteralExpr().asString();
if (!description.equals(pointcut.getMethodType().getDescriptor())) {
continue;
}
return arguments.get(3).asLambdaExpr();
return arguments.get(4).asLambdaExpr();
}
throw new IllegalArgumentException("Cannot find lambda expression for pointcut " + pointcut);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ private CallSiteConstants() {}

public static final String HAS_ENABLED_PROPERTY_CLASS = CALL_SITES_CLASS + ".HasEnabledProperty";

public static final String ADVICE_TYPE_CLASS = "AdviceType";

public static final String STACK_DUP_MODE_CLASS = "StackDupMode";

public static final String METHOD_HANDLER_CLASS = "MethodHandler";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ final class AdviceGeneratorTest extends BaseCsiPluginTest {
interfaces(CallSites)
helpers(BeforeAdvice)
advices(0) {
type("BEFORE")
pointcut('java/security/MessageDigest', 'getInstance', '(Ljava/lang/String;)Ljava/security/MessageDigest;')
statements(
'handler.dupParameters(descriptor, StackDupMode.COPY);',
Expand Down Expand Up @@ -76,6 +77,7 @@ final class AdviceGeneratorTest extends BaseCsiPluginTest {
interfaces(CallSites)
helpers(AroundAdvice)
advices(0) {
type("AROUND")
pointcut('java/lang/String', 'replaceAll', '(Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;')
statements(
'handler.advice("datadog/trace/plugin/csi/impl/AdviceGeneratorTest$AroundAdvice", "around", "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)Ljava/lang/String;");'
Expand Down Expand Up @@ -106,6 +108,7 @@ final class AdviceGeneratorTest extends BaseCsiPluginTest {
interfaces(CallSites)
helpers(AfterAdvice)
advices(0) {
type("AFTER")
pointcut('java/lang/String', 'concat', '(Ljava/lang/String;)Ljava/lang/String;')
statements(
'handler.dupInvoke(owner, descriptor, StackDupMode.COPY);',
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package datadog.trace.plugin.csi.impl.assertion

class AdviceAssert {
protected String type
protected String owner
protected String method
protected String descriptor
protected Collection<String> statements

void type(String type) {
assert type == this.type
}

void pointcut(String owner, String method, String descriptor) {
assert owner == this.owner
assert method == this.method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,12 @@ class AssertBuilder<C extends CallSiteAssert> {
return getMethodCalls(acceptMethod).findAll {
it.nameAsString == 'addAdvice'
}.collect {
def (owner, method, descriptor) = it.arguments.subList(0, 3)*.asStringLiteralExpr()*.asString()
final handlerLambda = it.arguments[3].asLambdaExpr()
final adviceType = it.arguments.get(0).asFieldAccessExpr().getName()
def (owner, method, descriptor) = it.arguments.subList(1, 4)*.asStringLiteralExpr()*.asString()
final handlerLambda = it.arguments[4].asLambdaExpr()
final advice = handlerLambda.body.asBlockStmt().statements*.toString()
return new AdviceAssert([
type : adviceType,
owner : owner,
method : method,
descriptor: descriptor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ class IastExtensionTest extends BaseCsiPluginTest {
return getMethodCalls(acceptMethod).findAll {
it.nameAsString == 'addAdvice'
}.collect {
def (owner, method, descriptor) = it.arguments.subList(0, 3)*.asStringLiteralExpr()*.asString()
final handlerLambda = it.arguments[3].asLambdaExpr()
def (owner, method, descriptor) = it.arguments.subList(1, 4)*.asStringLiteralExpr()*.asString()
final handlerLambda = it.arguments[4].asLambdaExpr()
final statements = handlerLambda.body.asBlockStmt().statements
final instrumentedStmt = statements.get(0).asIfStmt()
final executedStmt = statements.get(1).asIfStmt()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package datadog.trace.agent.tooling.bytebuddy.csi;

import static datadog.trace.agent.tooling.csi.CallSiteAdvice.AdviceType.AROUND;

import datadog.trace.agent.tooling.csi.CallSites;
import datadog.trace.agent.tooling.csi.InvokeAdvice;
import datadog.trace.agent.tooling.muzzle.ReferenceMatcher;
Expand Down Expand Up @@ -44,6 +46,7 @@ public Iterable<CallSites> get() {
return Collections.singletonList(
(container -> {
container.addAdvice(
AROUND,
"javax/servlet/ServletRequest",
"getParameter",
"(Ljava/lang/String;)Ljava/lang/String;",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import datadog.trace.agent.tooling.bytebuddy.ClassFileLocators;
import datadog.trace.agent.tooling.csi.CallSiteAdvice;
import datadog.trace.agent.tooling.csi.CallSites;
import datadog.trace.agent.tooling.csi.InvokeAdvice;
import datadog.trace.agent.tooling.csi.InvokeDynamicAdvice;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -141,6 +143,11 @@ public CallSiteAdvice findAdvice(
return methodAdvices.get(descriptor);
}

/** Gets the type of advice we are dealing with */
public byte typeOf(final CallSiteAdvice advice) {
return ((TypedAdvice) advice).getType();
}

public String[] getHelpers() {
return helpers;
}
Expand Down Expand Up @@ -176,15 +183,17 @@ public void addHelpers(final String... helperClassNames) {

@Override
public void addAdvice(
final String type,
final byte type,
final String owner,
final String method,
final String descriptor,
final CallSiteAdvice advice) {
final Map<String, Map<String, CallSiteAdvice>> typeAdvices =
advices.computeIfAbsent(type, k -> new HashMap<>());
advices.computeIfAbsent(owner, k -> new HashMap<>());
final Map<String, CallSiteAdvice> methodAdvices =
typeAdvices.computeIfAbsent(method, k -> new HashMap<>());
final CallSiteAdvice oldAdvice = methodAdvices.put(descriptor, advice);
final CallSiteAdvice oldAdvice =
methodAdvices.put(descriptor, TypedAdvice.withType(advice, type));
if (oldAdvice != null) {
throw new UnsupportedOperationException(
String.format(
Expand Down Expand Up @@ -360,4 +369,67 @@ public interface Listener {
void onConstantPool(
@Nonnull TypeDescription type, @Nonnull ConstantPool pool, final byte[] classFile);
}

private interface TypedAdvice {
byte getType();

static CallSiteAdvice withType(final CallSiteAdvice advice, final byte type) {
if (advice instanceof InvokeAdvice) {
return new InvokeWithType((InvokeAdvice) advice, type);
} else {
return new InvokeDynamicWithType((InvokeDynamicAdvice) advice, type);
}
}
}

private static class InvokeWithType implements InvokeAdvice, TypedAdvice {
private final InvokeAdvice advice;
private final byte type;

private InvokeWithType(InvokeAdvice advice, byte type) {
this.advice = advice;
this.type = type;
}

@Override
public byte getType() {
return type;
}

@Override
public void apply(
final MethodHandler handler,
final int opcode,
final String owner,
final String name,
final String descriptor,
final boolean isInterface) {
advice.apply(handler, opcode, owner, name, descriptor, isInterface);
}
}

private static class InvokeDynamicWithType implements InvokeDynamicAdvice, TypedAdvice {
private final InvokeDynamicAdvice advice;
private final byte type;

private InvokeDynamicWithType(final InvokeDynamicAdvice advice, final byte type) {
this.advice = advice;
this.type = type;
}

@Override
public byte getType() {
return type;
}

@Override
public void apply(
final MethodHandler handler,
final String name,
final String descriptor,
final Handle bootstrapMethodHandle,
final Object... bootstrapMethodArguments) {
advice.apply(handler, name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package datadog.trace.agent.tooling.bytebuddy.csi;

import static datadog.trace.agent.tooling.csi.CallSiteAdvice.AdviceType.AFTER;
import static datadog.trace.api.telemetry.LogCollector.SEND_TELEMETRY;
import static net.bytebuddy.jar.asm.ClassWriter.COMPUTE_MAXS;

Expand Down Expand Up @@ -126,7 +127,7 @@ public MethodVisitor visitMethod(

private static class CallSiteMethodVisitor extends MethodVisitor
implements CallSiteAdvice.MethodHandler {
private final Advices advices;
protected final Advices advices;

private CallSiteMethodVisitor(
@Nonnull final Advices advices, @Nonnull final MethodVisitor delegated) {
Expand All @@ -144,12 +145,22 @@ public void visitMethodInsn(

CallSiteAdvice advice = advices.findAdvice(owner, name, descriptor);
if (advice instanceof InvokeAdvice) {
((InvokeAdvice) advice).apply(this, opcode, owner, name, descriptor, isInterface);
invokeAdvice((InvokeAdvice) advice, opcode, owner, name, descriptor, isInterface);
} else {
mv.visitMethodInsn(opcode, owner, name, descriptor, isInterface);
}
}

protected void invokeAdvice(
final InvokeAdvice advice,
final int opcode,
final String owner,
final String name,
final String descriptor,
final boolean isInterface) {
advice.apply(this, opcode, owner, name, descriptor, isInterface);
}

@Override
public void visitInvokeDynamicInsn(
final String name,
Expand All @@ -158,14 +169,27 @@ public void visitInvokeDynamicInsn(
final Object... bootstrapMethodArguments) {
CallSiteAdvice advice = advices.findAdvice(bootstrapMethodHandle);
if (advice instanceof InvokeDynamicAdvice) {
((InvokeDynamicAdvice) advice)
.apply(this, name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);
invokeDynamicAdvice(
(InvokeDynamicAdvice) advice,
name,
descriptor,
bootstrapMethodHandle,
bootstrapMethodArguments);
} else {
mv.visitInvokeDynamicInsn(
name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);
}
}

protected void invokeDynamicAdvice(
final InvokeDynamicAdvice advice,
final String name,
final String descriptor,
final Handle bootstrapMethodHandle,
final Object... bootstrapMethodArguments) {
advice.apply(this, name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);
}

@Override
public void instruction(final int opcode) {
mv.visitInsn(opcode);
Expand Down Expand Up @@ -347,5 +371,22 @@ public void dupParameters(final String methodDescriptor, final StackDupMode mode
super.dupParameters(
methodDescriptor, isSuperCall ? StackDupMode.PREPEND_ARRAY_SUPER_CTOR : mode);
}

@Override
protected void invokeAdvice(
final InvokeAdvice advice,
final int opcode,
final String owner,
final String name,
final String descriptor,
final boolean isInterface) {
if (isSuperCall && advices.typeOf(advice) != AFTER) {
// TODO APPSEC-57009 calls to super are only instrumented by after call sites
// just ignore the advice and keep on
mv.visitMethodInsn(opcode, owner, name, descriptor, isInterface);
} else {
super.invokeAdvice(advice, opcode, owner, name, descriptor, isInterface);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,13 @@ enum StackDupMode {
/** Copies the parameters in an array and appends it */
APPEND_ARRAY
}

abstract class AdviceType {

private AdviceType() {}

public static final byte BEFORE = -1;
public static final byte AROUND = 0;
public static final byte AFTER = 1;
}
}
Loading