Skip to content

Commit ebc6c28

Browse files
Revert "[mlir] Fix race condition introduced in ThreadLocalCache (#93280)" (#93290)
This reverts commit 6977bfb.
1 parent 430729d commit ebc6c28

File tree

1 file changed

+25
-72
lines changed

1 file changed

+25
-72
lines changed

mlir/include/mlir/Support/ThreadLocalCache.h

+25-72
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "mlir/Support/LLVM.h"
1818
#include "llvm/ADT/DenseMap.h"
19+
#include "llvm/Support/ManagedStatic.h"
1920
#include "llvm/Support/Mutex.h"
2021

2122
namespace mlir {
@@ -24,80 +25,28 @@ namespace mlir {
2425
/// cache has very large lock contention.
2526
template <typename ValueT>
2627
class ThreadLocalCache {
27-
struct PerInstanceState;
28-
29-
/// The "observer" is owned by a thread-local cache instance. It is
30-
/// constructed the first time a `ThreadLocalCache` instance is accessed by a
31-
/// thread, unless `perInstanceState` happens to get re-allocated to the same
32-
/// address as a previous one. This class is destructed when the thread in which
33-
/// the `thread_local` cache lives is destroyed.
34-
///
35-
/// This class is called the "observer" because while values cached in
36-
/// thread-local caches are owned by `PerInstanceState`, a reference is stored
37-
/// via this class in the TLC. With a double pointer, it knows when the
38-
/// referenced value has been destroyed.
39-
struct Observer {
40-
/// This is the double pointer, explicitly allocated because we need to keep
41-
/// the address stable if the TLC map re-allocates. It is owned by the
42-
/// observer and shared with the value owner.
43-
std::shared_ptr<ValueT *> ptr = std::make_shared<ValueT *>(nullptr);
44-
/// Because `Owner` living inside `PerInstanceState` contains a reference to
45-
/// the double pointer, and likewise this class contains a reference to the
46-
/// value, we need to synchronize destruction of the TLC and the
47-
/// `PerInstanceState` to avoid racing. This weak pointer is acquired during
48-
/// TLC destruction if the `PerInstanceState` hasn't entered its destructor
49-
/// yet, and prevents it from happening.
50-
std::weak_ptr<PerInstanceState> keepalive;
51-
};
52-
53-
/// This struct owns the cache entries. It contains a reference back to the
54-
/// reference inside the cache so that it can be written to null to indicate
55-
/// that the cache entry is invalidated. It needs to do this because
56-
/// `perInstanceState` could get re-allocated to the same pointer and we don't
57-
/// remove entries from the TLC when it is deallocated. Thus, we have to reset
58-
/// the TLC entries to a starting state in case the `ThreadLocalCache` lives
59-
/// shorter than the threads.
60-
struct Owner {
61-
/// Save a pointer to the reference and write it to the newly created entry.
62-
Owner(Observer &observer)
63-
: value(std::make_unique<ValueT>()), ptrRef(observer.ptr) {
64-
*observer.ptr = value.get();
65-
}
66-
~Owner() {
67-
if (std::shared_ptr<ValueT *> ptr = ptrRef.lock())
68-
*ptr = nullptr;
69-
}
70-
71-
Owner(Owner &&) = default;
72-
Owner &operator=(Owner &&) = default;
73-
74-
std::unique_ptr<ValueT> value;
75-
std::weak_ptr<ValueT *> ptrRef;
76-
};
77-
7828
// Keep a separate shared_ptr protected state that can be acquired atomically
7929
// instead of using shared_ptr's for each value. This avoids a problem
8030
// where the instance shared_ptr is locked() successfully, and then the
8131
// ThreadLocalCache gets destroyed before remove() can be called successfully.
8232
struct PerInstanceState {
83-
/// Remove the given value entry. This is called when a thread local cache
84-
/// is destructing but still contains references to values owned by the
85-
/// `PerInstanceState`. Removal is required because it prevents writeback to
86-
/// a pointer that was deallocated.
33+
/// Remove the given value entry. This is generally called when a thread
34+
/// local cache is destructing.
8735
void remove(ValueT *value) {
8836
// Erase the found value directly, because it is guaranteed to be in the
8937
// list.
9038
llvm::sys::SmartScopedLock<true> threadInstanceLock(instanceMutex);
91-
auto it = llvm::find_if(instances, [&](Owner &instance) {
92-
return instance.value.get() == value;
93-
});
39+
auto it =
40+
llvm::find_if(instances, [&](std::unique_ptr<ValueT> &instance) {
41+
return instance.get() == value;
42+
});
9443
assert(it != instances.end() && "expected value to exist in cache");
9544
instances.erase(it);
9645
}
9746

9847
/// Owning pointers to all of the values that have been constructed for this
9948
/// object in the static cache.
100-
SmallVector<Owner, 1> instances;
49+
SmallVector<std::unique_ptr<ValueT>, 1> instances;
10150

10251
/// A mutex used when a new thread instance has been added to the cache for
10352
/// this object.
@@ -108,22 +57,22 @@ class ThreadLocalCache {
10857
/// instance of the non-static cache and a weak reference to an instance of
10958
/// ValueT. We use a weak reference here so that the object can be destroyed
11059
/// without needing to lock access to the cache itself.
111-
struct CacheType : public llvm::SmallDenseMap<PerInstanceState *, Observer> {
60+
struct CacheType
61+
: public llvm::SmallDenseMap<PerInstanceState *,
62+
std::pair<std::weak_ptr<ValueT>, ValueT *>> {
11263
~CacheType() {
113-
// Remove the values of this cache that haven't already expired. This is
114-
// required because if we don't remove them, they will contain a reference
115-
// back to the data here that is being destroyed.
116-
for (auto &[instance, observer] : *this)
117-
if (std::shared_ptr<PerInstanceState> state = observer.keepalive.lock())
118-
state->remove(*observer.ptr);
64+
// Remove the values of this cache that haven't already expired.
65+
for (auto &it : *this)
66+
if (std::shared_ptr<ValueT> value = it.second.first.lock())
67+
it.first->remove(value.get());
11968
}
12069

12170
/// Clear out any unused entries within the map. This method is not
12271
/// thread-safe, and should only be called by the same thread as the cache.
12372
void clearExpiredEntries() {
12473
for (auto it = this->begin(), e = this->end(); it != e;) {
12574
auto curIt = it++;
126-
if (!*curIt->second.ptr)
75+
if (curIt->second.first.expired())
12776
this->erase(curIt);
12877
}
12978
}
@@ -140,23 +89,27 @@ class ThreadLocalCache {
14089
ValueT &get() {
14190
// Check for an already existing instance for this thread.
14291
CacheType &staticCache = getStaticCache();
143-
Observer &threadInstance = staticCache[perInstanceState.get()];
144-
if (ValueT *value = *threadInstance.ptr)
92+
std::pair<std::weak_ptr<ValueT>, ValueT *> &threadInstance =
93+
staticCache[perInstanceState.get()];
94+
if (ValueT *value = threadInstance.second)
14595
return *value;
14696

14797
// Otherwise, create a new instance for this thread.
14898
{
14999
llvm::sys::SmartScopedLock<true> threadInstanceLock(
150100
perInstanceState->instanceMutex);
151-
perInstanceState->instances.emplace_back(threadInstance);
101+
threadInstance.second =
102+
perInstanceState->instances.emplace_back(std::make_unique<ValueT>())
103+
.get();
152104
}
153-
threadInstance.keepalive = perInstanceState;
105+
threadInstance.first =
106+
std::shared_ptr<ValueT>(perInstanceState, threadInstance.second);
154107

155108
// Before returning the new instance, take the chance to clear out any unused
156109
// entries in the static map. The cache is only cleared within the same
157110
// thread to remove the need to lock the cache itself.
158111
staticCache.clearExpiredEntries();
159-
return **threadInstance.ptr;
112+
return *threadInstance.second;
160113
}
161114
ValueT &operator*() { return get(); }
162115
ValueT *operator->() { return &get(); }

0 commit comments

Comments
 (0)