Skip to content

Commit 844d103

Browse files
authored
Merge pull request #77006 from swiftlang/egorzhdan/cxx-contiguous
[cxx-interop] Add `UnsafeCxxContiguousIterator` & `UnsafeCxxMutableContiguousIterator` protocols
2 parents d1a26d0 + 3a200de commit 844d103

File tree

8 files changed

+294
-63
lines changed

8 files changed

+294
-63
lines changed

include/swift/AST/KnownProtocols.def

+2
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ PROTOCOL(UnsafeCxxInputIterator)
142142
PROTOCOL(UnsafeCxxMutableInputIterator)
143143
PROTOCOL(UnsafeCxxRandomAccessIterator)
144144
PROTOCOL(UnsafeCxxMutableRandomAccessIterator)
145+
PROTOCOL(UnsafeCxxContiguousIterator)
146+
PROTOCOL(UnsafeCxxMutableContiguousIterator)
145147

146148
PROTOCOL(AsyncSequence)
147149
PROTOCOL(AsyncIteratorProtocol)

lib/AST/ASTContext.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -1444,6 +1444,8 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
14441444
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
14451445
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
14461446
case KnownProtocolKind::UnsafeCxxMutableRandomAccessIterator:
1447+
case KnownProtocolKind::UnsafeCxxContiguousIterator:
1448+
case KnownProtocolKind::UnsafeCxxMutableContiguousIterator:
14471449
M = getLoadedModule(Id_Cxx);
14481450
break;
14491451
case KnownProtocolKind::Copyable:

lib/ClangImporter/ClangDerivedConformances.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -462,19 +462,28 @@ void swift::conformToCxxIteratorIfNeeded(
462462
auto isRandomAccessIteratorDecl = [&](const clang::CXXRecordDecl *base) {
463463
return isIteratorCategoryDecl(base, "random_access_iterator_tag");
464464
};
465+
auto isContiguousIteratorDecl = [&](const clang::CXXRecordDecl *base) {
466+
return isIteratorCategoryDecl(base, "contiguous_iterator_tag"); // C++20
467+
};
465468

466469
// Traverse all transitive bases of `underlyingDecl` to check if
467470
// it inherits from `std::input_iterator_tag`.
468471
bool isInputIterator = isInputIteratorDecl(underlyingCategoryDecl);
469472
bool isRandomAccessIterator =
470473
isRandomAccessIteratorDecl(underlyingCategoryDecl);
474+
bool isContiguousIterator = isContiguousIteratorDecl(underlyingCategoryDecl);
471475
underlyingCategoryDecl->forallBases([&](const clang::CXXRecordDecl *base) {
472476
if (isInputIteratorDecl(base)) {
473477
isInputIterator = true;
474478
}
475479
if (isRandomAccessIteratorDecl(base)) {
476480
isRandomAccessIterator = true;
477481
isInputIterator = true;
482+
}
483+
if (isContiguousIteratorDecl(base)) {
484+
isContiguousIterator = true;
485+
isRandomAccessIterator = true;
486+
isInputIterator = true;
478487
return false;
479488
}
480489
return true;
@@ -594,6 +603,15 @@ void swift::conformToCxxIteratorIfNeeded(
594603
else
595604
impl.addSynthesizedProtocolAttrs(
596605
decl, {KnownProtocolKind::UnsafeCxxRandomAccessIterator});
606+
607+
if (isContiguousIterator) {
608+
if (pointeeSettable)
609+
impl.addSynthesizedProtocolAttrs(
610+
decl, {KnownProtocolKind::UnsafeCxxMutableContiguousIterator});
611+
else
612+
impl.addSynthesizedProtocolAttrs(
613+
decl, {KnownProtocolKind::UnsafeCxxContiguousIterator});
614+
}
597615
}
598616

599617
void swift::conformToCxxConvertibleToBoolIfNeeded(

lib/IRGen/GenMeta.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -6967,6 +6967,8 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
69676967
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
69686968
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:
69696969
case KnownProtocolKind::UnsafeCxxMutableRandomAccessIterator:
6970+
case KnownProtocolKind::UnsafeCxxContiguousIterator:
6971+
case KnownProtocolKind::UnsafeCxxMutableContiguousIterator:
69706972
case KnownProtocolKind::Executor:
69716973
case KnownProtocolKind::SerialExecutor:
69726974
case KnownProtocolKind::TaskExecutor:

stdlib/public/Cxx/UnsafeCxxIterators.swift

+12
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,15 @@ public protocol UnsafeCxxMutableRandomAccessIterator:
8787
UnsafeCxxRandomAccessIterator, UnsafeCxxMutableInputIterator {}
8888

8989
extension UnsafeMutablePointer: UnsafeCxxMutableRandomAccessIterator {}
90+
91+
/// Bridged C++ iterator that allows traversing elements of a random access
92+
/// collection that are stored in contiguous memory segments.
93+
///
94+
/// Mostly useful for optimizing operations with containers that conform to
95+
/// `CxxRandomAccessCollection` and should not generally be used directly.
96+
///
97+
/// - SeeAlso: https://en.cppreference.com/w/cpp/named_req/ContiguousIterator
98+
public protocol UnsafeCxxContiguousIterator: UnsafeCxxRandomAccessIterator {}
99+
100+
public protocol UnsafeCxxMutableContiguousIterator:
101+
UnsafeCxxContiguousIterator, UnsafeCxxMutableRandomAccessIterator {}

test/Interop/Cxx/stdlib/overlay/Inputs/custom-iterator.h

+223-56
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,229 @@ struct HasTypedefIteratorTag {
286286
}
287287
};
288288

289+
struct MutableRACIterator {
290+
private:
291+
int *value;
292+
293+
public:
294+
struct iterator_category : std::random_access_iterator_tag,
295+
std::output_iterator_tag {};
296+
using value_type = int;
297+
using pointer = int *;
298+
using reference = const int &;
299+
using difference_type = int;
300+
301+
MutableRACIterator(int *value) : value(value) {}
302+
MutableRACIterator(const MutableRACIterator &other) = default;
303+
304+
const int &operator*() const { return *value; }
305+
int &operator*() { return *value; }
306+
307+
MutableRACIterator &operator++() {
308+
value++;
309+
return *this;
310+
}
311+
MutableRACIterator operator++(int) {
312+
auto tmp = MutableRACIterator(value);
313+
value++;
314+
return tmp;
315+
}
316+
317+
void operator+=(difference_type v) { value += v; }
318+
void operator-=(difference_type v) { value -= v; }
319+
MutableRACIterator operator+(difference_type v) const {
320+
return MutableRACIterator(value + v);
321+
}
322+
MutableRACIterator operator-(difference_type v) const {
323+
return MutableRACIterator(value - v);
324+
}
325+
friend MutableRACIterator operator+(difference_type v,
326+
const MutableRACIterator &it) {
327+
return it + v;
328+
}
329+
int operator-(const MutableRACIterator &other) const {
330+
return value - other.value;
331+
}
332+
333+
bool operator<(const MutableRACIterator &other) const {
334+
return value < other.value;
335+
}
336+
337+
bool operator==(const MutableRACIterator &other) const {
338+
return value == other.value;
339+
}
340+
bool operator!=(const MutableRACIterator &other) const {
341+
return value != other.value;
342+
}
343+
};
344+
345+
#if __cplusplus >= 202002L
346+
struct ConstContiguousIterator {
347+
private:
348+
const int *value;
349+
350+
public:
351+
using iterator_category = std::contiguous_iterator_tag;
352+
using value_type = int;
353+
using pointer = int *;
354+
using reference = const int &;
355+
using difference_type = int;
356+
357+
ConstContiguousIterator(const int *value) : value(value) {}
358+
ConstContiguousIterator(const ConstContiguousIterator &other) = default;
359+
360+
const int &operator*() const { return *value; }
361+
362+
ConstContiguousIterator &operator++() {
363+
value++;
364+
return *this;
365+
}
366+
ConstContiguousIterator operator++(int) {
367+
auto tmp = ConstContiguousIterator(value);
368+
value++;
369+
return tmp;
370+
}
371+
372+
void operator+=(difference_type v) { value += v; }
373+
void operator-=(difference_type v) { value -= v; }
374+
ConstContiguousIterator operator+(difference_type v) const {
375+
return ConstContiguousIterator(value + v);
376+
}
377+
ConstContiguousIterator operator-(difference_type v) const {
378+
return ConstContiguousIterator(value - v);
379+
}
380+
friend ConstContiguousIterator operator+(difference_type v,
381+
const ConstContiguousIterator &it) {
382+
return it + v;
383+
}
384+
int operator-(const ConstContiguousIterator &other) const {
385+
return value - other.value;
386+
}
387+
388+
bool operator<(const ConstContiguousIterator &other) const {
389+
return value < other.value;
390+
}
391+
392+
bool operator==(const ConstContiguousIterator &other) const {
393+
return value == other.value;
394+
}
395+
bool operator!=(const ConstContiguousIterator &other) const {
396+
return value != other.value;
397+
}
398+
};
399+
400+
struct HasCustomContiguousIteratorTag {
401+
private:
402+
const int *value;
403+
404+
public:
405+
struct CustomTag : std::contiguous_iterator_tag {};
406+
using iterator_category = CustomTag;
407+
using value_type = int;
408+
using pointer = int *;
409+
using reference = const int &;
410+
using difference_type = int;
411+
412+
HasCustomContiguousIteratorTag(const int *value) : value(value) {}
413+
HasCustomContiguousIteratorTag(const HasCustomContiguousIteratorTag &other) =
414+
default;
415+
416+
const int &operator*() const { return *value; }
417+
418+
HasCustomContiguousIteratorTag &operator++() {
419+
value++;
420+
return *this;
421+
}
422+
HasCustomContiguousIteratorTag operator++(int) {
423+
auto tmp = HasCustomContiguousIteratorTag(value);
424+
value++;
425+
return tmp;
426+
}
427+
428+
void operator+=(difference_type v) { value += v; }
429+
void operator-=(difference_type v) { value -= v; }
430+
HasCustomContiguousIteratorTag operator+(difference_type v) const {
431+
return HasCustomContiguousIteratorTag(value + v);
432+
}
433+
HasCustomContiguousIteratorTag operator-(difference_type v) const {
434+
return HasCustomContiguousIteratorTag(value - v);
435+
}
436+
friend HasCustomContiguousIteratorTag
437+
operator+(difference_type v, const HasCustomContiguousIteratorTag &it) {
438+
return it + v;
439+
}
440+
int operator-(const HasCustomContiguousIteratorTag &other) const {
441+
return value - other.value;
442+
}
443+
444+
bool operator<(const HasCustomContiguousIteratorTag &other) const {
445+
return value < other.value;
446+
}
447+
448+
bool operator==(const HasCustomContiguousIteratorTag &other) const {
449+
return value == other.value;
450+
}
451+
bool operator!=(const HasCustomContiguousIteratorTag &other) const {
452+
return value != other.value;
453+
}
454+
};
455+
456+
struct MutableContiguousIterator {
457+
private:
458+
int *value;
459+
460+
public:
461+
using iterator_category = std::contiguous_iterator_tag;
462+
using value_type = int;
463+
using pointer = int *;
464+
using reference = const int &;
465+
using difference_type = int;
466+
467+
MutableContiguousIterator(int *value) : value(value) {}
468+
MutableContiguousIterator(const MutableContiguousIterator &other) = default;
469+
470+
const int &operator*() const { return *value; }
471+
int &operator*() { return *value; }
472+
473+
MutableContiguousIterator &operator++() {
474+
value++;
475+
return *this;
476+
}
477+
MutableContiguousIterator operator++(int) {
478+
auto tmp = MutableContiguousIterator(value);
479+
value++;
480+
return tmp;
481+
}
482+
483+
void operator+=(difference_type v) { value += v; }
484+
void operator-=(difference_type v) { value -= v; }
485+
MutableContiguousIterator operator+(difference_type v) const {
486+
return MutableContiguousIterator(value + v);
487+
}
488+
MutableContiguousIterator operator-(difference_type v) const {
489+
return MutableContiguousIterator(value - v);
490+
}
491+
friend MutableContiguousIterator
492+
operator+(difference_type v, const MutableContiguousIterator &it) {
493+
return it + v;
494+
}
495+
int operator-(const MutableContiguousIterator &other) const {
496+
return value - other.value;
497+
}
498+
499+
bool operator<(const MutableContiguousIterator &other) const {
500+
return value < other.value;
501+
}
502+
503+
bool operator==(const MutableContiguousIterator &other) const {
504+
return value == other.value;
505+
}
506+
bool operator!=(const MutableContiguousIterator &other) const {
507+
return value != other.value;
508+
}
509+
};
510+
#endif
511+
289512
// MARK: Types that are not actually iterators
290513

291514
struct HasNoIteratorCategory {
@@ -916,62 +1139,6 @@ struct InputOutputConstIterator {
9161139
}
9171140
};
9181141

919-
struct MutableRACIterator {
920-
private:
921-
int *value;
922-
923-
public:
924-
struct iterator_category : std::random_access_iterator_tag,
925-
std::output_iterator_tag {};
926-
using value_type = int;
927-
using pointer = int *;
928-
using reference = const int &;
929-
using difference_type = int;
930-
931-
MutableRACIterator(int *value) : value(value) {}
932-
MutableRACIterator(const MutableRACIterator &other) = default;
933-
934-
const int &operator*() const { return *value; }
935-
int &operator*() { return *value; }
936-
937-
MutableRACIterator &operator++() {
938-
value++;
939-
return *this;
940-
}
941-
MutableRACIterator operator++(int) {
942-
auto tmp = MutableRACIterator(value);
943-
value++;
944-
return tmp;
945-
}
946-
947-
void operator+=(difference_type v) { value += v; }
948-
void operator-=(difference_type v) { value -= v; }
949-
MutableRACIterator operator+(difference_type v) const {
950-
return MutableRACIterator(value + v);
951-
}
952-
MutableRACIterator operator-(difference_type v) const {
953-
return MutableRACIterator(value - v);
954-
}
955-
friend MutableRACIterator operator+(difference_type v,
956-
const MutableRACIterator &it) {
957-
return it + v;
958-
}
959-
int operator-(const MutableRACIterator &other) const {
960-
return value - other.value;
961-
}
962-
963-
bool operator<(const MutableRACIterator &other) const {
964-
return value < other.value;
965-
}
966-
967-
bool operator==(const MutableRACIterator &other) const {
968-
return value == other.value;
969-
}
970-
bool operator!=(const MutableRACIterator &other) const {
971-
return value != other.value;
972-
}
973-
};
974-
9751142
/// clang::StmtIteratorBase
9761143
class ProtectedIteratorBase {
9771144
protected:

0 commit comments

Comments
 (0)