Skip to content

Fixed memory leak on a race between adding/removing from lock-free list #1845

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions kotlinx-coroutines-core/jvm/src/internal/LockFreeLinkedList.kt
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ public actual open class LockFreeLinkedListNode {
final override fun updatedNext(affected: Node, next: Node): Any = next.removed()

final override fun finishOnSuccess(affected: Node, next: Node) {
// Complete removal operation here. It bails out if next node is also removed and it becomes
// Complete removal operation here. It bails out if next node is also removed. It becomes
// responsibility of the next's removes to call correctPrev which would help fix all the links.
next.correctPrev(null)
}
Expand Down Expand Up @@ -531,7 +531,12 @@ public actual open class LockFreeLinkedListNode {
private fun finishAdd(next: Node) {
next._prev.loop { nextPrev ->
if (this.next !== next) return // this or next was removed or another node added, remover/adder fixes up links
if (next._prev.compareAndSet(nextPrev, this)) return
if (next._prev.compareAndSet(nextPrev, this)) {
// This newly added node could have been removed, and the above CAS would have added it physically again.
// Let us double-check for this situation and correct if needed
if (isRemoved) next.correctPrev(null)
return
}
}
}

Expand All @@ -546,15 +551,15 @@ public actual open class LockFreeLinkedListNode {
* * When this node is removed. In this case there is no need to waste time on corrections, because
* remover of this node will ultimately call [correctPrev] on the next node and that will fix all
* the links from this node, too.
* * When [op] descriptor is not `null` and and operation descriptor that is [OpDescriptor.isEarlierThan]
* * When [op] descriptor is not `null` and operation descriptor that is [OpDescriptor.isEarlierThan]
* that current [op] is found while traversing the list. This `null` result will be translated
* by callers to [RETRY_ATOMIC].
*/
private tailrec fun correctPrev(op: OpDescriptor?): Node? {
val oldPrev = _prev.value
var prev: Node = oldPrev
var last: Node? = null // will be set so that last.next === prev
while (true) { // move the the left until first non-removed node
while (true) { // move the left until first non-removed node
val prevNext: Any = prev._next.value
when {
// fast path to find quickly find prev node when everything is properly linked
Expand All @@ -565,7 +570,7 @@ public actual open class LockFreeLinkedListNode {
// Note: retry from scratch on failure to update prev
return correctPrev(op)
}
return prev // return a correct prev
return prev // return the correct prev
}
// slow path when we need to help remove operations
this.isRemoved -> return null // nothing to do, this node was removed, bail out asap to save time
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.internal

import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import java.util.concurrent.*
import kotlin.concurrent.*
import kotlin.test.*

class LockFreeLinkedListAddRemoveStressTest : TestBase() {
private class Node : LockFreeLinkedListNode()

private val nRepeat = 100_000 * stressTestMultiplier
private val list = LockFreeLinkedListHead()
private val barrier = CyclicBarrier(3)
private val done = atomic(false)
private val removed = atomic(0)

@Test
fun testStressAddRemove() {
val threads = ArrayList<Thread>()
threads += testThread("adder") {
val node = Node()
list.addLast(node)
if (node.remove()) removed.incrementAndGet()
}
threads += testThread("remover") {
val node = list.removeFirstOrNull()
if (node != null) removed.incrementAndGet()
}
try {
for (i in 1..nRepeat) {
barrier.await()
barrier.await()
assertEquals(i, removed.value)
list.validate()
}
} finally {
done.value = true
barrier.await()
threads.forEach { it.join() }
}
}

private fun testThread(name: String, op: () -> Unit) = thread(name = name) {
while (true) {
barrier.await()
if (done.value) break
op()
barrier.await()
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.internal
Expand All @@ -19,9 +19,9 @@ import kotlin.test.*
class LockFreeLinkedListAtomicLFStressTest {
private val env = LockFreedomTestEnvironment("LockFreeLinkedListAtomicLFStressTest")

data class IntNode(val i: Int) : LockFreeLinkedListNode()
private data class Node(val i: Long) : LockFreeLinkedListNode()

private val TEST_DURATION_SEC = 5 * stressTestMultiplier
private val nSeconds = 5 * stressTestMultiplier

private val nLists = 4
private val nAdderThreads = 4
Expand All @@ -32,7 +32,8 @@ class LockFreeLinkedListAtomicLFStressTest {
private val undone = AtomicLong()
private val missed = AtomicLong()
private val removed = AtomicLong()
val error = AtomicReference<Throwable>()
private val error = AtomicReference<Throwable>()
private val index = AtomicLong()

@Test
fun testStress() {
Expand All @@ -42,23 +43,23 @@ class LockFreeLinkedListAtomicLFStressTest {
when (rnd.nextInt(4)) {
0 -> {
val list = lists[rnd.nextInt(nLists)]
val node = IntNode(threadId)
val node = Node(index.incrementAndGet())
addLastOp(list, node)
randomSpinWaitIntermission()
tryRemoveOp(node)
}
1 -> {
// just to test conditional add
val list = lists[rnd.nextInt(nLists)]
val node = IntNode(threadId)
val node = Node(index.incrementAndGet())
addLastIfTrueOp(list, node)
randomSpinWaitIntermission()
tryRemoveOp(node)
}
2 -> {
// just to test failed conditional add and burn some time
val list = lists[rnd.nextInt(nLists)]
val node = IntNode(threadId)
val node = Node(index.incrementAndGet())
addLastIfFalseOp(list, node)
}
3 -> {
Expand All @@ -68,8 +69,8 @@ class LockFreeLinkedListAtomicLFStressTest {
check(idx1 < idx2) // that is our global order
val list1 = lists[idx1]
val list2 = lists[idx2]
val node1 = IntNode(threadId)
val node2 = IntNode(-threadId - 1)
val node1 = Node(index.incrementAndGet())
val node2 = Node(index.incrementAndGet())
addTwoOp(list1, node1, list2, node2)
randomSpinWaitIntermission()
tryRemoveOp(node1)
Expand All @@ -91,13 +92,13 @@ class LockFreeLinkedListAtomicLFStressTest {
removeTwoOp(list1, list2)
}
}
env.performTest(TEST_DURATION_SEC) {
val _undone = undone.get()
val _missed = missed.get()
val _removed = removed.get()
println(" Adders undone $_undone node additions")
println(" Adders missed $_missed nodes")
println("Remover removed $_removed nodes")
env.performTest(nSeconds) {
val undone = undone.get()
val missed = missed.get()
val removed = removed.get()
println(" Adders undone $undone node additions")
println(" Adders missed $missed nodes")
println("Remover removed $removed nodes")
}
error.get()?.let { throw it }
assertEquals(missed.get(), removed.get())
Expand All @@ -106,19 +107,19 @@ class LockFreeLinkedListAtomicLFStressTest {
lists.forEach { it.validate() }
}

private fun addLastOp(list: LockFreeLinkedListHead, node: IntNode) {
private fun addLastOp(list: LockFreeLinkedListHead, node: Node) {
list.addLast(node)
}

private fun addLastIfTrueOp(list: LockFreeLinkedListHead, node: IntNode) {
assertTrue(list.addLastIf(node, { true }))
private fun addLastIfTrueOp(list: LockFreeLinkedListHead, node: Node) {
assertTrue(list.addLastIf(node) { true })
}

private fun addLastIfFalseOp(list: LockFreeLinkedListHead, node: IntNode) {
assertFalse(list.addLastIf(node, { false }))
private fun addLastIfFalseOp(list: LockFreeLinkedListHead, node: Node) {
assertFalse(list.addLastIf(node) { false })
}

private fun addTwoOp(list1: LockFreeLinkedListHead, node1: IntNode, list2: LockFreeLinkedListHead, node2: IntNode) {
private fun addTwoOp(list1: LockFreeLinkedListHead, node1: Node, list2: LockFreeLinkedListHead, node2: Node) {
val add1 = list1.describeAddLast(node1)
val add2 = list2.describeAddLast(node2)
val op = object : AtomicOp<Any?>() {
Expand All @@ -138,7 +139,7 @@ class LockFreeLinkedListAtomicLFStressTest {
assertTrue(op.perform(null) == null)
}

private fun tryRemoveOp(node: IntNode) {
private fun tryRemoveOp(node: Node) {
if (node.remove())
undone.incrementAndGet()
else
Expand All @@ -165,5 +166,4 @@ class LockFreeLinkedListAtomicLFStressTest {
val success = op.perform(null) == null
if (success) removed.addAndGet(2)
}

}