16
16
17
17
#include " mlir/Support/LLVM.h"
18
18
#include " llvm/ADT/DenseMap.h"
19
+ #include " llvm/Support/ManagedStatic.h"
19
20
#include " llvm/Support/Mutex.h"
20
21
21
22
namespace mlir {
@@ -24,80 +25,28 @@ namespace mlir {
24
25
// / cache has very large lock contention.
25
26
template <typename ValueT>
26
27
class ThreadLocalCache {
27
- struct PerInstanceState ;
28
-
29
- // / The "observer" is owned by a thread-local cache instance. It is
30
- // / constructed the first time a `ThreadLocalCache` instance is accessed by a
31
- // / thread, unless `perInstanceState` happens to get re-allocated to the same
32
- // / address as a previous one. This class is destructed when the thread in which
33
- // / the `thread_local` cache lives is destroyed.
34
- // /
35
- // / This class is called the "observer" because while values cached in
36
- // / thread-local caches are owned by `PerInstanceState`, a reference is stored
37
- // / via this class in the TLC. With a double pointer, it knows when the
38
- // / referenced value has been destroyed.
39
- struct Observer {
40
- // / This is the double pointer, explicitly allocated because we need to keep
41
- // / the address stable if the TLC map re-allocates. It is owned by the
42
- // / observer and shared with the value owner.
43
- std::shared_ptr<ValueT *> ptr = std::make_shared<ValueT *>(nullptr );
44
- // / Because `Owner` living inside `PerInstanceState` contains a reference to
45
- // / the double pointer, and likewise this class contains a reference to the
46
- // / value, we need to synchronize destruction of the TLC and the
47
- // / `PerInstanceState` to avoid racing. This weak pointer is acquired during
48
- // / TLC destruction if the `PerInstanceState` hasn't entered its destructor
49
- // / yet, and prevents it from happening.
50
- std::weak_ptr<PerInstanceState> keepalive;
51
- };
52
-
53
- // / This struct owns the cache entries. It contains a reference back to the
54
- // / reference inside the cache so that it can be written to null to indicate
55
- // / that the cache entry is invalidated. It needs to do this because
56
- // / `perInstanceState` could get re-allocated to the same pointer and we don't
57
- // / remove entries from the TLC when it is deallocated. Thus, we have to reset
58
- // / the TLC entries to a starting state in case the `ThreadLocalCache` lives
59
- // / shorter than the threads.
60
- struct Owner {
61
- // / Save a pointer to the reference and write it to the newly created entry.
62
- Owner (Observer &observer)
63
- : value(std::make_unique<ValueT>()), ptrRef(observer.ptr) {
64
- *observer.ptr = value.get ();
65
- }
66
- ~Owner () {
67
- if (std::shared_ptr<ValueT *> ptr = ptrRef.lock ())
68
- *ptr = nullptr ;
69
- }
70
-
71
- Owner (Owner &&) = default ;
72
- Owner &operator =(Owner &&) = default ;
73
-
74
- std::unique_ptr<ValueT> value;
75
- std::weak_ptr<ValueT *> ptrRef;
76
- };
77
-
78
28
// Keep a separate shared_ptr protected state that can be acquired atomically
79
29
// instead of using shared_ptr's for each value. This avoids a problem
80
30
// where the instance shared_ptr is locked() successfully, and then the
81
31
// ThreadLocalCache gets destroyed before remove() can be called successfully.
82
32
struct PerInstanceState {
83
- // / Remove the given value entry. This is called when a thread local cache
84
- // / is destructing but still contains references to values owned by the
85
- // / `PerInstanceState`. Removal is required because it prevents writeback to
86
- // / a pointer that was deallocated.
33
+ // / Remove the given value entry. This is generally called when a thread
34
+ // / local cache is destructing.
87
35
void remove (ValueT *value) {
88
36
// Erase the found value directly, because it is guaranteed to be in the
89
37
// list.
90
38
llvm::sys::SmartScopedLock<true > threadInstanceLock (instanceMutex);
91
- auto it = llvm::find_if (instances, [&](Owner &instance) {
92
- return instance.value .get () == value;
93
- });
39
+ auto it =
40
+ llvm::find_if (instances, [&](std::unique_ptr<ValueT> &instance) {
41
+ return instance.get () == value;
42
+ });
94
43
assert (it != instances.end () && " expected value to exist in cache" );
95
44
instances.erase (it);
96
45
}
97
46
98
47
// / Owning pointers to all of the values that have been constructed for this
99
48
// / object in the static cache.
100
- SmallVector<Owner , 1 > instances;
49
+ SmallVector<std::unique_ptr<ValueT> , 1 > instances;
101
50
102
51
// / A mutex used when a new thread instance has been added to the cache for
103
52
// / this object.
@@ -108,22 +57,22 @@ class ThreadLocalCache {
108
57
// / instance of the non-static cache and a weak reference to an instance of
109
58
// / ValueT. We use a weak reference here so that the object can be destroyed
110
59
// / without needing to lock access to the cache itself.
111
- struct CacheType : public llvm ::SmallDenseMap<PerInstanceState *, Observer> {
60
+ struct CacheType
61
+ : public llvm::SmallDenseMap<PerInstanceState *,
62
+ std::pair<std::weak_ptr<ValueT>, ValueT *>> {
112
63
~CacheType () {
113
- // Remove the values of this cache that haven't already expired. This is
114
- // required because if we don't remove them, they will contain a reference
115
- // back to the data here that is being destroyed.
116
- for (auto &[instance, observer] : *this )
117
- if (std::shared_ptr<PerInstanceState> state = observer.keepalive .lock ())
118
- state->remove (*observer.ptr );
64
+ // Remove the values of this cache that haven't already expired.
65
+ for (auto &it : *this )
66
+ if (std::shared_ptr<ValueT> value = it.second .first .lock ())
67
+ it.first ->remove (value.get ());
119
68
}
120
69
121
70
// / Clear out any unused entries within the map. This method is not
122
71
// / thread-safe, and should only be called by the same thread as the cache.
123
72
void clearExpiredEntries () {
124
73
for (auto it = this ->begin (), e = this ->end (); it != e;) {
125
74
auto curIt = it++;
126
- if (!* curIt->second .ptr )
75
+ if (curIt->second .first . expired () )
127
76
this ->erase (curIt);
128
77
}
129
78
}
@@ -140,23 +89,27 @@ class ThreadLocalCache {
140
89
ValueT &get () {
141
90
// Check for an already existing instance for this thread.
142
91
CacheType &staticCache = getStaticCache ();
143
- Observer &threadInstance = staticCache[perInstanceState.get ()];
144
- if (ValueT *value = *threadInstance.ptr )
92
+ std::pair<std::weak_ptr<ValueT>, ValueT *> &threadInstance =
93
+ staticCache[perInstanceState.get ()];
94
+ if (ValueT *value = threadInstance.second )
145
95
return *value;
146
96
147
97
// Otherwise, create a new instance for this thread.
148
98
{
149
99
llvm::sys::SmartScopedLock<true > threadInstanceLock (
150
100
perInstanceState->instanceMutex );
151
- perInstanceState->instances .emplace_back (threadInstance);
101
+ threadInstance.second =
102
+ perInstanceState->instances .emplace_back (std::make_unique<ValueT>())
103
+ .get ();
152
104
}
153
- threadInstance.keepalive = perInstanceState;
105
+ threadInstance.first =
106
+ std::shared_ptr<ValueT>(perInstanceState, threadInstance.second );
154
107
155
108
// Before returning the new instance, take the chance to clear out any unused
156
109
// entries in the static map. The cache is only cleared within the same
157
110
// thread to remove the need to lock the cache itself.
158
111
staticCache.clearExpiredEntries ();
159
- return ** threadInstance.ptr ;
112
+ return *threadInstance.second ;
160
113
}
161
114
/// Dereference operator: returns the same reference that `get()` yields,
/// i.e. the value instance associated with the calling thread.
ValueT &operator *() { return get (); }
162
115
/// Member-access operator: returns the address of the value that `get()`
/// yields, so members of `ValueT` can be reached directly through the cache.
ValueT *operator ->() { return &get (); }
0 commit comments