Skip to content

Commit c3f09df

Browse files
gjcairoglbrntt
andauthored
Allow adding ServerInterceptors to specific services and methods (#2096)
## Motivation We want to allow users to customise the RPCs a registered interceptor should apply to on the server: - Intercept all requests - Intercept requests only meant for specific services - Intercept requests only meant for specific methods ## Modifications This PR adds a new `ServerInterceptorTarget` type that allows users to specify what the target of the interceptor should be. Existing APIs accepting `[any ServerInterceptor]` have been changed to instead take `[ServerInterceptorTarget]`. ## Result Users can have more control over to which requests interceptors are applied. --------- Co-authored-by: George Barnett <[email protected]>
1 parent f963523 commit c3f09df

File tree

9 files changed

+480
-49
lines changed

9 files changed

+480
-49
lines changed

Diff for: Sources/GRPCCore/Call/Server/RPCRouter.swift

+27-5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
/// the router has a handler for a method with ``hasHandler(forMethod:)`` or get a list of all
2323
/// methods with handlers registered by calling ``methods``. You can also remove the handler for a
2424
/// given method by calling ``removeHandler(forMethod:)``.
25+
/// You can also register any interceptors that you want applied to registered handlers via the
26+
/// ``registerInterceptors(pipeline:)`` method.
2527
///
2628
/// In most cases you won't need to interact with the router directly. Instead you should register
2729
/// your services with ``GRPCServer/init(transport:services:interceptors:)`` which will in turn
@@ -82,7 +84,8 @@ public struct RPCRouter: Sendable {
8284
}
8385

8486
@usableFromInline
85-
private(set) var handlers: [MethodDescriptor: RPCHandler]
87+
private(set) var handlers:
88+
[MethodDescriptor: (handler: RPCHandler, interceptors: [any ServerInterceptor])]
8689

8790
/// Creates a new router with no methods registered.
8891
public init() {
@@ -126,12 +129,13 @@ public struct RPCRouter: Sendable {
126129
_ context: ServerContext
127130
) async throws -> StreamingServerResponse<Output>
128131
) {
129-
self.handlers[descriptor] = RPCHandler(
132+
let handler = RPCHandler(
130133
method: descriptor,
131134
deserializer: deserializer,
132135
serializer: serializer,
133136
handler: handler
134137
)
138+
self.handlers[descriptor] = (handler, [])
135139
}
136140

137141
/// Removes any handler registered for the specified method.
@@ -142,6 +146,25 @@ public struct RPCRouter: Sendable {
142146
public mutating func removeHandler(forMethod descriptor: MethodDescriptor) -> Bool {
143147
return self.handlers.removeValue(forKey: descriptor) != nil
144148
}
149+
150+
/// Registers applicable interceptors to all currently-registered handlers.
151+
///
152+
/// - Important: Calling this method will apply the interceptors only to existing handlers. Any handlers registered via
153+
/// ``registerHandler(forMethod:deserializer:serializer:handler:)`` _after_ calling this method will not have
154+
/// any interceptors applied to them. If you want to make sure all registered methods have any applicable interceptors applied,
155+
/// only call this method _after_ you have registered all handlers.
156+
/// - Parameter pipeline: The interceptor pipeline operations to register to all currently-registered handlers. The order of the
157+
/// interceptors matters.
158+
/// - SeeAlso: ``ServerInterceptorPipelineOperation``.
159+
@inlinable
160+
public mutating func registerInterceptors(pipeline: [ServerInterceptorPipelineOperation]) {
161+
for descriptor in self.handlers.keys {
162+
let applicableOperations = pipeline.filter { $0.applies(to: descriptor) }
163+
if !applicableOperations.isEmpty {
164+
self.handlers[descriptor]?.interceptors = applicableOperations.map { $0.interceptor }
165+
}
166+
}
167+
}
145168
}
146169

147170
extension RPCRouter {
@@ -150,10 +173,9 @@ extension RPCRouter {
150173
RPCAsyncSequence<RPCRequestPart, any Error>,
151174
RPCWriter<RPCResponsePart>.Closable
152175
>,
153-
context: ServerContext,
154-
interceptors: [any ServerInterceptor]
176+
context: ServerContext
155177
) async {
156-
if let handler = self.handlers[stream.descriptor] {
178+
if let (handler, interceptors) = self.handlers[stream.descriptor] {
157179
await handler.handle(stream: stream, context: context, interceptors: interceptors)
158180
} else {
159181
// If this throws then the stream must be closed which we can't do anything about, so ignore

Diff for: Sources/GRPCCore/Call/Server/ServerInterceptor.swift

+9-8
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121
/// been returned from a service. They are typically used for cross-cutting concerns like filtering
2222
/// requests, validating messages, logging additional data, and tracing.
2323
///
24-
/// Interceptors are registered with the server apply to all RPCs. If you need to modify the
25-
/// behavior of an interceptor on a per-RPC basis then you can use the
26-
/// ``ServerContext/descriptor`` to determine which RPC is being called and
27-
/// conditionalise behavior accordingly.
24+
/// Interceptors can be registered with the server either directly or via ``ServerInterceptorPipelineOperation``s.
25+
/// You may register them for all services registered with a server, for RPCs directed to specific services, or
26+
/// for RPCs directed to specific methods. If you need to modify the behavior of an interceptor on a
27+
/// per-RPC basis in more detail, then you can use the ``ServerContext/descriptor`` to determine
28+
/// which RPC is being called and conditionalise behavior accordingly.
2829
///
2930
/// ## RPC filtering
3031
///
@@ -33,19 +34,19 @@
3334
/// demonstrates this.
3435
///
3536
/// ```swift
36-
/// struct AuthServerInterceptor: Sendable {
37+
/// struct AuthServerInterceptor: ServerInterceptor {
3738
/// let isAuthorized: @Sendable (String, MethodDescriptor) async throws -> Void
3839
///
3940
/// func intercept<Input: Sendable, Output: Sendable>(
4041
/// request: StreamingServerRequest<Input>,
41-
/// context: ServerInterceptorContext,
42+
/// context: ServerContext,
4243
/// next: @Sendable (
4344
/// _ request: StreamingServerRequest<Input>,
44-
/// _ context: ServerInterceptorContext
45+
/// _ context: ServerContext
4546
/// ) async throws -> StreamingServerResponse<Output>
4647
/// ) async throws -> StreamingServerResponse<Output> {
4748
/// // Extract the auth token.
48-
/// guard let token = request.metadata["authorization"] else {
49+
/// guard let token = request.metadata[stringValues: "authorization"].first(where: { _ in true }) else {
4950
/// throw RPCError(code: .unauthenticated, message: "Not authenticated")
5051
/// }
5152
///
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright 2024, gRPC Authors All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
/// A `ServerInterceptorPipelineOperation` describes to which RPCs a server interceptor should be applied.
18+
///
19+
/// You can configure a server interceptor to be applied to:
20+
/// - all RPCs and services;
21+
/// - requests directed only to specific services registered with your server; or
22+
/// - requests directed only to specific methods (of a specific service).
23+
///
24+
/// - SeeAlso: ``ServerInterceptor`` for more information on server interceptors.
25+
public struct ServerInterceptorPipelineOperation: Sendable {
26+
/// The subject of a ``ServerInterceptorPipelineOperation``.
27+
/// The subject of an interceptor can either be all services and methods, only specific services, or only specific methods.
28+
public struct Subject: Sendable {
29+
internal enum Wrapped: Sendable {
30+
case all
31+
case services(Set<ServiceDescriptor>)
32+
case methods(Set<MethodDescriptor>)
33+
}
34+
35+
private let wrapped: Wrapped
36+
37+
/// An operation subject specifying an interceptor that applies to all RPCs across all services will be registered with this server.
38+
public static var all: Self { .init(wrapped: .all) }
39+
40+
/// An operation subject specifying an interceptor that will be applied only to RPCs directed to the specified services.
41+
/// - Parameters:
42+
/// - services: The list of service names for which this interceptor should intercept RPCs.
43+
/// - Returns: A ``ServerInterceptorPipelineOperation``.
44+
public static func services(_ services: Set<ServiceDescriptor>) -> Self {
45+
Self(wrapped: .services(services))
46+
}
47+
48+
/// An operation subject specifying an interceptor that will be applied only to RPCs directed to the specified service methods.
49+
/// - Parameters:
50+
/// - methods: The list of method descriptors for which this interceptor should intercept RPCs.
51+
/// - Returns: A ``ServerInterceptorPipelineOperation``.
52+
public static func methods(_ methods: Set<MethodDescriptor>) -> Self {
53+
Self(wrapped: .methods(methods))
54+
}
55+
56+
@usableFromInline
57+
internal func applies(to descriptor: MethodDescriptor) -> Bool {
58+
switch self.wrapped {
59+
case .all:
60+
return true
61+
62+
case .services(let services):
63+
return services.map({ $0.fullyQualifiedService }).contains(descriptor.service)
64+
65+
case .methods(let methods):
66+
return methods.contains(descriptor)
67+
}
68+
}
69+
}
70+
71+
/// The interceptor specified for this operation.
72+
public let interceptor: any ServerInterceptor
73+
74+
@usableFromInline
75+
internal let subject: Subject
76+
77+
private init(interceptor: any ServerInterceptor, appliesTo: Subject) {
78+
self.interceptor = interceptor
79+
self.subject = appliesTo
80+
}
81+
82+
/// Create an operation, specifying which ``ServerInterceptor`` to apply and to which ``Subject``.
83+
/// - Parameters:
84+
/// - interceptor: The ``ServerInterceptor`` to register with the server.
85+
/// - subject: The ``Subject`` to which the `interceptor` applies.
86+
/// - Returns: A ``ServerInterceptorPipelineOperation``.
87+
public static func apply(_ interceptor: any ServerInterceptor, to subject: Subject) -> Self {
88+
Self(interceptor: interceptor, appliesTo: subject)
89+
}
90+
91+
/// Returns whether this ``ServerInterceptorPipelineOperation`` applies to the given `descriptor`.
92+
/// - Parameter descriptor: A ``MethodDescriptor`` for which to test whether this interceptor applies.
93+
/// - Returns: `true` if this interceptor applies to the given `descriptor`, or `false` otherwise.
94+
@inlinable
95+
internal func applies(to descriptor: MethodDescriptor) -> Bool {
96+
self.subject.applies(to: descriptor)
97+
}
98+
}

Diff for: Sources/GRPCCore/GRPCServer.swift

+26-21
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,6 @@ public final class GRPCServer: Sendable {
7878
/// The services registered which the server is serving.
7979
private let router: RPCRouter
8080

81-
/// A collection of ``ServerInterceptor`` implementations which are applied to all accepted
82-
/// RPCs.
83-
///
84-
/// RPCs are intercepted in the order that interceptors are added. That is, a request received
85-
/// from the client will first be intercepted by the first added interceptor followed by the
86-
/// second, and so on.
87-
private let interceptors: [any ServerInterceptor]
88-
8981
/// The state of the server.
9082
private let state: Mutex<State>
9183

@@ -154,33 +146,46 @@ public final class GRPCServer: Sendable {
154146
services: [any RegistrableRPCService],
155147
interceptors: [any ServerInterceptor] = []
156148
) {
157-
var router = RPCRouter()
158-
for service in services {
159-
service.registerMethods(with: &router)
160-
}
161-
162-
self.init(transport: transport, router: router, interceptors: interceptors)
149+
self.init(
150+
transport: transport,
151+
services: services,
152+
interceptorPipeline: interceptors.map { .apply($0, to: .all) }
153+
)
163154
}
164155

165156
/// Creates a new server with no resources.
166157
///
167158
/// - Parameters:
168159
/// - transport: The transport the server should listen on.
169-
/// - router: A ``RPCRouter`` used by the server to route accepted streams to method handlers.
170-
/// - interceptors: A collection of interceptors providing cross-cutting functionality to each
160+
/// - services: Services offered by the server.
161+
/// - interceptorPipeline: A collection of interceptors providing cross-cutting functionality to each
171162
/// accepted RPC. The order in which interceptors are added reflects the order in which they
172163
/// are called. The first interceptor added will be the first interceptor to intercept each
173164
/// request. The last interceptor added will be the final interceptor to intercept each
174165
/// request before calling the appropriate handler.
175-
public init(
166+
public convenience init(
176167
transport: any ServerTransport,
177-
router: RPCRouter,
178-
interceptors: [any ServerInterceptor] = []
168+
services: [any RegistrableRPCService],
169+
interceptorPipeline: [ServerInterceptorPipelineOperation]
179170
) {
171+
var router = RPCRouter()
172+
for service in services {
173+
service.registerMethods(with: &router)
174+
}
175+
router.registerInterceptors(pipeline: interceptorPipeline)
176+
177+
self.init(transport: transport, router: router)
178+
}
179+
180+
/// Creates a new server with no resources.
181+
///
182+
/// - Parameters:
183+
/// - transport: The transport the server should listen on.
184+
/// - router: A ``RPCRouter`` used by the server to route accepted streams to method handlers.
185+
public init(transport: any ServerTransport, router: RPCRouter) {
180186
self.state = Mutex(.notStarted)
181187
self.transport = transport
182188
self.router = router
183-
self.interceptors = interceptors
184189
}
185190

186191
/// Starts the server and runs until the registered transport has closed.
@@ -206,7 +211,7 @@ public final class GRPCServer: Sendable {
206211

207212
do {
208213
try await transport.listen { stream, context in
209-
await self.router.handle(stream: stream, context: context, interceptors: self.interceptors)
214+
await self.router.handle(stream: stream, context: context)
210215
}
211216
} catch {
212217
throw RuntimeError(

Diff for: Tests/GRPCCoreTests/Call/Server/Internal/ServerRPCExecutorTests.swift

+3-1
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,9 @@ final class ServerRPCExecutorTests: XCTestCase {
333333

334334
func testThrowingInterceptor() async throws {
335335
let harness = ServerRPCExecutorTestHarness(
336-
interceptors: [.throwError(RPCError(code: .unavailable, message: "Unavailable"))]
336+
interceptors: [
337+
.throwError(RPCError(code: .unavailable, message: "Unavailable"))
338+
]
337339
)
338340

339341
try await harness.execute(handler: .echo) { inbound in

0 commit comments

Comments
 (0)