Skip to content

Commit 0dd50fc

Browse files
committed
Remove the rest of RawKind usages in Conversions.
1 parent ddc8505 commit 0dd50fc

File tree

2 files changed

+44
-11
lines changed

2 files changed

+44
-11
lines changed

src/Microsoft.ML.Core/Data/DataKind.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,32 @@ public static ulong ToMaxInt(this DataKind kind)
104104
return 0;
105105
}
106106

107+
/// <summary>
108+
/// For integer DataKinds, this returns the maximum legal value. For un-supported kinds,
109+
/// it returns zero.
110+
/// </summary>
111+
public static ulong ToMaxInt(this Type type)
112+
{
113+
if (type == typeof(sbyte))
114+
return (ulong)sbyte.MaxValue;
115+
else if (type == typeof(byte))
116+
return byte.MaxValue;
117+
else if (type == typeof(short))
118+
return (ulong)short.MaxValue;
119+
else if (type == typeof(ushort))
120+
return ushort.MaxValue;
121+
else if (type == typeof(int))
122+
return int.MaxValue;
123+
else if (type == typeof(uint))
124+
return uint.MaxValue;
125+
else if (type == typeof(long))
126+
return long.MaxValue;
127+
else if (type == typeof(ulong))
128+
return ulong.MaxValue;
129+
130+
return 0;
131+
}
132+
107133
/// <summary>
108134
/// For integer DataKinds, this returns the minimum legal value. For un-supported kinds,
109135
/// it returns one.

src/Microsoft.ML.Data/Data/Conversion.cs

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using System.Collections.Generic;
99
using System.Globalization;
1010
using System.Reflection;
11+
using System.Runtime.InteropServices;
1112
using System.Text;
1213
using System.Threading;
1314
using Microsoft.ML.Internal.Utilities;
@@ -394,7 +395,7 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst,
394395
// Smaller dst means mapping values to NA.
395396
if (keySrc.Count != keyDst.Count)
396397
return false;
397-
if (keySrc.Count == 0 && keySrc.RawKind > keyDst.RawKind)
398+
if (keySrc.Count == 0 && Marshal.SizeOf(keySrc.RawType) > Marshal.SizeOf(keyDst.RawType))
398399
return false;
399400
// REVIEW: Should we allow contiguous to be changed when Count is zero?
400401
if (keySrc.Contiguous != keyDst.Contiguous)
@@ -407,11 +408,11 @@ public bool TryGetStandardConversion(ColumnType typeSrc, ColumnType typeDst,
407408
// does not allow this.
408409
if (!KeyType.IsValidDataType(typeDst.RawType))
409410
return false;
410-
if (keySrc.RawKind > typeDst.RawKind)
411+
if (Marshal.SizeOf(keySrc.RawType) > Marshal.SizeOf(typeDst.RawType))
411412
{
412413
if (keySrc.Count == 0)
413414
return false;
414-
if ((ulong)keySrc.Count > typeDst.RawKind.ToMaxInt())
415+
if ((ulong)keySrc.Count > typeDst.RawType.ToMaxInt())
415416
return false;
416417
}
417418
}
@@ -549,20 +550,19 @@ private TryParseMapper<TDst> GetKeyTryParse<TDst>(KeyType key)
549550
ulong min = key.Min;
550551
ulong max;
551552

552-
ulong count = DataKindExtensions.ToMaxInt(key.RawKind);
553+
ulong count = key.RawType.ToMaxInt();
553554
if (key.Count > 0)
554555
max = min - 1 + (ulong)key.Count;
555556
else if (min == 0)
556557
max = count - 1;
557-
else if (key.RawKind == DataKind.U8)
558+
else if (key.RawType == typeof(ulong))
558559
max = ulong.MaxValue;
559560
else if (min - 1 > ulong.MaxValue - count)
560561
max = ulong.MaxValue;
561562
else
562563
max = min - 1 + count;
563564

564-
bool identity;
565-
var fnConv = GetStandardConversion<U8, TDst>(NumberType.U8, NumberType.FromKind(key.RawKind), out identity);
565+
var fnConv = GetKeyStandardConversion<TDst>();
566566
return
567567
(in TX src, out TDst dst) =>
568568
{
@@ -592,20 +592,19 @@ private ValueMapper<TX, TDst> GetKeyParse<TDst>(KeyType key)
592592
ulong min = key.Min;
593593
ulong max;
594594

595-
ulong count = DataKindExtensions.ToMaxInt(key.RawKind);
595+
ulong count = key.RawType.ToMaxInt();
596596
if (key.Count > 0)
597597
max = min - 1 + (ulong)key.Count;
598598
else if (min == 0)
599599
max = count - 1;
600-
else if (key.RawKind == DataKind.U8)
600+
else if (key.RawType == typeof(U8))
601601
max = ulong.MaxValue;
602602
else if (min - 1 > ulong.MaxValue - count)
603603
max = ulong.MaxValue;
604604
else
605605
max = min - 1 + count;
606606

607-
bool identity;
608-
var fnConv = GetStandardConversion<U8, TDst>(NumberType.U8, NumberType.FromKind(key.RawKind), out identity);
607+
var fnConv = GetKeyStandardConversion<TDst>();
609608
return
610609
(in TX src, ref TDst dst) =>
611610
{
@@ -622,6 +621,14 @@ private ValueMapper<TX, TDst> GetKeyParse<TDst>(KeyType key)
622621
};
623622
}
624623

624+
private ValueMapper<U8, TDst> GetKeyStandardConversion<TDst>()
625+
{
626+
var delegatesKey = (typeof(U8), typeof(TDst));
627+
if (!_delegatesStd.TryGetValue(delegatesKey, out Delegate del))
628+
throw Contracts.Except("No standard conversion from '{0}' to '{1}'", typeof(U8), typeof(TDst));
629+
return (ValueMapper<U8, TDst>)del;
630+
}
631+
625632
private static StringBuilder ClearDst(ref StringBuilder dst)
626633
{
627634
if (dst == null)

0 commit comments

Comments
 (0)