From fd5593883d284c6b74edcd126f70bd2754ff5b9b Mon Sep 17 00:00:00 2001 From: Roman Elizarov Date: Thu, 5 Mar 2020 18:37:58 +0300 Subject: [PATCH] Fixed memory leak on a race between adding/removing from lock-free list * The problem was introduced by #1565. When doing concurrent add+removeFirst the following can happen: - "add" completes, but has not yet corrected the prev pointer in the next node - "removeFirst" removes the freshly added element - "add" performs "finishAdd" that adjusts the prev pointer of the next node and thus the removed element is pointed to from the list again * A separate LockFreeLinkedListAddRemoveStressTest is added that reproduces this problem. * The old LockFreeLinkedListAtomicLFStressTest is refactored a bit. --- .../jvm/src/internal/LockFreeLinkedList.kt | 15 +++-- .../LockFreeLinkedListAddRemoveStressTest.kt | 56 +++++++++++++++++++ .../LockFreeLinkedListAtomicLFStressTest.kt | 48 ++++++++-------- 3 files changed, 90 insertions(+), 29 deletions(-) create mode 100644 kotlinx-coroutines-core/jvm/test/internal/LockFreeLinkedListAddRemoveStressTest.kt diff --git a/kotlinx-coroutines-core/jvm/src/internal/LockFreeLinkedList.kt b/kotlinx-coroutines-core/jvm/src/internal/LockFreeLinkedList.kt index 26fd169da3..f718df04b5 100644 --- a/kotlinx-coroutines-core/jvm/src/internal/LockFreeLinkedList.kt +++ b/kotlinx-coroutines-core/jvm/src/internal/LockFreeLinkedList.kt @@ -390,7 +390,7 @@ public actual open class LockFreeLinkedListNode { final override fun updatedNext(affected: Node, next: Node): Any = next.removed() final override fun finishOnSuccess(affected: Node, next: Node) { - // Complete removal operation here. It bails out if next node is also removed and it becomes + // Complete removal operation here. It bails out if next node is also removed. It becomes // responsibility of the next's removes to call correctPrev which would help fix all the links. 
next.correctPrev(null) } @@ -531,7 +531,12 @@ public actual open class LockFreeLinkedListNode { private fun finishAdd(next: Node) { next._prev.loop { nextPrev -> if (this.next !== next) return // this or next was removed or another node added, remover/adder fixes up links - if (next._prev.compareAndSet(nextPrev, this)) return + if (next._prev.compareAndSet(nextPrev, this)) { + // This newly added node could have been removed, and the above CAS would have added it physically again. + // Let us double-check for this situation and correct if needed + if (isRemoved) next.correctPrev(null) + return + } } } @@ -546,7 +551,7 @@ public actual open class LockFreeLinkedListNode { * * When this node is removed. In this case there is no need to waste time on corrections, because * remover of this node will ultimately call [correctPrev] on the next node and that will fix all * the links from this node, too. - * * When [op] descriptor is not `null` and and operation descriptor that is [OpDescriptor.isEarlierThan] + * * When [op] descriptor is not `null` and operation descriptor that is [OpDescriptor.isEarlierThan] * that current [op] is found while traversing the list. This `null` result will be translated * by callers to [RETRY_ATOMIC]. */ @@ -554,7 +559,7 @@ public actual open class LockFreeLinkedListNode { val oldPrev = _prev.value var prev: Node = oldPrev var last: Node? 
= null // will be set so that last.next === prev - while (true) { // move the the left until first non-removed node + while (true) { // move to the left until first non-removed node val prevNext: Any = prev._next.value when { // fast path to find quickly find prev node when everything is properly linked @@ -565,7 +570,7 @@ public actual open class LockFreeLinkedListNode { // Note: retry from scratch on failure to update prev return correctPrev(op) } - return prev // return a correct prev + return prev // return the correct prev } // slow path when we need to help remove operations this.isRemoved -> return null // nothing to do, this node was removed, bail out asap to save time diff --git a/kotlinx-coroutines-core/jvm/test/internal/LockFreeLinkedListAddRemoveStressTest.kt b/kotlinx-coroutines-core/jvm/test/internal/LockFreeLinkedListAddRemoveStressTest.kt new file mode 100644 index 0000000000..3229e664c1 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/internal/LockFreeLinkedListAddRemoveStressTest.kt @@ -0,0 +1,56 @@ +/* + * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. 
+ */ + +package kotlinx.coroutines.internal + +import kotlinx.atomicfu.* +import kotlinx.coroutines.* +import java.util.concurrent.* +import kotlin.concurrent.* +import kotlin.test.* + +class LockFreeLinkedListAddRemoveStressTest : TestBase() { + private class Node : LockFreeLinkedListNode() + + private val nRepeat = 100_000 * stressTestMultiplier + private val list = LockFreeLinkedListHead() + private val barrier = CyclicBarrier(3) + private val done = atomic(false) + private val removed = atomic(0) + + @Test + fun testStressAddRemove() { + val threads = ArrayList<Thread>() + threads += testThread("adder") { + val node = Node() + list.addLast(node) + if (node.remove()) removed.incrementAndGet() + } + threads += testThread("remover") { + val node = list.removeFirstOrNull() + if (node != null) removed.incrementAndGet() + } + try { + for (i in 1..nRepeat) { + barrier.await() + barrier.await() + assertEquals(i, removed.value) + list.validate() + } + } finally { + done.value = true + barrier.await() + threads.forEach { it.join() } + } + } + + private fun testThread(name: String, op: () -> Unit) = thread(name = name) { + while (true) { + barrier.await() + if (done.value) break + op() + barrier.await() + } + } +} \ No newline at end of file diff --git a/kotlinx-coroutines-core/jvm/test/internal/LockFreeLinkedListAtomicLFStressTest.kt b/kotlinx-coroutines-core/jvm/test/internal/LockFreeLinkedListAtomicLFStressTest.kt index b967c46a8f..225b848186 100644 --- a/kotlinx-coroutines-core/jvm/test/internal/LockFreeLinkedListAtomicLFStressTest.kt +++ b/kotlinx-coroutines-core/jvm/test/internal/LockFreeLinkedListAtomicLFStressTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. 
*/ package kotlinx.coroutines.internal @@ -19,9 +19,9 @@ import kotlin.test.* class LockFreeLinkedListAtomicLFStressTest { private val env = LockFreedomTestEnvironment("LockFreeLinkedListAtomicLFStressTest") - data class IntNode(val i: Int) : LockFreeLinkedListNode() + private data class Node(val i: Long) : LockFreeLinkedListNode() - private val TEST_DURATION_SEC = 5 * stressTestMultiplier + private val nSeconds = 5 * stressTestMultiplier private val nLists = 4 private val nAdderThreads = 4 @@ -32,7 +32,8 @@ class LockFreeLinkedListAtomicLFStressTest { private val undone = AtomicLong() private val missed = AtomicLong() private val removed = AtomicLong() - val error = AtomicReference<Throwable>() + private val error = AtomicReference<Throwable>() + private val index = AtomicLong() @Test fun testStress() { @@ -42,7 +43,7 @@ class LockFreeLinkedListAtomicLFStressTest { when (rnd.nextInt(4)) { 0 -> { val list = lists[rnd.nextInt(nLists)] - val node = IntNode(threadId) + val node = Node(index.incrementAndGet()) addLastOp(list, node) randomSpinWaitIntermission() tryRemoveOp(node) @@ -50,7 +51,7 @@ class LockFreeLinkedListAtomicLFStressTest { 1 -> { // just to test conditional add val list = lists[rnd.nextInt(nLists)] - val node = IntNode(threadId) + val node = Node(index.incrementAndGet()) addLastIfTrueOp(list, node) randomSpinWaitIntermission() tryRemoveOp(node) @@ -58,7 +59,7 @@ class LockFreeLinkedListAtomicLFStressTest { 2 -> { // just to test failed conditional add and burn some time val list = lists[rnd.nextInt(nLists)] - val node = IntNode(threadId) + val node = Node(index.incrementAndGet()) addLastIfFalseOp(list, node) } 3 -> { @@ -68,8 +69,8 @@ class LockFreeLinkedListAtomicLFStressTest { check(idx1 < idx2) // that is our global order val list1 = lists[idx1] val list2 = lists[idx2] - val node1 = IntNode(threadId) - val node2 = IntNode(-threadId - 1) + val node1 = Node(index.incrementAndGet()) + val node2 = Node(index.incrementAndGet()) addTwoOp(list1, node1, list2, node2) 
randomSpinWaitIntermission() tryRemoveOp(node1) @@ -91,13 +92,13 @@ class LockFreeLinkedListAtomicLFStressTest { removeTwoOp(list1, list2) } } - env.performTest(TEST_DURATION_SEC) { - val _undone = undone.get() - val _missed = missed.get() - val _removed = removed.get() - println(" Adders undone $_undone node additions") - println(" Adders missed $_missed nodes") - println("Remover removed $_removed nodes") + env.performTest(nSeconds) { + val undone = undone.get() + val missed = missed.get() + val removed = removed.get() + println(" Adders undone $undone node additions") + println(" Adders missed $missed nodes") + println("Remover removed $removed nodes") } error.get()?.let { throw it } assertEquals(missed.get(), removed.get()) @@ -106,19 +107,19 @@ class LockFreeLinkedListAtomicLFStressTest { lists.forEach { it.validate() } } - private fun addLastOp(list: LockFreeLinkedListHead, node: IntNode) { + private fun addLastOp(list: LockFreeLinkedListHead, node: Node) { list.addLast(node) } - private fun addLastIfTrueOp(list: LockFreeLinkedListHead, node: IntNode) { - assertTrue(list.addLastIf(node, { true })) + private fun addLastIfTrueOp(list: LockFreeLinkedListHead, node: Node) { + assertTrue(list.addLastIf(node) { true }) } - private fun addLastIfFalseOp(list: LockFreeLinkedListHead, node: IntNode) { - assertFalse(list.addLastIf(node, { false })) + private fun addLastIfFalseOp(list: LockFreeLinkedListHead, node: Node) { + assertFalse(list.addLastIf(node) { false }) } - private fun addTwoOp(list1: LockFreeLinkedListHead, node1: IntNode, list2: LockFreeLinkedListHead, node2: IntNode) { + private fun addTwoOp(list1: LockFreeLinkedListHead, node1: Node, list2: LockFreeLinkedListHead, node2: Node) { val add1 = list1.describeAddLast(node1) val add2 = list2.describeAddLast(node2) val op = object : AtomicOp() { @@ -138,7 +139,7 @@ class LockFreeLinkedListAtomicLFStressTest { assertTrue(op.perform(null) == null) } - private fun tryRemoveOp(node: IntNode) { + private fun 
tryRemoveOp(node: Node) { if (node.remove()) undone.incrementAndGet() else @@ -165,5 +166,4 @@ class LockFreeLinkedListAtomicLFStressTest { val success = op.perform(null) == null if (success) removed.addAndGet(2) } - } \ No newline at end of file