@@ -489,7 +489,9 @@ extension Server.Configuration.CORS {
489
489
public struct AllowedOrigins : Hashable , Sendable {
490
490
enum Wrapped : Hashable , Sendable {
491
491
case all
492
+ case originBased
492
493
case only( [ String ] )
494
+ case custom( AnyCustomCORSAllowedOrigin )
493
495
}
494
496
495
497
private( set) var wrapped : Wrapped
@@ -500,10 +502,23 @@ extension Server.Configuration.CORS {
500
502
/// Allow all origin values.
501
503
public static let all = Self ( . all)
502
504
505
+ /// Allow all origin values; similar to `all` but returns the value of the origin header field
506
+ /// in the 'access-control-allow-origin' response header (rather than "*").
507
+ public static let originBased = Self ( . originBased)
508
+
503
509
/// Allow only the given origin values.
504
510
public static func only( _ allowed: [ String ] ) -> Self {
505
511
return Self ( . only( allowed) )
506
512
}
513
+
514
+ /// Provide a custom CORS origin check.
515
+ ///
516
+ /// - Parameter checkOrigin: A closure which is called with the value of the 'origin' header
517
+ /// and returns the value to use in the 'access-control-allow-origin' response header,
518
+ /// or `nil` if the origin is not allowed.
519
+ public static func custom< C: GRPCCustomCORSAllowedOrigin > ( _ custom: C ) -> Self {
520
+ return Self ( . custom( AnyCustomCORSAllowedOrigin ( custom) ) )
521
+ }
507
522
}
508
523
}
509
524
@@ -530,3 +545,46 @@ extension Comparable {
530
545
return min ( max ( self , range. lowerBound) , range. upperBound)
531
546
}
532
547
}
548
+
549
+ public protocol GRPCCustomCORSAllowedOrigin : Sendable , Hashable {
550
+ /// Returns the value to use for the 'access-control-allow-origin' response header for the given
551
+ /// value of the 'origin' request header.
552
+ ///
553
+ /// - Parameter origin: The value of the 'origin' request header field.
554
+ /// - Returns: The value to use for the 'access-control-allow-origin' header field or `nil` if no
555
+ /// CORS related headers should be returned.
556
+ func check( origin: String ) -> String ?
557
+ }
558
+
559
+ extension Server . Configuration . CORS . AllowedOrigins {
560
+ struct AnyCustomCORSAllowedOrigin : GRPCCustomCORSAllowedOrigin {
561
+ private var checkOrigin : @Sendable ( String ) -> String ?
562
+ private let hashInto : @Sendable ( inout Hasher ) -> Void
563
+ #if swift(>=5.7)
564
+ private let isEqualTo : @Sendable ( any GRPCCustomCORSAllowedOrigin ) -> Bool
565
+ #else
566
+ private let isEqualTo : @Sendable ( Any ) -> Bool
567
+ #endif
568
+
569
+ init < W: GRPCCustomCORSAllowedOrigin > ( _ wrap: W ) {
570
+ self . checkOrigin = { wrap. check ( origin: $0) }
571
+ self . hashInto = { wrap. hash ( into: & $0) }
572
+ self . isEqualTo = { wrap == ( $0 as? W ) }
573
+ }
574
+
575
+ func check( origin: String ) -> String ? {
576
+ return self . checkOrigin ( origin)
577
+ }
578
+
579
+ func hash( into hasher: inout Hasher ) {
580
+ self . hashInto ( & hasher)
581
+ }
582
+
583
+ static func == (
584
+ lhs: Server . Configuration . CORS . AllowedOrigins . AnyCustomCORSAllowedOrigin ,
585
+ rhs: Server . Configuration . CORS . AllowedOrigins . AnyCustomCORSAllowedOrigin
586
+ ) -> Bool {
587
+ return lhs. isEqualTo ( rhs)
588
+ }
589
+ }
590
+ }
0 commit comments