
Commit 89abc7e

vballoli and BradLarson authored
Add 1D VariationalAutoencoder (#401)
* Add 1D VariationalAutoencoder
* Add Autoencoder CMake
* Updating for Batcher.
* Reformat main.swift and update the target name in the README.
* Rework test dataset usage.
* Fix CMakeLists.

Co-authored-by: Brad Larson <[email protected]>
1 parent 27ef87d · commit 89abc7e

12 files changed: +202 -0 lines changed
Autoencoder/Autoencoder1D/CMakeLists.txt

+9

@@ -0,0 +1,9 @@
add_executable(Autoencoder1D
  main.swift)
target_link_libraries(Autoencoder1D PRIVATE
  Datasets
  ModelSupport)


install(TARGETS Autoencoder1D
  DESTINATION bin)
Autoencoder/Autoencoder2D/CMakeLists.txt

+9

@@ -0,0 +1,9 @@
add_executable(Autoencoder2D
  main.swift)
target_link_libraries(Autoencoder2D PRIVATE
  Datasets
  ModelSupport)


install(TARGETS Autoencoder2D
  DESTINATION bin)

Autoencoder/CMakeLists.txt

+3

@@ -0,0 +1,3 @@
add_subdirectory(Autoencoder1D)
add_subdirectory(Autoencoder2D)
add_subdirectory(VAE1D)

Autoencoder/VAE1D/CMakeLists.txt

+9

@@ -0,0 +1,9 @@
add_executable(VariationalAutoencoder1D
  main.swift)
target_link_libraries(VariationalAutoencoder1D PRIVATE
  Datasets
  ModelSupport)


install(TARGETS VariationalAutoencoder1D
  DESTINATION bin)

Autoencoder/VAE1D/README.md

+34

@@ -0,0 +1,34 @@
# 1D Variational Autoencoder

This is an example of a simple 1-dimensional variational autoencoder, using MNIST as the training dataset. The model is based on the paper "Auto-Encoding Variational Bayes" by [Kingma et al.](https://arxiv.org/abs/1312.6114). It should produce output similar to the following:

### Epoch 1
<p align="center">
<img src="images/epoch-1-input.jpg" height="270" width="360">
<img src="images/epoch-1-output.jpg" height="270" width="360">
</p>

### Epoch 10
<p align="center">
<img src="images/epoch-10-input.jpg" height="270" width="360">
<img src="images/epoch-10-output.jpg" height="270" width="360">
</p>


## Setup

To begin, you'll need the [latest version of Swift for
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
installed. Make sure you've added the correct version of `swift` to your path.

To train the model, run:

```
swift run -c release VariationalAutoencoder1D
```

## Key implementations

1. The reparameterization trick is implemented inside the VAE model.
2. The VAE model returns an `Array` of `Tensor<Float>` values; `Array` itself conforms to `Differentiable` via an extension, so the whole output can be differentiated through. (Reference: [S4TF API Docs](https://www.tensorflow.org/swift/api_docs/Extensions/Array))
3. The loss function combines the `sigmoidCrossEntropy` of the output with the KL divergence between the intermediate representations (a worked form of this loss is sketched just after this file).
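
For reference, here is a sketch (not part of the commit) of the loss described in item 3 above, written to match the `vaeLossFunction` in `main.swift` below; the per-element sums over reconstruction pixels and latent dimensions are assumptions read off that code:

```latex
\mathcal{L}(x, \hat{x}, \mu, \log\sigma^{2})
  = \underbrace{\sum_{i} \mathrm{BCE}\!\left(\hat{x}_{i}, x_{i}\right)}_{\texttt{sigmoidCrossEntropy},\ \text{sum reduction}}
  \;-\; \frac{1}{2} \sum_{j} \left(1 + \log\sigma_{j}^{2} - \mu_{j}^{2} - \sigma_{j}^{2}\right)
```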
Four binary image files are also added (759 Bytes, 1.14 KB, 780 Bytes, and 841 Bytes).

Autoencoder/VAE1D/main.swift

+134

@@ -0,0 +1,134 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Based on the paper: "Auto-Encoding Variational Bayes"
// by Diederik P Kingma and Max Welling
// Reference implementation: https://github.com/pytorch/examples/blob/master/vae/main.py

import Datasets
import Foundation
import ModelSupport
import TensorFlow

let epochCount = 10
let imageHeight = 28
let imageWidth = 28

let outputFolder = "./output/"
let dataset = MNIST(batchSize: 128, flattening: true)

let inputDim = 784  // 28 * 28 for any MNIST image
let hiddenDim = 400
let latentDim = 20

// Variational autoencoder
public struct VAE: Layer {
  // Encoder
  public var encoderDense1: Dense<Float>
  public var encoderDense2_1: Dense<Float>
  public var encoderDense2_2: Dense<Float>
  // Decoder
  public var decoderDense1: Dense<Float>
  public var decoderDense2: Dense<Float>

  public init() {
    self.encoderDense1 = Dense<Float>(
      inputSize: inputDim, outputSize: hiddenDim, activation: relu)
    self.encoderDense2_1 = Dense<Float>(inputSize: hiddenDim, outputSize: latentDim)
    self.encoderDense2_2 = Dense<Float>(inputSize: hiddenDim, outputSize: latentDim)

    self.decoderDense1 = Dense<Float>(
      inputSize: latentDim, outputSize: hiddenDim, activation: relu)
    self.decoderDense2 = Dense<Float>(inputSize: hiddenDim, outputSize: inputDim)
  }

  @differentiable
  public func callAsFunction(_ input: Tensor<Float>) -> [Tensor<Float>] {
    // Encode
    let intermediateInput = encoderDense1(input)
    let mu = encoderDense2_1(intermediateInput)
    let logVar = encoderDense2_2(intermediateInput)

    // Reparameterization trick
    let std = exp(0.5 * logVar)
    let epsilon = Tensor<Float>(randomNormal: std.shape)
    let z = mu + epsilon * std

    // Decode
    let output = z.sequenced(through: decoderDense1, decoderDense2)
    return [output, mu, logVar]
  }
}

var vae = VAE()
let optimizer = Adam(for: vae, learningRate: 1e-3)

// Loss function: the sum of the KL divergence of the embeddings and the cross-entropy loss
// between the input and its reconstruction.
func vaeLossFunction(
  input: Tensor<Float>, output: Tensor<Float>, mu: Tensor<Float>, logVar: Tensor<Float>
) -> Tensor<Float> {
  let crossEntropy = sigmoidCrossEntropy(logits: output, labels: input, reduction: _sum)
  let klDivergence = -0.5 * (1 + logVar - pow(mu, 2) - exp(logVar)).sum()
  return crossEntropy + klDivergence
}

// TODO: Find a cleaner way of extracting individual images that doesn't require a second dataset.
let singleImageDataset = MNIST(batchSize: 1, flattening: true)
let individualTestImages = singleImageDataset.test
var testImageIterator = individualTestImages.sequenced()

// Training loop
for epoch in 1...epochCount {
  // Test for each epoch
  if let nextIndividualImage = testImageIterator.next() {
    let sampleTensor = nextIndividualImage.first
    let sampleImage = Tensor(
      shape: [1, imageHeight * imageWidth], scalars: sampleTensor.scalars)

    let testOutputs = vae(sampleImage)
    let testImage = testOutputs[0]
    let testMu = testOutputs[1]
    let testLogVar = testOutputs[2]
    if epoch == 1 || epoch % 10 == 0 {
      do {
        try saveImage(
          sampleImage, shape: (imageWidth, imageHeight), format: .grayscale,
          directory: outputFolder, name: "epoch-\(epoch)-input")
        try saveImage(
          testImage, shape: (imageWidth, imageHeight), format: .grayscale,
          directory: outputFolder, name: "epoch-\(epoch)-output")
      } catch {
        print("Could not save image with error: \(error)")
      }
    }

    let sampleLoss = vaeLossFunction(
      input: sampleImage, output: testImage, mu: testMu, logVar: testLogVar)
    print("[Epoch: \(epoch)] Loss: \(sampleLoss)")
  }

  for batch in dataset.training.sequenced() {
    let x = batch.first

    let 𝛁model = TensorFlow.gradient(at: vae) { vae -> Tensor<Float> in
      let outputs = vae(x)
      let output = outputs[0]
      let mu = outputs[1]
      let logVar = outputs[2]
      return vaeLossFunction(input: x, output: output, mu: mu, logVar: logVar)
    }

    optimizer.update(&vae, along: 𝛁model)
  }
}
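
A usage note, not part of the commit: because the decoder layers of `VAE` are declared public, a trained model can also be used generatively by sampling a latent vector directly. Below is a minimal sketch, assuming the `vae`, `latentDim`, `imageWidth`, `imageHeight`, and `outputFolder` values from `main.swift` above are in scope, and applying `sigmoid` because the decoder emits logits; the output name is hypothetical.

```swift
// Hypothetical generative-sampling sketch (not in this commit).
// Draw z ~ N(0, I) and push it through the trained decoder layers.
let z = Tensor<Float>(randomNormal: [1, latentDim])
let logits = vae.decoderDense2(vae.decoderDense1(z))
let generated = sigmoid(logits)  // map decoder logits to pixel intensities in [0, 1]
do {
  try saveImage(
    generated, shape: (imageWidth, imageHeight), format: .grayscale,
    directory: outputFolder, name: "sampled-digit")
} catch {
  print("Could not save image with error: \(error)")
}
```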

CMakeLists.txt

+1

@@ -86,6 +86,7 @@ set_target_properties(SwiftProtobuf PROPERTIES
   INTERFACE_INCLUDE_DIRECTORIES ${CMAKE_Swift_MODULE_DIRECTORY})
 add_dependencies(SwiftProtobuf swift-protobuf-install)
 
+add_subdirectory(Autoencoder)
 add_subdirectory(Support)
 add_subdirectory(Batcher)
 add_subdirectory(Datasets)

Package.swift

+3

@@ -56,6 +56,9 @@ let package = Package(
         .target(
             name: "Autoencoder2D", dependencies: ["Datasets", "ModelSupport"],
             path: "Autoencoder/Autoencoder2D"),
+        .target(
+            name: "VariationalAutoencoder1D", dependencies: ["Datasets", "ModelSupport"],
+            path: "Autoencoder/VAE1D"),
         .target(name: "Catch", path: "Catch"),
         .target(name: "Gym-FrozenLake", path: "Gym/FrozenLake"),
         .target(name: "Gym-CartPole", path: "Gym/CartPole"),
