@@ -95,7 +95,7 @@ struct DiagnosticError: Error {
95
95
}
96
96
}
97
97
98
- enum UnsafePointerKind {
98
+ enum Mutability {
99
99
case Immutable
100
100
case Mutable
101
101
}
@@ -125,6 +125,33 @@ func replaceTypeName(_ type: TypeSyntax, _ name: TokenSyntax) -> TypeSyntax {
125
125
return TypeSyntax ( idType. with ( \. name, name) )
126
126
}
127
127
128
+ func getPointerMutability( text: String ) -> Mutability {
129
+ switch text {
130
+ case " UnsafePointer " : return . Immutable
131
+ case " UnsafeMutablePointer " : return . Mutable
132
+ case " UnsafeRawPointer " : return . Immutable
133
+ case " UnsafeMutableRawPointer " : return . Mutable
134
+ default :
135
+ throw DiagnosticError (
136
+ " expected Unsafe[Mutable][Raw]Pointer type for type \( prev) "
137
+ + " - first type token is ' \( text) ' " , node: name)
138
+ }
139
+ }
140
+
141
+ func getSafePointerName( mut: Mutability , generateSpan: Bool , isRaw: Bool ) -> TokenSyntax {
142
+ switch ( mut, generateSpan, isRaw) {
143
+ case ( . Immutable, true , true ) : return " RawSpan "
144
+ case ( . Mutable, true , true ) : return " MutableRawSpan "
145
+ case ( . Immutable, false , true ) : return " UnsafeRawBufferPointer "
146
+ case ( . Mutable, false , true ) : return " UnsafeMutableRawBufferPointer "
147
+
148
+ case ( . Immutable, true , false ) : return " Span "
149
+ case ( . Mutable, true , false ) : return " MutableSpan "
150
+ case ( . Immutable, false , false ) : return " UnsafeBufferPointer "
151
+ case ( . Mutable, false , false ) : return " UnsafeMutableBufferPointer "
152
+ }
153
+ }
154
+
128
155
func transformType( _ prev: TypeSyntax , _ variant: Variant , _ isSizedBy: Bool ) throws -> TypeSyntax {
129
156
if let optType = prev. as ( OptionalTypeSyntax . self) {
130
157
return TypeSyntax (
@@ -135,37 +162,16 @@ func transformType(_ prev: TypeSyntax, _ variant: Variant, _ isSizedBy: Bool) th
135
162
}
136
163
let name = try getTypeName ( prev)
137
164
let text = name. text
138
- let kind : UnsafePointerKind =
139
- switch text {
140
- case " UnsafePointer " : . Immutable
141
- case " UnsafeMutablePointer " : . Mutable
142
- case " UnsafeRawPointer " : . Immutable
143
- case " UnsafeMutableRawPointer " : . Mutable
144
- default :
145
- throw DiagnosticError (
146
- " expected Unsafe[Mutable][Raw]Pointer type for type \( prev) "
147
- + " - first type token is ' \( text) ' " , node: name)
148
- }
165
+ if !isSizedBy && ( text == " UnsafeRawPointer " || text == " UnsafeMutableRawPointer " ) {
166
+ throw DiagnosticError ( " raw pointers only supported for SizedBy " , node: name)
167
+ }
168
+
169
+ let kind : Mutability =
170
+ getPointerMutability ( text: text)
171
+ let token = getSafePointerName ( mut: kind, generateSpan: variant. generateSpan, isRaw: isSizedBy)
149
172
if isSizedBy {
150
- let token : TokenSyntax =
151
- switch ( kind, variant. generateSpan) {
152
- case ( . Immutable, true ) : " RawSpan "
153
- case ( . Mutable, true ) : " MutableRawSpan "
154
- case ( . Immutable, false ) : " UnsafeRawBufferPointer "
155
- case ( . Mutable, false ) : " UnsafeMutableRawBufferPointer "
156
- }
157
173
return TypeSyntax ( IdentifierTypeSyntax ( name: token) )
158
174
}
159
- if text == " UnsafeRawPointer " || text == " UnsafeMutableRawPointer " {
160
- throw DiagnosticError ( " raw pointers only supported for SizedBy " , node: name)
161
- }
162
- let token : TokenSyntax =
163
- switch ( kind, variant. generateSpan) {
164
- case ( . Immutable, true ) : " Span "
165
- case ( . Mutable, true ) : " MutableSpan "
166
- case ( . Immutable, false ) : " UnsafeBufferPointer "
167
- case ( . Mutable, false ) : " UnsafeMutableBufferPointer "
168
- }
169
175
return replaceTypeName ( prev, token)
170
176
}
171
177
@@ -183,13 +189,11 @@ protocol BoundsCheckedThunkBuilder {
183
189
184
190
func getParam( _ signature: FunctionSignatureSyntax , _ paramIndex: Int ) -> FunctionParameterSyntax {
185
191
let params = signature. parameterClause. parameters
186
- let index =
187
- if paramIndex > 0 {
188
- params. index ( params. startIndex, offsetBy: paramIndex)
189
- } else {
190
- params. startIndex
191
- }
192
- return params [ index]
192
+ if paramIndex > 0 {
193
+ return params [ params. index ( params. startIndex, offsetBy: paramIndex) ]
194
+ } else {
195
+ return params [ params. startIndex]
196
+ }
193
197
}
194
198
func getParam( _ funcDecl: FunctionDeclSyntax , _ paramIndex: Int ) -> FunctionParameterSyntax {
195
199
return getParam ( funcDecl. signature, paramIndex)
@@ -342,19 +346,17 @@ struct CountedOrSizedPointerThunkBuilder: PointerBoundsThunkBuilder {
342
346
343
347
func getCount( _ variant: Variant ) -> ExprSyntax {
344
348
let countName = isSizedBy && variant. generateSpan ? " byteCount " : " count "
345
- return if nullable {
346
- ExprSyntax ( " \( name) ?. \( raw: countName) ?? 0 " )
347
- } else {
348
- ExprSyntax ( " \( name) . \( raw: countName) " )
349
+ if nullable {
350
+ return ExprSyntax ( " \( name) ?. \( raw: countName) ?? 0 " )
349
351
}
352
+ return ExprSyntax ( " \( name) . \( raw: countName) " )
350
353
}
351
354
352
355
func getPointerArg( ) -> ExprSyntax {
353
- return if nullable {
354
- ExprSyntax ( " \( name) ?.baseAddress " )
355
- } else {
356
- ExprSyntax ( " \( name) .baseAddress! " )
356
+ if nullable {
357
+ return ExprSyntax ( " \( name) ?.baseAddress " )
357
358
}
359
+ return ExprSyntax ( " \( name) .baseAddress! " )
358
360
}
359
361
360
362
func buildFunctionCall( _ argOverrides: [ Int : ExprSyntax ] , _ variant: Variant ) throws -> ExprSyntax
@@ -541,12 +543,7 @@ public struct PointerBoundsMacro: PeerMacro {
541
543
let endParamIndexArg = try getArgumentByName ( argumentList, " end " )
542
544
let endParamIndex : Int = try getIntLiteralValue ( endParamIndexArg)
543
545
let nonescapingExprArg = getOptionalArgumentByName ( argumentList, " nonescaping " )
544
- let nonescaping =
545
- if nonescapingExprArg != nil {
546
- try getBoolLiteralValue ( nonescapingExprArg!)
547
- } else {
548
- false
549
- }
546
+ let nonescaping = nonescapingExprArg != nil && try getBoolLiteralValue ( nonescapingExprArg!)
550
547
return EndedBy (
551
548
pointerIndex: startParamIndex, endIndex: endParamIndex, nonescaping: nonescaping,
552
549
original: ExprSyntax ( enumConstructorExpr) )
@@ -618,11 +615,10 @@ public struct PointerBoundsMacro: PeerMacro {
618
615
let i = pointerArg. pointerIndex
619
616
if i < 1 || i > paramCount {
620
617
let noteMessage =
621
- if paramCount > 0 {
618
+ paramCount > 0 ?
622
619
" function \( funcDecl. name) has parameter indices 1.. \( paramCount) "
623
- } else {
620
+ :
624
621
" function \( funcDecl. name) has no parameters "
625
- }
626
622
throw DiagnosticError (
627
623
" pointer index out of bounds " , node: pointerArg. original,
628
624
notes: [
@@ -674,13 +670,12 @@ public struct PointerBoundsMacro: PeerMacro {
674
670
} )
675
671
let newSignature = try builder. buildFunctionSignature ( [ : ] , variant)
676
672
let checks =
677
- if variant. skipTrivialCount {
673
+ variant. skipTrivialCount ?
678
674
[ ] as [ CodeBlockItemSyntax ]
679
- } else {
675
+ :
680
676
try builder. buildBoundsChecks ( variant) . map { e in
681
677
CodeBlockItemSyntax ( leadingTrivia: " \n " , item: e)
682
678
}
683
- }
684
679
let call = CodeBlockItemSyntax (
685
680
item: CodeBlockItemSyntax . Item (
686
681
ReturnStmtSyntax (
0 commit comments