4
4
5
5
using System ;
6
6
using System . IO ;
7
+ using System . IO . Compression ;
7
8
using System . Linq ;
8
9
using Microsoft . ML . Runtime ;
9
10
using Microsoft . ML . Runtime . CommandLine ;
@@ -41,6 +42,8 @@ internal sealed class TensorFlowMapper : IRowMapper
41
42
private readonly bool [ ] _isVectorInput ;
42
43
private readonly TFShape [ ] _tfInputShapes ;
43
44
private readonly TFDataType [ ] _tfInputTypes ;
45
+ private readonly bool _isFrozen ;
46
+ private readonly string _exportDir ;
44
47
45
48
private readonly string _outputColName ;
46
49
private readonly ColumnType _outputColType ;
@@ -66,7 +69,7 @@ public TensorFlowMapper(IHostEnvironment env, ISchema inputSchema, byte[] modelB
66
69
_host . CheckNonEmpty ( modelBytes , nameof ( modelBytes ) ) ;
67
70
_host . CheckNonEmpty ( inputColNames , nameof ( inputColNames ) ) ;
68
71
_host . CheckNonEmpty ( outputColName , nameof ( outputColName ) ) ;
69
-
72
+ _isFrozen = true ;
70
73
_session = LoadTFSession ( modelBytes , null ) ;
71
74
_host . CheckValue ( _session . Graph [ outputColName ] , nameof ( outputColName ) , "Output does not exist in the model" ) ;
72
75
_host . Check ( inputColNames . All ( name => _session . Graph [ name ] != null ) , "One of the input does not exist in the model" ) ;
@@ -83,6 +86,8 @@ public TensorFlowMapper(IHostEnvironment env, ISchema inputSchema, string export
83
86
_host . CheckValue ( inputSchema , nameof ( inputSchema ) ) ;
84
87
_host . CheckNonEmpty ( inputColNames , nameof ( inputColNames ) ) ;
85
88
_host . CheckNonEmpty ( outputColName , nameof ( outputColName ) ) ;
89
+ _isFrozen = false ;
90
+ _exportDir = exportDir ;
86
91
87
92
_session = LoadTFSession ( exportDir ) ;
88
93
_host . CheckValue ( _session . Graph [ outputColName ] , nameof ( outputColName ) , "Output does not exist in the model" ) ;
@@ -99,41 +104,92 @@ public static TensorFlowMapper Create(IHostEnvironment env, ModelLoadContext ctx
99
104
env . CheckValue ( ctx , nameof ( ctx ) ) ;
100
105
ctx . CheckAtModel ( GetVersionInfo ( ) ) ;
101
106
102
- var numInputs = ctx . Reader . ReadInt32 ( ) ;
103
- Contracts . CheckDecode ( numInputs > 0 ) ;
107
+ var isFrozen = ctx . Reader . ReadInt32 ( ) ;
108
+ if ( isFrozen == 1 )
109
+ {
110
+ var numInputs = ctx . Reader . ReadInt32 ( ) ;
111
+ Contracts . CheckDecode ( numInputs > 0 ) ;
112
+
113
+ string [ ] source = new string [ numInputs ] ;
114
+ for ( int j = 0 ; j < source . Length ; j ++ )
115
+ source [ j ] = ctx . LoadNonEmptyString ( ) ;
116
+
117
+ byte [ ] data = null ;
118
+ if ( ! ctx . TryLoadBinaryStream ( "TFModel" , r => data = r . ReadByteArray ( ) ) )
119
+ throw env . ExceptDecode ( ) ;
120
+
121
+ var outputColName = ctx . LoadNonEmptyString ( ) ;
122
+
123
+ return new TensorFlowMapper ( env , schema , data , source , outputColName ) ;
124
+ }
125
+ else
126
+ {
127
+ var numInputs = ctx . Reader . ReadInt32 ( ) ;
128
+ Contracts . CheckDecode ( numInputs > 0 ) ;
129
+
130
+ string [ ] source = new string [ numInputs ] ;
131
+ for ( int j = 0 ; j < source . Length ; j ++ )
132
+ source [ j ] = ctx . LoadNonEmptyString ( ) ;
133
+
134
+ // Load model binary
135
+ byte [ ] tfFilesBin = null ;
136
+ var load = ctx . TryLoadBinaryStream ( "TFSavedModel" , br => tfFilesBin = br . ReadByteArray ( ) ) ;
104
137
105
- string [ ] source = new string [ numInputs ] ;
106
- for ( int j = 0 ; j < source . Length ; j ++ )
107
- source [ j ] = ctx . LoadNonEmptyString ( ) ;
138
+ var tempDirName = Path . GetFullPath ( Path . Combine ( Path . GetTempPath ( ) , "_MLNET_TFTransform_" + Guid . NewGuid ( ) ) ) ;
139
+ var tempDir = Directory . CreateDirectory ( tempDirName ) ;
140
+ var tfZipFilePath = Path . Combine ( tempDir . FullName , "tf_savedmodel.zip" ) ;
108
141
109
- byte [ ] data = null ;
110
- if ( ! ctx . TryLoadBinaryStream ( "TFModel" , r => data = r . ReadByteArray ( ) ) )
111
- throw env . ExceptDecode ( ) ;
142
+ File . WriteAllBytes ( tfZipFilePath , tfFilesBin ) ;
143
+ ZipFile . ExtractToDirectory ( tfZipFilePath , Path . Combine ( tempDir . FullName , "tf_savedmodel" ) ) ;
112
144
113
- var outputColName = ctx . LoadNonEmptyString ( ) ;
145
+ var outputColName = ctx . LoadNonEmptyString ( ) ;
114
146
115
- return new TensorFlowMapper ( env , schema , data , source , outputColName ) ;
147
+ return new TensorFlowMapper ( env , schema , Path . Combine ( tempDir . FullName , "tf_savedmodel" ) , source , outputColName ) ;
148
+ }
116
149
}
117
150
118
151
public void Save ( ModelSaveContext ctx )
119
152
{
120
153
_host . AssertValue ( ctx ) ;
121
154
ctx . CheckAtModel ( ) ;
122
155
ctx . SetVersionInfo ( GetVersionInfo ( ) ) ;
156
+ ctx . Writer . Write ( _isFrozen ? 1 : 0 ) ;
157
+ if ( _isFrozen )
158
+ {
159
+ var buffer = new TFBuffer ( ) ;
160
+ _session . Graph . ToGraphDef ( buffer ) ;
123
161
124
- var buffer = new TFBuffer ( ) ;
125
- _session . Graph . ToGraphDef ( buffer ) ;
126
-
127
- ctx . SaveBinaryStream ( "TFModel" , w =>
162
+ ctx . SaveBinaryStream ( "TFModel" , w =>
163
+ {
164
+ w . WriteByteArray ( buffer . ToArray ( ) ) ;
165
+ } ) ;
166
+ Contracts . AssertNonEmpty ( _inputColNames ) ;
167
+ ctx . Writer . Write ( _inputColNames . Length ) ;
168
+ foreach ( var colName in _inputColNames )
169
+ ctx . SaveNonEmptyString ( colName ) ;
170
+
171
+ ctx . SaveNonEmptyString ( _outputColName ) ;
172
+ }
173
+ else
128
174
{
129
- w . WriteByteArray ( buffer . ToArray ( ) ) ;
130
- } ) ;
131
- Contracts . AssertNonEmpty ( _inputColNames ) ;
132
- ctx . Writer . Write ( _inputColNames . Length ) ;
133
- foreach ( var colName in _inputColNames )
134
- ctx . SaveNonEmptyString ( colName ) ;
135
-
136
- ctx . SaveNonEmptyString ( _outputColName ) ;
175
+ var tempDirName = Path . GetFullPath ( Path . Combine ( Path . GetTempPath ( ) , "_MLNET_TFTransform_" + Guid . NewGuid ( ) ) ) ;
176
+ var tempDir = Directory . CreateDirectory ( tempDirName ) ;
177
+ var tfZipFilePath = Path . Combine ( tempDir . FullName , "tf_savedmodel.zip" ) ;
178
+
179
+ ZipFile . CreateFromDirectory ( _exportDir , tfZipFilePath , CompressionLevel . Fastest , false ) ;
180
+ byte [ ] byteArray = File . ReadAllBytes ( tfZipFilePath ) ;
181
+ ctx . SaveBinaryStream ( "TFSavedModel" , w =>
182
+ {
183
+ w . WriteByteArray ( byteArray ) ;
184
+ } ) ;
185
+
186
+ Contracts . AssertNonEmpty ( _inputColNames ) ;
187
+ ctx . Writer . Write ( _inputColNames . Length ) ;
188
+ foreach ( var colName in _inputColNames )
189
+ ctx . SaveNonEmptyString ( colName ) ;
190
+
191
+ ctx . SaveNonEmptyString ( _outputColName ) ;
192
+ }
137
193
}
138
194
139
195
private TFSession LoadTFSession ( byte [ ] modelBytes , string modelArg )
0 commit comments