|
1 |
| -import Batcher |
| 1 | +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
2 | 15 | import Foundation
|
| 16 | +import TensorFlow |
| 17 | + |
| 18 | +public struct COCODataset<Entropy: RandomNumberGenerator> { |
| 19 | + /// Type of the collection of non-collated batches. |
| 20 | + public typealias Batches = Slices<Sampling<[ObjectDetectionExample], ArraySlice<Int>>> |
| 21 | + /// The type of the training data, represented as a sequence of epochs, which |
| 22 | + /// are collection of batches. |
| 23 | + public typealias Training = LazyMapSequence< |
| 24 | + TrainingEpochs<[ObjectDetectionExample], Entropy>, |
| 25 | + LazyMapSequence<Batches, [ObjectDetectionExample]> |
| 26 | + > |
| 27 | + /// The type of the validation data, represented as a collection of batches. |
| 28 | + public typealias Validation = LazyMapSequence<Slices<[ObjectDetectionExample]>, [ObjectDetectionExample]> |
| 29 | + /// The training epochs. |
| 30 | + public let training: Training |
| 31 | + /// The validation batches. |
| 32 | + public let validation: Validation |
| 33 | + |
| 34 | + /// Creates an instance with `batchSize` on `device` using `remoteBinaryArchiveLocation`. |
| 35 | + /// |
| 36 | + /// - Parameters: |
| 37 | + /// - training: The COCO metadata for the training data. |
| 38 | + /// - validation: The COCO metadata for the validation data. |
| 39 | + /// - includeMasks: Whether to include the segmentation masks when loading the dataset. |
| 40 | + /// - batchSize: Number of images provided per batch. |
| 41 | + /// - entropy: A source of randomness used to shuffle sample ordering. It |
| 42 | + /// will be stored in `self`, so if it is only pseudorandom and has value |
| 43 | + /// semantics, the sequence of epochs is deterministic and not dependent |
| 44 | + /// on other operations. |
| 45 | + /// - device: The Device on which resulting Tensors from this dataset will be placed, as well |
| 46 | + /// as where the latter stages of any conversion calculations will be performed. |
| 47 | + public init( |
| 48 | + training: COCO, validation: COCO, includeMasks: Bool, batchSize: Int, |
| 49 | + entropy: Entropy, device: Device, |
| 50 | + transform: @escaping (ObjectDetectionExample) -> [ObjectDetectionExample] |
| 51 | + ) { |
| 52 | + let trainingSamples = loadCOCOExamples( |
| 53 | + from: training, |
| 54 | + includeMasks: includeMasks, |
| 55 | + batchSize: batchSize) |
3 | 56 |
|
4 |
| -public struct COCODataset: ObjectDetectionDataset { |
5 |
| - public typealias SourceDataSet = [ObjectDetectionExample] |
6 |
| - public let trainingExamples: SourceDataSet |
7 |
| - public let training: Batcher<SourceDataSet> |
8 |
| - public let testExamples: SourceDataSet |
9 |
| - public let test: Batcher<SourceDataSet> |
| 57 | + self.training = TrainingEpochs(samples: trainingSamples, batchSize: batchSize, entropy: entropy) |
| 58 | + .lazy.map { (batches: Batches) -> LazyMapSequence<Batches, [ObjectDetectionExample]> in |
| 59 | + return batches.lazy.map { |
| 60 | + makeBatch(samples: $0, device: device, transform: transform) |
| 61 | + } |
| 62 | + } |
| 63 | + |
| 64 | + let validationSamples = loadCOCOExamples( |
| 65 | + from: validation, |
| 66 | + includeMasks: includeMasks, |
| 67 | + batchSize: batchSize) |
10 | 68 |
|
11 |
| - public init( |
12 |
| - training: COCO, test: COCO, |
13 |
| - includeMasks: Bool, batchSize: Int, numWorkers: Int |
14 |
| - ) { |
15 |
| - self.trainingExamples = |
16 |
| - loadCOCOExamples( |
17 |
| - from: training, |
18 |
| - includeMasks: includeMasks, |
19 |
| - batchSize: batchSize, |
20 |
| - numWorkers: numWorkers) |
21 |
| - self.training = |
22 |
| - Batcher( |
23 |
| - on: trainingExamples, |
24 |
| - batchSize: batchSize, |
25 |
| - numWorkers: numWorkers, |
26 |
| - shuffle: true) |
27 |
| - self.testExamples = |
28 |
| - loadCOCOExamples( |
29 |
| - from: test, |
30 |
| - includeMasks: includeMasks, |
31 |
| - batchSize: batchSize, |
32 |
| - numWorkers: numWorkers) |
33 |
| - self.test = |
34 |
| - Batcher( |
35 |
| - on: testExamples, |
36 |
| - batchSize: batchSize, |
37 |
| - numWorkers: numWorkers, |
38 |
| - shuffle: false) |
| 69 | + self.validation = validationSamples.inBatches(of: batchSize).lazy.map { |
| 70 | + makeBatch(samples: $0, device: device, transform: transform) |
39 | 71 | }
|
| 72 | + } |
| 73 | + |
| 74 | + public static func identity(_ example: ObjectDetectionExample) -> [ObjectDetectionExample] { |
| 75 | + return [example] |
| 76 | + } |
40 | 77 | }
|
41 | 78 |
|
42 |
| -func loadCOCOExamples(from coco: COCO, includeMasks: Bool, batchSize: Int, numWorkers: Int) |
| 79 | +extension COCODataset: ObjectDetectionData where Entropy == SystemRandomNumberGenerator { |
| 80 | + /// Creates an instance with `batchSize`, using the SystemRandomNumberGenerator. |
| 81 | + public init( |
| 82 | + training: COCO, validation: COCO, includeMasks: Bool, batchSize: Int, |
| 83 | + on device: Device = Device.default, |
| 84 | + transform: @escaping (ObjectDetectionExample) -> [ObjectDetectionExample] = COCODataset.identity |
| 85 | + ) { |
| 86 | + self.init( |
| 87 | + training: training, validation: validation, includeMasks: includeMasks, batchSize: batchSize, |
| 88 | + entropy: SystemRandomNumberGenerator(), device: device, transform: transform) |
| 89 | + } |
| 90 | +} |
| 91 | + |
| 92 | + |
| 93 | +func loadCOCOExamples(from coco: COCO, includeMasks: Bool, batchSize: Int) |
43 | 94 | -> [ObjectDetectionExample]
|
44 | 95 | {
|
45 | 96 | let images = coco.metadata["images"] as! [COCO.Image]
|
46 | 97 | let batchCount: Int = images.count / batchSize + 1
|
47 |
| - let n = min(numWorkers, batchCount) |
48 | 98 | let batches = Array(0..<batchCount)
|
49 |
| - let examples: [[ObjectDetectionExample]] = batches._concurrentMap(nthreads: n) { batchIdx in |
| 99 | + let examples: [[ObjectDetectionExample]] = batches.map { batchIdx in |
50 | 100 | var examples: [ObjectDetectionExample] = []
|
51 | 101 | for i in 0..<batchSize {
|
52 | 102 | let idx = batchSize * batchIdx + i
|
@@ -118,3 +168,12 @@ func loadCOCOExample(coco: COCO, image: COCO.Image, includeMasks: Bool) -> Objec
|
118 | 168 | }
|
119 | 169 | return ObjectDetectionExample(image: img, objects: objects)
|
120 | 170 | }
|
| 171 | + |
| 172 | +fileprivate func makeBatch<BatchSamples: Collection>( |
| 173 | + samples: BatchSamples, device: Device, |
| 174 | + transform: (ObjectDetectionExample) -> [ObjectDetectionExample] |
| 175 | +) -> [ObjectDetectionExample] where BatchSamples.Element == ObjectDetectionExample { |
| 176 | + return samples.reduce([]) { |
| 177 | + $0 + transform($1) |
| 178 | + } |
| 179 | +} |
0 commit comments