diff --git a/src/Microsoft.ML.Api/ApiUtils.cs b/src/Microsoft.ML.Api/ApiUtils.cs index 96e821f16e..6061d7474c 100644 --- a/src/Microsoft.ML.Api/ApiUtils.cs +++ b/src/Microsoft.ML.Api/ApiUtils.cs @@ -19,8 +19,7 @@ private static OpCode GetAssignmentOpCode(Type t) { // REVIEW: This should be a Dictionary based solution. // DvTypes, strings, arrays, all nullable types, VBuffers and UInt128. - if (t == typeof(DvInt8) || t == typeof(DvInt4) || t == typeof(DvInt2) || t == typeof(DvInt1) || - t == typeof(DvBool) || t == typeof(DvText) || t == typeof(string) || t.IsArray || + if (t == typeof(DvText) || t == typeof(DvBool) || t == typeof(string) || t.IsArray || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) || (t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) || t == typeof(DvDateTime) || t == typeof(DvDateTimeZone) || t == typeof(DvTimeSpan) || t == typeof(UInt128)) diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index 6962080a7e..d493dcc3e9 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -131,46 +131,6 @@ private Delegate CreateGetter(int index) Ch.Assert(colType.ItemType.IsText); return CreateConvertingArrayGetterDelegate(index, x => x == null ? DvText.NA : new DvText(x)); } - else if (outputType.GetElementType() == typeof(int)) - { - Ch.Assert(colType.ItemType == NumberType.I4); - return CreateConvertingArrayGetterDelegate(index, x => x); - } - else if (outputType.GetElementType() == typeof(int?)) - { - Ch.Assert(colType.ItemType == NumberType.I4); - return CreateConvertingArrayGetterDelegate(index, x => x ?? DvInt4.NA); - } - else if (outputType.GetElementType() == typeof(long)) - { - Ch.Assert(colType.ItemType == NumberType.I8); - return CreateConvertingArrayGetterDelegate(index, x => x); - } - else if (outputType.GetElementType() == typeof(long?)) - { - Ch.Assert(colType.ItemType == NumberType.I8); - return CreateConvertingArrayGetterDelegate(index, x => x ?? DvInt8.NA); - } - else if (outputType.GetElementType() == typeof(short)) - { - Ch.Assert(colType.ItemType == NumberType.I2); - return CreateConvertingArrayGetterDelegate(index, x => x); - } - else if (outputType.GetElementType() == typeof(short?)) - { - Ch.Assert(colType.ItemType == NumberType.I2); - return CreateConvertingArrayGetterDelegate(index, x => x ?? DvInt2.NA); - } - else if (outputType.GetElementType() == typeof(sbyte)) - { - Ch.Assert(colType.ItemType == NumberType.I1); - return CreateConvertingArrayGetterDelegate(index, x => x); - } - else if (outputType.GetElementType() == typeof(sbyte?)) - { - Ch.Assert(colType.ItemType == NumberType.I1); - return CreateConvertingArrayGetterDelegate(index, x => x ?? DvInt1.NA); - } else if (outputType.GetElementType() == typeof(bool)) { Ch.Assert(colType.ItemType.IsBool); @@ -220,54 +180,6 @@ private Delegate CreateGetter(int index) Ch.Assert(colType.IsBool); return CreateConvertingGetterDelegate(index, x => x ?? DvBool.NA); } - else if (outputType == typeof(int)) - { - // int -> DvInt4 - Ch.Assert(colType == NumberType.I4); - return CreateConvertingGetterDelegate(index, x => x); - } - else if (outputType == typeof(int?)) - { - // int? -> DvInt4 - Ch.Assert(colType == NumberType.I4); - return CreateConvertingGetterDelegate(index, x => x ?? DvInt4.NA); - } - else if (outputType == typeof(short)) - { - // short -> DvInt2 - Ch.Assert(colType == NumberType.I2); - return CreateConvertingGetterDelegate(index, x => x); - } - else if (outputType == typeof(short?)) - { - // short? -> DvInt2 - Ch.Assert(colType == NumberType.I2); - return CreateConvertingGetterDelegate(index, x => x ?? DvInt2.NA); - } - else if (outputType == typeof(long)) - { - // long -> DvInt8 - Ch.Assert(colType == NumberType.I8); - return CreateConvertingGetterDelegate(index, x => x); - } - else if (outputType == typeof(long?)) - { - // long? -> DvInt8 - Ch.Assert(colType == NumberType.I8); - return CreateConvertingGetterDelegate(index, x => x ?? DvInt8.NA); - } - else if (outputType == typeof(sbyte)) - { - // sbyte -> DvInt1 - Ch.Assert(colType == NumberType.I1); - return CreateConvertingGetterDelegate(index, x => x); - } - else if (outputType == typeof(sbyte?)) - { - // sbyte? -> DvInt1 - Ch.Assert(colType == NumberType.I1); - return CreateConvertingGetterDelegate(index, x => x ?? DvInt1.NA); - } // T -> T if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable<>)) Ch.Assert(colType.RawType == Nullable.GetUnderlyingType(outputType)); diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index eda5b6656b..5aacaa76a5 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -292,46 +292,6 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit Ch.Assert(colType.ItemType.IsBool); return CreateConvertingVBufferSetter(input, index, poke, peek, x => (bool?)x); } - else if (fieldType.GetElementType() == typeof(int)) - { - Ch.Assert(colType.ItemType == NumberType.I4); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (int)x); - } - else if (fieldType.GetElementType() == typeof(int?)) - { - Ch.Assert(colType.ItemType == NumberType.I4); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (int?)x); - } - else if (fieldType.GetElementType() == typeof(short)) - { - Ch.Assert(colType.ItemType == NumberType.I2); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (short)x); - } - else if (fieldType.GetElementType() == typeof(short?)) - { - Ch.Assert(colType.ItemType == NumberType.I2); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (short?)x); - } - else if (fieldType.GetElementType() == typeof(long)) - { - Ch.Assert(colType.ItemType == NumberType.I8); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (long)x); - } - else if (fieldType.GetElementType() == typeof(long?)) - { - Ch.Assert(colType.ItemType == NumberType.I8); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (long?)x); - } - else if (fieldType.GetElementType() == typeof(sbyte)) - { - Ch.Assert(colType.ItemType == NumberType.I1); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (sbyte)x); - } - else if (fieldType.GetElementType() == typeof(sbyte?)) - { - Ch.Assert(colType.ItemType == NumberType.I1); - return CreateConvertingVBufferSetter(input, index, poke, peek, x => (sbyte?)x); - } // VBuffer -> T[] if (fieldType.GetElementType().IsGenericType && fieldType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable<>)) @@ -372,54 +332,7 @@ private Action GenerateSetter(IRow input, int index, InternalSchemaDefinit Ch.Assert(peek == null); return CreateConvertingActionSetter(input, index, poke, x => (bool?)x); } - else if (fieldType == typeof(int)) - { - Ch.Assert(colType == NumberType.I4); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (int)x); - } - else if (fieldType == typeof(int?)) - { - Ch.Assert(colType == NumberType.I4); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (int?)x); - } - else if (fieldType == typeof(short)) - { - Ch.Assert(colType == NumberType.I2); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (short)x); - } - else if (fieldType == typeof(short?)) - { - Ch.Assert(colType == NumberType.I2); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (short?)x); - } - else if (fieldType == typeof(long)) - { - Ch.Assert(colType == NumberType.I8); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (long)x); - } - else if (fieldType == typeof(long?)) - { - Ch.Assert(colType == NumberType.I8); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (long?)x); - } - else if (fieldType == typeof(sbyte)) - { - Ch.Assert(colType == NumberType.I1); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (sbyte)x); - } - else if (fieldType == typeof(sbyte?)) - { - Ch.Assert(colType == NumberType.I1); - Ch.Assert(peek == null); - return CreateConvertingActionSetter(input, index, poke, x => (sbyte?)x); - } + // T -> T if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Nullable<>)) Ch.Assert(colType.RawType == Nullable.GetUnderlyingType(fieldType)); diff --git a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs index 91ec2e638e..40ee1c3527 100644 --- a/src/Microsoft.ML.Core/CommandLine/CmdParser.cs +++ b/src/Microsoft.ML.Core/CommandLine/CmdParser.cs @@ -517,23 +517,23 @@ public static int GetConsoleWindowWidth() private struct Coord { - internal Int16 X; - internal Int16 Y; + internal short X; + internal short Y; } private struct SmallRect { - internal Int16 Left; - internal Int16 Top; - internal Int16 Right; - internal Int16 Bottom; + internal short Left; + internal short Top; + internal short Right; + internal short Bottom; } private struct ConsoleScreenBufferInfo { internal Coord DwSize; internal Coord DwCursorPosition; - internal Int16 WAttributes; + internal short WAttributes; internal SmallRect SrWindow; internal Coord DwMaximumWindowSize; } diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index 0249745691..841f96e0a5 100644 --- a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -55,7 +55,7 @@ public enum DataKind : byte public static class DataKindExtensions { public const DataKind KindMin = DataKind.I1; - public const DataKind KindLim = DataKind.UG + 1; + public const DataKind KindLim = DataKind.U16 + 1; public const int KindCount = KindLim - KindMin; /// @@ -141,19 +141,19 @@ public static Type ToType(this DataKind kind) switch (kind) { case DataKind.I1: - return typeof(DvInt1); + return typeof(sbyte); case DataKind.U1: return typeof(byte); case DataKind.I2: - return typeof(DvInt2); + return typeof(short); case DataKind.U2: return typeof(ushort); case DataKind.I4: - return typeof(DvInt4); + return typeof(int); case DataKind.U4: return typeof(uint); case DataKind.I8: - return typeof(DvInt8); + return typeof(long); case DataKind.U8: return typeof(ulong); case DataKind.R4: @@ -185,29 +185,29 @@ public static bool TryGetDataKind(this Type type, out DataKind kind) Contracts.CheckValueOrNull(type); // REVIEW: Make this more efficient. Should we have a global dictionary? - if (type == typeof(DvInt1) || type == typeof(sbyte) || type == typeof(sbyte?)) + if (type == typeof(sbyte)) kind = DataKind.I1; - else if (type == typeof(byte) || type == typeof(byte?)) + else if (type == typeof(byte)) kind = DataKind.U1; - else if (type == typeof(DvInt2)|| type== typeof(short) || type == typeof(short?)) + else if (type == typeof(short)) kind = DataKind.I2; - else if (type == typeof(ushort)|| type == typeof(ushort?)) + else if (type == typeof(ushort)) kind = DataKind.U2; - else if (type == typeof(DvInt4) || type == typeof(int)|| type == typeof(int?)) + else if (type == typeof(int)) kind = DataKind.I4; - else if (type == typeof(uint)|| type == typeof(uint?)) + else if (type == typeof(uint)) kind = DataKind.U4; - else if (type == typeof(DvInt8) || type==typeof(long)|| type == typeof(long?)) + else if (type == typeof(long)) kind = DataKind.I8; - else if (type == typeof(ulong)|| type == typeof(ulong?)) + else if (type == typeof(ulong)) kind = DataKind.U8; - else if (type == typeof(Single)|| type == typeof(Single?)) + else if (type == typeof(Single)) kind = DataKind.R4; - else if (type == typeof(Double)|| type == typeof(Double?)) + else if (type == typeof(Double)) kind = DataKind.R8; else if (type == typeof(DvText)) kind = DataKind.TX; - else if (type == typeof(DvBool) || type == typeof(bool) || type == typeof(bool?)) + else if (type == typeof(DvBool) || type == typeof(bool)) kind = DataKind.BL; else if (type == typeof(DvTimeSpan)) kind = DataKind.TS; diff --git a/src/Microsoft.ML.Core/Data/DateTime.cs b/src/Microsoft.ML.Core/Data/DateTime.cs index d11be2a494..1e90a80f81 100644 --- a/src/Microsoft.ML.Core/Data/DateTime.cs +++ b/src/Microsoft.ML.Core/Data/DateTime.cs @@ -18,7 +18,7 @@ namespace Microsoft.ML.Runtime.Data public struct DvDateTime : IEquatable, IComparable { public const long MaxTicks = 3155378975999999999; - private readonly DvInt8 _ticks; + private readonly long _ticks; /// /// This ctor initializes _ticks to the value of sdt.Ticks, and ignores its DateTimeKind value. @@ -32,10 +32,10 @@ public DvDateTime(SysDateTime sdt) /// /// This ctor accepts any value for ticks, but produces an NA if ticks is out of the legal range. /// - public DvDateTime(DvInt8 ticks) + public DvDateTime(long ticks) { - if ((ulong)ticks.RawValue > MaxTicks) - _ticks = DvInt8.NA; + if ((ulong)ticks > MaxTicks) + _ticks = long.MinValue; else _ticks = ticks; AssertValid(); @@ -44,10 +44,10 @@ public DvDateTime(DvInt8 ticks) [Conditional("DEBUG")] internal void AssertValid() { - Contracts.Assert((ulong)_ticks.RawValue <= MaxTicks || _ticks.IsNA); + Contracts.Assert((ulong)_ticks <= MaxTicks || _ticks == long.MinValue); } - public DvInt8 Ticks + public long Ticks { get { @@ -81,14 +81,11 @@ public bool IsNA get { AssertValid(); - return (ulong)_ticks.RawValue > MaxTicks; + return (ulong)_ticks > MaxTicks; } } - public static DvDateTime NA - { - get { return new DvDateTime(DvInt8.NA); } - } + public static DvDateTime NA => new DvDateTime(long.MinValue); public static explicit operator SysDateTime?(DvDateTime dvDt) { @@ -124,12 +121,12 @@ internal SysDateTime GetSysDateTime() { AssertValid(); Contracts.Assert(!IsNA); - return new SysDateTime(_ticks.RawValue); + return new SysDateTime(_ticks); } public bool Equals(DvDateTime other) { - return _ticks.RawValue == other._ticks.RawValue; + return _ticks == other._ticks; } public override bool Equals(object obj) @@ -139,9 +136,9 @@ public override bool Equals(object obj) public int CompareTo(DvDateTime other) { - if (_ticks.RawValue == other._ticks.RawValue) + if (_ticks == other._ticks) return 0; - return _ticks.RawValue < other._ticks.RawValue ? -1 : 1; + return _ticks < other._ticks ? -1 : 1; } public override int GetHashCode() @@ -162,11 +159,11 @@ public struct DvDateTimeZone : IEquatable, IComparable /// The number of clock ticks in the date time portion /// The time zone offset in minutes - public DvDateTimeZone(DvInt8 ticks, DvInt2 offset) + public DvDateTimeZone(long ticks, short offset) { var dt = new DvDateTime(ticks); - if (dt.IsNA || offset.IsNA || MinMinutesOffset > offset.RawValue || offset.RawValue > MaxMinutesOffset) + if (dt.IsNA || offset == short.MinValue || MinMinutesOffset > offset || offset > MaxMinutesOffset) { _dateTime = DvDateTime.NA; - _offset = DvInt2.NA; + _offset = short.MinValue; } else { @@ -203,7 +200,7 @@ public DvDateTimeZone(SysDateTimeOffset dto) Contracts.Assert(success); _dateTime = ValidateDate(new DvDateTime(dto.DateTime), ref _offset); Contracts.Assert(!_dateTime.IsNA); - Contracts.Assert(!_offset.IsNA); + Contracts.Assert(_offset != short.MinValue); AssertValid(); } @@ -217,7 +214,7 @@ public DvDateTimeZone(DvDateTime dt, DvTimeSpan offset) if (dt.IsNA || offset.IsNA || !TryValidateOffset(offset.Ticks, out _offset)) { _dateTime = DvDateTime.NA; - _offset = DvInt2.NA; + _offset = short.MinValue; } else _dateTime = ValidateDate(dt, ref _offset); @@ -233,20 +230,19 @@ public DvDateTimeZone(DvDateTime dt, DvTimeSpan offset) /// The offset. This value is assumed to be validated as a legal offset: /// a value in whole minutes, between -14 and 14 hours. /// The UTC DvDateTime representing the input clock time minus the offset - private static DvDateTime ValidateDate(DvDateTime dateTime, ref DvInt2 offset) + private static DvDateTime ValidateDate(DvDateTime dateTime, ref short offset) { Contracts.Assert(!dateTime.IsNA); - Contracts.Assert(!offset.IsNA); // Validate that both the UTC and clock times are legal. - Contracts.Assert(MinMinutesOffset <= offset.RawValue && offset.RawValue <= MaxMinutesOffset); - var offsetTicks = offset.RawValue * TicksPerMinute; + Contracts.Assert(MinMinutesOffset <= offset && offset <= MaxMinutesOffset); + var offsetTicks = offset * TicksPerMinute; // This operation cannot overflow because offset should have already been validated to be within // 14 hours and the DateTime instance is more than that distance from the boundaries of Int64. - long utcTicks = dateTime.Ticks.RawValue - offsetTicks; + long utcTicks = dateTime.Ticks - offsetTicks; var dvdt = new DvDateTime(utcTicks); if (dvdt.IsNA) - offset = DvInt2.NA; + offset = short.MinValue; return dvdt; } @@ -257,36 +253,37 @@ private static DvDateTime ValidateDate(DvDateTime dateTime, ref DvInt2 offset) /// /// /// - private static bool TryValidateOffset(DvInt8 offsetTicks, out DvInt2 offset) + private static bool TryValidateOffset(long offsetTicks, out short offset) { - if (offsetTicks.IsNA || offsetTicks.RawValue % TicksPerMinute != 0) + if (offsetTicks == long.MinValue || offsetTicks % TicksPerMinute != 0) { - offset = DvInt2.NA; + offset = short.MinValue; return false; } - long mins = offsetTicks.RawValue / TicksPerMinute; + long mins = offsetTicks / TicksPerMinute; short res = (short)mins; if (res != mins || res > MaxMinutesOffset || res < MinMinutesOffset) { - offset = DvInt2.NA; + offset = short.MinValue; return false; } offset = res; - Contracts.Assert(!offset.IsNA); + Contracts.Assert(offset != short.MinValue); return true; } [Conditional("DEBUG")] private void AssertValid() { + _dateTime.AssertValid(); _dateTime.AssertValid(); if (_dateTime.IsNA) - Contracts.Assert(_offset.IsNA); + Contracts.Assert(_offset == short.MinValue); else { - Contracts.Assert(MinMinutesOffset <= _offset.RawValue && _offset.RawValue <= MaxMinutesOffset); - Contracts.Assert((ulong)(_dateTime.Ticks.RawValue + _offset.RawValue * TicksPerMinute) + Contracts.Assert(MinMinutesOffset <= _offset && _offset <= MaxMinutesOffset); + Contracts.Assert((ulong)(_dateTime.Ticks + _offset * TicksPerMinute) <= (ulong)DvDateTime.MaxTicks); } } @@ -298,7 +295,7 @@ public DvDateTime ClockDateTime AssertValid(); if (_dateTime.IsNA) return DvDateTime.NA; - var res = new DvDateTime(_dateTime.Ticks.RawValue + _offset.RawValue * TicksPerMinute); + var res = new DvDateTime(_dateTime.Ticks + _offset * TicksPerMinute); Contracts.Assert(!res.IsNA); return res; } @@ -326,16 +323,16 @@ public DvTimeSpan Offset get { AssertValid(); - if (_offset.IsNA) + if (_offset == short.MinValue) return DvTimeSpan.NA; - return new DvTimeSpan(_offset.RawValue * TicksPerMinute); + return new DvTimeSpan(_offset * TicksPerMinute); } } /// /// Gets the offset in minutes. /// - public DvInt2 OffsetMinutes + public short OffsetMinutes { get { @@ -392,7 +389,7 @@ public bool IsNA // and _offset = 0. public static DvDateTimeZone NA { - get { return new DvDateTimeZone(DvDateTime.NA, DvInt2.NA); } + get { return new DvDateTimeZone(DvDateTime.NA, short.MinValue); } } public static explicit operator SysDateTimeOffset?(DvDateTimeZone dvDto) @@ -427,7 +424,7 @@ private DateTimeOffset GetSysDateTimeOffset() { AssertValid(); Contracts.Assert(!IsNA); - return new SysDateTimeOffset(ClockDateTime.GetSysDateTime(), new TimeSpan(0, _offset.RawValue, 0)); + return new SysDateTimeOffset(ClockDateTime.GetSysDateTime(), new TimeSpan(0, _offset, 0)); } /// @@ -436,7 +433,7 @@ private DateTimeOffset GetSysDateTimeOffset() /// public bool Equals(DvDateTimeZone other) { - return _offset.RawValue == other._offset.RawValue && _dateTime.Equals(other._dateTime); + return _offset == other._offset && _dateTime.Equals(other._dateTime); } public override bool Equals(object obj) @@ -456,9 +453,9 @@ public int CompareTo(DvDateTimeZone other) int res = _dateTime.CompareTo(other._dateTime); if (res != 0) return res; - if (_offset.RawValue == other._offset.RawValue) + if (_offset == other._offset) return 0; - return _offset.RawValue < other._offset.RawValue ? -1 : 1; + return _offset < other._offset ? -1 : 1; } public override int GetHashCode() @@ -472,11 +469,11 @@ public override int GetHashCode() /// public struct DvTimeSpan : IEquatable, IComparable { - private readonly DvInt8 _ticks; + private readonly long _ticks; - public DvInt8 Ticks { get { return _ticks; } } + public long Ticks { get { return _ticks; } } - public DvTimeSpan(DvInt8 ticks) + public DvTimeSpan(long ticks) { _ticks = ticks; } @@ -488,24 +485,24 @@ public DvTimeSpan(SysTimeSpan sts) public DvTimeSpan(SysTimeSpan? sts) { - _ticks = sts != null ? sts.GetValueOrDefault().Ticks : DvInt8.NA; + _ticks = sts != null ? sts.GetValueOrDefault().Ticks : long.MinValue; } public bool IsNA { - get { return _ticks.IsNA; } + get { return _ticks == long.MinValue; } } public static DvTimeSpan NA { - get { return new DvTimeSpan(DvInt8.NA); } + get { return new DvTimeSpan(long.MinValue); } } public static explicit operator SysTimeSpan?(DvTimeSpan ts) { if (ts.IsNA) return null; - return new SysTimeSpan(ts._ticks.RawValue); + return new SysTimeSpan(ts._ticks); } public static implicit operator DvTimeSpan(SysTimeSpan sts) @@ -522,12 +519,12 @@ public override string ToString() { if (IsNA) return ""; - return new SysTimeSpan(_ticks.RawValue).ToString("c"); + return new SysTimeSpan(_ticks).ToString("c"); } public bool Equals(DvTimeSpan other) { - return _ticks.RawValue == other._ticks.RawValue; + return _ticks == other._ticks; } public override bool Equals(object obj) @@ -537,9 +534,9 @@ public override bool Equals(object obj) public int CompareTo(DvTimeSpan other) { - if (_ticks.RawValue == other._ticks.RawValue) + if (_ticks == other._ticks) return 0; - return _ticks.RawValue < other._ticks.RawValue ? -1 : 1; + return _ticks < other._ticks ? -1 : 1; } public override int GetHashCode() diff --git a/src/Microsoft.ML.Core/Data/DvInt1.cs b/src/Microsoft.ML.Core/Data/DvInt1.cs deleted file mode 100644 index ced2a4688d..0000000000 --- a/src/Microsoft.ML.Core/Data/DvInt1.cs +++ /dev/null @@ -1,264 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Runtime.CompilerServices; - -namespace Microsoft.ML.Runtime.Data -{ - using BL = DvBool; - using I2 = DvInt2; - using I4 = DvInt4; - using I8 = DvInt8; - using IX = DvInt1; - using R4 = Single; - using R8 = Double; - using RawI8 = Int64; - using RawIX = SByte; - - public struct DvInt1 : IEquatable, IComparable - { - public const RawIX RawNA = RawIX.MinValue; - - // Ideally this would be readonly. However, note that this struct has no - // ctor, but instead only has conversion operators. The implicit conversion - // operator from RawIX to DvIX performs better than an equivalent ctor, - // and the conversion operator must assign the _value field. - private RawIX _value; - - /// - /// Property to return the raw value. - /// - public RawIX RawValue - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value; } - } - - /// - /// Static method to return the raw value. This is more convenient than the - /// property in code-generation scenarios. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static RawIX GetRawBits(IX a) - { - return a._value; - } - - public static IX NA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return RawNA; } - } - - public bool IsNA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value == RawNA; } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX value) - { - IX res; - res._value = value; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX? value) - { - IX res; - res._value = value ?? RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX(IX value) - { - if (value._value == RawNA) - throw Contracts.ExceptValue(nameof(value), "NA cast to sbyte"); - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX?(IX value) - { - if (value._value == RawNA) - return null; - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(BL a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(I2 a) - { - RawIX res = (RawIX)a.RawValue; - if (res != a.RawValue) - return RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(I4 a) - { - RawIX res = (RawIX)a.RawValue; - if (res != a.RawValue) - return RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(I8 a) - { - RawIX res = (RawIX)a.RawValue; - if (res != a.RawValue) - return RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R4 a) - { - return (IX)(R8)a; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R4(IX a) - { - if (a._value == RawNA) - return R4.NaN; - return (R4)a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R8 a) - { - const R8 lim = -(R8)RawIX.MinValue; - if (-lim < a && a < lim) - { - RawIX n = (RawIX)a; -#if DEBUG - Contracts.Assert(!a.IsNA()); - Contracts.Assert(n != RawNA); - RawI8 nn = (RawI8)a; - Contracts.Assert(nn == n); - if (a >= 0) - Contracts.Assert(a - 1 < n & n <= a); - else - Contracts.Assert(a <= n & n < a + 1); -#endif - return n; - } - - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R8(IX a) - { - if (a._value == RawNA) - return R8.NaN; - return (R8)a._value; - } - - public override int GetHashCode() - { - return _value.GetHashCode(); - } - - public override bool Equals(object obj) - { - if (obj is IX) - return _value == ((IX)obj)._value; - return false; - } - - public bool Equals(IX other) - { - return _value == other._value; - } - - public int CompareTo(IX other) - { - if (_value == other._value) - return 0; - return _value < other._value ? -1 : 1; - } - - public override string ToString() - { - if (_value == RawNA) - return "NA"; - return _value.ToString(); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator ==(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av == bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator !=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av != bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av < bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av <= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av >= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av > bv ? BL.True : BL.False; - return BL.NA; - } - } -} diff --git a/src/Microsoft.ML.Core/Data/DvInt2.cs b/src/Microsoft.ML.Core/Data/DvInt2.cs deleted file mode 100644 index 33599f6468..0000000000 --- a/src/Microsoft.ML.Core/Data/DvInt2.cs +++ /dev/null @@ -1,263 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Runtime.CompilerServices; - -namespace Microsoft.ML.Runtime.Data -{ - using BL = DvBool; - using I1 = DvInt1; - using I4 = DvInt4; - using I8 = DvInt8; - using IX = DvInt2; - using R4 = Single; - using R8 = Double; - using RawI8 = Int64; - using RawIX = Int16; - - public struct DvInt2 : IEquatable, IComparable - { - public const RawIX RawNA = RawIX.MinValue; - - // Ideally this would be readonly. However, note that this struct has no - // ctor, but instead only has conversion operators. The implicit conversion - // operator from RawIX to DvIX performs better than an equivalent ctor, - // and the conversion operator must assign the _value field. - private RawIX _value; - - /// - /// Property to return the raw value. - /// - public RawIX RawValue - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value; } - } - - /// - /// Static method to return the raw value. This is more convenient than the - /// property in code-generation scenarios. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static RawIX GetRawBits(IX a) - { - return a._value; - } - - public static IX NA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return RawNA; } - } - - public bool IsNA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value == RawNA; } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX value) - { - IX res; - res._value = value; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX? value) - { - IX res; - res._value = value ?? RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX(IX value) - { - if (value._value == RawNA) - throw Contracts.ExceptValue(nameof(value), "NA cast to short"); - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX?(IX value) - { - if (value._value == RawNA) - return null; - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(BL a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(I1 a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(I4 a) - { - RawIX res = (RawIX)a.RawValue; - if (res != a.RawValue) - return RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(I8 a) - { - RawIX res = (RawIX)a.RawValue; - if (res != a.RawValue) - return RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R4 a) - { - return (IX)(R8)a; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R4(IX a) - { - if (a._value == RawNA) - return R4.NaN; - return (R4)a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R8 a) - { - const R8 lim = -(R8)RawIX.MinValue; - if (-lim < a && a < lim) - { - RawIX n = (RawIX)a; -#if DEBUG - Contracts.Assert(!a.IsNA()); - Contracts.Assert(n != RawNA); - RawI8 nn = (RawI8)a; - Contracts.Assert(nn == n); - if (a >= 0) - Contracts.Assert(a - 1 < n & n <= a); - else - Contracts.Assert(a <= n & n < a + 1); -#endif - return n; - } - - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R8(IX a) - { - if (a._value == RawNA) - return R8.NaN; - return (R8)a._value; - } - - public override int GetHashCode() - { - return _value.GetHashCode(); - } - - public override bool Equals(object obj) - { - if (obj is IX) - return _value == ((IX)obj)._value; - return false; - } - - public bool Equals(IX other) - { - return _value == other._value; - } - - public int CompareTo(IX other) - { - if (_value == other._value) - return 0; - return _value < other._value ? -1 : 1; - } - - public override string ToString() - { - if (_value == RawNA) - return "NA"; - return _value.ToString(); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator ==(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av == bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator !=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av != bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av < bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av <= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av >= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av > bv ? BL.True : BL.False; - return BL.NA; - } - } -} diff --git a/src/Microsoft.ML.Core/Data/DvInt4.cs b/src/Microsoft.ML.Core/Data/DvInt4.cs deleted file mode 100644 index 23c7e89242..0000000000 --- a/src/Microsoft.ML.Core/Data/DvInt4.cs +++ /dev/null @@ -1,456 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Runtime.CompilerServices; - -namespace Microsoft.ML.Runtime.Data -{ - using BL = DvBool; - using I1 = DvInt1; - using I2 = DvInt2; - using I8 = DvInt8; - using IX = DvInt4; - using R4 = Single; - using R8 = Double; - using RawI8 = Int64; - using RawIX = Int32; - - public struct DvInt4 : IEquatable, IComparable - { - public const RawIX RawNA = RawIX.MinValue; - - // Ideally this would be readonly. However, note that this struct has no - // ctor, but instead only has conversion operators. The implicit conversion - // operator from RawIX to DvIX performs better than an equivalent ctor, - // and the conversion operator must assign the _value field. - private RawIX _value; - - /// - /// Property to return the raw value. - /// - public RawIX RawValue - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value; } - } - - /// - /// Static method to return the raw value. This is more convenient than the - /// property in code-generation scenarios. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static RawIX GetRawBits(IX a) - { - return a._value; - } - - public static IX NA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return RawNA; } - } - - public bool IsNA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value == RawNA; } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX value) - { - IX res; - res._value = value; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX? value) - { - IX res; - res._value = value ?? RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX(IX value) - { - if (value._value == RawNA) - throw Contracts.ExceptValue(nameof(value), "NA cast to int"); - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX?(IX value) - { - if (value._value == RawNA) - return null; - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(BL a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(I1 a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(I2 a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(I8 a) - { - RawIX res = (RawIX)a.RawValue; - if (res != a.RawValue) - return RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R4 a) - { - return (IX)(R8)a; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R4(IX a) - { - if (a._value == RawNA) - return R4.NaN; - return (R4)a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R8 a) - { - const R8 lim = -(R8)RawIX.MinValue; - if (-lim < a && a < lim) - { - RawIX n = (RawIX)a; -#if DEBUG - Contracts.Assert(!a.IsNA()); - Contracts.Assert(n != RawNA); - RawI8 nn = (RawI8)a; - Contracts.Assert(nn == n); - if (a >= 0) - Contracts.Assert(a - 1 < n & n <= a); - else - Contracts.Assert(a <= n & n < a + 1); -#endif - return n; - } - - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R8(IX a) - { - if (a._value == RawNA) - return R8.NaN; - return (R8)a._value; - } - - public override int GetHashCode() - { - return _value.GetHashCode(); - } - - public override bool Equals(object obj) - { - if (obj is IX) - return _value == ((IX)obj)._value; - return false; - } - - public bool Equals(IX other) - { - return _value == other._value; - } - - public int CompareTo(IX other) - { - if (_value == other._value) - return 0; - return _value < other._value ? -1 : 1; - } - - public override string ToString() - { - if (_value == RawNA) - return "NA"; - return _value.ToString(); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator ==(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av == bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator !=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av != bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av < bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av <= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av >= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av > bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator -(IX a) - { - return -a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator +(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - { - var res = av + bv; - // Overflow happens iff the sign of the result is different than both source values. - if ((av ^ res) >= 0) - return res; - if ((bv ^ res) >= 0) - return res; - } - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator -(IX a, IX b) - { - var av = a._value; - var bv = -b._value; - if (av != RawNA && bv != RawNA) - { - var res = av + bv; - // Overflow happens iff the sign of the result is different than both source values. - if ((av ^ res) >= 0) - return res; - if ((bv ^ res) >= 0) - return res; - } - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator *(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - { - RawI8 res = (RawI8)av * bv; - if (-RawIX.MaxValue <= res && res <= RawIX.MaxValue) - return (RawIX)res; - } - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator /(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA && bv != 0) - return av / bv; - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator %(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA && bv != 0) - return av % bv; - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Abs(IX a) - { - // Can't use Math.Abs since it throws on the RawNA value. - return a._value >= 0 ? a._value : -a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Sign(IX a) - { - var val = a._value; - var neg = -val; - // This works for NA since -RawNA == RawNA. - return val > neg ? +1 : val < neg ? -1 : val; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Min(IX a, IX b) - { - var v1 = a._value; - var v2 = b._value; - // This works for NA since RawNA == RawIX.MinValue. - return v1 <= v2 ? v1 : v2; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public IX Min(IX b) - { - var v1 = _value; - var v2 = b._value; - // This works for NA since RawNA == RawIX.MinValue. - return v1 <= v2 ? v1 : v2; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Max(IX a, IX b) - { - var v1 = a._value; - var v2 = b._value; - // This works for NA since RawNA - 1 == RawIX.MaxValue. - return v1 - 1 >= v2 - 1 ? v1 : v2; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public IX Max(IX b) - { - var v1 = _value; - var v2 = b._value; - // This works for NA since RawNA - 1 == RawIX.MaxValue. - return v1 - 1 >= v2 - 1 ? v1 : v2; - } - - /// - /// Raise a to the b power. Special cases: - /// * 1^NA => 1 - /// * NA^0 => 1 - /// - public static IX Pow(IX a, IX b) - { - var av = a.RawValue; - var bv = b.RawValue; - - if (av == 1) - return 1; - switch (bv) - { - case 0: - return 1; - case 1: - return av; - case 2: - return a * a; - case RawNA: - return RawNA; - } - if (av == -1) - return (bv & 1) == 0 ? 1 : -1; - if (bv < 0) - return RawNA; - if (av == RawNA) - return RawNA; - - // Since the abs of the base is at least two, the exponent must be less than 31. - if (bv >= 31) - return RawNA; - - bool neg = false; - if (av < 0) - { - av = -av; - neg = (bv & 1) != 0; - } - Contracts.Assert(av >= 2); - - // Since the exponent is at least three, the base must be <= 1290. - Contracts.Assert(bv >= 3); - if (av > 1290) - return RawNA; - - // REVIEW: Should we use a checked context and exception catching like I8 does? - ulong u = (ulong)(uint)av; - ulong result = 1; - for (; ; ) - { - if ((bv & 1) != 0 && (result *= u) > RawIX.MaxValue) - return RawNA; - bv >>= 1; - if (bv == 0) - break; - if ((u *= u) > RawIX.MaxValue) - return RawNA; - } - Contracts.Assert(result <= RawIX.MaxValue); - - var res = (RawIX)result; - if (neg) - res = -res; - return res; - } - } -} diff --git a/src/Microsoft.ML.Core/Data/DvInt8.cs b/src/Microsoft.ML.Core/Data/DvInt8.cs deleted file mode 100644 index 3212e21fa6..0000000000 --- a/src/Microsoft.ML.Core/Data/DvInt8.cs +++ /dev/null @@ -1,511 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Runtime.CompilerServices; - -namespace Microsoft.ML.Runtime.Data -{ - using BL = DvBool; - using I1 = DvInt1; - using I2 = DvInt2; - using I4 = DvInt4; - using IX = DvInt8; - using R4 = Single; - using R8 = Double; - using RawIX = Int64; - - public struct DvInt8 : IEquatable, IComparable - { - public const RawIX RawNA = RawIX.MinValue; - - // Ideally this would be readonly. However, note that this struct has no - // ctor, but instead only has conversion operators. The implicit conversion - // operator from RawIX to DvIX performs better than an equivalent ctor, - // and the conversion operator must assign the _value field. - private RawIX _value; - - /// - /// Property to return the raw value. - /// - public RawIX RawValue - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value; } - } - - /// - /// Static method to return the raw value. This is more convenient than the - /// property in code-generation scenarios. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static RawIX GetRawBits(IX a) - { - return a._value; - } - - public static IX NA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return RawNA; } - } - - public bool IsNA - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get { return _value == RawNA; } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX value) - { - IX res; - res._value = value; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(RawIX? value) - { - IX res; - res._value = value ?? RawNA; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX(IX value) - { - if (value._value == RawNA) - throw Contracts.ExceptValue(nameof(value), "NA cast to long"); - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator RawIX?(IX value) - { - if (value._value == RawNA) - return null; - return value._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(BL a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(I1 a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(I2 a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static implicit operator IX(I4 a) - { - if (a.IsNA) - return RawNA; - return (RawIX)a.RawValue; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R4 a) - { - return (IX)(R8)a; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R4(IX a) - { - if (a._value == RawNA) - return R4.NaN; - return (R4)a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator IX(R8 a) - { - const R8 lim = -(R8)RawIX.MinValue; - if (-lim < a && a < lim) - { - RawIX n = (RawIX)a; -#if DEBUG - Contracts.Assert(!a.IsNA()); - Contracts.Assert(n != RawNA); - // Note that an R8 cannot represent long.MaxValue exactly so y + 1.0 below might be the same as y. - R8 x = a; - R8 y = n; - if (a < 0) - { - x = -x; - y = -y; - } - Contracts.Assert(y <= x); - Contracts.Assert(x < y + 1.0 | y + 1.0 == y & x == y); -#endif - return n; - } - - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static explicit operator R8(IX a) - { - if (a._value == RawNA) - return R8.NaN; - return (R8)a._value; - } - - public override int GetHashCode() - { - return _value.GetHashCode(); - } - - public override bool Equals(object obj) - { - if (obj is IX) - return _value == ((IX)obj)._value; - return false; - } - - public bool Equals(IX other) - { - return _value == other._value; - } - - public int CompareTo(IX other) - { - if (_value == other._value) - return 0; - return _value < other._value ? -1 : 1; - } - - public override string ToString() - { - if (_value == RawNA) - return "NA"; - return _value.ToString(); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator ==(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av == bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator !=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av != bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av < bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator <=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av <= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >=(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av >= bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static BL operator >(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - return av > bv ? BL.True : BL.False; - return BL.NA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator -(IX a) - { - return -a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator +(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA) - { - var res = av + bv; - // Overflow happens iff the sign of the result is different than both source values. - if ((av ^ res) >= 0) - return res; - if ((bv ^ res) >= 0) - return res; - } - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator -(IX a, IX b) - { - var av = a._value; - var bv = -b._value; - if (av != RawNA && bv != RawNA) - { - var res = av + bv; - // Overflow happens iff the sign of the result is different than both source values. - if ((av ^ res) >= 0) - return res; - if ((bv ^ res) >= 0) - return res; - } - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator *(IX a, IX b) - { - var av = a._value; - var bv = b._value; - bool neg = (av ^ bv) < 0; - if (av < 0) - { - if (av == RawNA) - return RawNA; - av = -av; - } - if (bv < 0) - { - if (bv == RawNA) - return RawNA; - bv = -bv; - } - - // Deal with the low 32 bits. - ulong lo1 = (ulong)av & 0x00000000FFFFFFFF; - ulong lo2 = (ulong)bv & 0x00000000FFFFFFFF; - RawIX res = (RawIX)(lo1 * lo2); - if (res < 0) - return RawNA; - - // Get the high 32 bits, including cross terms. - ulong hi1 = (ulong)av >> 32; - ulong hi2 = (ulong)bv >> 32; - if (hi1 != 0) - { - // If both high words are non-zero, overflow is guaranteed. - if (hi2 != 0) - return RawNA; - // Compute the cross term. - ulong tmp = hi1 * lo2; - if ((tmp & 0xFFFFFFFF80000000) != 0) - return RawNA; - res += (long)(tmp << 32); - if (res < 0) - return RawNA; - } - else if (hi2 != 0) - { - // Compute the cross term. - ulong tmp = hi2 * lo1; - if ((tmp & 0xFFFFFFFF80000000) != 0) - return RawNA; - res += (long)(tmp << 32); - if (res < 0) - return RawNA; - } - - // Adjust the sign. - if (neg) - res = -res; - return res; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator /(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA && bv != 0) - return av / bv; - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX operator %(IX a, IX b) - { - var av = a._value; - var bv = b._value; - if (av != RawNA && bv != RawNA && bv != 0) - return av % bv; - return RawNA; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Abs(IX a) - { - // Can't use Math.Abs since it throws on the RawNA value. - return a._value >= 0 ? a._value : -a._value; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Sign(IX a) - { - var val = a._value; - var neg = -val; - // This works for NA since -RawNA == RawNA. - return val > neg ? +1 : val < neg ? -1 : val; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Min(IX a, IX b) - { - var v1 = a._value; - var v2 = b._value; - // This works for NA since RawNA == RawIX.MinValue. - return v1 <= v2 ? v1 : v2; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public IX Min(IX b) - { - var v1 = _value; - var v2 = b._value; - // This works for NA since RawNA == RawIX.MinValue. - return v1 <= v2 ? v1 : v2; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static IX Max(IX a, IX b) - { - var v1 = a._value; - var v2 = b._value; - // This works for NA since RawNA - 1 == RawIX.MaxValue. - return v1 - 1 >= v2 - 1 ? v1 : v2; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public IX Max(IX b) - { - var v1 = _value; - var v2 = b._value; - // This works for NA since RawNA - 1 == RawIX.MaxValue. - return v1 - 1 >= v2 - 1 ? v1 : v2; - } - - /// - /// Raise a to the b power. Special cases: - /// * 1^NA => 1 - /// * NA^0 => 1 - /// - public static IX Pow(IX a, IX b) - { - var av = a.RawValue; - var bv = b.RawValue; - - if (av == 1) - return 1; - switch (bv) - { - case 0: - return 1; - case 1: - return av; - case 2: - return a * a; - case RawNA: - return RawNA; - } - if (av == -1) - return (bv & 1) == 0 ? 1 : -1; - if (bv < 0) - return RawNA; - if (av == RawNA) - return RawNA; - - // Since the abs of the base is at least two, the exponent must be less than 63. - if (bv >= 63) - return RawNA; - - bool neg = false; - if (av < 0) - { - av = -av; - neg = (bv & 1) != 0; - } - Contracts.Assert(av >= 2); - - // Since the exponent is at least three, the base must be < 2^21. - Contracts.Assert(bv >= 3); - if (av >= (1L << 21)) - return RawNA; - - long res = 1; - long x = av; - // REVIEW: Is the catch too slow in the overflow case? - try - { - checked - { - for (; ; ) - { - if ((bv & 1) != 0) - res *= x; - bv >>= 1; - if (bv == 0) - break; - x *= x; - } - } - } - catch (OverflowException) - { - return RawNA; - } - Contracts.Assert(res > 0); - - if (neg) - res = -res; - return res; - } - } -} diff --git a/src/Microsoft.ML.Core/Data/MetadataUtils.cs b/src/Microsoft.ML.Core/Data/MetadataUtils.cs index a0de4e51ca..fbb3c54e86 100644 --- a/src/Microsoft.ML.Core/Data/MetadataUtils.cs +++ b/src/Microsoft.ML.Core/Data/MetadataUtils.cs @@ -420,9 +420,9 @@ public static bool TryGetCategoricalFeatureIndices(ISchema schema, int colIndex, return isValid; var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex); - if (type?.RawType == typeof(VBuffer)) + if (type?.RawType == typeof(VBuffer)) { - VBuffer catIndices = default(VBuffer); + VBuffer catIndices = default(VBuffer); schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex, ref catIndices); VBufferUtils.Densify(ref catIndices); int columnSlotsCount = schema.GetColumnType(colIndex).AsVector.VectorSizeCore; @@ -432,19 +432,19 @@ public static bool TryGetCategoricalFeatureIndices(ISchema schema, int colIndex, isValid = true; for (int i = 0; i < catIndices.Values.Length; i += 2) { - if (catIndices.Values[i].RawValue > catIndices.Values[i + 1].RawValue || - catIndices.Values[i].RawValue <= previousEndIndex || - catIndices.Values[i].RawValue >= columnSlotsCount || - catIndices.Values[i + 1].RawValue >= columnSlotsCount) + if (catIndices.Values[i] > catIndices.Values[i + 1] || + catIndices.Values[i] <= previousEndIndex || + catIndices.Values[i] >= columnSlotsCount || + catIndices.Values[i + 1] >= columnSlotsCount) { isValid = false; break; } - previousEndIndex = catIndices.Values[i + 1].RawValue; + previousEndIndex = catIndices.Values[i + 1]; } if (isValid) - categoricalFeatures = catIndices.Values.Select(val => val.RawValue).ToArray(); + categoricalFeatures = catIndices.Values.Select(val => val).ToArray(); } } diff --git a/src/Microsoft.ML.Core/Utilities/Stream.cs b/src/Microsoft.ML.Core/Utilities/Stream.cs index 41c794e17f..171d73ff65 100644 --- a/src/Microsoft.ML.Core/Utilities/Stream.cs +++ b/src/Microsoft.ML.Core/Utilities/Stream.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Float = System.Single; - using System; using System.Collections; using System.Collections.Generic; @@ -178,7 +176,7 @@ public static void WriteBytesNoCount(this BinaryWriter writer, byte[] values, in /// /// Writes a length prefixed array of Floats. /// - public static void WriteFloatArray(this BinaryWriter writer, Float[] values) + public static void WriteFloatArray(this BinaryWriter writer, float[] values) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -197,7 +195,7 @@ public static void WriteFloatArray(this BinaryWriter writer, Float[] values) /// /// Writes a length prefixed array of Floats. /// - public static void WriteFloatArray(this BinaryWriter writer, Float[] values, int count) + public static void WriteFloatArray(this BinaryWriter writer, float[] values, int count) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -211,7 +209,7 @@ public static void WriteFloatArray(this BinaryWriter writer, Float[] values, int /// /// Writes a specified number of floats starting at the specified index from an array. /// - public static void WriteFloatArray(this BinaryWriter writer, Float[] values, int start, int count) + public static void WriteFloatArray(this BinaryWriter writer, float[] values, int start, int count) { Contracts.AssertValue(writer); Contracts.AssertValue(values); @@ -225,7 +223,7 @@ public static void WriteFloatArray(this BinaryWriter writer, Float[] values, int /// /// Writes a length prefixed array of Floats. /// - public static void WriteFloatArray(this BinaryWriter writer, IEnumerable values, int count) + public static void WriteFloatArray(this BinaryWriter writer, IEnumerable values, int count) { Contracts.AssertValue(writer); Contracts.AssertValue(values); @@ -244,7 +242,7 @@ public static void WriteFloatArray(this BinaryWriter writer, IEnumerable /// /// Writes an array of Floats without the length prefix. /// - public static void WriteFloatsNoCount(this BinaryWriter writer, Float[] values, int count) + public static void WriteFloatsNoCount(this BinaryWriter writer, float[] values, int count) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -257,7 +255,7 @@ public static void WriteFloatsNoCount(this BinaryWriter writer, Float[] values, /// /// Writes a length prefixed array of singles. /// - public static void WriteSingleArray(this BinaryWriter writer, Single[] values) + public static void WriteSingleArray(this BinaryWriter writer, float[] values) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -276,7 +274,7 @@ public static void WriteSingleArray(this BinaryWriter writer, Single[] values) /// /// Writes a length prefixed array of singles. /// - public static void WriteSingleArray(this BinaryWriter writer, Single[] values, int count) + public static void WriteSingleArray(this BinaryWriter writer, float[] values, int count) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -290,7 +288,7 @@ public static void WriteSingleArray(this BinaryWriter writer, Single[] values, i /// /// Writes an array of singles without the length prefix. /// - public static void WriteSinglesNoCount(this BinaryWriter writer, Single[] values, int count) + public static void WriteSinglesNoCount(this BinaryWriter writer, float[] values, int count) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -303,7 +301,7 @@ public static void WriteSinglesNoCount(this BinaryWriter writer, Single[] values /// /// Writes a length prefixed array of doubles. /// - public static void WriteDoubleArray(this BinaryWriter writer, Double[] values) + public static void WriteDoubleArray(this BinaryWriter writer, double[] values) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -315,14 +313,14 @@ public static void WriteDoubleArray(this BinaryWriter writer, Double[] values) } writer.Write(values.Length); - foreach (Double val in values) + foreach (double val in values) writer.Write(val); } /// /// Writes a length prefixed array of doubles. /// - public static void WriteDoubleArray(this BinaryWriter writer, Double[] values, int count) + public static void WriteDoubleArray(this BinaryWriter writer, double[] values, int count) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -336,7 +334,7 @@ public static void WriteDoubleArray(this BinaryWriter writer, Double[] values, i /// /// Writes an array of doubles without the length prefix. /// - public static void WriteDoublesNoCount(this BinaryWriter writer, Double[] values, int count) + public static void WriteDoublesNoCount(this BinaryWriter writer, double[] values, int count) { Contracts.AssertValue(writer); Contracts.AssertValueOrNull(values); @@ -427,7 +425,7 @@ public static void WriteBitArray(this BinaryWriter writer, BitArray arr) } } - public static long WriteSByteStream(this BinaryWriter writer, IEnumerable e) + public static long WriteSByteStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -438,7 +436,7 @@ public static long WriteSByteStream(this BinaryWriter writer, IEnumerable return c; } - public static long WriteByteStream(this BinaryWriter writer, IEnumerable e) + public static long WriteByteStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -449,7 +447,7 @@ public static long WriteByteStream(this BinaryWriter writer, IEnumerable e return c; } - public static long WriteIntStream(this BinaryWriter writer, IEnumerable e) + public static long WriteIntStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -460,7 +458,7 @@ public static long WriteIntStream(this BinaryWriter writer, IEnumerable e return c; } - public static long WriteUIntStream(this BinaryWriter writer, IEnumerable e) + public static long WriteUIntStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -471,7 +469,7 @@ public static long WriteUIntStream(this BinaryWriter writer, IEnumerable return c; } - public static long WriteShortStream(this BinaryWriter writer, IEnumerable e) + public static long WriteShortStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -482,7 +480,7 @@ public static long WriteShortStream(this BinaryWriter writer, IEnumerable return c; } - public static long WriteUShortStream(this BinaryWriter writer, IEnumerable e) + public static long WriteUShortStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -493,7 +491,7 @@ public static long WriteUShortStream(this BinaryWriter writer, IEnumerable e) + public static long WriteLongStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -504,7 +502,7 @@ public static long WriteLongStream(this BinaryWriter writer, IEnumerable return c; } - public static long WriteULongStream(this BinaryWriter writer, IEnumerable e) + public static long WriteULongStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -515,7 +513,7 @@ public static long WriteULongStream(this BinaryWriter writer, IEnumerable e) + public static long WriteSingleStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -526,7 +524,7 @@ public static long WriteSingleStream(this BinaryWriter writer, IEnumerable e) + public static long WriteDoubleStream(this BinaryWriter writer, IEnumerable e) { long c = 0; foreach (var v in e) @@ -606,12 +604,12 @@ public static bool ReadBoolByte(this BinaryReader reader) return b != 0; } - public static Float ReadFloat(this BinaryReader reader) + public static float ReadFloat(this BinaryReader reader) { return reader.ReadSingle(); } - public static Float[] ReadFloatArray(this BinaryReader reader) + public static float[] ReadFloatArray(this BinaryReader reader) { Contracts.AssertValue(reader); @@ -620,16 +618,16 @@ public static Float[] ReadFloatArray(this BinaryReader reader) return ReadFloatArray(reader, size); } - public static Float[] ReadFloatArray(this BinaryReader reader, int size) + public static float[] ReadFloatArray(this BinaryReader reader, int size) { Contracts.AssertValue(reader); Contracts.Assert(size >= 0); if (size == 0) return null; - var values = new Float[size]; + var values = new float[size]; - long bufferSizeInBytes = (long)size * sizeof(Float); + long bufferSizeInBytes = (long)size * sizeof(float); if (bufferSizeInBytes < _bulkReadThresholdInBytes) { for (int i = 0; i < size; i++) @@ -649,14 +647,14 @@ public static Float[] ReadFloatArray(this BinaryReader reader, int size) return values; } - public static void ReadFloatArray(this BinaryReader reader, Float[] array, int start, int count) + public static void ReadFloatArray(this BinaryReader reader, float[] array, int start, int count) { Contracts.AssertValue(reader); Contracts.AssertValue(array); Contracts.Assert(0 <= start && start < array.Length); Contracts.Assert(0 < count && count <= array.Length - start); - long bufferReadLengthInBytes = (long)count * sizeof(Float); + long bufferReadLengthInBytes = (long)count * sizeof(float); if (bufferReadLengthInBytes < _bulkReadThresholdInBytes) { for (int i = 0; i < count; i++) @@ -668,15 +666,15 @@ public static void ReadFloatArray(this BinaryReader reader, Float[] array, int s { fixed (void* dst = array) { - long bufferBeginOffsetInBytes = (long)start * sizeof(Float); - long bufferSizeInBytes = ((long)array.Length - start) * sizeof(Float); + long bufferBeginOffsetInBytes = (long)start * sizeof(float); + long bufferSizeInBytes = ((long)array.Length - start) * sizeof(float); ReadBytes(reader, (byte*)dst + bufferBeginOffsetInBytes, bufferSizeInBytes, bufferReadLengthInBytes); } } } } - public static Single[] ReadSingleArray(this BinaryReader reader) + public static float[] ReadSingleArray(this BinaryReader reader) { Contracts.AssertValue(reader); int size = reader.ReadInt32(); @@ -684,15 +682,15 @@ public static Single[] ReadSingleArray(this BinaryReader reader) return ReadSingleArray(reader, size); } - public static Single[] ReadSingleArray(this BinaryReader reader, int size) + public static float[] ReadSingleArray(this BinaryReader reader, int size) { Contracts.AssertValue(reader); Contracts.Assert(size >= 0); if (size == 0) return null; - var values = new Single[size]; + var values = new float[size]; - long bufferSizeInBytes = (long)size * sizeof(Single); + long bufferSizeInBytes = (long)size * sizeof(float); if (bufferSizeInBytes < _bulkReadThresholdInBytes) { for (int i = 0; i < size; i++) @@ -712,7 +710,7 @@ public static Single[] ReadSingleArray(this BinaryReader reader, int size) return values; } - public static Double[] ReadDoubleArray(this BinaryReader reader) + public static double[] ReadDoubleArray(this BinaryReader reader) { Contracts.AssertValue(reader); @@ -721,15 +719,15 @@ public static Double[] ReadDoubleArray(this BinaryReader reader) return ReadDoubleArray(reader, size); } - public static Double[] ReadDoubleArray(this BinaryReader reader, int size) + public static double[] ReadDoubleArray(this BinaryReader reader, int size) { Contracts.AssertValue(reader); Contracts.Assert(size >= 0); if (size == 0) return null; - var values = new Double[size]; + var values = new double[size]; - long bufferSizeInBytes = (long)size * sizeof(Double); + long bufferSizeInBytes = (long)size * sizeof(double); if (bufferSizeInBytes < _bulkReadThresholdInBytes) { for (int i = 0; i < size; i++) diff --git a/src/Microsoft.ML.Data/Data/Conversion.cs b/src/Microsoft.ML.Data/Data/Conversion.cs index 0a9833064a..e39a0242e6 100644 --- a/src/Microsoft.ML.Data/Data/Conversion.cs +++ b/src/Microsoft.ML.Data/Data/Conversion.cs @@ -17,16 +17,12 @@ namespace Microsoft.ML.Runtime.Data.Conversion using BL = DvBool; using DT = DvDateTime; using DZ = DvDateTimeZone; - using I1 = DvInt1; - using I2 = DvInt2; - using I4 = DvInt4; - using I8 = DvInt8; using R4 = Single; using R8 = Double; - using RawI1 = SByte; - using RawI2 = Int16; - using RawI4 = Int32; - using RawI8 = Int64; + using I1 = SByte; + using I2 = Int16; + using I4 = Int32; + using I8 = Int64; using SB = StringBuilder; using TS = DvTimeSpan; using TX = DvText; @@ -244,10 +240,6 @@ private Conversions() AddStd(Convert); AddAux(Convert); - AddIsNA(IsNA); - AddIsNA(IsNA); - AddIsNA(IsNA); - AddIsNA(IsNA); AddIsNA(IsNA); AddIsNA(IsNA); AddIsNA(IsNA); @@ -256,10 +248,6 @@ private Conversions() AddIsNA
(IsNA); AddIsNA(IsNA); - AddGetNA(GetNA); - AddGetNA(GetNA); - AddGetNA(GetNA); - AddGetNA(GetNA); AddGetNA(GetNA); AddGetNA(GetNA); AddGetNA(GetNA); @@ -268,10 +256,6 @@ private Conversions() AddGetNA
(GetNA); AddGetNA(GetNA); - AddHasNA(HasNA); - AddHasNA(HasNA); - AddHasNA(HasNA); - AddHasNA(HasNA); AddHasNA(HasNA); AddHasNA(HasNA); AddHasNA(HasNA); @@ -846,10 +830,6 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) // The IsNA methods are for efficient delegates (instance instead of static). #region IsNA - private bool IsNA(ref I1 src) => src.IsNA; - private bool IsNA(ref I2 src) => src.IsNA; - private bool IsNA(ref I4 src) => src.IsNA; - private bool IsNA(ref I8 src) => src.IsNA; private bool IsNA(ref R4 src) => src.IsNA(); private bool IsNA(ref R8 src) => src.IsNA(); private bool IsNA(ref BL src) => src.IsNA; @@ -860,10 +840,6 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion IsNA #region HasNA - private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA) return true; } return false; } - private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA) return true; } return false; } - private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA) return true; } return false; } - private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA) return true; } return false; } private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA()) return true; } return false; } private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA()) return true; } return false; } private bool HasNA(ref VBuffer src) { for (int i = 0; i < src.Count; i++) { if (src.Values[i].IsNA) return true; } return false; } @@ -874,10 +850,10 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion HasNA #region IsDefault - private bool IsDefault(ref I1 src) => src.RawValue == 0; - private bool IsDefault(ref I2 src) => src.RawValue == 0; - private bool IsDefault(ref I4 src) => src.RawValue == 0; - private bool IsDefault(ref I8 src) => src.RawValue == 0; + private bool IsDefault(ref I1 src) => src == default(I1); + private bool IsDefault(ref I2 src) => src == default(I2); + private bool IsDefault(ref I4 src) => src == default(I4); + private bool IsDefault(ref I8 src) => src == default(I8); private bool IsDefault(ref R4 src) => src == 0; private bool IsDefault(ref R8 src) => src == 0; private bool IsDefault(ref TX src) => src.IsEmpty; @@ -900,10 +876,6 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion HasZero #region GetNA - private void GetNA(ref I1 value) => value = I1.NA; - private void GetNA(ref I2 value) => value = I2.NA; - private void GetNA(ref I4 value) => value = I4.NA; - private void GetNA(ref I8 value) => value = I8.NA; private void GetNA(ref R4 value) => value = R4.NaN; private void GetNA(ref R8 value) => value = R8.NaN; private void GetNA(ref BL value) => value = BL.NA; @@ -1022,10 +994,10 @@ public ValueGetter GetNAOrDefaultGetter(ColumnType type) #endregion ToR8 #region ToStringBuilder - public void Convert(ref I1 src, ref SB dst) { ClearDst(ref dst); if (!src.IsNA) dst.Append(src.RawValue); } - public void Convert(ref I2 src, ref SB dst) { ClearDst(ref dst); if (!src.IsNA) dst.Append(src.RawValue); } - public void Convert(ref I4 src, ref SB dst) { ClearDst(ref dst); if (!src.IsNA) dst.Append(src.RawValue); } - public void Convert(ref I8 src, ref SB dst) { ClearDst(ref dst); if (!src.IsNA) dst.Append(src.RawValue); } + public void Convert(ref I1 src, ref SB dst) { ClearDst(ref dst); dst.Append(src); } + public void Convert(ref I2 src, ref SB dst) { ClearDst(ref dst); dst.Append(src); } + public void Convert(ref I4 src, ref SB dst) { ClearDst(ref dst); dst.Append(src); } + public void Convert(ref I8 src, ref SB dst) { ClearDst(ref dst); dst.Append(src); } public void Convert(ref U1 src, ref SB dst) => ClearDst(ref dst).Append(src); public void Convert(ref U2 src, ref SB dst) => ClearDst(ref dst).Append(src); public void Convert(ref U4 src, ref SB dst) => ClearDst(ref dst).Append(src); @@ -1063,6 +1035,7 @@ public void Convert(ref BL src, ref SB dst) ///
public bool TryParse(ref TX src, out U1 dst) { + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to unsigned integer type."); ulong res; if (!TryParse(ref src, out res) || res > U1.MaxValue) { @@ -1078,6 +1051,7 @@ public bool TryParse(ref TX src, out U1 dst) /// public bool TryParse(ref TX src, out U2 dst) { + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to unsigned integer type."); ulong res; if (!TryParse(ref src, out res) || res > U2.MaxValue) { @@ -1093,6 +1067,7 @@ public bool TryParse(ref TX src, out U2 dst) /// public bool TryParse(ref TX src, out U4 dst) { + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to unsigned integer type."); ulong res; if (!TryParse(ref src, out res) || res > U4.MaxValue) { @@ -1108,12 +1083,7 @@ public bool TryParse(ref TX src, out U4 dst) /// public bool TryParse(ref TX src, out U8 dst) { - if (src.IsNA) - { - dst = 0; - return false; - } - + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to unsigned integer type."); int ichMin; int ichLim; string text = src.GetRawUnderlyingBufferInfo(out ichMin, out ichLim); @@ -1226,11 +1196,13 @@ private bool IsStdMissing(ref TX src) /// Utility to assist in parsing key-type values. The min and max values define /// the legal input value bounds. The output dst value is "normalized" so min is /// mapped to 1, max is mapped to 1 + (max - min). - /// Missing values are mapped to zero with a true return. + /// Exception is thrown for missing values. /// Unparsable or out of range values are mapped to zero with a false return. /// public bool TryParseKey(ref TX src, U8 min, U8 max, out U8 dst) { + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to unsigned integer type."); + Contracts.Check(!IsStdMissing(ref src), "Missing text value cannot be converted to unsigned integer type."); Contracts.Assert(min <= max); // This simply ensures we don't have min == 0 and max == U8.MaxValue. This is illegal since @@ -1255,7 +1227,7 @@ public bool TryParseKey(ref TX src, U8 min, U8 max, out U8 dst) { dst = 0; // Return true only for standard forms for NA. - return IsStdMissing(ref src); + return false; } if (min > uu || uu > max) @@ -1301,57 +1273,65 @@ private bool TryParseCore(string text, int ich, int lim, out ulong dst) /// /// This produces zero for empty. It returns false if the text is not parsable or overflows. - /// On failure, it sets dst to the NA value. + /// On failure, it sets dst to the default value. /// public bool TryParse(ref TX src, out I1 dst) { - long res; - bool f = TryParseSigned(RawI1.MaxValue, ref src, out res); - Contracts.Assert(f || res == I1.RawNA); - Contracts.Assert((RawI1)res == res); - dst = (RawI1)res; - return f; + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); + + dst = default; + TryParseSigned(I1.MaxValue, ref src, out long? res); + Contracts.Check(res.HasValue, "Value could not be parsed from text to sbyte."); + Contracts.Check((I1)res == res, "Overflow or underflow occured while converting value in text to sbyte."); + dst = (I1)res; + return true; } /// /// This produces zero for empty. It returns false if the text is not parsable or overflows. - /// On failure, it sets dst to the NA value. + /// On failure, it sets dst to the default value. /// public bool TryParse(ref TX src, out I2 dst) { - long res; - bool f = TryParseSigned(RawI2.MaxValue, ref src, out res); - Contracts.Assert(f || res == I2.RawNA); - Contracts.Assert((RawI2)res == res); - dst = (RawI2)res; - return f; + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); + + dst = default; + TryParseSigned(I2.MaxValue, ref src, out long? res); + Contracts.Check(res.HasValue, "Value could not be parsed from text to short."); + Contracts.Check((I2)res == res, "Overflow or underflow occured while converting value in text to short."); + dst = (I2)res; + return true; } /// /// This produces zero for empty. It returns false if the text is not parsable or overflows. - /// On failure, it sets dst to the NA value. + /// On failure, it sets dst to the defualt value. /// public bool TryParse(ref TX src, out I4 dst) { - long res; - bool f = TryParseSigned(RawI4.MaxValue, ref src, out res); - Contracts.Assert(f || res == I4.RawNA); - Contracts.Assert((RawI4)res == res); - dst = (RawI4)res; - return f; + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); + + dst = default; + TryParseSigned(I4.MaxValue, ref src, out long? res); + Contracts.Check(res.HasValue, "Value could not be parsed from text to int32."); + Contracts.Check((I4)res == res, "Overflow or underflow occured while converting value in text to int."); + dst = (I4)res; + return true; } /// /// This produces zero for empty. It returns false if the text is not parsable or overflows. - /// On failure, it sets dst to the NA value. + /// On failure, it sets dst to the default value. /// public bool TryParse(ref TX src, out I8 dst) { - long res; - bool f = TryParseSigned(RawI8.MaxValue, ref src, out res); - Contracts.Assert(f || res == I8.RawNA); - dst = res; - return f; + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); + + dst = default; + TryParseSigned(I8.MaxValue, ref src, out long? res); + Contracts.Check(res.HasValue, "Value could not be parsed from text to long."); + dst = (I8)res; + return true; } /// @@ -1389,61 +1369,57 @@ private bool TryParseNonNegative(string text, int ich, int lim, out long result) /// /// This produces zero for empty. It returns false if the text is not parsable as a signed integer - /// or the result overflows. The min legal value is -max. The NA value is -max - 1. + /// or the result overflows. The min legal value is -max. The NA value null. /// When it returns false, result is set to the NA value. The result can be NA on true return, /// since some representations of NA are not considered parse failure. /// - private bool TryParseSigned(long max, ref TX span, out long result) + private void TryParseSigned(long max, ref TX span, out long? result) { Contracts.Assert(max > 0); Contracts.Assert((max & (max + 1)) == 0); if (!span.HasChars) { - if (span.IsNA) - result = -max - 1; - else - result = 0; - return true; + result = default(long); + return; } int ichMin; int ichLim; string text = span.GetRawUnderlyingBufferInfo(out ichMin, out ichLim); - - long val; + ulong val; if (span[0] == '-') { if (span.Length == 1 || - !TryParseNonNegative(text, ichMin + 1, ichLim, out val) || - val > max) + !TryParseCore(text, ichMin + 1, ichLim, out val) || + (val > ((ulong)max + 1))) { - result = -max - 1; - return false; + result = null; + return; } Contracts.Assert(val >= 0); result = -(long)val; - Contracts.Assert(long.MinValue < result && result <= 0); - return true; + Contracts.Assert(long.MinValue <= result && result <= 0); + return; } - if (!TryParseNonNegative(text, ichMin, ichLim, out val)) + long sVal; + if (!TryParseNonNegative(text, ichMin, ichLim, out sVal)) { - // Check for acceptable NA forms: ? NaN NA and N/A. - result = -max - 1; - return IsStdMissing(ref span); + result = null; + return; } - Contracts.Assert(val >= 0); - if (val > max) + Contracts.Assert(sVal >= 0); + if (sVal > max) { - result = -max - 1; - return false; + result = null; + return; } - result = (long)val; + result = (long)sVal; Contracts.Assert(0 <= result && result <= long.MaxValue); - return true; + return; } /// @@ -1530,42 +1506,40 @@ public bool TryParse(ref TX src, out DZ dst) return IsStdMissing(ref src); } - // These map unparsable and overflow values to "NA", which is the value Ix.MinValue. Note that this NA - // value is the "evil" value - the non-zero value, x, such that x == -x. Note also, that for I4, this - // matches R's representation of NA. + // These throw an exception for unparsable and overflow values. private I1 ParseI1(ref TX src) { - long res; - bool f = TryParseSigned(RawI1.MaxValue, ref src, out res); - Contracts.Assert(f || res == I1.RawNA); - Contracts.Assert((RawI1)res == res); - return (RawI1)res; + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); + TryParseSigned(I1.MaxValue, ref src, out long? res); + Contracts.Check(res.HasValue, "Value could not be parsed from text to sbyte."); + Contracts.Check((I1)res == res, "Overflow or underflow occured while converting value in text to sbyte."); + return (I1)res; } private I2 ParseI2(ref TX src) { - long res; - bool f = TryParseSigned(RawI2.MaxValue, ref src, out res); - Contracts.Assert(f || res == I2.RawNA); - Contracts.Assert((RawI2)res == res); - return (RawI2)res; + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); + TryParseSigned(I2.MaxValue, ref src, out long? res); + Contracts.Check(res.HasValue, "Value could not be parsed from text to short."); + Contracts.Check((I2)res == res, "Overflow or underflow occured while converting value in text to short."); + return (I2)res; } private I4 ParseI4(ref TX src) { - long res; - bool f = TryParseSigned(RawI4.MaxValue, ref src, out res); - Contracts.Assert(f || res == I4.RawNA); - Contracts.Assert((RawI4)res == res); - return (RawI4)res; + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); + TryParseSigned(I4.MaxValue, ref src, out long? res); + Contracts.Check(res.HasValue, "Value could not be parsed from text to int."); + Contracts.Check((I4)res == res, "Overflow or underflow occured while converting value in text to int."); + return (I4)res; } private I8 ParseI8(ref TX src) { - long res; - bool f = TryParseSigned(RawI8.MaxValue, ref src, out res); - Contracts.Assert(f || res == I8.RawNA); - return res; + Contracts.Check(!src.IsNA, "Missing text value cannot be converted to integer type."); + TryParseSigned(I8.MaxValue, ref src, out long? res); + Contracts.Check(res.HasValue, "Value could not be parsed from text to long."); + return res.Value; } // These map unparsable and overflow values to zero. The unsigned integer types do not have an NA value. @@ -1573,6 +1547,7 @@ private I8 ParseI8(ref TX src) // unsigned integer types. private U1 ParseU1(ref TX span) { + Contracts.Check(!span.IsNA, "Missing text value cannot be converted to unsigned integer type."); ulong res; if (!TryParse(ref span, out res)) return 0; @@ -1583,6 +1558,7 @@ private U1 ParseU1(ref TX span) private U2 ParseU2(ref TX span) { + Contracts.Check(!span.IsNA, "Missing text value cannot be converted to unsigned integer type."); ulong res; if (!TryParse(ref span, out res)) return 0; @@ -1593,6 +1569,7 @@ private U2 ParseU2(ref TX span) private U4 ParseU4(ref TX span) { + Contracts.Check(!span.IsNA, "Missing text value cannot be converted to unsigned integer type."); ulong res; if (!TryParse(ref span, out res)) return 0; @@ -1603,6 +1580,7 @@ private U4 ParseU4(ref TX span) private U8 ParseU8(ref TX span) { + Contracts.Check(!span.IsNA, "Missing text value cannot be converted to unsigned integer type."); ulong res; if (!TryParse(ref span, out res)) return 0; diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs index 7bc0a8d2ad..582212738a 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/BinaryLoader.cs @@ -729,7 +729,13 @@ public void GetMetadata(string kind, int col, ref TValue value) /// /// Upper inclusive bound of versions this reader can read. /// - private const ulong ReaderVersion = MissingTextVersion; + private const ulong ReaderVersion = StandardDataTypesVersion; + + /// + /// The first version that removes DvTypes and uses .NET standard + /// data types. + /// + private const ulong StandardDataTypesVersion = 0x0001000100010006; /// /// The first version of the format that accomodated DvText.NA. diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs index d04adaf099..5793a0b129 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/CodecFactory.cs @@ -44,13 +44,13 @@ public CodecFactory(IHostEnvironment env, MemoryStreamPool memPool = null) _loadNameToCodecCreator = new Dictionary(); _simpleCodecTypeMap = new Dictionary(); // Register the current codecs. - RegisterSimpleCodec(new UnsafeTypeCodec(this)); + RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); - RegisterSimpleCodec(new UnsafeTypeCodec(this)); + RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); - RegisterSimpleCodec(new UnsafeTypeCodec(this)); + RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); - RegisterSimpleCodec(new UnsafeTypeCodec(this)); + RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); RegisterSimpleCodec(new UnsafeTypeCodec(this)); diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs index f840773872..6029027f36 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Codecs.cs @@ -626,8 +626,8 @@ public Writer(DateTimeCodec codec, Stream stream) public override void Write(ref DvDateTime value) { - var ticks = value.Ticks.RawValue; - Contracts.Assert(ticks == DvInt8.RawNA || (ulong)ticks <= DvDateTime.MaxTicks); + var ticks = value.Ticks; + Contracts.Assert(ticks == long.MinValue || (ulong)ticks <= DvDateTime.MaxTicks); Writer.Write(ticks); _numWritten++; } @@ -639,7 +639,7 @@ public override void Commit() public override long GetCommitLengthEstimate() { - return _numWritten * sizeof(Int64); + return _numWritten * sizeof(long); } } @@ -658,7 +658,7 @@ public override void MoveNext() { Contracts.Assert(_remaining > 0, "already consumed all values"); var value = Reader.ReadInt64(); - Contracts.CheckDecode(value == DvInt8.RawNA || (ulong)value <= DvDateTime.MaxTicks); + Contracts.CheckDecode(value == long.MinValue || (ulong)value <= DvDateTime.MaxTicks); _value = new DvDateTime(value); _remaining--; } @@ -711,19 +711,19 @@ public override void Write(ref DvDateTimeZone value) var ticks = value.ClockDateTime.Ticks; var offset = value.OffsetMinutes; - _ticks.Add(ticks.RawValue); - if (ticks.IsNA) + _ticks.Add(ticks); + if (ticks == long.MinValue) { - Contracts.Assert(offset.IsNA); + Contracts.Assert(offset == short.MinValue); _offsets.Add(0); } else { Contracts.Assert( - offset.RawValue >= DvDateTimeZone.MinMinutesOffset && - offset.RawValue <= DvDateTimeZone.MaxMinutesOffset); - Contracts.Assert(0 <= ticks.RawValue && ticks.RawValue <= DvDateTime.MaxTicks); - _offsets.Add(offset.RawValue); + offset >= DvDateTimeZone.MinMinutesOffset && + offset <= DvDateTimeZone.MaxMinutesOffset); + Contracts.Assert(0 <= ticks && ticks <= DvDateTime.MaxTicks); + _offsets.Add(offset); } } @@ -740,7 +740,7 @@ public override void Commit() public override long GetCommitLengthEstimate() { - return (long)_offsets.Count * (sizeof(Int64) + sizeof(Int16)); + return (long)_offsets.Count * (sizeof(long) + sizeof(short)); } } @@ -773,7 +773,7 @@ public Reader(DateTimeZoneCodec codec, Stream stream, int items) for (int i = 0; i < _entries; i++) { _ticks[i] = Reader.ReadInt64(); - Contracts.CheckDecode(_ticks[i] == DvInt8.RawNA || (ulong)_ticks[i] <= DvDateTime.MaxTicks); + Contracts.CheckDecode(_ticks[i] == long.MinValue || (ulong)_ticks[i] <= DvDateTime.MaxTicks); } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/Header.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/Header.cs index 36186cf7af..b552ab6523 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/Header.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/Header.cs @@ -34,8 +34,9 @@ public struct Header //public const ulong WriterVersion = 0x0001000100010002; // Codec changes. //public const ulong WriterVersion = 0x0001000100010003; // Slot names. //public const ulong WriterVersion = 0x0001000100010004; // Column metadata. - public const ulong WriterVersion = 0x0001000100010005; // "NA" DvText support. - public const ulong CanBeReadByVersion = 0x0001000100010005; + //public const ulong WriterVersion = 0x0001000100010005; // "NA" DvText support. + public const ulong WriterVersion = 0x0001000100010006; // Replace DvTypes with .NET Standard data types. + public const ulong CanBeReadByVersion = 0x0001000100010006; internal static string VersionToString(ulong v) { diff --git a/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs b/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs index 026228d6be..7a17b84ac1 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Binary/UnsafeTypeOps.cs @@ -32,17 +32,13 @@ internal static class UnsafeTypeOpsFactory static UnsafeTypeOpsFactory() { _type2ops = new Dictionary(); - _type2ops[typeof(SByte)] = new SByteUnsafeTypeOps(); - _type2ops[typeof(DvInt1)] = new DvI1UnsafeTypeOps(); + _type2ops[typeof(sbyte)] = new SByteUnsafeTypeOps(); _type2ops[typeof(Byte)] = new ByteUnsafeTypeOps(); - _type2ops[typeof(Int16)] = new Int16UnsafeTypeOps(); - _type2ops[typeof(DvInt2)] = new DvI2UnsafeTypeOps(); + _type2ops[typeof(short)] = new Int16UnsafeTypeOps(); _type2ops[typeof(UInt16)] = new UInt16UnsafeTypeOps(); - _type2ops[typeof(Int32)] = new Int32UnsafeTypeOps(); - _type2ops[typeof(DvInt4)] = new DvI4UnsafeTypeOps(); + _type2ops[typeof(int)] = new Int32UnsafeTypeOps(); _type2ops[typeof(UInt32)] = new UInt32UnsafeTypeOps(); - _type2ops[typeof(Int64)] = new Int64UnsafeTypeOps(); - _type2ops[typeof(DvInt8)] = new DvI8UnsafeTypeOps(); + _type2ops[typeof(long)] = new Int64UnsafeTypeOps(); _type2ops[typeof(UInt64)] = new UInt64UnsafeTypeOps(); _type2ops[typeof(Single)] = new SingleUnsafeTypeOps(); _type2ops[typeof(Double)] = new DoubleUnsafeTypeOps(); @@ -55,29 +51,16 @@ public static UnsafeTypeOps Get() return (UnsafeTypeOps)_type2ops[typeof(T)]; } - private sealed class SByteUnsafeTypeOps : UnsafeTypeOps + private sealed class SByteUnsafeTypeOps : UnsafeTypeOps { - public override int Size { get { return sizeof(SByte); } } - public override unsafe void Apply(SByte[] array, Action func) + public override int Size { get { return sizeof(sbyte); } } + public override unsafe void Apply(sbyte[] array, Action func) { - fixed (SByte* pArray = array) + fixed (sbyte* pArray = array) func(new IntPtr(pArray)); } - public override void Write(SByte a, BinaryWriter writer) { writer.Write(a); } - public override SByte Read(BinaryReader reader) { return reader.ReadSByte(); } - } - - private sealed class DvI1UnsafeTypeOps : UnsafeTypeOps - { - public override int Size { get { return sizeof(SByte); } } - public override unsafe void Apply(DvInt1[] array, Action func) - { - fixed (DvInt1* pArray = array) - func(new IntPtr(pArray)); - } - - public override void Write(DvInt1 a, BinaryWriter writer) { writer.Write(a.RawValue); } - public override DvInt1 Read(BinaryReader reader) { return reader.ReadSByte(); } + public override void Write(sbyte a, BinaryWriter writer) { writer.Write(a); } + public override sbyte Read(BinaryReader reader) { return reader.ReadSByte(); } } private sealed class ByteUnsafeTypeOps : UnsafeTypeOps @@ -92,29 +75,16 @@ public override unsafe void Apply(Byte[] array, Action func) public override Byte Read(BinaryReader reader) { return reader.ReadByte(); } } - private sealed class Int16UnsafeTypeOps : UnsafeTypeOps + private sealed class Int16UnsafeTypeOps : UnsafeTypeOps { - public override int Size { get { return sizeof(Int16); } } - public override unsafe void Apply(Int16[] array, Action func) + public override int Size { get { return sizeof(short); } } + public override unsafe void Apply(short[] array, Action func) { - fixed (Int16* pArray = array) + fixed (short* pArray = array) func(new IntPtr(pArray)); } - public override void Write(Int16 a, BinaryWriter writer) { writer.Write(a); } - public override Int16 Read(BinaryReader reader) { return reader.ReadInt16(); } - } - - private sealed class DvI2UnsafeTypeOps : UnsafeTypeOps - { - public override int Size { get { return sizeof(Int16); } } - public override unsafe void Apply(DvInt2[] array, Action func) - { - fixed (DvInt2* pArray = array) - func(new IntPtr(pArray)); - } - - public override void Write(DvInt2 a, BinaryWriter writer) { writer.Write(a.RawValue); } - public override DvInt2 Read(BinaryReader reader) { return reader.ReadInt16(); } + public override void Write(short a, BinaryWriter writer) { writer.Write(a); } + public override short Read(BinaryReader reader) { return reader.ReadInt16(); } } private sealed class UInt16UnsafeTypeOps : UnsafeTypeOps @@ -129,29 +99,16 @@ public override unsafe void Apply(UInt16[] array, Action func) public override UInt16 Read(BinaryReader reader) { return reader.ReadUInt16(); } } - private sealed class Int32UnsafeTypeOps : UnsafeTypeOps + private sealed class Int32UnsafeTypeOps : UnsafeTypeOps { - public override int Size { get { return sizeof(Int32); } } - public override unsafe void Apply(Int32[] array, Action func) + public override int Size { get { return sizeof(int); } } + public override unsafe void Apply(int[] array, Action func) { - fixed (Int32* pArray = array) + fixed (int* pArray = array) func(new IntPtr(pArray)); } - public override void Write(Int32 a, BinaryWriter writer) { writer.Write(a); } - public override Int32 Read(BinaryReader reader) { return reader.ReadInt32(); } - } - - private sealed class DvI4UnsafeTypeOps : UnsafeTypeOps - { - public override int Size { get { return sizeof(Int32); } } - public override unsafe void Apply(DvInt4[] array, Action func) - { - fixed (DvInt4* pArray = array) - func(new IntPtr(pArray)); - } - - public override void Write(DvInt4 a, BinaryWriter writer) { writer.Write(a.RawValue); } - public override DvInt4 Read(BinaryReader reader) { return reader.ReadInt32(); } + public override void Write(int a, BinaryWriter writer) { writer.Write(a); } + public override int Read(BinaryReader reader) { return reader.ReadInt32(); } } private sealed class UInt32UnsafeTypeOps : UnsafeTypeOps @@ -166,29 +123,16 @@ public override unsafe void Apply(UInt32[] array, Action func) public override UInt32 Read(BinaryReader reader) { return reader.ReadUInt32(); } } - private sealed class Int64UnsafeTypeOps : UnsafeTypeOps + private sealed class Int64UnsafeTypeOps : UnsafeTypeOps { - public override int Size { get { return sizeof(Int64); } } - public override unsafe void Apply(Int64[] array, Action func) + public override int Size { get { return sizeof(long); } } + public override unsafe void Apply(long[] array, Action func) { - fixed (Int64* pArray = array) + fixed (long* pArray = array) func(new IntPtr(pArray)); } - public override void Write(Int64 a, BinaryWriter writer) { writer.Write(a); } - public override Int64 Read(BinaryReader reader) { return reader.ReadInt64(); } - } - - private sealed class DvI8UnsafeTypeOps : UnsafeTypeOps - { - public override int Size { get { return sizeof(Int64); } } - public override unsafe void Apply(DvInt8[] array, Action func) - { - fixed (DvInt8* pArray = array) - func(new IntPtr(pArray)); - } - - public override void Write(DvInt8 a, BinaryWriter writer) { writer.Write(a.RawValue); } - public override DvInt8 Read(BinaryReader reader) { return reader.ReadInt64(); } + public override void Write(long a, BinaryWriter writer) { writer.Write(a); } + public override long Read(BinaryReader reader) { return reader.ReadInt64(); } } private sealed class UInt64UnsafeTypeOps : UnsafeTypeOps @@ -229,14 +173,14 @@ public override unsafe void Apply(Double[] array, Action func) private sealed class DvTimeSpanUnsafeTypeOps : UnsafeTypeOps { - public override int Size { get { return sizeof(Int64); } } + public override int Size { get { return sizeof(long); } } public override unsafe void Apply(DvTimeSpan[] array, Action func) { fixed (DvTimeSpan* pArray = array) func(new IntPtr(pArray)); } - public override void Write(DvTimeSpan a, BinaryWriter writer) { writer.Write(a.Ticks.RawValue); } + public override void Write(DvTimeSpan a, BinaryWriter writer) { writer.Write(a.Ticks); } public override DvTimeSpan Read(BinaryReader reader) { return new DvTimeSpan(reader.ReadInt64()); } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs index c0d0f25b17..57c1322062 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs @@ -992,16 +992,24 @@ public int GatherFields(DvText lineSpan, string path = null, long line = 0) } var spanT = Fields.Spans[Fields.Count - 1]; - // Note that Convert produces NA if the text is unparsable. - DvInt4 csrc = default(DvInt4); - Conversion.Conversions.Instance.Convert(ref spanT, ref csrc); - csrcSparse = csrc.RawValue; - if (csrcSparse <= 0) + // Note that Convert throws exception the text is unparsable. + int csrc = default; + try + { + Conversions.Instance.Convert(ref spanT, ref csrc); + } + catch + { + Contracts.Assert(csrc == default); + } + + if (csrc <= 0) { _stats.LogBadFmt(ref scan, "Bad dimensionality or ambiguous sparse item. Use sparse=- for non-sparse file, and/or quote the value."); break; } + csrcSparse = csrc; srcLimFixed = Fields.Indices[--Fields.Count]; if (csrcSparse >= SrcLim - srcLimFixed) csrcSparse = SrcLim - srcLimFixed - 1; diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs index 8e4f3be56c..d5f599fab1 100644 --- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs @@ -136,7 +136,7 @@ protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, A var thresholdAtK = new List(); var thresholdAtP = new List(); var thresholdAtNumAnomalies = new List(); - var numAnoms = new List(); + var numAnoms = new List(); var scores = new List(); var labels = new List(); @@ -678,11 +678,11 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary col == numAnomIndex || (hasStrat && col == stratCol))) { - var numAnomGetter = cursor.GetGetter(numAnomIndex); + var numAnomGetter = cursor.GetGetter(numAnomIndex); ValueGetter stratGetter = null; if (hasStrat) { diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs index 7a1c25665c..3efa662ef9 100644 --- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs @@ -986,9 +986,9 @@ protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMa var labelType = perInst.Schema.GetColumnType(labelCol); if (labelType.IsKey && (!perInst.Schema.HasKeyNames(labelCol, labelType.KeyCount) || labelType.RawKind != DataKind.U4)) { - perInst = LambdaColumnMapper.Create(Host, "ConvertToLong", perInst, schema.Label.Name, - schema.Label.Name, perInst.Schema.GetColumnType(labelCol), NumberType.I8, - (ref uint src, ref DvInt8 dst) => dst = src == 0 ? DvInt8.NA : src - 1 + (long)labelType.AsKey.Min); + perInst = LambdaColumnMapper.Create(Host, "ConvertToDouble", perInst, schema.Label.Name, + schema.Label.Name, perInst.Schema.GetColumnType(labelCol), NumberType.R8, + (ref uint src, ref double dst) => dst = src == 0 ? double.NaN : src - 1 + (double)labelType.AsKey.Min); } var perInstSchema = perInst.Schema; diff --git a/src/Microsoft.ML.Data/Training/TrainerUtils.cs b/src/Microsoft.ML.Data/Training/TrainerUtils.cs index 33d3d1490d..b86daaace4 100644 --- a/src/Microsoft.ML.Data/Training/TrainerUtils.cs +++ b/src/Microsoft.ML.Data/Training/TrainerUtils.cs @@ -593,7 +593,7 @@ public void Signal(CursOpt opt) } /// - /// This supports Weight (Float), Group (ulong), and Id (DvInt8) columns. + /// This supports Weight (Float), Group (ulong), and Id (UInt128) columns. /// public class StandardScalarCursor : TrainingCursorBase { diff --git a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs index b2024cc18c..e47d1780f4 100644 --- a/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/ConcatTransform.cs @@ -404,7 +404,7 @@ protected override void GetMetadataCore(string kind, int iinfo, ref TVal if (_typesCategoricals[iinfo] == null) throw MetadataUtils.ExceptGetMetadata(); - MetadataUtils.Marshal, TValue>(GetCategoricalSlotRanges, iinfo, ref value); + MetadataUtils.Marshal, TValue>(GetCategoricalSlotRanges, iinfo, ref value); break; case MetadataUtils.Kinds.IsNormalized: if (!_isNormalized[iinfo]) @@ -417,9 +417,9 @@ protected override void GetMetadataCore(string kind, int iinfo, ref TVal } } - private void GetCategoricalSlotRanges(int iiinfo, ref VBuffer dst) + private void GetCategoricalSlotRanges(int iiinfo, ref VBuffer dst) { - List allValues = new List(); + List allValues = new List(); int slotCount = 0; for (int i = 0; i < Infos[iiinfo].SrcIndices.Length; i++) { @@ -440,7 +440,7 @@ private void GetCategoricalSlotRanges(int iiinfo, ref VBuffer dst) Contracts.Assert(allValues.Count > 0); - dst = new VBuffer(allValues.Count, allValues.ToArray()); + dst = new VBuffer(allValues.Count, allValues.ToArray()); } private void IsNormalized(int iinfo, ref DvBool dst) diff --git a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs index 230cfbe680..ef9791f7d3 100644 --- a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs @@ -393,14 +393,14 @@ private void ComputeType(ISchema input, int[] slotsMin, int[] slotsMax, int iinf { if (MetadataUtils.TryGetCategoricalFeatureIndices(Source.Schema, Infos[iinfo].Source, out categoricalRanges)) { - VBuffer dst = default(VBuffer); + VBuffer dst = default(VBuffer); GetCategoricalSlotRangesCore(iinfo, slotDropper.SlotsMin, slotDropper.SlotsMax, categoricalRanges, ref dst); // REVIEW: cache dst as opposed to caculating it again. if (dst.Length > 0) { Contracts.Assert(dst.Length % 2 == 0); - bldr.AddGetter>(MetadataUtils.Kinds.CategoricalSlotRanges, + bldr.AddGetter>(MetadataUtils.Kinds.CategoricalSlotRanges, MetadataUtils.GetCategoricalType(dst.Length / 2), GetCategoricalSlotRanges); } } @@ -443,7 +443,7 @@ private void GetSlotNames(int iinfo, ref VBuffer dst) infoEx.SlotDropper.DropSlots(ref names, ref dst); } - private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) + private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) { if (_exes[iinfo].CategoricalRanges != null) { @@ -452,7 +452,7 @@ private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) } } - private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slotsMax, int[] catRanges, ref VBuffer dst) + private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slotsMax, int[] catRanges, ref VBuffer dst) { Host.Assert(0 <= iinfo && iinfo < Infos.Length); Host.Assert(slotsMax != null && slotsMin != null); @@ -467,9 +467,9 @@ private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slots int previousDropSlotsIndex = 0; int droppedSlotsCount = 0; bool combine = false; - DvInt4 min = -1; - DvInt4 max = -1; - List newCategoricalSlotRanges = new List(); + int min = -1; + int max = -1; + List newCategoricalSlotRanges = new List(); // Six possible ways a drop slot range interacts with categorical slots range. // @@ -498,7 +498,7 @@ private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slots } else { - Contracts.Assert(min.RawValue == -1 && max.RawValue == -1); + Contracts.Assert(min == -1 && max == -1); min = ranges[rangesIndex] - droppedSlotsCount; max = ranges[rangesIndex + 1] - droppedSlotsCount; } @@ -515,14 +515,14 @@ private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slots rangesIndex += 2; if (combine) { - Contracts.Assert(min.RawValue >= 0 && min.RawValue <= max.RawValue); + Contracts.Assert(min >= 0 && min <= max); newCategoricalSlotRanges.Add(min); newCategoricalSlotRanges.Add(max); min = max = -1; combine = false; } - Contracts.Assert(min.RawValue == -1 && max.RawValue == -1); + Contracts.Assert(min == -1 && max == -1); } else if (slotsMin[dropSlotsIndex] > ranges[rangesIndex] && @@ -535,7 +535,7 @@ private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slots } else { - Contracts.Assert(min.RawValue == -1 && max.RawValue == -1); + Contracts.Assert(min == -1 && max == -1); min = ranges[rangesIndex] - droppedSlotsCount; max = slotsMin[dropSlotsIndex] - 1 - droppedSlotsCount; @@ -576,28 +576,28 @@ private void GetCategoricalSlotRangesCore(int iinfo, int[] slotsMin, int[] slots min = max = -1; } - Contracts.Assert(min.RawValue == -1 && max.RawValue == -1); + Contracts.Assert(min == -1 && max == -1); for (int i = rangesIndex; i < ranges.Length; i++) newCategoricalSlotRanges.Add(ranges[i] - droppedSlotsCount); Contracts.Assert(newCategoricalSlotRanges.Count % 2 == 0); - Contracts.Assert(newCategoricalSlotRanges.TrueForAll(x => x.RawValue >= 0)); + Contracts.Assert(newCategoricalSlotRanges.TrueForAll(x => x >= 0)); Contracts.Assert(0 <= droppedSlotsCount && droppedSlotsCount <= slotsMax[slotsMax.Length - 1] + 1); if (newCategoricalSlotRanges.Count > 0) - dst = new VBuffer(newCategoricalSlotRanges.Count, newCategoricalSlotRanges.ToArray()); + dst = new VBuffer(newCategoricalSlotRanges.Count, newCategoricalSlotRanges.ToArray()); } private void CombineRanges( - DvInt4 minRange1, DvInt4 maxRange1, DvInt4 minRange2, DvInt4 maxRange2, - out DvInt4 newRangeMin, out DvInt4 newRangeMax) + int minRange1, int maxRange1, int minRange2, int maxRange2, + out int newRangeMin, out int newRangeMax) { - Contracts.Assert(minRange2.RawValue >= 0 && maxRange2.RawValue >= 0); - Contracts.Assert(minRange2.RawValue <= maxRange2.RawValue); - Contracts.Assert(minRange1.RawValue >= 0 && maxRange1.RawValue >= 0); - Contracts.Assert(minRange1.RawValue <= maxRange1.RawValue); - Contracts.Assert(maxRange1.RawValue + 1 == minRange2.RawValue); + Contracts.Assert(minRange2 >= 0 && maxRange2 >= 0); + Contracts.Assert(minRange2 <= maxRange2); + Contracts.Assert(minRange1 >= 0 && maxRange1 >= 0); + Contracts.Assert(minRange1 <= maxRange1); + Contracts.Assert(maxRange1 + 1 == minRange2); newRangeMin = minRange1; newRangeMax = maxRange2; diff --git a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs index cacd681141..3584531848 100644 --- a/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs @@ -430,9 +430,9 @@ public ValueGetter GetGetter(int col) return fn; } - private ValueGetter MakeGetter() + private ValueGetter MakeGetter() { - return (ref DvInt8 value) => + return (ref long value) => { Ch.Check(IsGood); value = Input.Position; diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index 0f4b616a49..6671d0b033 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -319,7 +319,7 @@ private static void ComputeType(KeyToVectorTransform trans, ISchema input, int i if (!bag && info.TypeSrc.ValueCount > 0) { - bldr.AddGetter>(MetadataUtils.Kinds.CategoricalSlotRanges, + bldr.AddGetter>(MetadataUtils.Kinds.CategoricalSlotRanges, MetadataUtils.GetCategoricalType(info.TypeSrc.ValueCount), trans.GetCategoricalSlotRanges); } @@ -334,7 +334,7 @@ protected override ColumnType GetColumnTypeCore(int iinfo) return _types[iinfo]; } - private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) + private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) { Host.Assert(0 <= iinfo && iinfo < Infos.Length); @@ -342,7 +342,7 @@ private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) Host.Assert(info.TypeSrc.ValueCount > 0); - DvInt4[] ranges = new DvInt4[info.TypeSrc.ValueCount * 2]; + int[] ranges = new int[info.TypeSrc.ValueCount * 2]; int size = info.TypeSrc.ItemType.KeyCount; ranges[0] = 0; @@ -353,7 +353,7 @@ private void GetCategoricalSlotRanges(int iinfo, ref VBuffer dst) ranges[i + 1] = ranges[i] + size - 1; } - dst = new VBuffer(ranges.Length, ranges); + dst = new VBuffer(ranges.Length, ranges); } // Used for slot names when appropriate. diff --git a/src/Microsoft.ML.FastTree/Dataset/Dataset.cs b/src/Microsoft.ML.FastTree/Dataset/Dataset.cs index f31ce73a94..b1b24bd4a1 100644 --- a/src/Microsoft.ML.FastTree/Dataset/Dataset.cs +++ b/src/Microsoft.ML.FastTree/Dataset/Dataset.cs @@ -609,7 +609,7 @@ public int[][] GetAssignments(double[] fraction, int randomSeed, out int[][] ass for (int i = 0; i < numParts; ++i) { cumulative += fraction[i]; - thresh[i] = (int)(cumulative * Int32.MaxValue); + thresh[i] = (int)(cumulative * int.MaxValue); if (fraction[i] == 0.0) thresh[i]--; } diff --git a/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs b/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs index 43304b22f7..f94701856a 100644 --- a/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs +++ b/src/Microsoft.ML.FastTree/Utils/PseudorandomFunction.cs @@ -19,7 +19,7 @@ public sealed class PseudorandomFunction public PseudorandomFunction(Random rand) { - _data = _periodics.Select(x => Enumerable.Range(0, x).Select(y => rand.Next(-1, Int32.MaxValue) + 1).ToArray()).ToArray(); + _data = _periodics.Select(x => Enumerable.Range(0, x).Select(y => rand.Next(-1, int.MaxValue) + 1).ToArray()).ToArray(); } public int Apply(ulong seed) diff --git a/src/Microsoft.ML.Parquet/ParquetLoader.cs b/src/Microsoft.ML.Parquet/ParquetLoader.cs index 503debae65..8600238b76 100644 --- a/src/Microsoft.ML.Parquet/ParquetLoader.cs +++ b/src/Microsoft.ML.Parquet/ParquetLoader.cs @@ -499,21 +499,21 @@ private Delegate CreateGetterDelegate(int col) case DataType.Byte: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.SignedByte: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.UnsignedByte: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Short: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.UnsignedShort: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int16: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.UnsignedInt16: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int32: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int64: - return CreateGetterDelegateCore(col, _parquetConversions.Conv); + return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.Int96: return CreateGetterDelegateCore(col, _parquetConversions.Conv); case DataType.ByteArray: @@ -678,17 +678,17 @@ public ParquetConversions(IChannel channel) public void Conv(ref byte[] src, ref VBuffer dst) => dst = src != null ? new VBuffer(src.Length, src) : new VBuffer(0, new byte[0]); - public void Conv(ref sbyte? src, ref DvInt1 dst) => dst = src ?? DvInt1.NA; + public void Conv(ref sbyte? src, ref sbyte dst) => dst = (sbyte)src; public void Conv(ref byte src, ref byte dst) => dst = src; - public void Conv(ref short? src, ref DvInt2 dst) => dst = src ?? DvInt2.NA; + public void Conv(ref short? src, ref short dst) => dst = (short)src; public void Conv(ref ushort src, ref ushort dst) => dst = src; - public void Conv(ref int? src, ref DvInt4 dst) => dst = src ?? DvInt4.NA; + public void Conv(ref int? src, ref int dst) => dst = (int)src; - public void Conv(ref long? src, ref DvInt8 dst) => dst = src ?? DvInt8.NA; + public void Conv(ref long? src, ref long dst) => dst = (long)src; public void Conv(ref float? src, ref Single dst) => dst = src ?? Single.NaN; diff --git a/src/Microsoft.ML.PipelineInference/RecipeInference.cs b/src/Microsoft.ML.PipelineInference/RecipeInference.cs index 827bc813d7..5aa71da682 100644 --- a/src/Microsoft.ML.PipelineInference/RecipeInference.cs +++ b/src/Microsoft.ML.PipelineInference/RecipeInference.cs @@ -202,7 +202,7 @@ protected override IEnumerable ApplyCore(Type predictorType, TransformInference.SuggestedTransform[] transforms) { yield return - new SuggestedRecipe(ToString(), transforms, new SuggestedRecipe.SuggestedLearner[0], Int32.MinValue + 1); + new SuggestedRecipe(ToString(), transforms, new SuggestedRecipe.SuggestedLearner[0], int.MinValue + 1); } public override string ToString() => "Default transforms"; @@ -251,7 +251,7 @@ protected override IEnumerable ApplyCore(Type predictorType, } yield return - new SuggestedRecipe(ToString(), transforms, new[] { learner }, Int32.MaxValue); + new SuggestedRecipe(ToString(), transforms, new[] { learner }, int.MaxValue); } public override string ToString() => "Text classification optimized for speed and accuracy"; diff --git a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs index 91874291b0..7b987d4d71 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/ModelStatistics.cs @@ -414,7 +414,7 @@ public void AddStatsColumns(List list, LinearBinaryPredictor parent, Ro _env.AssertValueOrNull(parent); _env.AssertValue(schema); - DvInt8 count = _trainingExampleCount; + long count = _trainingExampleCount; list.Add(RowColumnUtils.GetColumn("Count of training examples", NumberType.I8, ref count)); var dev = _deviance; list.Add(RowColumnUtils.GetColumn("Residual Deviance", NumberType.R4, ref dev)); diff --git a/src/Microsoft.ML.Sweeper/Parameters.cs b/src/Microsoft.ML.Sweeper/Parameters.cs index dd46374732..6f78bcf521 100644 --- a/src/Microsoft.ML.Sweeper/Parameters.cs +++ b/src/Microsoft.ML.Sweeper/Parameters.cs @@ -588,7 +588,7 @@ public bool TryParseParameter(string paramValue, Type paramType, string paramNam } if (option.StartsWith("steps")) { - numSteps = Int32.Parse(option.Substring(option.IndexOf(':') + 1)); + numSteps = int.Parse(option.Substring(option.IndexOf(':') + 1)); optionsSpecified[1] = true; } if (option.StartsWith("inc")) @@ -613,9 +613,9 @@ public bool TryParseParameter(string paramValue, Type paramType, string paramNam if (paramType == typeof(UInt16) || paramType == typeof(UInt32) || paramType == typeof(UInt64) - || paramType == typeof(Int16) - || paramType == typeof(Int32) - || paramType == typeof(Int64)) + || paramType == typeof(short) + || paramType == typeof(int) + || paramType == typeof(long)) { long min; long max; diff --git a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs index 0af833a046..58fc29b67a 100644 --- a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs @@ -407,7 +407,7 @@ private void GetLabels(Transposer trans, ColumnType labelType, int labelCol) // Note: NAs have their own separate bin. if (labelType == NumberType.I4) { - var tmp = default(VBuffer); + var tmp = default(VBuffer); trans.GetSingleSlotValue(labelCol, ref tmp); BinInts(ref tmp, ref labels, _numBins, out min, out lim); _numLabels = lim - min; @@ -486,7 +486,7 @@ private Single[] ComputeMutualInformation(Transposer trans, int col) if (type.ItemType == NumberType.I4) { return ComputeMutualInformation(trans, col, - (ref VBuffer src, ref VBuffer dst, out int min, out int lim) => + (ref VBuffer src, ref VBuffer dst, out int min, out int lim) => { BinInts(ref src, ref dst, _numBins, out min, out lim); }); @@ -674,29 +674,20 @@ private static ValueMapper, VBuffer> BinKeys(ColumnType colTy } /// - /// Maps from DvInt4 to ints. NaNs (and only NaNs) are mapped to the first bin. + /// Maps Ints. /// - private void BinInts(ref VBuffer input, ref VBuffer output, + private void BinInts(ref VBuffer input, ref VBuffer output, int numBins, out int min, out int lim) { Contracts.Assert(_singles.Count == 0); - if (input.Values != null) - { - for (int i = 0; i < input.Count; i++) - { - var val = input.Values[i]; - if (!val.IsNA) - _singles.Add((Single)val); - } - } var bounds = _binFinder.FindBins(numBins, _singles, input.Length - input.Count); min = -1 - bounds.FindIndexSorted(0); lim = min + bounds.Length + 1; int offset = min; - ValueMapper mapper = - (ref DvInt4 src, ref int dst) => - dst = src.IsNA ? offset : offset + 1 + bounds.FindIndexSorted((Single)src); + ValueMapper mapper = + (ref int src, ref int dst) => + dst = offset + 1 + bounds.FindIndexSorted((Single)src); mapper.MapVector(ref input, ref output); _singles.Clear(); } diff --git a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs index 2340f9b413..40e9fc6d61 100644 --- a/src/Microsoft.ML.Transforms/NAReplaceUtils.cs +++ b/src/Microsoft.ML.Transforms/NAReplaceUtils.cs @@ -22,22 +22,10 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.RawKind) { - case DataKind.I1: - return new I1.MeanAggregatorOne(ch, cursor, col); - case DataKind.I2: - return new I2.MeanAggregatorOne(ch, cursor, col); - case DataKind.I4: - return new I4.MeanAggregatorOne(ch, cursor, col); - case DataKind.I8: - return new Long.MeanAggregatorOne(ch, type, cursor, col); case DataKind.R4: return new R4.MeanAggregatorOne(ch, cursor, col); case DataKind.R8: return new R8.MeanAggregatorOne(ch, cursor, col); - case DataKind.TS: - return new Long.MeanAggregatorOne(ch, type, cursor, col); - case DataKind.DT: - return new Long.MeanAggregatorOne(ch, type, cursor, col); default: break; } @@ -46,22 +34,10 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.RawKind) { - case DataKind.I1: - return new I1.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.I2: - return new I2.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.I4: - return new I4.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.I8: - return new Long.MinMaxAggregatorOne(ch, type, cursor, col, kind == ReplacementKind.Max); case DataKind.R4: return new R4.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max); case DataKind.R8: return new R8.MinMaxAggregatorOne(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.TS: - return new Long.MinMaxAggregatorOne(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.DT: - return new Long.MinMaxAggregatorOne(ch, type, cursor, col, kind == ReplacementKind.Max); default: break; } @@ -78,22 +54,10 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.ItemType.RawKind) { - case DataKind.I1: - return new I1.MeanAggregatorBySlot(ch, type, cursor, col); - case DataKind.I2: - return new I2.MeanAggregatorBySlot(ch, type, cursor, col); - case DataKind.I4: - return new I4.MeanAggregatorBySlot(ch, type, cursor, col); - case DataKind.I8: - return new Long.MeanAggregatorBySlot(ch, type, cursor, col); case DataKind.R4: return new R4.MeanAggregatorBySlot(ch, type, cursor, col); case DataKind.R8: return new R8.MeanAggregatorBySlot(ch, type, cursor, col); - case DataKind.TS: - return new Long.MeanAggregatorBySlot(ch, type, cursor, col); - case DataKind.DT: - return new Long.MeanAggregatorBySlot(ch, type, cursor, col); default: break; } @@ -102,22 +66,10 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.ItemType.RawKind) { - case DataKind.I1: - return new I1.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.I2: - return new I2.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.I4: - return new I4.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.I8: - return new Long.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); case DataKind.R4: return new R4.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); case DataKind.R8: return new R8.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.TS: - return new Long.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.DT: - return new Long.MinMaxAggregatorBySlot(ch, type, cursor, col, kind == ReplacementKind.Max); default: break; } @@ -130,22 +82,10 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.ItemType.RawKind) { - case DataKind.I1: - return new I1.MeanAggregatorAcrossSlots(ch, cursor, col); - case DataKind.I2: - return new I2.MeanAggregatorAcrossSlots(ch, cursor, col); - case DataKind.I4: - return new I4.MeanAggregatorAcrossSlots(ch, cursor, col); - case DataKind.I8: - return new Long.MeanAggregatorAcrossSlots(ch, type, cursor, col); case DataKind.R4: return new R4.MeanAggregatorAcrossSlots(ch, cursor, col); case DataKind.R8: return new R8.MeanAggregatorAcrossSlots(ch, cursor, col); - case DataKind.TS: - return new Long.MeanAggregatorAcrossSlots(ch, type, cursor, col); - case DataKind.DT: - return new Long.MeanAggregatorAcrossSlots(ch, type, cursor, col); default: break; } @@ -154,22 +94,10 @@ private static StatAggregator CreateStatAggregator(IChannel ch, ColumnType type, { switch (type.ItemType.RawKind) { - case DataKind.I1: - return new I1.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.I2: - return new I2.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.I4: - return new I4.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.I8: - return new Long.MinMaxAggregatorAcrossSlots(ch, type, cursor, col, kind == ReplacementKind.Max); case DataKind.R4: return new R4.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max); case DataKind.R8: return new R8.MinMaxAggregatorAcrossSlots(ch, cursor, col, kind == ReplacementKind.Max); - case DataKind.TS: - return new Long.MinMaxAggregatorAcrossSlots(ch, type, cursor, col, kind == ReplacementKind.Max); - case DataKind.DT: - return new Long.MinMaxAggregatorAcrossSlots(ch, type, cursor, col, kind == ReplacementKind.Max); default: break; } @@ -503,17 +431,17 @@ private void AssertValid(long valMax) Contracts.Assert(_cna >= 0); } - public void Update(long val, long valMax) + public void Update(long? val, long valMax) { AssertValid(valMax); - Contracts.Assert(-valMax - 1 <= val && val <= valMax); + Contracts.Assert(!val.HasValue || -valMax <= val && val <= valMax); - if (val >= 0) + if (!val.HasValue) + _cna++; + else if (val >= 0) IntUtils.Add(ref _sumHi, ref _sumLo, (ulong)val); - else if (val >= -valMax) - IntUtils.Sub(ref _sumHi, ref _sumLo, (ulong)(-val)); else - _cna++; + IntUtils.Sub(ref _sumHi, ref _sumLo, (ulong)(-val)); AssertValid(valMax); } @@ -928,800 +856,5 @@ public override object GetStat() } } } - - private static class I1 - { - // Utilizes MeanStatInt for the mean aggregators of all IX types, TS, and DT. - - private const long MaxVal = sbyte.MaxValue; - - public sealed class MeanAggregatorOne : StatAggregator - { - public MeanAggregatorOne(IChannel ch, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - } - - protected override void ProcessRow(ref DvInt1 val) - { - Stat.Update(val.RawValue, MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - return (DvInt1)(sbyte)val; - } - } - - public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots - { - public MeanAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - } - - protected override void ProcessValue(ref DvInt1 val) - { - Stat.Update(val.RawValue, MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, ValueCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - return (DvInt1)(sbyte)val; - } - } - - public sealed class MeanAggregatorBySlot : StatAggregatorBySlot - { - public MeanAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col) - : base(ch, type, cursor, col) - { - } - - protected override void ProcessValue(ref DvInt1 val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - Stat[slot].Update(val.RawValue, MaxVal); - } - - public override object GetStat() - { - DvInt1[] stat = new DvInt1[Stat.Length]; - for (int slot = 0; slot < stat.Length; slot++) - { - long val = Stat[slot].GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - stat[slot] = (DvInt1)(sbyte)val; - } - return stat; - } - } - - public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne - { - public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = (sbyte)(ReturnMax ? -MaxVal : MaxVal); - } - - protected override void ProcessValueMin(ref DvInt1 val) - { - var raw = val.RawValue; - if (raw < Stat && raw != DvInt1.RawNA) - Stat = raw; - } - - protected override void ProcessValueMax(ref DvInt1 val) - { - var raw = val.RawValue; - if (raw > Stat) - Stat = raw; - } - - public override object GetStat() - { - return (DvInt1)Stat; - } - } - - public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots - { - public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = (sbyte)(ReturnMax ? -MaxVal : MaxVal); - } - - protected override void ProcessValueMin(ref DvInt1 val) - { - var raw = val.RawValue; - if (raw < Stat && raw != DvInt1.RawNA) - Stat = raw; - } - - protected override void ProcessValueMax(ref DvInt1 val) - { - var raw = val.RawValue; - if (raw > Stat) - Stat = raw; - } - - public override object GetStat() - { - // If sparsity occurred, fold in a zero. - if (ValueCount > (ulong)ValuesProcessed) - { - var def = default(DvInt1); - ProcValueDelegate(ref def); - } - return (DvInt1)Stat; - } - } - - public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot - { - public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) - : base(ch, type, cursor, col, returnMax) - { - sbyte bound = (sbyte)(ReturnMax ? -MaxVal : MaxVal); - for (int i = 0; i < Stat.Length; i++) - Stat[i] = bound; - } - - protected override void ProcessValueMin(ref DvInt1 val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = val.RawValue; - if (raw < Stat[slot] && raw != DvInt1.RawNA) - Stat[slot] = raw; - } - - protected override void ProcessValueMax(ref DvInt1 val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = val.RawValue; - if (raw > Stat[slot]) - Stat[slot] = raw; - } - - public override object GetStat() - { - DvInt1[] stat = new DvInt1[Stat.Length]; - // Account for defaults resulting from sparsity. - for (int slot = 0; slot < Stat.Length; slot++) - { - if (GetValuesProcessed(slot) < RowCount) - { - var def = default(DvInt1); - ProcValueDelegate(ref def, slot); - } - stat[slot] = (DvInt1)Stat[slot]; - } - return stat; - } - } - } - - private static class I2 - { - private const long MaxVal = short.MaxValue; - - public sealed class MeanAggregatorOne : StatAggregator - { - public MeanAggregatorOne(IChannel ch, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - } - - protected override void ProcessRow(ref DvInt2 val) - { - Stat.Update(val.RawValue, MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - return (DvInt2)(short)val; - } - } - - public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots - { - public MeanAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - } - - protected override void ProcessValue(ref DvInt2 val) - { - Stat.Update(val.RawValue, MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, ValueCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - return (DvInt2)(short)val; - } - } - - public sealed class MeanAggregatorBySlot : StatAggregatorBySlot - { - public MeanAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col) - : base(ch, type, cursor, col) - { - } - - protected override void ProcessValue(ref DvInt2 val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - Stat[slot].Update(val.RawValue, MaxVal); - } - - public override object GetStat() - { - DvInt2[] stat = new DvInt2[Stat.Length]; - for (int slot = 0; slot < stat.Length; slot++) - { - long val = Stat[slot].GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - stat[slot] = (DvInt2)(short)val; - } - return stat; - } - } - - public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne - { - public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = (short)(ReturnMax ? -MaxVal : MaxVal); - } - - protected override void ProcessValueMin(ref DvInt2 val) - { - var raw = val.RawValue; - if (raw < Stat && raw != DvInt2.RawNA) - Stat = raw; - } - - protected override void ProcessValueMax(ref DvInt2 val) - { - var raw = val.RawValue; - if (raw > Stat) - Stat = raw; - } - - public override object GetStat() - { - return (DvInt2)Stat; - } - } - - public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots - { - public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = (short)(ReturnMax ? -MaxVal : MaxVal); - } - - protected override void ProcessValueMin(ref DvInt2 val) - { - var raw = val.RawValue; - if (raw < Stat && raw != DvInt2.RawNA) - Stat = raw; - } - - protected override void ProcessValueMax(ref DvInt2 val) - { - var raw = val.RawValue; - if (raw > Stat) - Stat = raw; - } - - public override object GetStat() - { - // If sparsity occurred, fold in a zero. - if (ValueCount > (ulong)ValuesProcessed) - { - var def = default(DvInt2); - ProcValueDelegate(ref def); - } - return (DvInt2)Stat; - } - } - - public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot - { - public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) - : base(ch, type, cursor, col, returnMax) - { - short bound = (short)(ReturnMax ? -MaxVal : MaxVal); - for (int i = 0; i < Stat.Length; i++) - Stat[i] = bound; - } - - protected override void ProcessValueMin(ref DvInt2 val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = val.RawValue; - if (raw < Stat[slot] && raw != DvInt2.RawNA) - Stat[slot] = raw; - } - - protected override void ProcessValueMax(ref DvInt2 val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = val.RawValue; - if (raw > Stat[slot]) - Stat[slot] = raw; - } - - public override object GetStat() - { - DvInt2[] stat = new DvInt2[Stat.Length]; - // Account for defaults resulting from sparsity. - for (int slot = 0; slot < Stat.Length; slot++) - { - if (GetValuesProcessed(slot) < RowCount) - { - var def = default(DvInt2); - ProcValueDelegate(ref def, slot); - } - stat[slot] = (DvInt2)Stat[slot]; - } - return stat; - } - } - } - - private static class I4 - { - private const long MaxVal = int.MaxValue; - - public sealed class MeanAggregatorOne : StatAggregator - { - public MeanAggregatorOne(IChannel ch, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - } - - protected override void ProcessRow(ref DvInt4 val) - { - Stat.Update(val.RawValue, MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - return (DvInt4)(int)val; - } - } - - public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots - { - public MeanAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - } - - protected override void ProcessValue(ref DvInt4 val) - { - Stat.Update(val.RawValue, MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, ValueCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - return (DvInt4)(int)val; - } - } - - public sealed class MeanAggregatorBySlot : StatAggregatorBySlot - { - public MeanAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col) - : base(ch, type, cursor, col) - { - } - - protected override void ProcessValue(ref DvInt4 val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - Stat[slot].Update(val.RawValue, MaxVal); - } - - public override object GetStat() - { - DvInt4[] stat = new DvInt4[Stat.Length]; - for (int slot = 0; slot < stat.Length; slot++) - { - long val = Stat[slot].GetCurrentValue(Ch, RowCount, MaxVal); - Ch.Assert(-MaxVal - 1 <= val && val <= MaxVal); - stat[slot] = (DvInt4)(int)val; - } - return stat; - } - } - - public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne - { - public MinMaxAggregatorOne(IChannel ch, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = (int)(ReturnMax ? -MaxVal : MaxVal); - } - - protected override void ProcessValueMin(ref DvInt4 val) - { - var raw = val.RawValue; - if (raw < Stat && raw != DvInt4.RawNA) - Stat = raw; - } - - protected override void ProcessValueMax(ref DvInt4 val) - { - var raw = val.RawValue; - if (raw > Stat) - Stat = raw; - } - - public override object GetStat() - { - return (DvInt4)Stat; - } - } - - public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots - { - public MinMaxAggregatorAcrossSlots(IChannel ch, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = (int)(ReturnMax ? -MaxVal : MaxVal); - } - - protected override void ProcessValueMin(ref DvInt4 val) - { - var raw = val.RawValue; - if (raw < Stat && raw != DvInt4.RawNA) - Stat = raw; - } - - protected override void ProcessValueMax(ref DvInt4 val) - { - var raw = val.RawValue; - if (raw > Stat) - Stat = raw; - } - - public override object GetStat() - { - // If sparsity occurred, fold in a zero. - if (ValueCount > (ulong)ValuesProcessed) - { - var def = default(DvInt4); - ProcValueDelegate(ref def); - } - return (DvInt4)Stat; - } - } - - public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot - { - public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) - : base(ch, type, cursor, col, returnMax) - { - int bound = (int)(ReturnMax ? -MaxVal : MaxVal); - for (int i = 0; i < Stat.Length; i++) - Stat[i] = bound; - } - - protected override void ProcessValueMin(ref DvInt4 val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = val.RawValue; - if (raw < Stat[slot] && raw != DvInt4.RawNA) - Stat[slot] = raw; - } - - protected override void ProcessValueMax(ref DvInt4 val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = val.RawValue; - if (raw > Stat[slot]) - Stat[slot] = raw; - } - - public override object GetStat() - { - DvInt4[] stat = new DvInt4[Stat.Length]; - // Account for defaults resulting from sparsity. - for (int slot = 0; slot < Stat.Length; slot++) - { - if (GetValuesProcessed(slot) < RowCount) - { - var def = default(DvInt4); - ProcValueDelegate(ref def, slot); - } - stat[slot] = (DvInt4)Stat[slot]; - } - return stat; - } - } - } - - private static class Long - { - private const long MaxVal = long.MaxValue; - - public sealed class MeanAggregatorOne : StatAggregator - { - // Converts between TItem and long. - private Converter _converter; - - public MeanAggregatorOne(IChannel ch, ColumnType type, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - _converter = CreateConverter(type); - } - - protected override void ProcessRow(ref TItem val) - { - Stat.Update(_converter.ToLong(val), MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, RowCount, MaxVal); - return _converter.FromLong(val); - } - } - - public sealed class MeanAggregatorAcrossSlots : StatAggregatorAcrossSlots - { - private Converter _converter; - - public MeanAggregatorAcrossSlots(IChannel ch, ColumnType type, IRowCursor cursor, int col) - : base(ch, cursor, col) - { - _converter = CreateConverter(type); - } - - protected override void ProcessValue(ref TItem val) - { - Stat.Update(_converter.ToLong(val), MaxVal); - } - - public override object GetStat() - { - long val = Stat.GetCurrentValue(Ch, ValueCount, MaxVal); - return _converter.FromLong(val); - } - } - - public sealed class MeanAggregatorBySlot : StatAggregatorBySlot - { - private Converter _converter; - - public MeanAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col) - : base(ch, type, cursor, col) - { - _converter = CreateConverter(type); - } - - protected override void ProcessValue(ref TItem val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - Stat[slot].Update(_converter.ToLong(val), MaxVal); - } - - public override object GetStat() - { - TItem[] stat = new TItem[Stat.Length]; - for (int slot = 0; slot < stat.Length; slot++) - { - long val = Stat[slot].GetCurrentValue(Ch, RowCount, MaxVal); - stat[slot] = _converter.FromLong(val); - } - return stat; - } - } - - public sealed class MinMaxAggregatorOne : MinMaxAggregatorOne - { - private Converter _converter; - - public MinMaxAggregatorOne(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = ReturnMax ? -MaxVal : MaxVal; - _converter = CreateConverter(type); - } - - protected override void ProcessValueMin(ref TItem val) - { - var raw = _converter.ToLong(val); - if (raw < Stat && -MaxVal <= raw) - Stat = raw; - } - - protected override void ProcessValueMax(ref TItem val) - { - var raw = _converter.ToLong(val); - if (raw > Stat) - Stat = raw; - } - - public override object GetStat() - { - return _converter.FromLong(Stat); - } - } - - public sealed class MinMaxAggregatorAcrossSlots : MinMaxAggregatorAcrossSlots - { - private Converter _converter; - - public MinMaxAggregatorAcrossSlots(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) - : base(ch, cursor, col, returnMax) - { - Stat = ReturnMax ? -MaxVal : MaxVal; - _converter = CreateConverter(type); - } - - protected override void ProcessValueMin(ref TItem val) - { - var raw = _converter.ToLong(val); - if (raw < Stat && -MaxVal <= raw) - Stat = raw; - } - - protected override void ProcessValueMax(ref TItem val) - { - var raw = _converter.ToLong(val); - if (raw > Stat) - Stat = raw; - } - - public override object GetStat() - { - // If sparsity occurred, fold in a zero. - if (ValueCount > (ulong)ValuesProcessed) - { - TItem def = default(TItem); - ProcValueDelegate(ref def); - } - return _converter.FromLong(Stat); - } - } - - public sealed class MinMaxAggregatorBySlot : MinMaxAggregatorBySlot - { - private Converter _converter; - - public MinMaxAggregatorBySlot(IChannel ch, ColumnType type, IRowCursor cursor, int col, bool returnMax) - : base(ch, type, cursor, col, returnMax) - { - long bound = ReturnMax ? -MaxVal : MaxVal; - for (int i = 0; i < Stat.Length; i++) - Stat[i] = bound; - - _converter = CreateConverter(type); - } - - protected override void ProcessValueMin(ref TItem val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = _converter.ToLong(val); - if (raw < Stat[slot] && -MaxVal <= raw) - Stat[slot] = raw; - } - - protected override void ProcessValueMax(ref TItem val, int slot) - { - Ch.Assert(0 <= slot && slot < Stat.Length); - var raw = _converter.ToLong(val); - if (raw > Stat[slot]) - Stat[slot] = raw; - } - - public override object GetStat() - { - TItem[] stat = new TItem[Stat.Length]; - // Account for defaults resulting from sparsity. - for (int slot = 0; slot < Stat.Length; slot++) - { - if (GetValuesProcessed(slot) < RowCount) - { - var def = default(TItem); - ProcValueDelegate(ref def, slot); - } - stat[slot] = _converter.FromLong(Stat[slot]); - } - return stat; - } - } - - private static Converter CreateConverter(ColumnType type) - { - Contracts.AssertValue(type); - Contracts.Assert(typeof(TItem) == type.ItemType.RawType); - Converter converter; - if (type.ItemType.IsTimeSpan) - converter = new TSConverter(); - else if (type.ItemType.IsDateTime) - converter = new DTConverter(); - else - { - Contracts.Assert(type.ItemType.RawKind == DataKind.I8); - converter = new I8Converter(); - } - return (Converter)converter; - } - - /// - /// The base class for conversions from types to long. - /// - private abstract class Converter - { - } - - private abstract class Converter : Converter - { - public abstract long ToLong(T val); - public abstract T FromLong(long val); - } - - private sealed class I8Converter : Converter - { - public override long ToLong(DvInt8 val) - { - return val.RawValue; - } - - public override DvInt8 FromLong(long val) - { - Contracts.Assert(DvInt8.RawNA != val); - return (DvInt8)val; - } - } - - private sealed class TSConverter : Converter - { - public override long ToLong(DvTimeSpan val) - { - return val.Ticks.RawValue; - } - - public override DvTimeSpan FromLong(long val) - { - Contracts.Assert(DvInt8.RawNA != val); - return new DvTimeSpan(val); - } - } - - private sealed class DTConverter : Converter - { - public override long ToLong(DvDateTime val) - { - return val.Ticks.RawValue; - } - - public override DvDateTime FromLong(long val) - { - Contracts.Assert(0 <= val && val <= DvDateTime.MaxTicks); - return new DvDateTime(val); - } - } - } } } \ No newline at end of file diff --git a/src/Microsoft.ML/Data/TextLoader.cs b/src/Microsoft.ML/Data/TextLoader.cs index 330412185e..f7eec8ac52 100644 --- a/src/Microsoft.ML/Data/TextLoader.cs +++ b/src/Microsoft.ML/Data/TextLoader.cs @@ -160,19 +160,19 @@ private static bool TryGetDataKind(Type type, out DataKind kind) Contracts.AssertValue(type); // REVIEW: Make this more efficient. Should we have a global dictionary? - if (type == typeof(DvInt1) || type == typeof(sbyte)) + if (type == typeof(sbyte)) kind = DataKind.I1; else if (type == typeof(byte) || type == typeof(char)) kind = DataKind.U1; - else if (type == typeof(DvInt2) || type == typeof(short)) + else if (type == typeof(short)) kind = DataKind.I2; else if (type == typeof(ushort)) kind = DataKind.U2; - else if (type == typeof(DvInt4) || type == typeof(int)) + else if ( type == typeof(int)) kind = DataKind.I4; else if (type == typeof(uint)) kind = DataKind.U4; - else if (type == typeof(DvInt8) || type == typeof(long)) + else if (type == typeof(long)) kind = DataKind.I8; else if (type == typeof(ulong)) kind = DataKind.U8; diff --git a/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt b/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt index e7d128e400..8815a6b6ee 100644 --- a/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt +++ b/test/BaselineOutput/SingleDebug/Command/Datatypes-datatypes.txt @@ -14,6 +14,6 @@ bl i1 i2 i4 i8 ts dto dt tx 0 127 32767 2147483647 9223372036854775807 "2.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" foo 1 -127 -32767 -2147483647 -9223372036854775807 "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" xyz - "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" + -128 -32768 -2147483648 -9223372036854775808 "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" 9 0:0 - + -128 -32768 -2147483648 -9223372036854775808 diff --git a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Data.txt b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Data.txt deleted file mode 100644 index c7049cd12a..0000000000 --- a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Data.txt +++ /dev/null @@ -1,9 +0,0 @@ -#@ TextLoader{ -#@ header+ -#@ sep=tab -#@ col=foo:I4:0 -#@ col=bar:I4:1 -#@ } -foo bar -1 2 -1 diff --git a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Schema.txt b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Schema.txt deleted file mode 100644 index 8fa619c171..0000000000 --- a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetNull-Schema.txt +++ /dev/null @@ -1,4 +0,0 @@ ----- ParquetLoader ---- -2 columns: - foo: I4 - bar: I4 diff --git a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Data.txt b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Data.txt index af1e19e1cc..85a3d35b4b 100644 --- a/test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Data.txt +++ b/test/BaselineOutput/SingleDebug/SavePipe/TestParquetPrimitiveDataTypes-Data.txt @@ -11,5 +11,5 @@ #@ col=string:TX:7 #@ } sbyte short int long bool DateTimeOffset Interval string - 1 "2018-09-01T19:53:18.2910000+00:00" "31.00:00:00.0010000" "" +-128 -32768 -2147483648 -9223372036854775808 1 "2018-09-01T19:53:18.2910000+00:00" "31.00:00:00.0010000" "" 127 32767 2147483647 9223372036854775807 0 "2018-09-01T19:53:18.3110000+00:00" "31.00:00:00.0010000" """""" diff --git a/test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt b/test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt index e7d128e400..8815a6b6ee 100644 --- a/test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt +++ b/test/BaselineOutput/SingleRelease/Command/Datatypes-datatypes.txt @@ -14,6 +14,6 @@ bl i1 i2 i4 i8 ts dto dt tx 0 127 32767 2147483647 9223372036854775807 "2.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" foo 1 -127 -32767 -2147483647 -9223372036854775807 "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" xyz - "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" + -128 -32768 -2147483648 -9223372036854775808 "7.00:00:00" "2008-11-30T00:00:00.0000000+00:00" "2013-08-05T00:00:00.0000000" 9 0:0 - + -128 -32768 -2147483648 -9223372036854775808 diff --git a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Data.txt b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Data.txt deleted file mode 100644 index c7049cd12a..0000000000 --- a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Data.txt +++ /dev/null @@ -1,9 +0,0 @@ -#@ TextLoader{ -#@ header+ -#@ sep=tab -#@ col=foo:I4:0 -#@ col=bar:I4:1 -#@ } -foo bar -1 2 -1 diff --git a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Schema.txt b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Schema.txt deleted file mode 100644 index 8fa619c171..0000000000 --- a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetNull-Schema.txt +++ /dev/null @@ -1,4 +0,0 @@ ----- ParquetLoader ---- -2 columns: - foo: I4 - bar: I4 diff --git a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Data.txt b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Data.txt index af1e19e1cc..85a3d35b4b 100644 --- a/test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Data.txt +++ b/test/BaselineOutput/SingleRelease/SavePipe/TestParquetPrimitiveDataTypes-Data.txt @@ -11,5 +11,5 @@ #@ col=string:TX:7 #@ } sbyte short int long bool DateTimeOffset Interval string - 1 "2018-09-01T19:53:18.2910000+00:00" "31.00:00:00.0010000" "" +-128 -32768 -2147483648 -9223372036854775808 1 "2018-09-01T19:53:18.2910000+00:00" "31.00:00:00.0010000" "" 127 32767 2147483647 9223372036854775807 0 "2018-09-01T19:53:18.3110000+00:00" "31.00:00:00.0010000" """""" diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs b/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs index 35859783ad..3fd6789acb 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/CoreBaseTestClass.cs @@ -153,19 +153,19 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ switch (type.RawKind) { case DataKind.I1: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U1: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I2: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U2: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I4: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U4: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I8: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U8: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.R4: @@ -196,19 +196,19 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ switch (type.ItemType.RawKind) { case DataKind.I1: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U1: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I2: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U2: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I4: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U4: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I8: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U8: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.R4: diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs new file mode 100644 index 0000000000..94a0459a35 --- /dev/null +++ b/test/Microsoft.ML.Core.Tests/UnitTests/DataTypes.cs @@ -0,0 +1,325 @@ +using System; +using System.IO; +using System.Linq; +using System.Text; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.Conversion; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Runtime.RunTests +{ + public class DataTypesTest : TestDataViewBase + { + public DataTypesTest(ITestOutputHelper helper) + : base(helper) + { + } + + private readonly static Conversions _conv = Conversions.Instance; + + [Fact] + public void TXToSByte() + { + var mapper = GetMapper(); + + Assert.NotNull(mapper); + + //1. sbyte.MinValue in text to sbyte. + sbyte minValue = sbyte.MinValue; + sbyte maxValue = sbyte.MaxValue; + DvText src = new DvText(minValue.ToString()); + sbyte dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, minValue); + + //2. sbyte.MaxValue in text to sbyte. + src = new DvText(maxValue.ToString()); + dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, maxValue); + + //3. ERROR condition: sbyte.MinValue - 1 in text to sbyte. + src = new DvText((sbyte.MinValue - 1).ToString()); + dst = 0; + bool error = false; + try + { + mapper(ref src, ref dst); + } + catch(Exception ex) + { + Assert.Equal("Value could not be parsed from text to sbyte.", ex.Message); + error = true; + } + + Assert.True(error); + + //4. ERROR condition: sbyte.MaxValue + 1 in text to sbyte. + src = new DvText((sbyte.MaxValue + 1).ToString()); + dst = 0; + error = false; + try + { + mapper(ref src, ref dst); + } + catch(Exception ex) + { + Assert.Equal("Value could not be parsed from text to sbyte.", ex.Message); + error = true; + } + + Assert.True(error); + + //5. Empty string in text to sbyte. + src = default; + dst = -1; + mapper(ref src, ref dst); + Assert.Equal(default, dst); + + //6. Missing value in text to sbyte. + src = DvText.NA; + dst = -1; + try + { + mapper(ref src, ref dst); + } + catch (Exception ex) + { + Assert.Equal("Missing text value cannot be converted to integer type.", ex.Message); + error = true; + } + + Assert.True(error); + } + + [Fact] + public void TXToShort() + { + var mapper = GetMapper(); + + Assert.NotNull(mapper); + + //1. short.MinValue in text to short. + short minValue = short.MinValue; + short maxValue = short.MaxValue; + DvText src = new DvText(minValue.ToString()); + short dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, minValue); + + //2. short.MaxValue in text to short. + src = new DvText(maxValue.ToString()); + dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, maxValue); + + //3. ERROR condition: short.MinValue - 1 in text to short. + src = new DvText((minValue - 1).ToString()); + dst = 0; + bool error = false; + try + { + mapper(ref src, ref dst); + } + catch(Exception ex) + { + Assert.Equal("Value could not be parsed from text to short.", ex.Message); + error = true; + } + + Assert.True(error); + + //4. ERROR condition: short.MaxValue + 1 in text to short. + src = new DvText((maxValue + 1).ToString()); + dst = 0; + error = false; + try + { + mapper(ref src, ref dst); + } + catch (Exception ex) + { + Assert.Equal("Value could not be parsed from text to short.", ex.Message); + error = true; + } + + Assert.True(error); + + //5. Empty value in text to short. + src = default; + dst = -1; + mapper(ref src, ref dst); + Assert.Equal(default, dst); + + //6. Missing string in text to sbyte. + src = DvText.NA; + dst = -1; + error = false; + try + { + mapper(ref src, ref dst); + } + catch (Exception ex) + { + Assert.Equal("Missing text value cannot be converted to integer type.", ex.Message); + error = true; + } + + Assert.True(error); + } + + [Fact] + public void TXToInt() + { + var mapper = GetMapper(); + + Assert.NotNull(mapper); + + //1. int.MinValue in text to int. + int minValue = int.MinValue; + int maxValue = int.MaxValue; + DvText src = new DvText(minValue.ToString()); + int dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, minValue); + + //2. int.MaxValue in text to int. + src = new DvText(maxValue.ToString()); + dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, maxValue); + + //3. ERROR condition: int.MinValue - 1 in text to int. + src = new DvText(((long)minValue - 1).ToString()); + dst = 0; + bool error = false; + try + { + mapper(ref src, ref dst); + } + catch (Exception ex) + { + Assert.Equal("Value could not be parsed from text to int.", ex.Message); + error = true; + } + + Assert.True(error); + + //4. ERROR condition: int.MaxValue + 1 in text to int. + src = new DvText(((long)maxValue + 1).ToString()); + dst = 0; + error = false; + try + { + mapper(ref src, ref dst); + } + catch (Exception ex) + { + Assert.Equal("Value could not be parsed from text to int.", ex.Message); + error = true; + } + + Assert.True(error); + + //5. Empty value in text to int. + src = default; + dst = -1; + mapper(ref src, ref dst); + Assert.Equal(default, dst); + + //6. Missing string in text to sbyte. + src = DvText.NA; + dst = -1; + error = false; + try + { + mapper(ref src, ref dst); + } + catch (Exception ex) + { + Assert.Equal("Missing text value cannot be converted to integer type.", ex.Message); + error = true; + } + + Assert.True(error); + } + + [Fact] + public void TXToLong() + { + var mapper = GetMapper(); + + Assert.NotNull(mapper); + + //1. long.MinValue in text to long. + var minValue = long.MinValue; + var maxValue = long.MaxValue; + DvText src = new DvText(minValue.ToString()); + var dst = default(long); + mapper(ref src, ref dst); + Assert.Equal(dst, minValue); + + //2. long.MaxValue in text to long. + src = new DvText(maxValue.ToString()); + dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, maxValue); + + //3. long.MinValue - 1 in text to long. + src = new DvText(((long)minValue - 1).ToString()); + dst = 0; + mapper(ref src, ref dst); + Assert.Equal(dst, (long)minValue - 1); + + //4. ERROR condition: long.MaxValue + 1 in text to long. + src = new DvText(((ulong)maxValue + 1).ToString()); + dst = 0; + bool error = false; + try + { + mapper(ref src, ref dst); + } + catch (Exception ex) + { + Assert.Equal("Value could not be parsed from text to long.", ex.Message); + error = true; + } + + Assert.True(error); + + //5. Empty value in text to long. + src = default; + dst = -1; + mapper(ref src, ref dst); + Assert.Equal(default, dst); + + //6. Missing string in text to sbyte. + error = false; + src = DvText.NA; + dst = -1; + try + { + mapper(ref src, ref dst); + } + catch (Exception ex) + { + Assert.Equal("Missing text value cannot be converted to integer type.", ex.Message); + error = true; + } + + Assert.True(error); + } + + public ValueMapper GetMapper() + { + Assert.True(typeof(TSrc).TryGetDataKind(out DataKind srcDataKind)); + Assert.True(typeof(TDst).TryGetDataKind(out DataKind dstDataKind)); + + return Conversions.Instance.GetStandardConversion( + TextType.Instance, NumberType.FromKind(dstDataKind), out bool identity); + } + } +} + + diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs b/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs index a3f5d8231b..5df434e7ae 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/DvTypes.cs @@ -10,60 +10,6 @@ namespace Microsoft.ML.Runtime.RunTests { public sealed class DvTypeTests { - [Fact] - public void TestComparableDvInt4() - { - const int count = 100; - - var rand = RandomUtils.Create(42); - var values = new DvInt4[2 * count]; - for (int i = 0; i < count; i++) - { - var v = values[i] = rand.Next(); - values[values.Length - i - 1] = v; - } - - // Assign two NA's at random. - int iv1 = rand.Next(values.Length); - int iv2 = rand.Next(values.Length - 1); - if (iv2 >= iv1) - iv2++; - values[iv1] = DvInt4.NA; - values[iv2] = DvInt4.NA; - Array.Sort(values); - - Assert.True(values[0].IsNA); - Assert.True(values[1].IsNA); - Assert.True(!values[2].IsNA); - - Assert.True((values[0] == values[1]).IsNA); - Assert.True((values[0] != values[1]).IsNA); - Assert.True((values[0] <= values[1]).IsNA); - Assert.True(values[0].Equals(values[1])); - Assert.True(values[0].CompareTo(values[1]) == 0); - - Assert.True((values[1] == values[2]).IsNA); - Assert.True((values[1] != values[2]).IsNA); - Assert.True((values[1] <= values[2]).IsNA); - Assert.True(!values[1].Equals(values[2])); - Assert.True(values[1].CompareTo(values[2]) < 0); - - for (int i = 3; i < values.Length; i++) - { - DvBool eq = values[i - 1] == values[i]; - DvBool ne = values[i - 1] != values[i]; - DvBool le = values[i - 1] <= values[i]; - bool feq = values[i - 1].Equals(values[i]); - int cmp = values[i - 1].CompareTo(values[i]); - Assert.True(!eq.IsNA); - Assert.True(!ne.IsNA); - Assert.True(eq.IsTrue == ne.IsFalse); - Assert.True(le.IsTrue); - Assert.True(feq == eq.IsTrue); - Assert.True(cmp <= 0); - Assert.True(feq == (cmp == 0)); - } - } [Fact] public void TestComparableDvText() diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index ded58f50bb..a237d924a7 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -66,7 +66,7 @@ private IDataView GetBreastCancerDataviewWithTextColumns() { new TextLoader.Column("Label", type: null, 0), new TextLoader.Column("F1", DataKind.Text, 1), - new TextLoader.Column("F2", DataKind.I4, 2), + new TextLoader.Column("F2", DataKind.R4, 2), new TextLoader.Column("Rest", type: null, new [] { new TextLoader.Range(3, 9) }) } }, @@ -1968,7 +1968,6 @@ public void EntryPointConvert() { "Transforms.ColumnTypeConverter", "Transforms.ColumnTypeConverter", - "Transforms.ColumnTypeConverter", }, new[] { @@ -1984,7 +1983,7 @@ public void EntryPointConvert() { 'Name': 'Feat', 'Source': 'FT', - 'Type': 'I1' + 'Type': 'R4' }, { 'Name': 'Key1', @@ -1994,18 +1993,11 @@ public void EntryPointConvert() ]", @"'Column': [ { - 'Name': 'Ints', + 'Name': 'Doubles', 'Source': 'Feat' } ], - 'Type': 'I4'", - @"'Column': [ - { - 'Name': 'Floats', - 'Source': 'Ints' - } - ], - 'Type': 'Num'", + 'Type': 'R8'", }); } diff --git a/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs b/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs index b5b53677ab..4ce87f070b 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs @@ -149,19 +149,19 @@ public void TransposerTest() ArrayDataViewBuilder builder = new ArrayDataViewBuilder(Env); // A is to check the splitting of a sparse-ish column. - var dataA = GenerateHelper(rowCount, 0.1, rgen, () => (DvInt4)rgen.Next(), 50, 5, 10, 15); - dataA[rowCount / 2] = new VBuffer(50, 0, null, null); // Coverage for the null vbuffer case. + var dataA = GenerateHelper(rowCount, 0.1, rgen, () => (int)rgen.Next(), 50, 5, 10, 15); + dataA[rowCount / 2] = new VBuffer(50, 0, null, null); // Coverage for the null vbuffer case. builder.AddColumn("A", NumberType.I4, dataA); // B is to check the splitting of a dense-ish column. builder.AddColumn("B", NumberType.R8, GenerateHelper(rowCount, 0.8, rgen, rgen.NextDouble, 50, 0, 25, 49)); // C is to just have some column we do nothing with. - builder.AddColumn("C", NumberType.I2, GenerateHelper(rowCount, 0.1, rgen, () => (DvInt2)1, 30, 3, 10, 24)); + builder.AddColumn("C", NumberType.I2, GenerateHelper(rowCount, 0.1, rgen, () => (short)1, 30, 3, 10, 24)); // D is to check some column we don't have to split because it's sufficiently small. builder.AddColumn("D", NumberType.R8, GenerateHelper(rowCount, 0.1, rgen, rgen.NextDouble, 3, 1)); // E is to check a sparse scalar column. builder.AddColumn("E", NumberType.U4, GenerateHelper(rowCount, 0.1, rgen, () => (uint)rgen.Next(int.MinValue, int.MaxValue))); // F is to check a dense-ish scalar column. - builder.AddColumn("F", NumberType.I4, GenerateHelper(rowCount, 0.8, rgen, () => (DvInt4)rgen.Next())); + builder.AddColumn("F", NumberType.I4, GenerateHelper(rowCount, 0.8, rgen, () => rgen.Next())); IDataView view = builder.GetDataView(); @@ -182,11 +182,11 @@ public void TransposerTest() } // Check the contents Assert.Null(trans.TransposeSchema.GetSlotType(2)); // C check to see that it's not transposable. - TransposeCheckHelper(view, 0, trans); // A check. + TransposeCheckHelper(view, 0, trans); // A check. TransposeCheckHelper(view, 1, trans); // B check. TransposeCheckHelper(view, 3, trans); // D check. TransposeCheckHelper(view, 4, trans); // E check. - TransposeCheckHelper(view, 5, trans); // F check. + TransposeCheckHelper(view, 5, trans); // F check. } // Force save. Recheck columns that would have previously been passthrough columns. @@ -201,7 +201,7 @@ public void TransposerTest() Assert.Null(trans.TransposeSchema.GetSlotType(2)); TransposeCheckHelper(view, 3, trans); // D check. TransposeCheckHelper(view, 4, trans); // E check. - TransposeCheckHelper(view, 5, trans); // F check. + TransposeCheckHelper(view, 5, trans); // F check. } } @@ -214,19 +214,19 @@ public void TransposerSaverLoaderTest() ArrayDataViewBuilder builder = new ArrayDataViewBuilder(Env); // A is to check the splitting of a sparse-ish column. - var dataA = GenerateHelper(rowCount, 0.1, rgen, () => (DvInt4)rgen.Next(), 50, 5, 10, 15); - dataA[rowCount / 2] = new VBuffer(50, 0, null, null); // Coverage for the null vbuffer case. + var dataA = GenerateHelper(rowCount, 0.1, rgen, () => (int)rgen.Next(), 50, 5, 10, 15); + dataA[rowCount / 2] = new VBuffer(50, 0, null, null); // Coverage for the null vbuffer case. builder.AddColumn("A", NumberType.I4, dataA); // B is to check the splitting of a dense-ish column. builder.AddColumn("B", NumberType.R8, GenerateHelper(rowCount, 0.8, rgen, rgen.NextDouble, 50, 0, 25, 49)); // C is to just have some column we do nothing with. - builder.AddColumn("C", NumberType.I2, GenerateHelper(rowCount, 0.1, rgen, () => (DvInt2)1, 30, 3, 10, 24)); + builder.AddColumn("C", NumberType.I2, GenerateHelper(rowCount, 0.1, rgen, () => (short)1, 30, 3, 10, 24)); // D is to check some column we don't have to split because it's sufficiently small. builder.AddColumn("D", NumberType.R8, GenerateHelper(rowCount, 0.1, rgen, rgen.NextDouble, 3, 1)); // E is to check a sparse scalar column. builder.AddColumn("E", NumberType.U4, GenerateHelper(rowCount, 0.1, rgen, () => (uint)rgen.Next(int.MinValue, int.MaxValue))); // F is to check a dense-ish scalar column. - builder.AddColumn("F", NumberType.I4, GenerateHelper(rowCount, 0.8, rgen, () => (DvInt4)rgen.Next())); + builder.AddColumn("F", NumberType.I4, GenerateHelper(rowCount, 0.8, rgen, () => (int)rgen.Next())); IDataView view = builder.GetDataView(); @@ -241,12 +241,12 @@ public void TransposerSaverLoaderTest() // First check whether this as an IDataView yields the same values. CheckSameValues(view, loader); - TransposeCheckHelper(view, 0, loader); // A + TransposeCheckHelper(view, 0, loader); // A TransposeCheckHelper(view, 1, loader); // B - TransposeCheckHelper(view, 2, loader); // C + TransposeCheckHelper(view, 2, loader); // C TransposeCheckHelper(view, 3, loader); // D TransposeCheckHelper(view, 4, loader); // E - TransposeCheckHelper(view, 5, loader); // F + TransposeCheckHelper(view, 5, loader); // F Done(); } diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index 603d9dcd2f..e29407cb40 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -165,7 +165,7 @@ public void AssertStaticKeys() var col1 = RowColumnUtils.GetColumn("stay", new KeyType(DataKind.U4, 0, 3), ref value1, RowColumnUtils.GetRow(counted, meta1)); // Next the case where those values are ints. - var metaValues2 = new VBuffer(3, new DvInt4[] { 1, 2, 3, 4 }); + var metaValues2 = new VBuffer(3, new int[] { 1, 2, 3, 4 }); var meta2 = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, new VectorType(NumberType.I4, 4), ref metaValues2); var value2 = new VBuffer(2, 0, null, null); var col2 = RowColumnUtils.GetColumn("awhile", new VectorType(new KeyType(DataKind.U1, 2, 4), 2), ref value2, RowColumnUtils.GetRow(counted, meta2)); diff --git a/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs b/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs index f5be433b3e..ace98c93f9 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/Parquet.cs @@ -33,7 +33,19 @@ public void TestParquetPrimitiveDataTypes() public void TestParquetNull() { string pathData = GetDataPath(@"Parquet", "test-null.parquet"); - TestCore(pathData, false, new[] { "loader=Parquet{bigIntDates=+}" }, forceDense: true); + bool exception = false; + try + { + TestCore(pathData, false, new[] { "loader=Parquet{bigIntDates=+}" }, forceDense: true); + } + catch (Exception ex) + { + Assert.Equal("Nullable object must have a value.", ex.Message); + exception = true; + } + + Assert.True(exception, "Test failed because control reached here without an expected exception for nullable values."); + Done(); } } diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index 4ab2f0e6e5..b31ca2a2ad 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -1010,19 +1010,19 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ switch (type.RawKind) { case DataKind.I1: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U1: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I2: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U2: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I4: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U4: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.I8: - return GetComparerOne(r1, r2, col, (x, y) => x.RawValue == y.RawValue); + return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.U8: return GetComparerOne(r1, r2, col, (x, y) => x == y); case DataKind.R4: @@ -1056,19 +1056,19 @@ protected Func GetColumnComparer(IRow r1, IRow r2, int col, ColumnType typ switch (type.ItemType.RawKind) { case DataKind.I1: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U1: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I2: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U2: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I4: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U4: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.I8: - return GetComparerVec(r1, r2, col, size, (x, y) => x.RawValue == y.RawValue); + return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.U8: return GetComparerVec(r1, r2, col, size, (x, y) => x == y); case DataKind.R4: diff --git a/test/Microsoft.ML.TestFramework/TestSparseDataView.cs b/test/Microsoft.ML.TestFramework/TestSparseDataView.cs index 08c9e17a28..37db8a26f4 100644 --- a/test/Microsoft.ML.TestFramework/TestSparseDataView.cs +++ b/test/Microsoft.ML.TestFramework/TestSparseDataView.cs @@ -4,6 +4,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; +using System; using Xunit; using Xunit.Abstractions; @@ -34,7 +35,7 @@ private class SparseExample public void SparseDataView() { GenericSparseDataView(new[] { 1f, 2f, 3f }, new[] { 1f, 10f, 100f }); - GenericSparseDataView(new DvInt4[] { 1, 2, 3 }, new DvInt4[] { 1, 10, 100 }); + GenericSparseDataView(new int[] { 1, 2, 3 }, new int[] { 1, 10, 100 }); GenericSparseDataView(new DvBool[] { true, true, true }, new DvBool[] { false, false, false }); GenericSparseDataView(new double[] { 1, 2, 3 }, new double[] { 1, 10, 100 }); GenericSparseDataView(new DvText[] { new DvText("a"), new DvText("b"), new DvText("c") }, @@ -76,7 +77,7 @@ private void GenericSparseDataView(T[] v1, T[] v2) public void DenseDataView() { GenericDenseDataView(new[] { 1f, 2f, 3f }, new[] { 1f, 10f, 100f }); - GenericDenseDataView(new DvInt4[] { 1, 2, 3 }, new DvInt4[] { 1, 10, 100 }); + GenericDenseDataView(new int[] { 1, 2, 3 }, new int[] { 1, 10, 100 }); GenericDenseDataView(new DvBool[] { true, true, true }, new DvBool[] { false, false, false }); GenericDenseDataView(new double[] { 1, 2, 3 }, new double[] { 1, 10, 100 }); GenericDenseDataView(new DvText[] { new DvText("a"), new DvText("b"), new DvText("c") }, diff --git a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs index 7690368e2f..96f3d3a6c7 100644 --- a/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs +++ b/test/Microsoft.ML.Tests/CollectionDataSourceTests.cs @@ -293,23 +293,7 @@ public class ConversionSimpleClass public ulong fuLong; public float fFloat; public double fDouble; - public bool fBool; - public string fString; - } - - public class ConversionNullalbeClass - { - public int? fInt; - public uint? fuInt; - public short? fShort; - public ushort? fuShort; - public sbyte? fsByte; - public byte? fByte; - public long? fLong; - public ulong? fuLong; - public float? fFloat; - public double? fDouble; - public bool? fBool; + public DvBool fBool; public string fString; } @@ -434,56 +418,6 @@ public void RoundTripConversionWithBasicTypes() new ConversionSimpleClass() }; - var dataNullable = new List - { - new ConversionNullalbeClass() - { - fInt = int.MaxValue - 1, - fuInt = uint.MaxValue - 1, - fBool = true, - fsByte = sbyte.MaxValue - 1, - fByte = byte.MaxValue - 1, - fDouble = double.MaxValue - 1, - fFloat = float.MaxValue - 1, - fLong = long.MaxValue - 1, - fuLong = ulong.MaxValue - 1, - fShort = short.MaxValue - 1, - fuShort = ushort.MaxValue - 1, - fString = "ha" - }, - new ConversionNullalbeClass() - { - fInt = int.MaxValue, - fuInt = uint.MaxValue, - fBool = true, - fsByte = sbyte.MaxValue, - fByte = byte.MaxValue, - fDouble = double.MaxValue, - fFloat = float.MaxValue, - fLong = long.MaxValue, - fuLong = ulong.MaxValue, - fShort = short.MaxValue, - fuShort = ushort.MaxValue, - fString = "ooh" - }, - new ConversionNullalbeClass() - { - fInt = int.MinValue + 1, - fuInt = uint.MinValue, - fBool = false, - fsByte = sbyte.MinValue + 1, - fByte = byte.MinValue, - fDouble = double.MinValue + 1, - fFloat = float.MinValue + 1, - fLong = long.MinValue + 1, - fuLong = ulong.MinValue, - fShort = short.MinValue + 1, - fuShort = ushort.MinValue, - fString = "" - }, - new ConversionNullalbeClass() - }; - using (var env = new TlcEnvironment()) { var dataView = ComponentCreation.CreateDataView(env, data); @@ -494,15 +428,6 @@ public void RoundTripConversionWithBasicTypes() Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); } Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); - - dataView = ComponentCreation.CreateDataView(env, dataNullable); - var enumeratorNullable = dataView.AsEnumerable(env, false).GetEnumerator(); - var originalNullableEnumerator = dataNullable.GetEnumerator(); - while (enumeratorNullable.MoveNext() && originalNullableEnumerator.MoveNext()) - { - Assert.True(CompareThroughReflection(enumeratorNullable.Current, originalNullableEnumerator.Current)); - } - Assert.True(!enumeratorNullable.MoveNext() && !originalNullableEnumerator.MoveNext()); } } @@ -542,38 +467,6 @@ public void ConversionExceptionsBehavior() } } - public class ConversionLossMinValueClass - { - public int? fInt; - public long? fLong; - public short? fShort; - public sbyte? fSByte; - } - - [Fact] - public void ConversionMinValueToNullBehavior() - { - using (var env = new TlcEnvironment()) - { - - var data = new List - { - new ConversionLossMinValueClass() { fSByte = null, fInt = null, fLong = null, fShort = null }, - new ConversionLossMinValueClass() { fSByte = sbyte.MinValue, fInt = int.MinValue, fLong = long.MinValue, fShort = short.MinValue } - }; - foreach (var field in typeof(ConversionLossMinValueClass).GetFields()) - { - var dataView = ComponentCreation.CreateDataView(env, data); - var enumerator = dataView.AsEnumerable(env, false).GetEnumerator(); - while (enumerator.MoveNext()) - { - Assert.True(enumerator.Current.fInt == null && enumerator.Current.fLong == null && - enumerator.Current.fSByte == null && enumerator.Current.fShort == null); - } - } - } - } - public class ConversionLossMinValueClassProperties { private int? _fInt; @@ -771,23 +664,7 @@ public class ClassWithArrays public ulong[] fuLong; public float[] fFloat; public double[] fDouble; - public bool[] fBool; - } - - public class ClassWithNullableArrays - { - public string[] fString; - public int?[] fInt; - public uint?[] fuInt; - public short?[] fShort; - public ushort?[] fuShort; - public sbyte?[] fsByte; - public byte?[] fByte; - public long?[] fLong; - public ulong?[] fuLong; - public float?[] fFloat; - public double?[] fDouble; - public bool?[] fBool; + public DvBool[] fBool; } [Fact] @@ -801,7 +678,7 @@ public void RoundTripConversionWithArrays() fInt = new int[3] { 0, 1, 2 }, fFloat = new float[3] { -0.99f, 0f, 0.99f }, fString = new string[2] { "hola", "lola" }, - fBool = new bool[2] { true, false }, + fBool = new DvBool[2] { true, false }, fByte = new byte[3] { 0, 124, 255 }, fDouble = new double[3] { -1, 0, 1 }, fLong = new long[] { 0, 1, 2 }, @@ -815,27 +692,6 @@ public void RoundTripConversionWithArrays() new ClassWithArrays() }; - var nullableData = new List - { - new ClassWithNullableArrays() - { - fInt = new int?[3] { null, -1, 1 }, - fFloat = new float?[3] { -0.99f, null, 0.99f }, - fString = new string[2] { null, "" }, - fBool = new bool?[3] { true, null, false }, - fByte = new byte?[4] { 0, 125, null, 255 }, - fDouble = new double?[3] { -1, null, 1 }, - fLong = new long?[] { null, -1, 1 }, - fsByte = new sbyte?[3] { -127, 127, null }, - fShort = new short?[3] { 0, null, 32767 }, - fuInt = new uint?[4] { null, 42, 0, uint.MaxValue }, - fuLong = new ulong?[3] { ulong.MaxValue, null, 0 }, - fuShort = new ushort?[3] { 0, null, ushort.MaxValue } - }, - new ClassWithNullableArrays() { fInt = new int?[3] { -2, 1, 0 }, fFloat = new float?[3] { 0.99f, 0f, -0.99f }, fString = new string[2] { "lola", "hola" } }, - new ClassWithNullableArrays() - }; - using (var env = new TlcEnvironment()) { var dataView = ComponentCreation.CreateDataView(env, data); @@ -846,15 +702,6 @@ public void RoundTripConversionWithArrays() Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); } Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); - - var nullableDataView = ComponentCreation.CreateDataView(env, nullableData); - var enumeratorNullable = nullableDataView.AsEnumerable(env, false).GetEnumerator(); - var originalNullalbleEnumerator = nullableData.GetEnumerator(); - while (enumeratorNullable.MoveNext() && originalNullalbleEnumerator.MoveNext()) - { - Assert.True(CompareThroughReflection(enumeratorNullable.Current, originalNullalbleEnumerator.Current)); - } - Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext()); } } public class ClassWithArrayProperties @@ -870,7 +717,7 @@ public class ClassWithArrayProperties private ulong[] _fuLong; private float[] _fFloat; private double[] _fDouble; - private bool[] _fBool; + private DvBool[] _fBool; public string[] StringProp { get { return _fString; } set { _fString = value; } } public int[] IntProp { get { return _fInt; } set { _fInt = value; } } public uint[] UIntProp { get { return _fuInt; } set { _fuInt = value; } } @@ -882,36 +729,7 @@ public class ClassWithArrayProperties public ulong[] ULongProp { get { return _fuLong; } set { _fuLong = value; } } public float[] FloatProp { get { return _fFloat; } set { _fFloat = value; } } public double[] DobuleProp { get { return _fDouble; } set { _fDouble = value; } } - public bool[] BoolProp { get { return _fBool; } set { _fBool = value; } } - } - - public class ClassWithNullableArrayProperties - { - private string[] _fString; - private int?[] _fInt; - private uint?[] _fuInt; - private short?[] _fShort; - private ushort?[] _fuShort; - private sbyte?[] _fsByte; - private byte?[] _fByte; - private long?[] _fLong; - private ulong?[] _fuLong; - private float?[] _fFloat; - private double?[] _fDouble; - private bool?[] _fBool; - - public string[] StringProp { get { return _fString; } set { _fString = value; } } - public int?[] IntProp { get { return _fInt; } set { _fInt = value; } } - public uint?[] UIntProp { get { return _fuInt; } set { _fuInt = value; } } - public short?[] ShortProp { get { return _fShort; } set { _fShort = value; } } - public ushort?[] UShortProp { get { return _fuShort; } set { _fuShort = value; } } - public sbyte?[] SByteProp { get { return _fsByte; } set { _fsByte = value; } } - public byte?[] ByteProp { get { return _fByte; } set { _fByte = value; } } - public long?[] LongProp { get { return _fLong; } set { _fLong = value; } } - public ulong?[] ULongProp { get { return _fuLong; } set { _fuLong = value; } } - public float?[] SingleProp { get { return _fFloat; } set { _fFloat = value; } } - public double?[] DoubleProp { get { return _fDouble; } set { _fDouble = value; } } - public bool?[] BoolProp { get { return _fBool; } set { _fBool = value; } } + public DvBool[] BoolProp { get { return _fBool; } set { _fBool = value; } } } [Fact] @@ -925,7 +743,7 @@ public void RoundTripConversionWithArrayPropertiess() IntProp = new int[3] { 0, 1, 2 }, FloatProp = new float[3] { -0.99f, 0f, 0.99f }, StringProp = new string[2] { "hola", "lola" }, - BoolProp = new bool[2] { true, false }, + BoolProp = new DvBool[2] { true, false }, ByteProp = new byte[3] { 0, 124, 255 }, DobuleProp = new double[3] { -1, 0, 1 }, LongProp = new long[] { 0, 1, 2 }, @@ -939,26 +757,6 @@ public void RoundTripConversionWithArrayPropertiess() new ClassWithArrayProperties() }; - var nullableData = new List - { - new ClassWithNullableArrayProperties() - { - IntProp = new int?[3] { null, -1, 1 }, - SingleProp = new float?[3] { -0.99f, null, 0.99f }, - StringProp = new string[2] { null, "" }, - BoolProp = new bool?[3] { true, null, false }, - ByteProp = new byte?[4] { 0, 125, null, 255 }, - DoubleProp = new double?[3] { -1, null, 1 }, - LongProp = new long?[] { null, -1, 1 }, - SByteProp = new sbyte?[3] { -127, 127, null }, - ShortProp = new short?[3] { 0, null, 32767 }, - UIntProp = new uint?[4] { null, 42, 0, uint.MaxValue }, - ULongProp = new ulong?[3] { ulong.MaxValue, null, 0 }, - UShortProp = new ushort?[3] { 0, null, ushort.MaxValue } - }, - new ClassWithNullableArrayProperties() { IntProp = new int?[3] { -2, 1, 0 }, SingleProp = new float?[3] { 0.99f, 0f, -0.99f }, StringProp = new string[2] { "lola", "hola" } }, - new ClassWithNullableArrayProperties() - }; using (var env = new TlcEnvironment()) { @@ -970,15 +768,6 @@ public void RoundTripConversionWithArrayPropertiess() Assert.True(CompareThroughReflection(enumeratorSimple.Current, originalEnumerator.Current)); } Assert.True(!enumeratorSimple.MoveNext() && !originalEnumerator.MoveNext()); - - var nullableDataView = ComponentCreation.CreateDataView(env, nullableData); - var enumeratorNullable = nullableDataView.AsEnumerable(env, false).GetEnumerator(); - var originalNullalbleEnumerator = nullableData.GetEnumerator(); - while (enumeratorNullable.MoveNext() && originalNullalbleEnumerator.MoveNext()) - { - Assert.True(CompareThroughReflection(enumeratorNullable.Current, originalNullalbleEnumerator.Current)); - } - Assert.True(!enumeratorNullable.MoveNext() && !originalNullalbleEnumerator.MoveNext()); } } diff --git a/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs b/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs index 60d2dc2fb1..dc39b588f8 100644 --- a/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/CopyColumnEstimatorTests.cs @@ -171,16 +171,16 @@ private void ValidateCopyColumnTransformer(IDataView result) { using (var cursor = result.GetRowCursor(x => true)) { - DvInt4 avalue = 0; - DvInt4 bvalue = 0; - DvInt4 dvalue = 0; - DvInt4 evalue = 0; - DvInt4 fvalue = 0; - var aGetter = cursor.GetGetter(0); - var bGetter = cursor.GetGetter(1); - var dGetter = cursor.GetGetter(3); - var eGetter = cursor.GetGetter(4); - var fGetter = cursor.GetGetter(5); + int avalue = 0; + int bvalue = 0; + int dvalue = 0; + int evalue = 0; + int fvalue = 0; + var aGetter = cursor.GetGetter(0); + var bGetter = cursor.GetGetter(1); + var dGetter = cursor.GetGetter(3); + var eGetter = cursor.GetGetter(4); + var fGetter = cursor.GetGetter(5); while (cursor.MoveNext()) { aGetter(ref avalue); diff --git a/test/Microsoft.ML.Tests/LearningPipelineTests.cs b/test/Microsoft.ML.Tests/LearningPipelineTests.cs index f19e3285d7..458cc01e71 100644 --- a/test/Microsoft.ML.Tests/LearningPipelineTests.cs +++ b/test/Microsoft.ML.Tests/LearningPipelineTests.cs @@ -119,7 +119,7 @@ public class BooleanLabelData public float[] Features; [ColumnName("Label")] - public bool Label; + public DvBool Label; } [Fact] @@ -137,36 +137,6 @@ public void BooleanLabelPipeline() var model = pipeline.Train(); } - public class NullableBooleanLabelData - { - [ColumnName("Features")] - [VectorType(2)] - public float[] Features; - - [ColumnName("Label")] - public bool? Label; - } - - [Fact] - public void NullableBooleanLabelPipeline() - { - var data = new NullableBooleanLabelData[2]; - data[0] = new NullableBooleanLabelData - { - Features = new float[] { 0.0f, 1.0f }, - Label = null - }; - data[1] = new NullableBooleanLabelData - { - Features = new float[] { 1.0f, 0.0f }, - Label = false - }; - var pipeline = new LearningPipeline(); - pipeline.Add(CollectionDataSource.Create(data)); - pipeline.Add(new FastForestBinaryClassifier()); - var model = pipeline.Train(); - } - [Fact] public void AppendPipeline() { diff --git a/test/Microsoft.ML.Tests/TextLoaderTests.cs b/test/Microsoft.ML.Tests/TextLoaderTests.cs index 50a7e55975..82a2b6192d 100644 --- a/test/Microsoft.ML.Tests/TextLoaderTests.cs +++ b/test/Microsoft.ML.Tests/TextLoaderTests.cs @@ -7,13 +7,131 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.TestFramework; using System; +using System.IO; using Xunit; using Xunit.Abstractions; namespace Microsoft.ML.EntryPoints.Tests { + public class TextLoaderTestPipe : TestDataPipeBase + { + public TextLoaderTestPipe(ITestOutputHelper output) + : base(output) + { + + } + + [Fact] + public void TestTextLoaderDataTypes() + { + string pathData = DeleteOutputPath("SavePipe", "TextInput.txt"); + File.WriteAllLines(pathData, new string[] { + string.Format("{0},{1},{2},{3}", sbyte.MinValue, short.MinValue, int.MinValue, long.MinValue), + string.Format("{0},{1},{2},{3}", sbyte.MaxValue, short.MaxValue, int.MaxValue, long.MaxValue), + "\"\",\"\",\"\",\"\"" + }); + + var data = TestCore(pathData, true, + new[] { + "loader=Text{col=DvInt1:I1:0 col=DvInt2:I2:1 col=DvInt4:I4:2 col=DvInt8:I8:3 sep=comma}", + }, logCurs: true); + + using (var cursor = data.GetRowCursor((a => true))) + { + var col1 = cursor.GetGetter(0); + var col2 = cursor.GetGetter(1); + var col3 = cursor.GetGetter(2); + var col4 = cursor.GetGetter(3); + + Assert.True(cursor.MoveNext()); + + sbyte[] sByteTargets = new sbyte[] { sbyte.MinValue, sbyte.MaxValue, default}; + short[] shortTargets = new short[] { short.MinValue, short.MaxValue, default }; + int[] intTargets = new int[] { int.MinValue, int.MaxValue, default }; + long[] longTargets = new long[] { long.MinValue, long.MaxValue, default }; + + int i = 0; + for (; i < sByteTargets.Length; i++) + { + sbyte sbyteValue = -1; + col1(ref sbyteValue); + Assert.Equal(sByteTargets[i], sbyteValue); + + short shortValue = -1; + col2(ref shortValue); + Assert.Equal(shortTargets[i], shortValue); + + int intValue = -1; + col3(ref intValue); + Assert.Equal(intTargets[i], intValue); + + long longValue = -1; + col4(ref longValue); + Assert.Equal(longTargets[i], longValue); + + if (i < sByteTargets.Length - 1) + Assert.True(cursor.MoveNext()); + else + Assert.False(cursor.MoveNext()); + } + + Assert.Equal(i, sByteTargets.Length); + } + } + + [Fact] + public void TestTextLoaderInvalidLongMin() + { + string pathData = DeleteOutputPath("SavePipe", "TextInput.txt"); + File.WriteAllLines(pathData, new string[] { + "-9223372036854775809" + + }); + + try + { + var data = TestCore(pathData, true, + new[] { + "loader=Text{col=DvInt8:I8:0 sep=comma}", + }, logCurs: true); + } + catch(Exception ex) + { + Assert.Equal("Value could not be parsed from text to long.", ex.Message); + return; + } + + Assert.True(false, "Test failed."); + } + + [Fact] + public void TestTextLoaderInvalidLongMax() + { + string pathData = DeleteOutputPath("SavePipe", "TextInput.txt"); + File.WriteAllLines(pathData, new string[] { + "9223372036854775808" + }); + + try + { + var data = TestCore(pathData, true, + new[] { + "loader=Text{col=DvInt8:I8:0 sep=comma}", + }, logCurs: true); + } + catch (Exception ex) + { + Assert.Equal("Value could not be parsed from text to long.", ex.Message); + return; + } + + Assert.True(false, "Test failed."); + } + } + public class TextLoaderTests : BaseTestClass { public TextLoaderTests(ITestOutputHelper output) @@ -21,7 +139,7 @@ public TextLoaderTests(ITestOutputHelper output) { } - + [Fact] public void ConstructorDoesntThrow() {