Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.

Add string.GetHashCode(ROS<char>) and related APIs #20422

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ internal static partial class Globalization
internal static extern unsafe bool EndsWith(SafeSortHandle sortHandle, string target, int cwTargetLength, string source, int cwSourceLength, CompareOptions options);

[DllImport(Libraries.GlobalizationNative, CharSet = CharSet.Unicode, EntryPoint = "GlobalizationNative_GetSortKey")]
internal static extern unsafe int GetSortKey(SafeSortHandle sortHandle, string str, int strLength, byte* sortKey, int sortKeyLength, CompareOptions options);
internal static extern unsafe int GetSortKey(SafeSortHandle sortHandle, char* str, int strLength, byte* sortKey, int sortKeyLength, CompareOptions options);

[DllImport(Libraries.GlobalizationNative, CharSet = CharSet.Unicode, EntryPoint = "GlobalizationNative_CompareStringOrdinalIgnoreCase")]
internal static extern unsafe int CompareStringOrdinalIgnoreCase(char* lpStr1, int cwStr1Len, char* lpStr2, int cwStr2Len);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -798,14 +798,17 @@ private unsafe SortKey CreateSortKey(string source, CompareOptions options)
}
else
{
int sortKeyLength = Interop.Globalization.GetSortKey(_sortHandle, source, source.Length, null, 0, options);
keyData = new byte[sortKeyLength];

fixed (byte* pSortKey = keyData)
fixed (char* pSource = source)
{
if (Interop.Globalization.GetSortKey(_sortHandle, source, source.Length, pSortKey, sortKeyLength, options) != sortKeyLength)
int sortKeyLength = Interop.Globalization.GetSortKey(_sortHandle, pSource, source.Length, null, 0, options);
keyData = new byte[sortKeyLength];

fixed (byte* pSortKey = keyData)
{
throw new ArgumentException(SR.Arg_ExternalException);
if (Interop.Globalization.GetSortKey(_sortHandle, pSource, source.Length, pSortKey, sortKeyLength, options) != sortKeyLength)
{
throw new ArgumentException(SR.Arg_ExternalException);
}
}
}
}
Expand Down Expand Up @@ -856,42 +859,43 @@ private static unsafe bool IsSortable(char *text, int length)
// ---- PAL layer ends here ----
// -----------------------------

internal unsafe int GetHashCodeOfStringCore(string source, CompareOptions options)
internal unsafe int GetHashCodeOfStringCore(ReadOnlySpan<char> source, CompareOptions options)
{
Debug.Assert(!_invariantMode);

Debug.Assert(source != null);
Debug.Assert((options & (CompareOptions.Ordinal | CompareOptions.OrdinalIgnoreCase)) == 0);

if (source.Length == 0)
{
return 0;
}

int sortKeyLength = Interop.Globalization.GetSortKey(_sortHandle, source, source.Length, null, 0, options);
fixed (char* pSource = source)
{
int sortKeyLength = Interop.Globalization.GetSortKey(_sortHandle, pSource, source.Length, null, 0, options);

byte[] borrowedArr = null;
Span<byte> span = sortKeyLength <= 512 ?
stackalloc byte[512] :
(borrowedArr = ArrayPool<byte>.Shared.Rent(sortKeyLength));
byte[] borrowedArr = null;
Span<byte> span = sortKeyLength <= 512 ?
stackalloc byte[512] :
(borrowedArr = ArrayPool<byte>.Shared.Rent(sortKeyLength));

fixed (byte* pSortKey = &MemoryMarshal.GetReference(span))
{
if (Interop.Globalization.GetSortKey(_sortHandle, source, source.Length, pSortKey, sortKeyLength, options) != sortKeyLength)
fixed (byte* pSortKey = &MemoryMarshal.GetReference(span))
{
throw new ArgumentException(SR.Arg_ExternalException);
if (Interop.Globalization.GetSortKey(_sortHandle, pSource, source.Length, pSortKey, sortKeyLength, options) != sortKeyLength)
{
throw new ArgumentException(SR.Arg_ExternalException);
}
}
}

int hash = Marvin.ComputeHash32(span.Slice(0, sortKeyLength), Marvin.DefaultSeed);
int hash = Marvin.ComputeHash32(span.Slice(0, sortKeyLength), Marvin.DefaultSeed);

// Return the borrowed array if necessary.
if (borrowedArr != null)
{
ArrayPool<byte>.Shared.Return(borrowedArr);
}
// Return the borrowed array if necessary.
if (borrowedArr != null)
{
ArrayPool<byte>.Shared.Return(borrowedArr);
}

return hash;
return hash;
}
}

private static CompareOptions GetOrdinalCompareOptions(CompareOptions options)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,10 @@ internal static int LastIndexOfOrdinalCore(string source, string value, int star

return FindStringOrdinal(FIND_FROMEND, source, startIndex - count + 1, count, value, value.Length, ignoreCase);
}

private unsafe int GetHashCodeOfStringCore(string source, CompareOptions options)
private unsafe int GetHashCodeOfStringCore(ReadOnlySpan<char> source, CompareOptions options)
{
Debug.Assert(!_invariantMode);

Debug.Assert(source != null);
Debug.Assert((options & (CompareOptions.Ordinal | CompareOptions.OrdinalIgnoreCase)) == 0);

if (source.Length == 0)
Expand All @@ -130,14 +128,19 @@ private unsafe int GetHashCodeOfStringCore(string source, CompareOptions options
{
int sortKeyLength = Interop.Kernel32.LCMapStringEx(_sortHandle != IntPtr.Zero ? null : _sortName,
flags,
pSource, source.Length,
pSource, source.Length /* in chars */,
null, 0,
null, null, _sortHandle);
if (sortKeyLength == 0)
{
throw new ArgumentException(SR.Arg_ExternalException);
}

// Note in calls to LCMapStringEx below, the input buffer is specified in wchars (and wchar count),
// but the output buffer is specified in bytes (and byte count). This is because when generating
// sort keys, LCMapStringEx treats the output buffer as containing opaque binary data.
// See https://docs.microsoft.com/en-us/windows/desktop/api/winnls/nf-winnls-lcmapstringex.

byte[] borrowedArr = null;
Span<byte> span = sortKeyLength <= 512 ?
stackalloc byte[512] :
Expand All @@ -147,7 +150,7 @@ private unsafe int GetHashCodeOfStringCore(string source, CompareOptions options
{
if (Interop.Kernel32.LCMapStringEx(_sortHandle != IntPtr.Zero ? null : _sortName,
flags,
pSource, source.Length,
pSource, source.Length /* in chars */,
pSortKey, sortKeyLength,
null, null, _sortHandle) != sortKeyLength)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1420,43 +1420,71 @@ internal int GetHashCodeOfString(string source, CompareOptions options)
{
throw new ArgumentNullException(nameof(source));
}
if ((options & ValidHashCodeOfStringMaskOffFlags) == 0)
{
// No unsupported flags are set - continue on with the regular logic

if (_invariantMode)
{
return ((options & CompareOptions.IgnoreCase) != 0) ? source.GetHashCodeOrdinalIgnoreCase() : source.GetHashCode();
}

if ((options & ValidHashCodeOfStringMaskOffFlags) != 0)
return GetHashCodeOfStringCore(source, options);
}
else if (options == CompareOptions.Ordinal)
{
throw new ArgumentException(SR.Argument_InvalidFlag, nameof(options));
// We allow Ordinal in isolation
return source.GetHashCode();
}

if (_invariantMode)
else if (options == CompareOptions.OrdinalIgnoreCase)
{
return ((options & CompareOptions.IgnoreCase) != 0) ? source.GetHashCodeOrdinalIgnoreCase() : source.GetHashCode();
// We allow OrdinalIgnoreCase in isolation
return source.GetHashCodeOrdinalIgnoreCase();
}
else
{
// Unsupported combination of flags specified
throw new ArgumentException(SR.Argument_InvalidFlag, nameof(options));
}

return GetHashCodeOfStringCore(source, options);
}

public virtual int GetHashCode(string source, CompareOptions options)
{
if (source == null)
// virtual method delegates to non-virtual method
return GetHashCodeOfString(source, options);
}

public int GetHashCode(ReadOnlySpan<char> source, CompareOptions options)
{
//
// Parameter validation
//
if ((options & ValidHashCodeOfStringMaskOffFlags) == 0)
{
throw new ArgumentNullException(nameof(source));
}
// No unsupported flags are set - continue on with the regular logic

if (options == CompareOptions.Ordinal)
if (_invariantMode)
{
return ((options & CompareOptions.IgnoreCase) != 0) ? string.GetHashCodeOrdinalIgnoreCase(source) : string.GetHashCode(source);
}

return GetHashCodeOfStringCore(source, options);
}
else if (options == CompareOptions.Ordinal)
{
return source.GetHashCode();
// We allow Ordinal in isolation
return string.GetHashCode(source);
}

if (options == CompareOptions.OrdinalIgnoreCase)
else if (options == CompareOptions.OrdinalIgnoreCase)
{
return source.GetHashCodeOrdinalIgnoreCase();
// We allow OrdinalIgnoreCase in isolation
return string.GetHashCodeOrdinalIgnoreCase(source);
}
else
{
// Unsupported combination of flags specified
throw new ArgumentException(SR.Argument_InvalidFlag, nameof(options));
}

//
// GetHashCodeOfString performs additional parameter validation: it throws when
// options includes Ordinal, OrdinalIgnoreCase, or StringSort.
//

return GetHashCodeOfString(source, options);
}

////////////////////////////////////////////////////////////////////////
Expand Down
45 changes: 43 additions & 2 deletions src/System.Private.CoreLib/shared/System/String.Comparison.cs
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ public static bool Equals(string a, string b, StringComparison comparisonType)
public override int GetHashCode()
{
ulong seed = Marvin.DefaultSeed;
return Marvin.ComputeHash32(ref Unsafe.As<char, byte>(ref _firstChar), _stringLength * 2, (uint)seed, (uint)(seed >> 32));
return Marvin.ComputeHash32(ref Unsafe.As<char, byte>(ref _firstChar), _stringLength * 2 /* in bytes, not chars */, (uint)seed, (uint)(seed >> 32));
}

// Gets a hash code for this string and this comparison. If strings A and B and comparison C are such
Expand All @@ -759,7 +759,48 @@ public override int GetHashCode()
internal int GetHashCodeOrdinalIgnoreCase()
{
ulong seed = Marvin.DefaultSeed;
return Marvin.ComputeHash32OrdinalIgnoreCase(ref _firstChar, _stringLength, (uint)seed, (uint)(seed >> 32));
return Marvin.ComputeHash32OrdinalIgnoreCase(ref _firstChar, _stringLength /* in chars, not bytes */, (uint)seed, (uint)(seed >> 32));
}

// Span-based counterpart of String.GetHashCode(): produces an ordinal hash over the chars of 'value'.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int GetHashCode(ReadOnlySpan<char> value)
{
    // Split the 64-bit default seed into its low and high 32-bit halves.
    ulong seed = Marvin.DefaultSeed;
    uint seedLow = (uint)seed;
    uint seedHigh = (uint)(seed >> 32);

    // The data is handed to Marvin as bytes, so reinterpret the first char as a byte
    // and express the length in bytes (two per char), not in chars.
    ref byte firstByte = ref Unsafe.As<char, byte>(ref MemoryMarshal.GetReference(value));
    int byteCount = value.Length * 2;

    return Marvin.ComputeHash32(ref firstByte, byteCount, seedLow, seedHigh);
}

// Span-based counterpart of String.GetHashCode(StringComparison): hashes 'value' according to 'comparisonType'.
// Unrecognized comparison types are reported through ThrowHelper.ThrowArgumentException.
public static int GetHashCode(ReadOnlySpan<char> value, StringComparison comparisonType)
{
    // The ordinal comparisons are served by the span-based ordinal helpers; no CultureInfo is consulted.
    if (comparisonType == StringComparison.Ordinal)
    {
        return GetHashCode(value);
    }

    if (comparisonType == StringComparison.OrdinalIgnoreCase)
    {
        return GetHashCodeOrdinalIgnoreCase(value);
    }

    // The remaining valid values are culture-based: pick the culture, then delegate to its CompareInfo.
    CultureInfo culture;
    switch (comparisonType)
    {
        case StringComparison.CurrentCulture:
        case StringComparison.CurrentCultureIgnoreCase:
            culture = CultureInfo.CurrentCulture;
            break;

        case StringComparison.InvariantCulture:
        case StringComparison.InvariantCultureIgnoreCase:
            culture = CultureInfo.InvariantCulture;
            break;

        default:
            ThrowHelper.ThrowArgumentException(ExceptionResource.NotSupported_StringComparison, ExceptionArgument.comparisonType);
            Debug.Fail("Should not reach this point.");
            return default;
    }

    return culture.CompareInfo.GetHashCode(value, GetCaseCompareOfComparisonCulture(comparisonType));
}

// Span-based counterpart of the instance GetHashCodeOrdinalIgnoreCase():
// forwards the chars of 'value' to Marvin.ComputeHash32OrdinalIgnoreCase with the default seed.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static int GetHashCodeOrdinalIgnoreCase(ReadOnlySpan<char> value)
{
    // Split the 64-bit default seed into its low and high 32-bit halves.
    ulong seed = Marvin.DefaultSeed;
    uint seedLow = (uint)seed;
    uint seedHigh = (uint)(seed >> 32);

    // This entry point takes a char reference, so the count is passed in chars (not bytes).
    ref char firstChar = ref MemoryMarshal.GetReference(value);
    int charCount = value.Length;

    return Marvin.ComputeHash32OrdinalIgnoreCase(ref firstChar, charCount, seedLow, seedHigh);
}

// Use this if and only if 'Denial of Service' attacks are not a concern (i.e. never used for free-form user input),
Expand Down