Skip to content

Internalize IDataTransform #2509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions src/Microsoft.ML.Data/DataView/RowToRowMapperTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,16 @@ internal interface IRowMapper : ICanSaveModel
/// </summary>
ITransformer GetTransformer();
}

public delegate void SignatureLoadRowMapper(ModelLoadContext ctx, Schema schema);
[BestFriend]
internal delegate void SignatureLoadRowMapper(ModelLoadContext ctx, Schema schema);

/// <summary>
/// This class is a transform that can add any number of output columns, that depend on any number of input columns.
/// It does so with the help of an <see cref="IRowMapper"/>, that is given a schema in its constructor, and has methods
/// to get the dependencies on input columns and the getters for the output columns, given an active set of output columns.
/// </summary>
public sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMapper,
[BestFriend]
internal sealed class RowToRowMapperTransform : RowToRowTransformBase, IRowToRowMapper,
ITransformCanSaveOnnx, ITransformCanSavePfa, ITransformTemplate
{
private readonly IRowMapper _mapper;
Expand Down Expand Up @@ -92,8 +93,7 @@ private static VersionInfo GetVersionInfo()

bool ICanSavePfa.CanSavePfa => _mapper is ICanSavePfa pfaMapper ? pfaMapper.CanSavePfa : false;

[BestFriend]
internal RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper, Func<Schema, IRowMapper> mapperFactory)
public RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapper mapper, Func<Schema, IRowMapper> mapperFactory)
: base(env, RegistrationName, input)
{
Contracts.CheckValue(mapper, nameof(mapper));
Expand All @@ -103,8 +103,7 @@ internal RowToRowMapperTransform(IHostEnvironment env, IDataView input, IRowMapp
_bindings = new ColumnBindings(input.Schema, mapper.GetOutputColumns());
}

[BestFriend]
internal static Schema GetOutputSchema(Schema inputSchema, IRowMapper mapper)
public static Schema GetOutputSchema(Schema inputSchema, IRowMapper mapper)
{
Contracts.CheckValue(inputSchema, nameof(inputSchema));
Contracts.CheckValue(mapper, nameof(mapper));
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@

namespace Microsoft.ML.Transforms
{
public sealed class LabelConvertTransform : OneToOneTransformBase
[BestFriend]
internal sealed class LabelConvertTransform : OneToOneTransformBase
{
public sealed class Column : OneToOneColumn
{
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ public Schema GetOutputSchema(Schema inputSchema)

public IDataView Transform(IDataView input) => MakeDataTransform(input);

protected RowToRowMapperTransform MakeDataTransform(IDataView input)
[BestFriend]
private protected RowToRowMapperTransform MakeDataTransform(IDataView input)
{
Host.CheckValue(input, nameof(input));
return new RowToRowMapperTransform(Host, input, MakeRowMapper(input.Schema), MakeRowMapper);
Expand Down
12 changes: 8 additions & 4 deletions src/Microsoft.ML.Data/Transforms/TransformBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ namespace Microsoft.ML.Data
/// <summary>
/// Base class for transforms.
/// </summary>
public abstract class TransformBase : IDataTransform
[BestFriend]
internal abstract class TransformBase : IDataTransform
{
protected readonly IHost Host;

Expand Down Expand Up @@ -104,7 +105,8 @@ public RowCursor GetRowCursor(IEnumerable<Schema.Column> columnsNeeded, Random r
/// <summary>
/// Base class for transforms that map single input row to single output row.
/// </summary>
public abstract class RowToRowTransformBase : TransformBase
[BestFriend]
internal abstract class RowToRowTransformBase : TransformBase
{
protected RowToRowTransformBase(IHostEnvironment env, string name, IDataView input)
: base(env, name, input)
Expand Down Expand Up @@ -150,7 +152,8 @@ void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx)
}
}

public abstract class RowToRowMapperTransformBase : RowToRowTransformBase, IRowToRowMapper
[BestFriend]
internal abstract class RowToRowMapperTransformBase : RowToRowTransformBase, IRowToRowMapper
{
protected RowToRowMapperTransformBase(IHostEnvironment env, string name, IDataView input)
: base(env, name, input)
Expand Down Expand Up @@ -244,7 +247,8 @@ public override bool IsColumnActive(int col)
/// on multiple input columns.
/// This class provides the implementation of ISchema and IRowCursor.
/// </summary>
public abstract class OneToOneTransformBase : RowToRowMapperTransformBase, ITransposeDataView, ITransformCanSavePfa,
[BestFriend]
internal abstract class OneToOneTransformBase : RowToRowMapperTransformBase, ITransposeDataView, ITransformCanSavePfa,
ITransformCanSaveOnnx
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a few [BestFriend] in this class, if you are doing another iteration, maybe we could make those public and remove the attribute. Not a requirement though.

{
/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.ImageAnalytics/VectorToImageTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace Microsoft.ML.ImageAnalytics
/// <summary>
/// Transform which takes one or many columns with vectors in them and transform them to <see cref="ImageType"/> representation.
/// </summary>
public sealed class VectorToImageTransform : OneToOneTransformBase
internal sealed class VectorToImageTransform : OneToOneTransformBase
{
public class Column : OneToOneColumn
{
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Transforms/HashJoiningTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ namespace Microsoft.ML.Transforms.Conversions
/// column there is an option to specify which slots should be hashed together into one output slot.
/// This transform can be applied either to single valued columns or to known length vector columns.
/// </summary>
public sealed class HashJoiningTransform : OneToOneTransformBase
[BestFriend]
internal sealed class HashJoiningTransform : OneToOneTransformBase
{
public const int NumBitsMin = 1;
public const int NumBitsLim = 32;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

namespace Microsoft.ML.Transforms
{
public sealed class MissingValueIndicatorTransform : OneToOneTransformBase
internal sealed class MissingValueIndicatorTransform : OneToOneTransformBase
{
public sealed class Column : OneToOneColumn
{
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.Transforms/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Core.Tests" + PublicKey.TestValue)]

[assembly: InternalsVisibleTo(assemblyName: "RunTests" + InternalPublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.EntryPoints" + PublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Runtime.NeuralNetworks" + InternalPublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.RServerScoring.NeuralNetworks" + InternalPublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Runtime.TextAnalytics" + InternalPublicKey.Value)]
Expand Down