Skip to content

Commit 6755cfe

Browse files
authored
tsan, xds: fix XdsClientWrapperForServerSds data races (#8107)
1 parent 8468b5c commit 6755cfe

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import java.util.concurrent.ScheduledExecutorService;
5555
import java.util.concurrent.ThreadFactory;
5656
import java.util.concurrent.TimeUnit;
57+
import java.util.concurrent.atomic.AtomicReference;
5758
import java.util.logging.Level;
5859
import java.util.logging.Logger;
5960
import javax.annotation.Nullable;
@@ -70,7 +71,7 @@ public final class XdsClientWrapperForServerSds {
7071
private static final TimeServiceResource timeServiceResource =
7172
new TimeServiceResource("GrpcServerXdsClient");
7273

73-
private EnvoyServerProtoData.Listener curListener;
74+
private AtomicReference<EnvoyServerProtoData.Listener> curListener = new AtomicReference<>();
7475
@SuppressWarnings("unused")
7576
@Nullable private XdsClient xdsClient;
7677
private final int port;
@@ -137,14 +138,14 @@ public void start(XdsClient xdsClient, String grpcServerResourceId) {
137138
new XdsClient.LdsResourceWatcher() {
138139
@Override
139140
public void onChanged(XdsClient.LdsUpdate update) {
140-
curListener = update.listener;
141+
curListener.set(update.listener);
141142
reportSuccess();
142143
}
143144

144145
@Override
145146
public void onResourceDoesNotExist(String resourceName) {
146147
logger.log(Level.WARNING, "Resource {0} is unavailable", resourceName);
147-
curListener = null;
148+
curListener.set(null);
148149
reportError(Status.NOT_FOUND.asException(), true);
149150
}
150151

@@ -180,7 +181,8 @@ private static boolean isResourceAbsent(Status status) {
180181
*/
181182
@Nullable
182183
public DownstreamTlsContext getDownstreamTlsContext(Channel channel) {
183-
if (curListener != null && channel != null) {
184+
EnvoyServerProtoData.Listener copyListener = curListener.get();
185+
if (copyListener != null && channel != null) {
184186
SocketAddress localAddress = channel.localAddress();
185187
SocketAddress remoteAddress = channel.remoteAddress();
186188
if (localAddress instanceof InetSocketAddress && remoteAddress instanceof InetSocketAddress) {
@@ -189,7 +191,7 @@ public DownstreamTlsContext getDownstreamTlsContext(Channel channel) {
189191
checkState(
190192
port == localInetAddr.getPort(),
191193
"Channel localAddress port does not match requested listener port");
192-
return getDownstreamTlsContext(localInetAddr, remoteInetAddr);
194+
return getDownstreamTlsContext(localInetAddr, remoteInetAddr, copyListener);
193195
}
194196
}
195197
return null;
@@ -204,9 +206,10 @@ public DownstreamTlsContext getDownstreamTlsContext(Channel channel) {
204206
* @param localInetAddr dest address of the inbound connection
205207
* @param remoteInetAddr source address of the inbound connection
206208
*/
207-
private DownstreamTlsContext getDownstreamTlsContext(
208-
InetSocketAddress localInetAddr, InetSocketAddress remoteInetAddr) {
209-
List<FilterChain> filterChains = curListener.getFilterChains();
209+
private static DownstreamTlsContext getDownstreamTlsContext(
210+
InetSocketAddress localInetAddr, InetSocketAddress remoteInetAddr,
211+
EnvoyServerProtoData.Listener listener) {
212+
List<FilterChain> filterChains = listener.getFilterChains();
210213

211214
filterChains = filterOnDestinationPort(filterChains);
212215
filterChains = filterOnIpAddress(filterChains, localInetAddr.getAddress(), true);
@@ -221,7 +224,7 @@ private DownstreamTlsContext getDownstreamTlsContext(
221224
} else if (filterChains.size() == 1) {
222225
return filterChains.get(0).getDownstreamTlsContext();
223226
}
224-
return curListener.getDefaultFilterChain().getDownstreamTlsContext();
227+
return listener.getDefaultFilterChain().getDownstreamTlsContext();
225228
}
226229

227230
// destination_port present => Always fail match
@@ -255,7 +258,7 @@ private static List<FilterChain> filterOnSourcePort(
255258
return filteredOnMatch.isEmpty() ? filteredOnEmpty : filteredOnMatch;
256259
}
257260

258-
private List<FilterChain> filterOnSourceType(
261+
private static List<FilterChain> filterOnSourceType(
259262
List<FilterChain> filterChains, InetAddress sourceAddress, InetAddress destAddress) {
260263
ArrayList<FilterChain> filtered = new ArrayList<>(filterChains.size());
261264
for (FilterChain filterChain : filterChains) {
@@ -350,7 +353,7 @@ public int hashCode() {
350353
}
351354

352355
// use prefix_ranges (CIDR) and get the most specific matches
353-
private List<FilterChain> filterOnIpAddress(
356+
private static List<FilterChain> filterOnIpAddress(
354357
List<FilterChain> filterChains, InetAddress address, boolean forDestination) {
355358
PriorityQueue<QueueElement> heap = new PriorityQueue<>(10, new QueueElementComparator());
356359

@@ -384,7 +387,8 @@ public void addServerWatcher(ServerWatcher serverWatcher) {
384387
synchronized (serverWatchers) {
385388
serverWatchers.add(serverWatcher);
386389
}
387-
if (curListener != null) {
390+
EnvoyServerProtoData.Listener copyListener = curListener.get();
391+
if (copyListener != null) {
388392
serverWatcher.onListenerUpdate();
389393
}
390394
}

0 commit comments

Comments
 (0)