|
2 | 2 | // The .NET Foundation licenses this file to you under the MIT license.
|
3 | 3 | // See the LICENSE file in the project root for more information.
|
4 | 4 |
|
| 5 | +using System; |
| 6 | +using System.Collections.Generic; |
5 | 7 | using Microsoft.ML.Data;
|
6 | 8 | using Microsoft.ML.Internal.Utilities;
|
7 | 9 | using Microsoft.ML.Model;
|
@@ -91,6 +93,40 @@ public void SetLabelsDistribution(double[] labelsDistribution, double[] weights)
|
91 | 93 | _instanceWeights = weights;
|
92 | 94 | }
|
93 | 95 |
|
| 96 | + internal IReadOnlyList<IReadOnlyList<double>> ExtractLeafSamples() |
| 97 | + { |
| 98 | + var sampledLabelsBuffer = new List<List<double>>(); |
| 99 | + var sampleCountPerLeaf = _labelsDistribution.Length / NumLeaves; |
| 100 | + for (int i = 0; i < NumLeaves; ++i) |
| 101 | + { |
| 102 | + var samplesPerLeaf = new List<double>(); |
| 103 | + var weightsPerLeaf = new List<double>(); |
| 104 | + for (int j = 0; j < sampleCountPerLeaf; ++j) |
| 105 | + samplesPerLeaf.Add(_labelsDistribution[i * sampleCountPerLeaf + j]); |
| 106 | + sampledLabelsBuffer.Add(samplesPerLeaf); |
| 107 | + } |
| 108 | + return sampledLabelsBuffer; |
| 109 | + } |
| 110 | + |
| 111 | + internal IReadOnlyList<IReadOnlyList<double>> ExtractLeafSampleWeights() |
| 112 | + { |
| 113 | + var labelWeightsBuffer = new List<List<double>>(); |
| 114 | + var sampleCountPerLeaf = _labelsDistribution.Length / NumLeaves; |
| 115 | + for (int i = 0; i < NumLeaves; ++i) |
| 116 | + { |
| 117 | + var weightsPerLeaf = new List<double>(); |
| 118 | + for (int j = 0; j < sampleCountPerLeaf; ++j) |
| 119 | + { |
| 120 | + if (_instanceWeights != null) |
| 121 | + weightsPerLeaf.Add(_instanceWeights[i * sampleCountPerLeaf + j]); |
| 122 | + else |
| 123 | + weightsPerLeaf.Add(1.0); |
| 124 | + } |
| 125 | + labelWeightsBuffer.Add(weightsPerLeaf); |
| 126 | + } |
| 127 | + return labelWeightsBuffer; |
| 128 | + } |
| 129 | + |
94 | 130 | public override int SizeInBytes()
|
95 | 131 | {
|
96 | 132 | return base.SizeInBytes() + _labelsDistribution.SizeInBytes() + (_instanceWeights != null ? _instanceWeights.SizeInBytes() : 0);
|
|
0 commit comments