Skip to content

[cxx-interop] Implements CxxMutableSpan, created from an UnsafeMutableBufferPointer #75369

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ PROTOCOL(CxxSequence)
PROTOCOL(CxxUniqueSet)
PROTOCOL(CxxVector)
PROTOCOL(CxxSpan)
PROTOCOL(CxxMutableSpan)
PROTOCOL(UnsafeCxxInputIterator)
PROTOCOL(UnsafeCxxMutableInputIterator)
PROTOCOL(UnsafeCxxRandomAccessIterator)
Expand Down
1 change: 1 addition & 0 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
case KnownProtocolKind::CxxUniqueSet:
case KnownProtocolKind::CxxVector:
case KnownProtocolKind::CxxSpan:
case KnownProtocolKind::CxxMutableSpan:
case KnownProtocolKind::UnsafeCxxInputIterator:
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
Expand Down
27 changes: 15 additions & 12 deletions lib/ClangImporter/ClangDerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1164,23 +1164,22 @@ void swift::conformToCxxSpanIfNeeded(ClangImporter::Implementation &impl,
if (!elementType || !sizeType)
return;

auto constPointerTypeDecl =
lookupNestedClangTypeDecl(clangDecl, "const_pointer");
auto pointerTypeDecl = lookupNestedClangTypeDecl(clangDecl, "pointer");
auto countTypeDecl = lookupNestedClangTypeDecl(clangDecl, "size_type");

if (!constPointerTypeDecl || !countTypeDecl)
if (!pointerTypeDecl || !countTypeDecl)
return;

// create fake variable for constPointer (constructor arg 1)
auto constPointerType = clangCtx.getTypeDeclType(constPointerTypeDecl);
auto fakeConstPointerVarDecl = clang::VarDecl::Create(
// create fake variable for pointer (constructor arg 1)
clang::QualType pointerType = clangCtx.getTypeDeclType(pointerTypeDecl);
auto fakePointerVarDecl = clang::VarDecl::Create(
clangCtx, /*DC*/ clangCtx.getTranslationUnitDecl(),
clang::SourceLocation(), clang::SourceLocation(), /*Id*/ nullptr,
constPointerType, clangCtx.getTrivialTypeSourceInfo(constPointerType),
pointerType, clangCtx.getTrivialTypeSourceInfo(pointerType),
clang::StorageClass::SC_None);

auto fakeConstPointer = new (clangCtx) clang::DeclRefExpr(
clangCtx, fakeConstPointerVarDecl, false, constPointerType,
auto fakePointer = new (clangCtx) clang::DeclRefExpr(
clangCtx, fakePointerVarDecl, false, pointerType,
clang::ExprValueKind::VK_LValue, clang::SourceLocation());

// create fake variable for count (constructor arg 2)
Expand All @@ -1197,8 +1196,7 @@ void swift::conformToCxxSpanIfNeeded(ClangImporter::Implementation &impl,

// Use clangSema.BuildCxxTypeConstructExpr to create a CXXTypeConstructExpr,
// passing constPointer and count
SmallVector<clang::Expr *, 2> constructExprArgs = {fakeConstPointer,
fakeCount};
SmallVector<clang::Expr *, 2> constructExprArgs = {fakePointer, fakeCount};

auto clangDeclTyInfo = clangCtx.getTrivialTypeSourceInfo(
clang::QualType(clangDecl->getTypeForDecl(), 0));
Expand Down Expand Up @@ -1226,5 +1224,10 @@ void swift::conformToCxxSpanIfNeeded(ClangImporter::Implementation &impl,
elementType->getUnderlyingType());
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Size"),
sizeType->getUnderlyingType());
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxSpan});

if (pointerType->getPointeeType().isConstQualified()) {
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxSpan});
} else {
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxMutableSpan});
}
}
1 change: 1 addition & 0 deletions lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6878,6 +6878,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::CxxUniqueSet:
case KnownProtocolKind::CxxVector:
case KnownProtocolKind::CxxSpan:
case KnownProtocolKind::CxxMutableSpan:
case KnownProtocolKind::UnsafeCxxInputIterator:
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
Expand Down
25 changes: 25 additions & 0 deletions stdlib/public/Cxx/CxxSpan.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,29 @@ extension CxxSpan {
"UnsafeBufferPointer should not point to nil")
self.init(unsafeBufferPointer.baseAddress!, Size(unsafeBufferPointer.count))
}

@inlinable
public init(_ unsafeMutableBufferPointer: UnsafeMutableBufferPointer<Element>) {
precondition(unsafeMutableBufferPointer.baseAddress != nil,
"UnsafeMutableBufferPointer should not point to nil")
self.init(unsafeMutableBufferPointer.baseAddress!, Size(unsafeMutableBufferPointer.count))
}
}

public protocol CxxMutableSpan<Element> {
associatedtype Element
associatedtype Size: BinaryInteger

init()
init(_ unsafeMutablePointer : UnsafeMutablePointer<Element>, _ count: Size)
}

extension CxxMutableSpan {
/// Creates a C++ span from a Swift UnsafeMutableBufferPointer
@inlinable
public init(_ unsafeMutableBufferPointer: UnsafeMutableBufferPointer<Element>) {
precondition(unsafeMutableBufferPointer.baseAddress != nil,
"UnsafeMutableBufferPointer should not point to nil")
self.init(unsafeMutableBufferPointer.baseAddress!, Size(unsafeMutableBufferPointer.count))
}
}
39 changes: 33 additions & 6 deletions test/Interop/Cxx/stdlib/Inputs/std-span.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,51 @@
#include <string>
#include <span>

using Span = std::span<const int>;
using SpanOfString = std::span<const std::string>;
using ConstSpan = std::span<const int>;
using Span = std::span<int>;
using ConstSpanOfString = std::span<const std::string>;
using SpanOfString = std::span<std::string>;

static int iarray[]{1, 2, 3};
static std::string sarray[]{"", "ab", "abc"};
static ConstSpan icspan = {iarray};
static Span ispan = {iarray};
static ConstSpanOfString scspan = {sarray};
static SpanOfString sspan = {sarray};

struct SpanBox {
std::span<const int> ispan;
std::span<const std::string> sspan;
ConstSpan icspan;
Span ispan;
ConstSpanOfString scspan;
SpanOfString sspan;
};

inline Span initSpan() {
class CppApi {
public:
ConstSpan getConstSpan();
Span getSpan();
};

ConstSpan CppApi::getConstSpan() {
ConstSpan sp{new int[2], 2};
return sp;
}

Span CppApi::getSpan() {
Span sp{new int[2], 2};
return sp;
}

inline ConstSpan initConstSpan() {
const int a[]{1, 2, 3};
return ConstSpan(a);
}

inline Span initSpan() {
int a[]{1, 2, 3};
return Span(a);
}

inline struct SpanBox getStructSpanBox() { return {iarray, sarray}; }
inline struct SpanBox getStructSpanBox() { return {iarray, iarray, sarray, sarray}; }

#endif // TEST_INTEROP_CXX_STDLIB_INPUTS_STD_SPAN_H
Loading