This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 0f11d8d

Migrating the COCO dataset to Epochs (#606)

* Initial conversion of COCO to Epochs.
* Fixing tests, adding convenience initializers.
* Forgot an assertion.
* Remove Batcher from Dataset CMakeLists.
* Adding a settable transform function for custom data manipulation in the COCO pipeline.
* Allow more than one example to be generated during preprocessing from a source example.
1 parent 20fa285 commit 0f11d8d

9 files changed: +201 -63 lines
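
For orientation, a minimal consuming sketch of the Epochs-based API this commit introduces. It assumes only the `COCODataset` convenience initializer and the `training`/`validation` properties shown in the diffs below, and borrows the small val/test variants used by the tests; it is an illustration, not part of the commit:

import Datasets

let dataset = COCODataset(
  training: COCOVariant.loadVal(),
  validation: COCOVariant.loadTest(),
  includeMasks: false, batchSize: 32)

// `training` is a lazy sequence of epochs; each epoch is a collection of
// batches, and each batch is an [ObjectDetectionExample].
for (epochIndex, epochBatches) in dataset.training.prefix(2).enumerated() {
  for batch in epochBatches {
    print("epoch \(epochIndex): batch of \(batch.count) examples")
  }
}

// `validation` is a single collection of batches.
for batch in dataset.validation {
  print("validation batch of \(batch.count) examples")
}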

Datasets/CMakeLists.txt (-1)

@@ -25,7 +25,6 @@ add_library(Datasets
   ImageSegmentationDataset.swift
   OxfordIIITPets/OxfordIIITPets.swift)
 target_link_libraries(Datasets PUBLIC
-  Batcher
   ModelSupport)
 set_target_properties(Datasets PROPERTIES
   INTERFACE_INCLUDE_DIRECTORIES ${CMAKE_Swift_MODULE_DIRECTORY})

Datasets/COCO/COCO.swift (+15)

@@ -1,3 +1,18 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
 import Foundation
 
 // Code below is ported from https://github.com/cocometadata/cocoapi

Datasets/COCO/COCODataset.swift (+97 -38)

@@ -1,52 +1,102 @@
-import Batcher
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 import Foundation
+import TensorFlow
+
+public struct COCODataset<Entropy: RandomNumberGenerator> {
+  /// Type of the collection of non-collated batches.
+  public typealias Batches = Slices<Sampling<[ObjectDetectionExample], ArraySlice<Int>>>
+  /// The type of the training data, represented as a sequence of epochs, which
+  /// are collection of batches.
+  public typealias Training = LazyMapSequence<
+    TrainingEpochs<[ObjectDetectionExample], Entropy>,
+    LazyMapSequence<Batches, [ObjectDetectionExample]>
+  >
+  /// The type of the validation data, represented as a collection of batches.
+  public typealias Validation = LazyMapSequence<Slices<[ObjectDetectionExample]>, [ObjectDetectionExample]>
+  /// The training epochs.
+  public let training: Training
+  /// The validation batches.
+  public let validation: Validation
+
+  /// Creates an instance with `batchSize` on `device` using `remoteBinaryArchiveLocation`.
+  ///
+  /// - Parameters:
+  ///   - training: The COCO metadata for the training data.
+  ///   - validation: The COCO metadata for the validation data.
+  ///   - includeMasks: Whether to include the segmentation masks when loading the dataset.
+  ///   - batchSize: Number of images provided per batch.
+  ///   - entropy: A source of randomness used to shuffle sample ordering. It
+  ///     will be stored in `self`, so if it is only pseudorandom and has value
+  ///     semantics, the sequence of epochs is deterministic and not dependent
+  ///     on other operations.
+  ///   - device: The Device on which resulting Tensors from this dataset will be placed, as well
+  ///     as where the latter stages of any conversion calculations will be performed.
+  public init(
+    training: COCO, validation: COCO, includeMasks: Bool, batchSize: Int,
+    entropy: Entropy, device: Device,
+    transform: @escaping (ObjectDetectionExample) -> [ObjectDetectionExample]
+  ) {
+    let trainingSamples = loadCOCOExamples(
+      from: training,
+      includeMasks: includeMasks,
+      batchSize: batchSize)
 
-public struct COCODataset: ObjectDetectionDataset {
-  public typealias SourceDataSet = [ObjectDetectionExample]
-  public let trainingExamples: SourceDataSet
-  public let training: Batcher<SourceDataSet>
-  public let testExamples: SourceDataSet
-  public let test: Batcher<SourceDataSet>
+    self.training = TrainingEpochs(samples: trainingSamples, batchSize: batchSize, entropy: entropy)
+      .lazy.map { (batches: Batches) -> LazyMapSequence<Batches, [ObjectDetectionExample]> in
+        return batches.lazy.map {
+          makeBatch(samples: $0, device: device, transform: transform)
+        }
+      }
+
+    let validationSamples = loadCOCOExamples(
+      from: validation,
+      includeMasks: includeMasks,
+      batchSize: batchSize)
 
-  public init(
-    training: COCO, test: COCO,
-    includeMasks: Bool, batchSize: Int, numWorkers: Int
-  ) {
-    self.trainingExamples =
-      loadCOCOExamples(
-        from: training,
-        includeMasks: includeMasks,
-        batchSize: batchSize,
-        numWorkers: numWorkers)
-    self.training =
-      Batcher(
-        on: trainingExamples,
-        batchSize: batchSize,
-        numWorkers: numWorkers,
-        shuffle: true)
-    self.testExamples =
-      loadCOCOExamples(
-        from: test,
-        includeMasks: includeMasks,
-        batchSize: batchSize,
-        numWorkers: numWorkers)
-    self.test =
-      Batcher(
-        on: testExamples,
-        batchSize: batchSize,
-        numWorkers: numWorkers,
-        shuffle: false)
+    self.validation = validationSamples.inBatches(of: batchSize).lazy.map {
+      makeBatch(samples: $0, device: device, transform: transform)
     }
+  }
+
+  public static func identity(_ example: ObjectDetectionExample) -> [ObjectDetectionExample] {
+    return [example]
+  }
 }
 
-func loadCOCOExamples(from coco: COCO, includeMasks: Bool, batchSize: Int, numWorkers: Int)
+extension COCODataset: ObjectDetectionData where Entropy == SystemRandomNumberGenerator {
+  /// Creates an instance with `batchSize`, using the SystemRandomNumberGenerator.
+  public init(
+    training: COCO, validation: COCO, includeMasks: Bool, batchSize: Int,
+    on device: Device = Device.default,
+    transform: @escaping (ObjectDetectionExample) -> [ObjectDetectionExample] = COCODataset.identity
+  ) {
+    self.init(
+      training: training, validation: validation, includeMasks: includeMasks, batchSize: batchSize,
+      entropy: SystemRandomNumberGenerator(), device: device, transform: transform)
+  }
+}
+
+
+func loadCOCOExamples(from coco: COCO, includeMasks: Bool, batchSize: Int)
   -> [ObjectDetectionExample]
 {
   let images = coco.metadata["images"] as! [COCO.Image]
   let batchCount: Int = images.count / batchSize + 1
-  let n = min(numWorkers, batchCount)
   let batches = Array(0..<batchCount)
-  let examples: [[ObjectDetectionExample]] = batches._concurrentMap(nthreads: n) { batchIdx in
+  let examples: [[ObjectDetectionExample]] = batches.map { batchIdx in
     var examples: [ObjectDetectionExample] = []
     for i in 0..<batchSize {
       let idx = batchSize * batchIdx + i
@@ -118,3 +168,12 @@ func loadCOCOExample(coco: COCO, image: COCO.Image, includeMasks: Bool) -> Objec
   }
   return ObjectDetectionExample(image: img, objects: objects)
 }
+
+fileprivate func makeBatch<BatchSamples: Collection>(
+  samples: BatchSamples, device: Device,
+  transform: (ObjectDetectionExample) -> [ObjectDetectionExample]
+) -> [ObjectDetectionExample] where BatchSamples.Element == ObjectDetectionExample {
+  return samples.reduce([]) {
+    $0 + transform($1)
+  }
+}
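
The settable `transform` hook above is applied to each source example as a batch is collated, and since it returns an array, one source example can contribute zero, one, or several entries to the batch. A hedged sketch of a custom transform (the filter/duplicate policy is invented for illustration and only uses the `objects` property defined in ObjectDetectionDataset.swift):

// Drop images with no annotated objects, and emit every remaining example
// twice (a stand-in for an augmentation that would return a modified copy).
func dropEmptyAndDuplicate(_ example: ObjectDetectionExample) -> [ObjectDetectionExample] {
  guard !example.objects.isEmpty else { return [] }
  return [example, example]
}

let augmented = COCODataset(
  training: COCOVariant.loadVal(),
  validation: COCOVariant.loadTest(),
  includeMasks: false, batchSize: 32,
  transform: dropEmptyAndDuplicate)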

Datasets/COCO/COCOVariant.swift (+14)

@@ -1,3 +1,17 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 import Foundation
 import ModelSupport
 

Datasets/LanguageModelDataset.swift (+14)

@@ -1,3 +1,17 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 import TensorFlow
 
 /// A dataset suitable for language modeling.

Datasets/ObjectDetectionDataset.swift (+38 -7)

@@ -1,4 +1,17 @@
-import Batcher
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 import Foundation
 import ModelSupport
 import TensorFlow
@@ -52,7 +65,7 @@ public struct LabeledObject {
   }
 }
 
-public struct ObjectDetectionExample: _Collatable, KeyPathIterable {
+public struct ObjectDetectionExample: KeyPathIterable {
   public let image: LazyImage
   public let objects: [LabeledObject]
 
@@ -62,10 +75,28 @@ public struct ObjectDetectionExample: _Collatable, KeyPathIterable {
   }
 }
 
-public protocol ObjectDetectionDataset {
-  associatedtype SourceDataSet: Collection
-  where SourceDataSet.Element == ObjectDetectionExample, SourceDataSet.Index == Int
+/// Types whose elements represent an object detection dataset (with both
+/// training and validation data).
+public protocol ObjectDetectionData {
+  /// The type of the training data, represented as a sequence of epochs, which
+  /// are collection of batches.
+  associatedtype Training: Sequence
+  where Training.Element: Collection, Training.Element.Element == [ObjectDetectionExample]
+  /// The type of the validation data, represented as a collection of batches.
+  associatedtype Validation: Collection where Validation.Element == [ObjectDetectionExample]
+  /// Creates an instance from a given `batchSize`.
+  init(
+    training: COCO, validation: COCO, includeMasks: Bool, batchSize: Int, on device: Device,
+    transform: @escaping (ObjectDetectionExample) -> [ObjectDetectionExample])
+  /// The `training` epochs.
+  var training: Training { get }
+  /// The `validation` batches.
+  var validation: Validation { get }
 
-  var training: Batcher<SourceDataSet> { get }
-  var test: Batcher<SourceDataSet> { get }
+  // The following is probably going to be necessary since we can't extract that
+  // information from `Epochs` or `Batches`.
+  /// The number of samples in the `training` set.
+  //var trainingSampleCount: Int {get}
+  /// The number of samples in the `validation` set.
+  //var validationSampleCount: Int {get}
 }
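
With `ObjectDetectionData` exposing the training data as a sequence of epochs of `[ObjectDetectionExample]` batches instead of a `Batcher`, callers can stay generic over any conforming dataset. A minimal sketch assuming only the protocol requirements above; the helper itself is hypothetical, not part of this commit:

// Count the batches produced in the first training epoch of any conformer.
func batchCount<D: ObjectDetectionData>(in dataset: D) -> Int {
  var count = 0
  for epochBatches in dataset.training.prefix(1) {
    count += epochBatches.count  // each element is one [ObjectDetectionExample] batch
  }
  return count
}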

Datasets/TensorPair.swift (+2 -3)

@@ -13,15 +13,14 @@
 // limitations under the License.
 
 import TensorFlow
-import Batcher
 
 /// A generic tuple of two tensors `Tensor`.
 ///
 /// - Note: `TensorPair` has a generic name and provides little semantic information, to conform to
 /// `Collatable`. You can use it for most basic datasets with one tensor of inputs and one tensor of
 /// labels but you should write your own struct for more complex tasks (or if you want more descriptive
 /// names).
-public struct TensorPair<S1: TensorFlowScalar, S2: TensorFlowScalar>: _Collatable, KeyPathIterable {
+public struct TensorPair<S1: TensorFlowScalar, S2: TensorFlowScalar>: KeyPathIterable {
   public var first: Tensor<S1>
   public var second: Tensor<S2>
 
@@ -30,4 +29,4 @@ public struct TensorPair<S1: TensorFlowScalar, S2: TensorFlowScalar>: _Collatabl
     self.first = first
     self.second = second
   }
-}
+}

Package.swift (+2 -2)

@@ -25,7 +25,7 @@ let package = Package(
     ],
     targets: [
        .target(name: "Batcher", path: "Batcher"),
-       .target(name: "Datasets", dependencies: ["ModelSupport", "Batcher"], path: "Datasets"),
+       .target(name: "Datasets", dependencies: ["ModelSupport"], path: "Datasets"),
        .target(name: "STBImage", path: "Support/STBImage"),
        .target(
            name: "ModelSupport", dependencies: ["SwiftProtobuf", "STBImage"], path: "Support",
@@ -117,7 +117,7 @@ let package = Package(
        ),
        .target(
            name: "pix2pix",
-           dependencies: ["Batcher", "ArgumentParser", "ModelSupport", "Datasets"],
+           dependencies: ["ArgumentParser", "ModelSupport", "Datasets"],
            path: "pix2pix"
        ),
        .target(

Tests/DatasetsTests/COCO/COCODatasetTests.swift (+19 -12)

@@ -9,26 +9,33 @@ final class COCODatasetTests: XCTestCase {
     // to avoid fetching the full training data during CI runs.
     let dataset = COCODataset(
       training: COCOVariant.loadVal(),
-      test: COCOVariant.loadTest(),
-      includeMasks: false, batchSize: 32, numWorkers: 8)
-    verify(dataset.trainingExamples)
-    verify(dataset.testExamples)
+      validation: COCOVariant.loadTest(),
+      includeMasks: false, batchSize: 32)
+
+    for epochBatches in dataset.training.prefix(1) {
+      let batch = epochBatches.first!
+      XCTAssertTrue(batch[0].image.width != 0)
+    }
+
+    let validationBatch = dataset.validation.first!
+    XCTAssertTrue(validationBatch[0].image.width != 0)
   }
 
   func testExamplesIncludingMasks() {
     // We use val/test variants here, instead of train/val,
     // to avoid fetching the full training data during CI runs.
     let dataset = COCODataset(
       training: COCOVariant.loadVal(),
-      test: COCOVariant.loadTest(),
-      includeMasks: true, batchSize: 32, numWorkers: 8)
-    verify(dataset.trainingExamples)
-    verify(dataset.testExamples)
-  }
+      validation: COCOVariant.loadTest(),
+      includeMasks: true, batchSize: 32)
+
+    for epochBatches in dataset.training.prefix(1) {
+      let batch = epochBatches.first!
+      XCTAssertTrue(batch[0].image.width != 0)
+    }
 
-  func verify(_ examples: [ObjectDetectionExample]) {
-    XCTAssertTrue(examples.count > 0)
-    XCTAssertTrue(examples[0].image.width != 0)
+    let validationBatch = dataset.validation.first!
+    XCTAssertTrue(validationBatch[0].image.width != 0)
   }
 
   static var allTests = [
