forked from grpc/grpc-swift
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path InProcessTransport+Client.swift
357 lines (325 loc) · 13.5 KB
/
InProcessTransport+Client.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
/*
* Copyright 2024, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
public import GRPCCore
private import Synchronization
extension InProcessTransport {
  /// An in-process implementation of a `ClientTransport`.
  ///
  /// This is useful when you're interested in testing your application without any actual networking layers
  /// involved, as the client and server will communicate directly with each other via in-process streams.
  ///
  /// To use this client, you'll have to provide an `InProcessTransport.Server` upon creation and,
  /// optionally, a `ServiceConfig`.
  ///
  /// Once you have a client, you must keep a long-running task executing ``connect()``, which
  /// will return only once all streams have been finished and ``beginGracefulShutdown()`` has been called
  /// on this client; or when the containing task is cancelled.
  ///
  /// To execute requests using this client, use ``withStream(descriptor:options:_:)``. If this function is
  /// called before ``connect()`` is called, then any streams will remain pending and the call will
  /// block until ``connect()`` is called, ``beginGracefulShutdown()`` is called (in which case the call
  /// fails with `RPCError/Code/failedPrecondition`), or the task is cancelled.
  ///
  /// - SeeAlso: `ClientTransport`
  public final class Client: ClientTransport {
    private enum State: Sendable {
      /// The client half and server half of a single open stream.
      typealias OpenStream = (
        RPCStream<Inbound, Outbound>,
        RPCStream<
          RPCAsyncSequence<RPCRequestPart, any Error>, RPCWriter<RPCResponsePart>.Closable
        >
      )

      /// `connect()` has not been called yet.
      struct UnconnectedState {
        var serverTransport: InProcessTransport.Server
        /// One continuation per `withStream` call that is waiting for the transport to
        /// connect (or shut down). Finishing a continuation wakes its waiter.
        var pendingStreams: [AsyncStream<Void>.Continuation]

        init(serverTransport: InProcessTransport.Server) {
          self.serverTransport = serverTransport
          self.pendingStreams = []
        }
      }

      /// `connect()` is running and new streams may be opened.
      struct ConnectedState {
        var serverTransport: InProcessTransport.Server
        var nextStreamID: Int
        var openStreams: [Int: OpenStream]
        /// Finishing this continuation makes `connect()` return.
        var signalEndContinuation: AsyncStream<Void>.Continuation

        init(
          fromUnconnected state: UnconnectedState,
          signalEndContinuation: AsyncStream<Void>.Continuation
        ) {
          self.serverTransport = state.serverTransport
          self.nextStreamID = 0
          self.openStreams = [:]
          self.signalEndContinuation = signalEndContinuation
        }
      }

      /// No new streams may be opened. Streams that were open at shutdown are tracked here
      /// so that `connect()` can be signalled once the last of them is removed.
      struct ClosedState {
        var openStreams: [Int: OpenStream]
        var signalEndContinuation: AsyncStream<Void>.Continuation?

        init() {
          self.openStreams = [:]
          self.signalEndContinuation = nil
        }

        init(fromConnected state: ConnectedState) {
          self.openStreams = state.openStreams
          self.signalEndContinuation = state.signalEndContinuation
        }
      }

      case unconnected(UnconnectedState)
      case connected(ConnectedState)
      case closed(ClosedState)
    }

    public typealias Inbound = RPCAsyncSequence<RPCResponsePart, any Error>
    public typealias Outbound = RPCWriter<RPCRequestPart>.Closable

    /// Retry throttle built from the service config's `retryThrottling` policy, or `nil` if none was set.
    public let retryThrottle: RetryThrottle?

    private let methodConfig: MethodConfigs
    private let state: Mutex<State>

    /// Creates a new in-process client transport.
    ///
    /// - Parameters:
    ///   - server: The in-process server transport to connect to.
    ///   - serviceConfig: Service configuration. Defaults to an empty `ServiceConfig()`.
    package init(
      server: InProcessTransport.Server,
      serviceConfig: ServiceConfig = ServiceConfig()
    ) {
      self.retryThrottle = serviceConfig.retryThrottling.map { RetryThrottle(policy: $0) }
      self.methodConfig = MethodConfigs(serviceConfig: serviceConfig)
      self.state = Mutex(.unconnected(.init(serverTransport: server)))
    }

    /// Establish and maintain a connection to the in-process server.
    ///
    /// Transitions the transport to the connected state, releasing any ``withStream(descriptor:options:_:)``
    /// calls that were waiting for a connection. It then suspends until either
    /// ``beginGracefulShutdown()`` has been called and all open streams have finished, or the task
    /// running this function is cancelled. On cancellation, any streams that are still open are
    /// finished with a `CancellationError` on both their client and server halves.
    ///
    /// - Throws: `RPCError` with code `RPCError/Code/failedPrecondition` if the transport is already
    ///   connected or has been closed.
    public func connect() async throws {
      let (stream, continuation) = AsyncStream<Void>.makeStream()

      let pendingStreams: [AsyncStream<Void>.Continuation] = try self.state.withLock { state in
        switch state {
        case .unconnected(let unconnectedState):
          state = .connected(
            .init(
              fromUnconnected: unconnectedState,
              signalEndContinuation: continuation
            )
          )
          return unconnectedState.pendingStreams
        case .connected:
          throw RPCError(
            code: .failedPrecondition,
            message: "Already connected to server."
          )
        case .closed:
          throw RPCError(
            code: .failedPrecondition,
            message: "Can't connect to server, transport is closed."
          )
        }
      }

      // Wake the waiting `withStream` calls outside the lock; they will now observe `.connected`.
      for pendingStream in pendingStreams {
        pendingStream.finish()
      }

      for await _ in stream {
        // This for-await loop will exit (and thus `connect()` will return)
        // only when the task is cancelled, or when the stream's continuation is
        // finished - whichever happens first.
        // The continuation will be finished when `beginGracefulShutdown()` is called
        // and there are no more open streams.
      }

      // If at this point there are any open streams, it's because cancellation
      // occurred and all open streams must now be closed.
      let openStreams = self.state.withLock { state in
        switch state {
        case .unconnected:
          // We have transitioned to connected, and we can't transition back.
          fatalError("Invalid state")
        case .connected(let connectedState):
          state = .closed(.init())
          return connectedState.openStreams.values
        case .closed(let closedState):
          return closedState.openStreams.values
        }
      }

      for (clientStream, serverStream) in openStreams {
        await clientStream.outbound.finish(throwing: CancellationError())
        await serverStream.outbound.finish(throwing: CancellationError())
      }
    }

    /// Signal to the transport that no new streams may be created.
    ///
    /// Existing streams may run to completion naturally, but calling ``withStream(descriptor:options:_:)``
    /// will result in an `RPCError` with code `RPCError/Code/failedPrecondition` being thrown. This
    /// includes calls that were already waiting for ``connect()`` when this function was called:
    /// they are woken up and fail with that error instead of blocking until cancelled.
    ///
    /// If there are no open streams, ``connect()`` is signalled to return immediately.
    ///
    /// If you want to forcefully cancel all active streams then cancel the task running ``connect()``.
    public func beginGracefulShutdown() {
      let continuationsToFinish: [AsyncStream<Void>.Continuation] = self.state.withLock { state in
        switch state {
        case .unconnected(let unconnectedState):
          state = .closed(.init())
          // Release callers blocked in `withStream` waiting for a connection that will now never
          // happen; they will observe `.closed` and fail with `failedPrecondition`.
          return unconnectedState.pendingStreams
        case .connected(let connectedState):
          if connectedState.openStreams.isEmpty {
            state = .closed(.init())
            return [connectedState.signalEndContinuation]
          } else {
            // Keep tracking the open streams; the last one to be removed finishes
            // the end continuation (see `removeStream(id:)`).
            state = .closed(.init(fromConnected: connectedState))
            return []
          }
        case .closed:
          return []
        }
      }

      for continuation in continuationsToFinish {
        continuation.finish()
      }
    }

    /// Opens a stream using the transport, and uses it as input into a user-provided closure.
    ///
    /// - Important: The opened stream is closed after the closure is finished.
    ///
    /// This transport implementation throws `RPCError/Code/failedPrecondition` if the transport
    /// is closing or has been closed.
    ///
    /// This implementation will queue any streams (and thus block this call) if this function is called before
    /// ``connect()``, until a connection is established - at which point all streams will be
    /// created - or until ``beginGracefulShutdown()`` is called, in which case the call fails.
    ///
    /// - Parameters:
    ///   - descriptor: A description of the method to open a stream for.
    ///   - options: Options specific to the stream.
    ///   - closure: A closure that takes the opened stream as parameter.
    /// - Returns: Whatever value was returned from `closure`.
    /// - Throws: `CancellationError` if the task is cancelled while waiting for a connection;
    ///   `RPCError` with code `failedPrecondition` if the transport is closed; the error thrown by the
    ///   server transport when it refuses the stream (non-`RPCError` errors are wrapped in an `RPCError`
    ///   with code `unknown`); or any error thrown by `closure`.
    public func withStream<T>(
      descriptor: MethodDescriptor,
      options: CallOptions,
      _ closure: (RPCStream<Inbound, Outbound>) async throws -> T
    ) async throws -> T {
      // Two cross-wired streams: what the client writes, the server reads, and vice versa.
      let request = GRPCAsyncThrowingStream.makeStream(of: RPCRequestPart.self)
      let response = GRPCAsyncThrowingStream.makeStream(of: RPCResponsePart.self)

      let clientStream = RPCStream(
        descriptor: descriptor,
        inbound: RPCAsyncSequence(wrapping: response.stream),
        outbound: RPCWriter.Closable(wrapping: request.continuation)
      )

      let serverStream = RPCStream(
        descriptor: descriptor,
        inbound: RPCAsyncSequence(wrapping: request.stream),
        outbound: RPCWriter.Closable(wrapping: response.continuation)
      )

      let waitForConnectionStream: AsyncStream<Void>? = self.state.withLock { state in
        if case .unconnected(var unconnectedState) = state {
          let (stream, continuation) = AsyncStream<Void>.makeStream()
          unconnectedState.pendingStreams.append(continuation)
          state = .unconnected(unconnectedState)
          return stream
        }
        return nil
      }

      if let waitForConnectionStream {
        for await _ in waitForConnectionStream {
          // This loop will exit when the task is cancelled, when the client
          // connects, or when the client is shut down before connecting.
        }
        try Task.checkCancellation()
      }

      let acceptStream: Result<Int, RPCError> = self.state.withLock { state in
        switch state {
        case .unconnected:
          // The state cannot be unconnected here: the pending continuation is only
          // finished after leaving `.unconnected` (by `connect()` or
          // `beginGracefulShutdown()`), and if the loop exited because the task was
          // cancelled, `Task.checkCancellation()` above has already thrown.
          fatalError("Invalid state.")

        case .connected(var connectedState):
          let streamID = connectedState.nextStreamID
          do {
            try connectedState.serverTransport.acceptStream(serverStream)
            connectedState.openStreams[streamID] = (clientStream, serverStream)
            connectedState.nextStreamID += 1
            state = .connected(connectedState)
            return .success(streamID)
          } catch let acceptStreamError as RPCError {
            return .failure(acceptStreamError)
          } catch {
            return .failure(RPCError(code: .unknown, message: "Unknown error: \(error)."))
          }

        case .closed:
          let error = RPCError(
            code: .failedPrecondition,
            message: "The client transport is closed."
          )
          return .failure(error)
        }
      }

      switch acceptStream {
      case .success(let streamID):
        let streamHandlingResult: Result<T, any Error>
        do {
          let result = try await closure(clientStream)
          streamHandlingResult = .success(result)
        } catch {
          streamHandlingResult = .failure(error)
        }

        await clientStream.outbound.finish()
        self.removeStream(id: streamID)

        return try streamHandlingResult.get()

      case .failure(let error):
        await serverStream.outbound.finish(throwing: error)
        await clientStream.outbound.finish(throwing: error)
        throw error
      }
    }

    /// Stops tracking the stream with the given ID. If the transport is closed and this was the
    /// last open stream, the end continuation is finished so that `connect()` can return.
    private func removeStream(id streamID: Int) {
      let maybeEndContinuation = self.state.withLock { state in
        switch state {
        case .unconnected:
          // The state cannot be unconnected at this point, because if we made
          // it this far, it's because the transport was connected.
          // Once connected, it's impossible to transition back to unconnected,
          // so this is an invalid state.
          fatalError("Invalid state")
        case .connected(var connectedState):
          connectedState.openStreams.removeValue(forKey: streamID)
          state = .connected(connectedState)
        case .closed(var closedState):
          closedState.openStreams.removeValue(forKey: streamID)
          state = .closed(closedState)
          if closedState.openStreams.isEmpty {
            // This was the last open stream: signal the closure of the client.
            return closedState.signalEndContinuation
          }
        }
        return nil
      }
      maybeEndContinuation?.finish()
    }

    /// Returns the execution configuration for a given method.
    ///
    /// - Parameter descriptor: The method to lookup configuration for.
    /// - Returns: Execution configuration for the method, if it exists.
    public func config(
      forMethod descriptor: MethodDescriptor
    ) -> MethodConfig? {
      self.methodConfig[descriptor]
    }
  }
}