1
- # Copyright (c) 2020 , NVIDIA CORPORATION. All rights reserved.
1
+ # Copyright (c) 2021 , NVIDIA CORPORATION. All rights reserved.
2
2
#
3
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
4
# you may not use this file except in compliance with the License.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ """ Preprocess dataset and prepare it for training
16
+
17
+ Example usage:
18
+ $ python preprocess_data.py --input_dir ./src --output_dir ./dst
19
+ --vol_per_file 2
20
+
21
+ All arguments are listed under `python preprocess_data.py -h`.
22
+
23
+ """
15
24
import os
16
25
import argparse
17
26
from random import shuffle
20
29
import nibabel as nib
21
30
import tensorflow as tf
22
31
23
-
24
32
PARSER = argparse .ArgumentParser ()
25
33
26
34
PARSER .add_argument ('--input_dir' , '-i' ,
38
46
39
47
40
48
def load_features (path ):
49
+ """ Load features from Nifti
50
+
51
+ :param path: Path to dataset
52
+ :return: Loaded data
53
+ """
41
54
data = np .zeros ((240 , 240 , 155 , 4 ), dtype = np .uint8 )
42
55
name = os .path .basename (path )
43
56
for i , modality in enumerate (["_t1.nii.gz" , "_t1ce.nii.gz" , "_t2.nii.gz" , "_flair.nii.gz" ]):
44
- vol = load_single_nifti (os .path .join (path , name + modality )).astype (np .float32 )
57
+ vol = load_single_nifti (os .path .join (path , name + modality )).astype (np .float32 )
45
58
vol [vol > 0.85 * vol .max ()] = 0.85 * vol .max ()
46
59
vol = 255 * vol / vol .max ()
47
60
data [..., i ] = vol .astype (np .uint8 )
@@ -50,16 +63,37 @@ def load_features(path):
50
63
51
64
52
65
def load_segmentation (path ):
66
+ """ Load segmentations from Nifti
67
+
68
+ :param path: Path to dataset
69
+ :return: Loaded data
70
+ """
53
71
path = os .path .join (path , os .path .basename (path )) + "_seg.nii.gz"
54
72
return load_single_nifti (path ).astype (np .uint8 )
55
73
56
74
57
75
def load_single_nifti (path ):
76
+ """ Load Nifti file as numpy
77
+
78
+ :param path: Path to file
79
+ :return: Loaded data
80
+ """
58
81
data = nib .load (path ).get_fdata ().astype (np .int16 )
59
82
return np .transpose (data , (1 , 0 , 2 ))
60
83
61
84
62
- def write_to_file (features_list , labels_list , foreground_mean_list , foreground_std_list , output_dir , count ):
85
+ def write_to_file (features_list , labels_list , foreground_mean_list , foreground_std_list , output_dir , # pylint: disable=R0913
86
+ count ):
87
+ """ Dump numpy array to tfrecord
88
+
89
+ :param features_list: List of features
90
+ :param labels_list: List of labels
91
+ :param foreground_mean_list: List of means for each volume
92
+ :param foreground_std_list: List of std for each volume
93
+ :param output_dir: Directory where to write
94
+ :param count: Index of the record
95
+ :return:
96
+ """
63
97
output_filename = os .path .join (output_dir , "volume-{}.tfrecord" .format (count ))
64
98
filelist = list (zip (np .array (features_list ),
65
99
np .array (labels_list ),
@@ -69,17 +103,22 @@ def write_to_file(features_list, labels_list, foreground_mean_list, foreground_s
69
103
70
104
71
105
def np_to_tfrecords (filelist , output_filename ):
106
+ """ Convert numpy array to tfrecord
107
+
108
+ :param filelist: List of files
109
+ :param output_filename: Destination directory
110
+ """
72
111
writer = tf .io .TFRecordWriter (output_filename )
73
112
74
- for idx in range ( len ( filelist )) :
75
- X = filelist [ idx ] [0 ].flatten ().tostring ()
76
- Y = filelist [ idx ] [1 ].flatten ().tostring ()
77
- mean = filelist [ idx ] [2 ].astype (np .float32 ).flatten ()
78
- stdev = filelist [ idx ] [3 ].astype (np .float32 ).flatten ()
113
+ for file_item in filelist :
114
+ sample = file_item [0 ].flatten ().tostring ()
115
+ label = file_item [1 ].flatten ().tostring ()
116
+ mean = file_item [2 ].astype (np .float32 ).flatten ()
117
+ stdev = file_item [3 ].astype (np .float32 ).flatten ()
79
118
80
119
d_feature = {}
81
- d_feature ['X' ] = tf .train .Feature (bytes_list = tf .train .BytesList (value = [X ]))
82
- d_feature ['Y' ] = tf .train .Feature (bytes_list = tf .train .BytesList (value = [Y ]))
120
+ d_feature ['X' ] = tf .train .Feature (bytes_list = tf .train .BytesList (value = [sample ]))
121
+ d_feature ['Y' ] = tf .train .Feature (bytes_list = tf .train .BytesList (value = [label ]))
83
122
d_feature ['mean' ] = tf .train .Feature (float_list = tf .train .FloatList (value = mean ))
84
123
d_feature ['stdev' ] = tf .train .Feature (float_list = tf .train .FloatList (value = stdev ))
85
124
@@ -90,8 +129,9 @@ def np_to_tfrecords(filelist, output_filename):
90
129
writer .close ()
91
130
92
131
93
- def main ():
94
- # parse arguments
132
+ def main (): # pylint: disable=R0914
133
+ """ Starting point of the application"""
134
+
95
135
params = PARSER .parse_args ()
96
136
input_dir = params .input_dir
97
137
output_dir = params .output_dir
@@ -101,7 +141,7 @@ def main():
101
141
if params .single_data_dir :
102
142
patient_list .extend ([os .path .join (input_dir , folder ) for folder in os .listdir (input_dir )])
103
143
else :
104
- assert "HGG" in os .listdir (input_dir ) and "LGG" in os .listdir (input_dir ),\
144
+ assert "HGG" in os .listdir (input_dir ) and "LGG" in os .listdir (input_dir ), \
105
145
"Data directory has to contain folders named HGG and LGG. " \
106
146
"If you have a single folder with patient's data please set --single_data_dir flag"
107
147
path_hgg = os .path .join (input_dir , "HGG" )
@@ -135,7 +175,7 @@ def main():
135
175
foreground_mean_list .append (fg_mean )
136
176
foreground_std_list .append (fg_std )
137
177
138
- if (i + 1 ) % params .vol_per_file == 0 :
178
+ if (i + 1 ) % params .vol_per_file == 0 :
139
179
write_to_file (features_list , labels_list , foreground_mean_list , foreground_std_list , output_dir , count )
140
180
141
181
# Clear lists
@@ -158,4 +198,3 @@ def main():
158
198
159
199
if __name__ == '__main__' :
160
200
main ()
161
-
0 commit comments