diff --git a/src/main/kotlin/io/libp2p/etc/types/Delegates.kt b/src/main/kotlin/io/libp2p/etc/types/Delegates.kt index e44335659..9c9e3a77a 100644 --- a/src/main/kotlin/io/libp2p/etc/types/Delegates.kt +++ b/src/main/kotlin/io/libp2p/etc/types/Delegates.kt @@ -22,16 +22,20 @@ fun > cappedVar(value: T, lowerBound: T, upperBound: T) = /** * Creates a Double delegate which may drop value to [0.0] when the new value is less than [decayToZero] */ -fun cappedDouble(value: Double, decayToZero: Double = Double.MIN_VALUE): CappedValueDelegate { - return cappedDouble(value, decayToZero) { Double.MAX_VALUE } +fun cappedDouble(value: Double, decayToZero: Double = Double.MIN_VALUE, updateListener: (Double) -> Unit = { }): CappedValueDelegate { + return cappedDouble(value, decayToZero, { Double.MAX_VALUE }, updateListener) } /** * Creates a Double delegate which may cap upper bound (set [upperBound] when the new value is greater) * and may drop value to [0.0] when the new value is less than [decayToZero] */ -fun cappedDouble(value: Double, decayToZero: Double = Double.MIN_VALUE, upperBound: () -> Double) = - CappedValueDelegate(value, { decayToZero }, { 0.0 }, upperBound, upperBound) +fun cappedDouble( + value: Double, + decayToZero: Double = Double.MIN_VALUE, + upperBound: () -> Double, + updateListener: (Double) -> Unit = { } +) = CappedValueDelegate(value, { decayToZero }, { 0.0 }, upperBound, upperBound, updateListener) // thanks to https://stackoverflow.com/a/47948047/9630725 class LazyMutable(val initializer: () -> T, val rejectSetAfterGet: Boolean = false) : ReadWriteProperty { @@ -64,16 +68,25 @@ data class CappedValueDelegate>( private val lowerBound: () -> C, private val lowerBoundVal: () -> C = lowerBound, private val upperBound: () -> C, - private val upperBoundVal: () -> C = upperBound + private val upperBoundVal: () -> C = upperBound, + private val updateListener: (C) -> Unit = { } ) : ReadWriteProperty { override fun getValue(thisRef: Any?, property: KProperty<*>): C { + val oldValue = this.value val v1 = if (value > upperBound()) upperBoundVal() else value value = if (value < lowerBound()) lowerBoundVal() else v1 + if (oldValue != value) { + updateListener(value) + } return value } override fun setValue(thisRef: Any?, property: KProperty<*>, value: C) { + val oldValue = this.value this.value = value + if (oldValue != value) { + updateListener(value) + } } } diff --git a/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt b/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt index 3067b9ee2..e857d8edf 100644 --- a/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt +++ b/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt @@ -2,7 +2,6 @@ package io.libp2p.pubsub.gossip import io.libp2p.core.InternalErrorException import io.libp2p.core.PeerId -import io.libp2p.core.multiformats.Protocol import io.libp2p.core.pubsub.ValidationResult import io.libp2p.etc.types.anyComplete import io.libp2p.etc.types.copy @@ -35,9 +34,6 @@ const val MaxIAskedEntries = 256 const val MaxPeerIHaveEntries = 256 const val MaxIWantRequestsEntries = 10 * 1024 -fun P2PService.PeerHandler.getIP(): String? = - streamHandler.stream.connection.remoteAddress().getStringComponent(Protocol.IP4) - fun P2PService.PeerHandler.isOutbound() = streamHandler.stream.connection.isInitiator fun P2PService.PeerHandler.getPeerProtocol(): PubsubProtocol { @@ -59,6 +55,15 @@ open class GossipRouter @JvmOverloads constructor( subscriptionTopicSubscriptionFilter: TopicSubscriptionFilter = TopicSubscriptionFilter.AllowAllTopicSubscriptionFilter() ) : AbstractRouter(subscriptionTopicSubscriptionFilter, params.maxGossipMessageSize) { + // The idea behind choosing these specific default values for acceptRequestsWhitelist was + // - from one side are pretty small and safe: peer unlikely be able to drop its score to `graylist` + // with 128 messages. But even if so then it's not critical to accept some extra messages before + // blocking - not too much space for DoS here + // - from the other side param values are pretty high to yield good performance gain + val acceptRequestsWhitelistThresholdScore = 0 + val acceptRequestsWhitelistMaxMessages = 128 + val acceptRequestsWhitelistDuration = 1.seconds + val score by lazy { GossipScore(scoreParams, executor, curTimeMillis) } val fanout: MutableMap> = linkedMapOf() val mesh: MutableMap> = linkedMapOf() @@ -78,6 +83,7 @@ open class GossipRouter @JvmOverloads constructor( TimeUnit.MILLISECONDS ) } + private val acceptRequestsWhitelist = mutableMapOf() override val seenMessages: SeenCache> by lazy { TTLSeenCache(SimpleSeenCache(), params.seenTTL, curTimeMillis) @@ -104,6 +110,7 @@ open class GossipRouter @JvmOverloads constructor( score.notifyDisconnected(peer) mesh.values.forEach { it.remove(peer) } fanout.values.forEach { it.remove(peer) } + acceptRequestsWhitelist -= peer collectPeerMessage(peer) // discard them super.onPeerDisconnected(peer) } @@ -173,7 +180,30 @@ open class GossipRouter @JvmOverloads constructor( } override fun acceptRequestsFrom(peer: PeerHandler): Boolean { - return isDirect(peer) || score.score(peer) >= score.params.graylistThreshold + if (isDirect(peer)) { + return true + } + + val curTime = curTimeMillis() + val whitelistEntry = acceptRequestsWhitelist[peer] + if (whitelistEntry != null && + curTime <= whitelistEntry.whitelistedTill && + whitelistEntry.messagesAccepted < acceptRequestsWhitelistMaxMessages + ) { + + acceptRequestsWhitelist[peer] = whitelistEntry.incrementMessageCount() + return true + } + + val peerScore = score.score(peer) + if (peerScore >= acceptRequestsWhitelistThresholdScore) { + acceptRequestsWhitelist[peer] = + AcceptRequestsWhitelistEntry(curTime + acceptRequestsWhitelistDuration.toMillis()) + } else { + acceptRequestsWhitelist -= peer + } + + return peerScore >= score.params.graylistThreshold } override fun validateMessageListLimits(msg: Rpc.RPC): Boolean { @@ -550,4 +580,8 @@ open class GossipRouter @JvmOverloads constructor( ).build() ) } + + data class AcceptRequestsWhitelistEntry(val whitelistedTill: Long, val messagesAccepted: Int = 0) { + fun incrementMessageCount() = AcceptRequestsWhitelistEntry(whitelistedTill, messagesAccepted + 1) + } } diff --git a/src/main/kotlin/io/libp2p/pubsub/gossip/GossipScore.kt b/src/main/kotlin/io/libp2p/pubsub/gossip/GossipScore.kt index 0e546e65a..cb3fd4a11 100644 --- a/src/main/kotlin/io/libp2p/pubsub/gossip/GossipScore.kt +++ b/src/main/kotlin/io/libp2p/pubsub/gossip/GossipScore.kt @@ -1,6 +1,7 @@ package io.libp2p.pubsub.gossip import io.libp2p.core.PeerId +import io.libp2p.core.multiformats.Protocol import io.libp2p.core.pubsub.ValidationResult import io.libp2p.etc.types.cappedDouble import io.libp2p.etc.types.createLRUMap @@ -16,6 +17,9 @@ import kotlin.math.max import kotlin.math.min import kotlin.math.pow +fun P2PService.PeerHandler.getIP(): String? = + streamHandler.stream.connection.remoteAddress().getStringComponent(Protocol.IP4) + class GossipScore( val params: GossipScoreParams = GossipScoreParams(), val executor: ScheduledExecutorService, @@ -25,20 +29,41 @@ class GossipScore( inner class TopicScores(val topic: Topic) { private val params: GossipTopicScoreParams get() = topicParams[topic] + private val recalcMaxDuration = params.timeInMeshQuantum + private var cachedScore: Double = 0.0 + private var cacheValid: Boolean = false + private var prevParams = params + private var prevTime = curTimeMillis() var joinedMeshTimeMillis: Long = 0 + set(value) { + field = value + cacheValid = false + } + var firstMessageDeliveries: Double by cappedDouble( 0.0, this@GossipScore.peerParams.decayToZero, - { params.firstMessageDeliveriesCap } + { params.firstMessageDeliveriesCap }, + { cacheValid = false } ) var meshMessageDeliveries: Double by cappedDouble( 0.0, this@GossipScore.peerParams.decayToZero, - { params.meshMessageDeliveriesCap } + { params.meshMessageDeliveriesCap }, + { cacheValid = false } + ) + var meshFailurePenalty: Double by cappedDouble( + 0.0, + this@GossipScore.peerParams.decayToZero, + { _ -> cacheValid = false } + ) + + var invalidMessages: Double by cappedDouble( + 0.0, + this@GossipScore.peerParams.decayToZero, + { _ -> cacheValid = false } ) - var meshFailurePenalty: Double by cappedDouble(0.0, this@GossipScore.peerParams.decayToZero) - var invalidMessages: Double by cappedDouble(0.0, this@GossipScore.peerParams.decayToZero) fun inMesh() = joinedMeshTimeMillis > 0 @@ -58,19 +83,26 @@ class GossipScore( fun meshMessageDeliveriesDeficitSqr() = meshMessageDeliveriesDeficit().pow(2) fun calcTopicScore(): Double { + val curTime = curTimeMillis() + if (cacheValid && prevParams === params && curTime - prevTime < recalcMaxDuration.toMillis()) { + return cachedScore + } + prevParams = params + prevTime = curTime val p1 = meshTimeNorm() val p2 = firstMessageDeliveries val p3 = meshMessageDeliveriesDeficitSqr() val p3b = meshFailurePenalty val p4 = invalidMessages.pow(2) - val ret = params.topicWeight * ( + cachedScore = params.topicWeight * ( p1 * params.timeInMeshWeight + p2 * params.firstMessageDeliveriesWeight + p3 * params.meshMessageDeliveriesWeight + p3b * params.meshFailurePenaltyWeight + p4 * params.invalidMessageDeliveriesWeight ) - return ret + cacheValid = true + return cachedScore } fun decayScores() { @@ -101,6 +133,7 @@ class GossipScore( private val validationTime: MutableMap = createLRUMap(1024) val peerScores = mutableMapOf() + private val peerIpCache = mutableMapOf() val refreshTask: ScheduledFuture<*> @@ -112,6 +145,8 @@ class GossipScore( private fun getPeerScores(peer: P2PService.PeerHandler) = peerScores.computeIfAbsent(peer.peerId) { PeerScores() } + private fun getPeerIp(peer: P2PService.PeerHandler): String? = peerIpCache[peer.peerId] + private fun getTopicScores(peer: P2PService.PeerHandler, topic: Topic) = getPeerScores(peer).topicScores.computeIfAbsent(topic) { TopicScores(it) } @@ -133,7 +168,7 @@ class GossipScore( ) val appScore = peerParams.appSpecificScore(peer.peerId) * peerParams.appSpecificWeight - val peersInIp: Int = peer.getIP()?.let { thisIp -> + val peersInIp: Int = getPeerIp(peer)?.let { thisIp -> if (peerParams.ipWhitelisted(thisIp)) 0 else peerScores.values.count { thisIp in it.ips } } ?: 0 @@ -164,12 +199,17 @@ class GossipScore( } getPeerScores(peer).disconnectedTimeMillis = curTimeMillis() + peerIpCache -= peer.peerId } fun notifyConnected(peer: P2PService.PeerHandler) { + peer.getIP()?.also { peerIp -> + peerIpCache[peer.peerId] = peerIp + } + getPeerScores(peer).apply { connectedTimeMillis = curTimeMillis() - peer.getIP()?.also { ips += it } + getPeerIp(peer)?.also { ips += it } } } diff --git a/src/test/kotlin/io/libp2p/etc/types/DelegatesTest.kt b/src/test/kotlin/io/libp2p/etc/types/DelegatesTest.kt index ab1551fa6..f1573f8c7 100644 --- a/src/test/kotlin/io/libp2p/etc/types/DelegatesTest.kt +++ b/src/test/kotlin/io/libp2p/etc/types/DelegatesTest.kt @@ -12,9 +12,21 @@ class DelegatesTest { val min = AtomicDouble(5.0) val minVal = AtomicDouble(0.0) - var cappedValueDelegate: Double by CappedValueDelegate(0.0, { min.get() }, { minVal.get() }, { max.get() }, { maxVal.get() }) + val cappedValueDelegateUpdates = mutableListOf() + val cappedDoubleUpdates = mutableListOf() + + var cappedValueDelegate: Double by CappedValueDelegate( + 0.0, + { min.get() }, + { minVal.get() }, + { max.get() }, + { maxVal.get() }, + { cappedValueDelegateUpdates += it } + ) + var cappedInt: Int by cappedVar(10, 5, 20) - var cappedDouble: Double by cappedDouble(0.0, 1.0) { max.get() } + var cappedDouble: Double by cappedDouble(0.0, 1.0, { -> max.get() }, { cappedDoubleUpdates += it }) + var blackhole: Double = 0.0 @Test fun cappedVarTest() { @@ -90,4 +102,44 @@ class DelegatesTest { min.set(7.0) assertThat(cappedValueDelegate).isEqualTo(0.0) } + + @Test + fun `test cappedDouble update callback`() { + cappedDouble = 5.0 + assertThat(cappedDoubleUpdates).containsExactly(5.0) + + cappedDouble = 5.0 + assertThat(cappedDoubleUpdates).containsExactly(5.0) + + cappedDouble = 4.0 + assertThat(cappedDoubleUpdates).containsExactly(5.0, 4.0) + + max.set(3.0) + blackhole = cappedDouble + assertThat(cappedDoubleUpdates).containsExactly(5.0, 4.0, 3.0) + + max.set(5.0) + blackhole = cappedDouble + assertThat(cappedDoubleUpdates).containsExactly(5.0, 4.0, 3.0) + } + + @Test + fun `test cappedValueDelegate update callback`() { + cappedValueDelegate = 8.0 + assertThat(cappedValueDelegateUpdates).containsExactly(8.0) + + cappedValueDelegate = 8.0 + assertThat(cappedValueDelegateUpdates).containsExactly(8.0) + + cappedValueDelegate = 7.0 + assertThat(cappedValueDelegateUpdates).containsExactly(8.0, 7.0) + + max.set(6.0) + blackhole = cappedValueDelegate + assertThat(cappedValueDelegateUpdates).containsExactly(8.0, 7.0, 15.0) + + max.set(7.0) + blackhole = cappedValueDelegate + assertThat(cappedValueDelegateUpdates).containsExactly(8.0, 7.0, 15.0) + } } diff --git a/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt b/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt index 6d2fed860..4e77dbe41 100644 --- a/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt +++ b/src/test/kotlin/io/libp2p/pubsub/gossip/GossipV1_1Tests.kt @@ -250,6 +250,60 @@ class GossipV1_1Tests { test.mockRouter.inboundMessages.clear() } + @Test + fun `test that acceptRequests whitelist is refreshed on timeout`() { + val appScore = AtomicDouble() + val peerScoreParams = GossipPeerScoreParams( + appSpecificScore = { appScore.get() }, + appSpecificWeight = 1.0 + ) + val scoreParams = GossipScoreParams( + peerScoreParams = peerScoreParams, + graylistThreshold = -100.0 + ) + val test = TwoRoutersTest(scoreParams = scoreParams) + + // with this score the peer should be whitelisted for some period + appScore.set(test.gossipRouter.acceptRequestsWhitelistThresholdScore.toDouble()) + + test.mockRouter.subscribe("topic1") + test.gossipRouter.subscribe("topic1") + + // 2 heartbeats - the topic should be GRAFTed + test.fuzz.timeController.addTime(2.seconds) + test.mockRouter.waitForMessage { it.hasControl() && it.control.graftCount > 0 } + test.mockRouter.inboundMessages.clear() + + val msg1 = Rpc.RPC.newBuilder() + .addPublish(newProtoMessage("topic1", 0L, "Hello-1".toByteArray())) + .build() + test.mockRouter.sendToSingle(msg1) + // at this point peer is whitelisted for a period + + appScore.set(-101.0) + + val graftMsg = Rpc.RPC.newBuilder().setControl( + Rpc.ControlMessage.newBuilder().addGraft( + Rpc.ControlGraft.newBuilder().setTopicID("topic1") + ) + ).build() + for (i in 0..2) { + test.fuzz.timeController.addTime(50.millis) + + // even having the score below gralist threshold the peer should be answered because + // it is still in acceptRequests whitelist + test.mockRouter.sendToSingle(graftMsg) + test.mockRouter.waitForMessage { it.hasControl() && it.control.pruneCount > 0 } + } + + test.fuzz.timeController.addTime(test.gossipRouter.acceptRequestsWhitelistDuration) + // at this point whitelist should be invalidated and score recalculated + + test.mockRouter.sendToSingle(graftMsg) + // the last message should be ignored + assertEquals(0, test.mockRouter.inboundMessages.size) + } + @Test fun testGraftFloodPenalty() { val test = TwoRoutersTest()