@@ -17,54 +17,110 @@ import NIOCore
17
17
import NIOHTTP1
18
18
19
19
/// Handler that manages the CORS protocol for requests incoming from the browser.
20
- internal class WebCORSHandler {
21
- var requestMethod : HTTPMethod ?
20
+ internal final class WebCORSHandler {
21
+ let configuration : Server . Configuration . CORS
22
+
23
+ private var state : State = . idle
24
+ private enum State : Equatable {
25
+ /// Starting state.
26
+ case idle
27
+ /// CORS preflight request is in progress.
28
+ case processingPreflightRequest
29
+ /// "Real" request is in progress.
30
+ case processingRequest( origin: String ? )
31
+ }
32
+
33
+ init ( configuration: Server . Configuration . CORS ) {
34
+ self . configuration = configuration
35
+ }
22
36
}
23
37
24
38
extension WebCORSHandler : ChannelInboundHandler {
25
39
typealias InboundIn = HTTPServerRequestPart
40
+ typealias InboundOut = HTTPServerRequestPart
26
41
typealias OutboundOut = HTTPServerResponsePart
27
42
28
43
func channelRead( context: ChannelHandlerContext , data: NIOAny ) {
29
- // If the request is OPTIONS, the request is not propagated further.
30
44
switch self . unwrapInboundIn ( data) {
31
- case let . head( requestHead) :
32
- self . requestMethod = requestHead. method
33
- if self . requestMethod == . OPTIONS {
34
- var headers = HTTPHeaders ( )
35
- headers. add ( name: " Access-Control-Allow-Origin " , value: " * " )
36
- headers. add ( name: " Access-Control-Allow-Methods " , value: " POST " )
37
- headers. add (
38
- name: " Access-Control-Allow-Headers " ,
39
- value: " content-type,x-grpc-web,x-user-agent "
40
- )
41
- headers. add ( name: " Access-Control-Max-Age " , value: " 86400 " )
42
- context. write (
43
- self . wrapOutboundOut ( . head( HTTPResponseHead (
44
- version: requestHead. version,
45
- status: . ok,
46
- headers: headers
47
- ) ) ) ,
48
- promise: nil
49
- )
50
- return
45
+ case let . head( head) :
46
+ self . receivedRequestHead ( context: context, head)
47
+
48
+ case let . body( body) :
49
+ self . receivedRequestBody ( context: context, body)
50
+
51
+ case let . end( trailers) :
52
+ self . receivedRequestEnd ( context: context, trailers)
53
+ }
54
+ }
55
+
56
+ private func receivedRequestHead( context: ChannelHandlerContext , _ head: HTTPRequestHead ) {
57
+ if head. method == . OPTIONS,
58
+ head. headers. contains ( . accessControlRequestMethod) ,
59
+ let origin = head. headers. first ( name: " origin " ) {
60
+ // If the request is OPTIONS with a access-control-request-method header it's a CORS
61
+ // preflight request and is not propagated further.
62
+ self . state = . processingPreflightRequest
63
+ self . handlePreflightRequest ( context: context, head: head, origin: origin)
64
+ } else {
65
+ self . state = . processingRequest( origin: head. headers. first ( name: " origin " ) )
66
+ context. fireChannelRead ( self . wrapInboundOut ( . head( head) ) )
67
+ }
68
+ }
69
+
70
+ private func receivedRequestBody( context: ChannelHandlerContext , _ body: ByteBuffer ) {
71
+ // OPTIONS requests do not have a body, but still handle this case to be
72
+ // cautious.
73
+ if self . state == . processingPreflightRequest {
74
+ return
75
+ }
76
+
77
+ context. fireChannelRead ( self . wrapInboundOut ( . body( body) ) )
78
+ }
79
+
80
+ private func receivedRequestEnd( context: ChannelHandlerContext , _ trailers: HTTPHeaders ? ) {
81
+ if self . state == . processingPreflightRequest {
82
+ // End of OPTIONS request; reset state and finish the response.
83
+ self . state = . idle
84
+ context. writeAndFlush ( self . wrapOutboundOut ( . end( nil ) ) , promise: nil )
85
+ } else {
86
+ context. fireChannelRead ( self . wrapInboundOut ( . end( trailers) ) )
87
+ }
88
+ }
89
+
90
+ private func handlePreflightRequest(
91
+ context: ChannelHandlerContext ,
92
+ head: HTTPRequestHead ,
93
+ origin: String
94
+ ) {
95
+ let responseHead : HTTPResponseHead
96
+
97
+ if let allowedOrigin = self . configuration. allowedOrigins. header ( origin) {
98
+ var headers = HTTPHeaders ( )
99
+ headers. reserveCapacity ( 4 + self . configuration. allowedHeaders. count)
100
+ headers. add ( name: . accessControlAllowOrigin, value: allowedOrigin)
101
+ headers. add ( name: . accessControlAllowMethods, value: " POST " )
102
+
103
+ for value in self . configuration. allowedHeaders {
104
+ headers. add ( name: . accessControlAllowHeaders, value: value)
51
105
}
52
- case . body:
53
- if self . requestMethod == . OPTIONS {
54
- // OPTIONS requests do not have a body, but still handle this case to be
55
- // cautious.
56
- return
106
+
107
+ if self . configuration. allowCredentialedRequests {
108
+ headers. add ( name: . accessControlAllowCredentials, value: " true " )
57
109
}
58
110
59
- case . end :
60
- if self . requestMethod == . OPTIONS {
61
- context . writeAndFlush ( self . wrapOutboundOut ( . end ( nil ) ) , promise : nil )
62
- self . requestMethod = nil
63
- return
111
+ if self . configuration . preflightCacheExpiration > 0 {
112
+ headers . add (
113
+ name : . accessControlMaxAge ,
114
+ value : " \( self . configuration . preflightCacheExpiration ) "
115
+ )
64
116
}
117
+ responseHead = HTTPResponseHead ( version: head. version, status: . ok, headers: headers)
118
+ } else {
119
+ // Not allowed; respond with 403. This is okay in a pre-flight request.
120
+ responseHead = HTTPResponseHead ( version: head. version, status: . forbidden)
65
121
}
66
- // The OPTIONS request should be fully handled at this point.
67
- context. fireChannelRead ( data )
122
+
123
+ context. write ( self . wrapOutboundOut ( . head ( responseHead ) ) , promise : nil )
68
124
}
69
125
}
70
126
@@ -74,25 +130,76 @@ extension WebCORSHandler: ChannelOutboundHandler {
74
130
func write( context: ChannelHandlerContext , data: NIOAny , promise: EventLoopPromise < Void > ? ) {
75
131
let responsePart = self . unwrapOutboundIn ( data)
76
132
switch responsePart {
77
- case let . head( responseHead) :
78
- var headers = responseHead. headers
79
- // CORS requires all requests to have an Allow-Origin header.
80
- headers. add ( name: " Access-Control-Allow-Origin " , value: " * " )
81
- //! FIXME: Check whether we can let browsers keep connections alive. It's not possible
82
- // now as the channel has a state that can't be reused since the pipeline is modified to
83
- // inject the gRPC call handler.
84
- headers. add ( name: " Connection " , value: " close " )
85
-
86
- context. write (
87
- self . wrapOutboundOut ( . head( HTTPResponseHead (
88
- version: responseHead. version,
89
- status: responseHead. status,
90
- headers: headers
91
- ) ) ) ,
92
- promise: promise
93
- )
94
- default :
133
+ case var . head( responseHead) :
134
+ switch self . state {
135
+ case let . processingRequest( origin) :
136
+ self . prepareCORSResponseHead ( & responseHead, origin: origin)
137
+ context. write ( self . wrapOutboundOut ( . head( responseHead) ) , promise: promise)
138
+
139
+ case . idle, . processingPreflightRequest:
140
+ assertionFailure ( " Writing response head when no request is in progress " )
141
+ context. close ( promise: nil )
142
+ }
143
+
144
+ case . body:
145
+ context. write ( data, promise: promise)
146
+
147
+ case . end:
148
+ self . state = . idle
95
149
context. write ( data, promise: promise)
96
150
}
97
151
}
152
+
153
+ private func prepareCORSResponseHead( _ head: inout HTTPResponseHead , origin: String ? ) {
154
+ guard let header = origin. flatMap ( { self . configuration. allowedOrigins. header ( $0) } ) else {
155
+ // No origin or the origin is not allowed; don't treat it as a CORS request.
156
+ return
157
+ }
158
+
159
+ head. headers. replaceOrAdd ( name: . accessControlAllowOrigin, value: header)
160
+
161
+ if self . configuration. allowCredentialedRequests {
162
+ head. headers. add ( name: . accessControlAllowCredentials, value: " true " )
163
+ }
164
+
165
+ //! FIXME: Check whether we can let browsers keep connections alive. It's not possible
166
+ // now as the channel has a state that can't be reused since the pipeline is modified to
167
+ // inject the gRPC call handler.
168
+ head. headers. replaceOrAdd ( name: " Connection " , value: " close " )
169
+ }
170
+ }
171
+
172
+ extension HTTPHeaders {
173
+ fileprivate enum CORSHeader : String {
174
+ case accessControlRequestMethod = " access-control-request-method "
175
+ case accessControlRequestHeaders = " access-control-request-headers "
176
+ case accessControlAllowOrigin = " access-control-allow-origin "
177
+ case accessControlAllowMethods = " access-control-allow-methods "
178
+ case accessControlAllowHeaders = " access-control-allow-headers "
179
+ case accessControlAllowCredentials = " access-control-allow-credentials "
180
+ case accessControlMaxAge = " access-control-max-age "
181
+ }
182
+
183
+ fileprivate func contains( _ name: CORSHeader ) -> Bool {
184
+ return self . contains ( name: name. rawValue)
185
+ }
186
+
187
+ fileprivate mutating func add( name: CORSHeader , value: String ) {
188
+ self . add ( name: name. rawValue, value: value)
189
+ }
190
+
191
+ fileprivate mutating func replaceOrAdd( name: CORSHeader , value: String ) {
192
+ self . replaceOrAdd ( name: name. rawValue, value: value)
193
+ }
194
+ }
195
+
196
+ extension Server . Configuration . CORS . AllowedOrigins {
197
+ internal func header( _ origin: String ) -> String ? {
198
+ switch self . wrapped {
199
+ case . all:
200
+ return " * "
201
+ case let . only( allowed) :
202
+ return allowed. contains ( origin) ? origin : nil
203
+ }
204
+ }
98
205
}
0 commit comments