This repository was archived by the owner on Mar 30, 2022. It is now read-only.

Commit 7bbc9ed

Adding initial guide to Tensors, along with a general docs Readme (#604)

* Added new Tensor guide and an overview Readme for the docs.
* Extracted accelerator backends to their own guide, added X10 debugging guide, added model summary guide.
* Moved Tensor guides into the main section and added a brief description of _Raw operators.
1 parent 3ef4fc7 commit 7bbc9ed

File tree

6 files changed: +411 -0 lines changed


docs/README.md (+39)

@@ -0,0 +1,39 @@
# Swift for TensorFlow documentation

This is the primary location of Swift for TensorFlow's documentation and tutorials.

Formatted versions of the current guides, tutorials, and automatically generated API documentation
can be found at [tensorflow.org/swift](https://www.tensorflow.org/swift). The original versions
of the guides can be found [here](site/guide/).

## Tutorials ![](https://www.tensorflow.org/images/colab_logo_32px.png)

Tutorial | Last Updated
-------- | ------------
[A Swift Tour](https://colab.research.google.com/github/tensorflow/swift/blob/main/docs/site/tutorials/a_swift_tour.ipynb) | March 2019
[Protocol-Oriented Programming & Generics](https://colab.research.google.com/github/tensorflow/swift/blob/main/docs/site/tutorials/protocol_oriented_generics.ipynb) | August 2019
[Python Interoperability](https://colab.research.google.com/github/tensorflow/swift/blob/main/docs/site/tutorials/python_interoperability.ipynb) | March 2019
[Custom Differentiation](https://colab.research.google.com/github/tensorflow/swift/blob/main/docs/site/tutorials/custom_differentiation.ipynb) | March 2019
[Sharp Edges in Differentiability](https://colab.research.google.com/github/tensorflow/swift/blob/main/docs/site/tutorials/Swift_autodiff_sharp_edges.ipynb) | November 2020
[Model Training Walkthrough](https://colab.research.google.com/github/tensorflow/swift/blob/main/docs/site/tutorials/model_training_walkthrough.ipynb) | March 2019
[Raw TensorFlow Operators](https://colab.research.google.com/github/tensorflow/swift/blob/main/docs/site/tutorials/raw_tensorflow_operators.ipynb) | December 2019
[Introducing X10, an XLA-Based Backend](https://colab.research.google.com/github/tensorflow/swift/blob/main/docs/site/tutorials/introducing_x10.ipynb) | May 2020

## Technology reference

Many different technological directions have been explored over the lifetime of the project. An
archive of reference guides, some now obsolete, can be found here:

Document | Last Updated | Status
-------- | ------------ | ------
[Swift Differentiable Programming Manifesto](https://github.com/apple/swift/blob/main/docs/DifferentiableProgramming.md) | January 2020 | Current
[Swift Differentiable Programming Implementation Overview](https://docs.google.com/document/d/1_BirmTqdotglwNTOcYAW-ib6mx_jl-gH9Dbg4WmHZh0) | August 2019 | Current
[Swift Differentiable Programming Design Overview](https://docs.google.com/document/d/1bPepWLfRQa6CtXqKA8CDQ87uZHixNav-TFjLSisuKag/edit?usp=sharing) | June 2019 | Outdated
[Differentiable Types](DifferentiableTypes.md) | March 2019 | Outdated
[Differentiable Functions and Differentiation APIs](DifferentiableFunctions.md) | March 2019 | Outdated
[Dynamic Property Iteration using Key Paths](DynamicPropertyIteration.md) | March 2019 | Current
[Hierarchical Parameter Iteration and Optimization](ParameterOptimization.md) | March 2019 | Current
[First-Class Automatic Differentiation in Swift: A Manifesto](https://gist.github.com/rxwei/30ba75ce092ab3b0dce4bde1fc2c9f1d) | October 2018 | Outdated
[Automatic Differentiation Whitepaper](AutomaticDifferentiation.md) | April 2018 | Outdated
[Python Interoperability](PythonInteroperability.md) | April 2018 | Current
[Graph Program Extraction](GraphProgramExtraction.md) | April 2018 | Outdated

docs/site/_book.yaml (+8)

@@ -21,7 +21,15 @@ upper_tabs:
     - title: Swift differentiable programming manifesto
       path: https://github.com/apple/swift/blob/main/docs/DifferentiableProgramming.md
       status: external
+    - title: Tensors
+      path: /swift/guide/tensors
+    - title: Accelerator backends
+      path: /swift/guide/backends
+    - title: Debugging X10 issues
+      path: /swift/guide/debugging_x10
     - heading: "Machine learning models"
+    - title: Model summaries
+      path: /swift/guide/model_summary
     - title: Swift for TensorFlow model garden
       path: https://github.com/tensorflow/swift-models
       status: external

docs/site/guide/backends.md (+101)

@@ -0,0 +1,101 @@
# Accelerator backends

It's pretty straightforward to describe a `Tensor` calculation, but when and how that calculation
is performed will depend on which backend is used for the `Tensor`s and when the results
are needed on the host CPU.

Behind the scenes, operations on `Tensor`s are dispatched to accelerators like GPUs or
[TPUs](https://cloud.google.com/tpu), or run on the CPU when no accelerator is available. This
happens automatically for you, and makes it easy to perform complex parallel calculations using
a high-level interface. However, it can be useful to understand how this dispatch occurs and be
able to customize it for optimal performance.

Swift for TensorFlow has two backends for performing accelerated computation: TensorFlow eager mode
and X10. The default backend is TensorFlow eager mode, but that can be overridden. An
[interactive tutorial](https://colab.research.google.com/github/tensorflow/swift/blob/main/docs/site/tutorials/introducing_x10.ipynb)
is available that walks you through the use of these different backends.

## TensorFlow eager mode

The TensorFlow eager mode backend leverages
[the TensorFlow C API](https://www.tensorflow.org/install/lang_c) to send each `Tensor` operation
to a GPU or CPU as it is encountered. The result of that operation is then retrieved and passed on
to the next operation.

This operation-by-operation dispatch is straightforward to understand and requires no explicit
configuration within your code. However, in many cases it does not result in optimal performance,
due to the overhead from sending off many small operations combined with the lack of operation
fusion and optimization that can occur when graphs of operations are present. Finally, TensorFlow
eager mode is incompatible with TPUs, and can only be used with CPUs and GPUs.
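As a minimal sketch of this default dispatch (placement can be left implicit, or spelled out with
`Device.defaultTFEager`):

```swift
import TensorFlow

// A minimal sketch: with no device specified, Tensors use the default
// backend (TensorFlow eager mode); the placement can also be made explicit.
let a = Tensor<Float>([1.0, 2.0, 3.0])                             // default placement
let b = Tensor<Float>([4.0, 5.0, 6.0], on: Device.defaultTFEager)  // explicit eager placement
let c = a + b  // dispatched to a GPU or CPU as it is encountered
print(c)
```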
## X10 (XLA-based tracing)

X10 is the name of the Swift for TensorFlow backend that uses lazy tensor tracing and [the XLA
optimizing compiler](https://www.tensorflow.org/xla) to improve performance, in many cases
significantly, over operation-by-operation dispatch. Additionally, it adds compatibility with
[TPUs](https://cloud.google.com/tpu), accelerators specifically optimized for the kinds of
calculations found within machine learning models.

The use of X10 for `Tensor` calculations is not the default, so you need to opt in to this backend.
That is done by specifying that a `Tensor` is placed on an XLA device:

```swift
let tensor1 = Tensor<Float>([0.0, 1.0, 2.0], on: Device.defaultXLA)
let tensor2 = Tensor<Float>([1.5, 2.5, 3.5], on: Device.defaultXLA)
```

After that point, describing a calculation is exactly the same as for TensorFlow eager mode:

```swift
let tensor3 = tensor1 + tensor2
```

Further detail can be provided when creating a `Tensor`, such as what kind of accelerator to use
and even which one, if several are available. For example, a `Tensor` can be created on the second
TPU device (assuming it is visible to the host the program is running on) using the following:

```swift
let tpuTensor = Tensor<Float>([0.0, 1.0, 2.0], on: Device(kind: .TPU, ordinal: 1, backend: .XLA))
```
No implicit movement of `Tensor`s between devices is performed, so if two `Tensor`s on different
devices are used in an operation together, a runtime error will occur. To manually copy the
contents of a `Tensor` to a new device, you can use the `Tensor(copying:to:)` initializer, as
sketched below. Some larger-scale structures that contain `Tensor`s within them, like models and
optimizers, have helper functions for moving all of their interior `Tensor`s to a new device in
one step.
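As a minimal sketch of such an explicit copy (assuming both the default eager and XLA devices are
available in a standard swift-apis toolchain):

```swift
import TensorFlow

// A minimal sketch: move a Tensor between backends explicitly.
let eagerTensor = Tensor<Float>([0.0, 1.0, 2.0], on: Device.defaultTFEager)
let xlaTensor = Tensor(copying: eagerTensor, to: Device.defaultXLA)

// Operands must share a device; mixing devices is a runtime error.
let doubled = xlaTensor + xlaTensor  // fine: both Tensors live on the XLA device
print(doubled.device)
```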
Unlike TensorFlow eager mode, operations using the X10 backend are not individually dispatched as
they are encountered. Instead, dispatching to an accelerator is only triggered by either reading
calculated values back to the host or by placing an explicit barrier. The way this works is that
the runtime starts from the value being read to the host (or the last calculation before a manual
barrier) and traces the graph of calculations that result in that value.
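A minimal sketch of this deferred execution (`sum()` and `scalarized()` are standard swift-apis
`Tensor` operations, used here only to force a host read):

```swift
import TensorFlow

// A minimal sketch: X10 records operations lazily; reading a value back
// on the host is what triggers compilation and execution of the trace.
let x = Tensor<Float>([0.0, 1.0, 2.0], on: Device.defaultXLA)
let y = x * x + 1.0    // recorded in the trace, not yet executed
let z = y.sum()        // still just tracing
print(z.scalarized())  // forces the traced graph to compile and run
```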
This traced graph is then converted to the XLA HLO intermediate representation and passed to the
XLA compiler to be optimized and compiled for execution on the accelerator. From there, the entire
calculation is sent to the accelerator and the end result obtained.

Compilation is a time-consuming process, so X10 is best used with massively parallel calculations
that are expressed via a graph and that are performed many times. Hash values and caching are used
so that identical graphs are only compiled once for every unique configuration.
For machine learning models, the training process often involves a loop where the model is
subjected to the same series of calculations over and over. You'll want each of these passes to be
seen as a repetition of the same trace, rather than as one long graph with repeated units inside
it. This is enabled by manually inserting a call to the `LazyTensorBarrier()` function at the
locations in your code where you wish a trace to end, as sketched below.
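As a rough sketch of where such a barrier goes (the loop body is schematic and stands in for a
real training step; `LazyTensorBarrier()` is the call under discussion):

```swift
import TensorFlow

// A schematic sketch: cut the trace once per step so each pass is seen as
// a repetition of the same graph rather than one ever-growing trace.
var state = Tensor<Float>(zeros: [4], on: Device.defaultXLA)
for step in 0..<100 {
    state = state * 0.9 + 1.0  // stands in for a real training step
    LazyTensorBarrier()        // end the trace for this step
    if step % 10 == 0 {
        print("step \(step): \(state[0].scalarized())")
    }
}
```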
### Mixed-precision support in X10

Training with mixed precision via X10 is supported, and both low-level and
high-level APIs are provided to control it. The
[low-level API](https://github.com/tensorflow/swift-apis/blob/main/Sources/TensorFlow/Core/MixedPrecision.swift)
offers two computed properties, `toReducedPrecision` and `toFullPrecision`, which
convert between full and reduced precision, along with `isReducedPrecision`
to query the precision. Besides `Tensor`s, models and optimizers can be converted
between full and reduced precision using this API.

Note that conversion to reduced precision doesn't change the logical type of a
`Tensor`. If `t` is a `Tensor<Float>`, `t.toReducedPrecision` is also a
`Tensor<Float>` with a reduced-precision underlying representation.
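A minimal sketch exercising the low-level properties named above (the values are arbitrary):

```swift
import TensorFlow

// A minimal sketch: round-trip a Tensor through reduced precision.
let full = Tensor<Float>([1.0, 2.0, 3.0], on: Device.defaultXLA)
let reduced = full.toReducedPrecision
print(reduced.isReducedPrecision)  // true
print(type(of: reduced))           // still logically Tensor<Float>

let restored = reduced.toFullPrecision
print(restored.isReducedPrecision) // false
```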
As with devices, operations between tensors of different precisions are not
allowed. This avoids silent and unwanted promotion to 32-bit floats, which would be hard
for the user to detect.

docs/site/guide/debugging_x10.md (+150)

@@ -0,0 +1,150 @@
# Debugging X10 issues

The X10 accelerator backend can provide significantly higher throughput for graph-based parallel
computation, but its deferred tracing and just-in-time compilation can sometimes lead to
non-obvious behavior. This might include frequent recompilation of traces due to graph or tensor
shape changes, or huge graphs that lead to memory issues during compilation.

One way to diagnose issues is to use the execution metrics and counters provided by
X10. The first thing to check when a model is slow is to generate a metrics
report.
# Metrics

To print a metrics report, add a `PrintX10Metrics()` call to your program:

```swift
import TensorFlow

...
PrintX10Metrics()
...
```

This will log various metrics and counters at the `INFO` level.
## Understanding the metrics report

The report includes things like:

- How many times we trigger XLA compilations and the total time spent on
  compilation.
- How many times we launch an XLA computation and the total time spent on
  execution.
- How many device data handles we create / destroy, etc.

This information is reported in terms of percentiles of the samples. An example
is:

```
Metric: CompileTime
  TotalSamples: 202
  Counter: 06m09s401ms746.001us
  ValueRate: 778ms572.062us / second
  Rate: 0.425201 / second
  Percentiles: 1%=001ms32.778us; 5%=001ms61.283us; 10%=001ms79.236us; 20%=001ms110.973us; 50%=001ms228.773us; 80%=001ms339.183us; 90%=001ms434.305us; 95%=002ms921.063us; 99%=21s102ms853.173us
```

We also provide counters, which are named integer variables that track internal
software status. For example:

```
Counter: CachedSyncTensors
  Value: 395
```
## Known caveats

`Tensor`s backed by X10 behave semantically like default eager mode `Tensor`s. However, there are
some performance and completeness caveats:

1. Degraded performance because of too many recompilations.

   XLA compilation is expensive. X10 automatically recompiles the graph every
   time new shapes are encountered, with no user intervention. Models need to
   see stabilized shapes within a few training steps, and from that point on no
   recompilation is needed. Additionally, the execution paths must stabilize
   quickly for the same reason: X10 recompiles when a new execution path is
   encountered. To sum up, in order to avoid recompilations:

   * Avoid highly variable dynamic shapes. However, a low number of different
     shapes could be fine. Pad tensors to fixed sizes when possible (see the
     padding sketch after this list).
   * Avoid loops with a different number of iterations between training steps.
     X10 currently unrolls loops, so a different number of loop iterations
     translates into a different (unrolled) execution path.
2. A small number of operations aren't supported by X10 yet.

   We currently have a handful of operations which aren't supported, either
   because there isn't a good way to express them via XLA and static shapes
   (currently just `nonZeroIndices`) or because of a lack of known use cases
   (several linear algebra operations and multinomial initialization). While
   the second category is easy to address as needed, the first category can
   only be addressed through interoperability with the CPU, non-XLA
   implementation. Using interoperability too often has significant
   performance implications because of host round-trips and fragmenting a
   fully fused model into multiple traces. Users are therefore advised to
   avoid using such operations in their models.

   On Linux, use `XLA_SAVE_TENSORS_FILE` (documented in the next section) to
   get the Swift stack trace which called the unsupported operation. Function
   names can be manually demangled using `swift-demangle`.
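As referenced in the first caveat above, a minimal sketch of padding to a stable shape. The
`padToFixedLength` helper, the `maxLength` constant, and the two-dimensional layout are
illustrative assumptions, not part of X10; `concatenated(with:alongAxis:)` is a standard
swift-apis `Tensor` operation:

```swift
import TensorFlow

// Illustrative sketch: pad variable-length batches to a single fixed size
// so X10 sees one stable shape instead of recompiling for every new length.
let maxLength = 128  // assumed, model-specific maximum

func padToFixedLength(_ batch: Tensor<Float>) -> Tensor<Float> {
    // `batch` is assumed to have shape [batchSize, length], length <= maxLength.
    let padAmount = maxLength - batch.shape[1]
    guard padAmount > 0 else { return batch }
    let padding = Tensor<Float>(zeros: [batch.shape[0], padAmount], on: batch.device)
    return batch.concatenated(with: padding, alongAxis: 1)
}
```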
# Obtaining and graphing traces

If you suspect there are problems with the way graphs are being traced, or want to understand the
tracing process, tools are provided to log out and visualize traces. You can have X10 log out the
traces it finds by setting the `XLA_SAVE_TENSORS_FILE` environment variable:

```sh
export XLA_SAVE_TENSORS_FILE=/home/person/TraceLog.txt
```

These trace logs come in three formats: `text`, `hlo`, and `dot`, with the format settable through
the `XLA_SAVE_TENSORS_FMT` environment variable:

```sh
export XLA_SAVE_TENSORS_FMT=text
```

When you run your application, the `text` representation that is logged out will show each
individual trace in a high-level text notation used by X10. The `hlo` representation shows the
intermediate representation that is passed to the XLA compiler. You may want to restrict the
number of iterations within your training or calculation loops to prevent these logs from becoming
too large. Also, each run of your application will append to this file, so you may wish to delete
it between runs.
Setting the variable `XLA_LOG_GRAPH_CHANGES` to 1 will also indicate within the trace log where
changes in the graph have occurred. This is extremely helpful in finding the places that cause
recompilation.

For a visual representation of a trace, the `dot` option will log out Graphviz-compatible graphs.
If you extract the portion of a trace that looks like

```
digraph G {
  ...
}
```

into its own file, Graphviz (assuming it is installed) can generate a visual diagram via

```sh
dot -Tpng trace.dot -o trace.png
```

Note that setting the `XLA_SAVE_TENSORS_FILE` environment variable, especially when used in
combination with `XLA_LOG_GRAPH_CHANGES`, will have a substantial negative impact on performance.
Only use these when debugging, and not for regular operation.
# Additional environment variables

Additional environment variables for debugging include:

* `XLA_USE_BF16`: If set to 1, transforms all `Float` values to BF16. This
  should only be used for debugging, since we offer automatic mixed precision.

* `XLA_USE_32BIT_LONG`: If set to 1, maps the S4TF `Long` type to the XLA 32-bit
  integer type. On TPU, 64-bit integer computations are expensive, so setting
  this flag might help. Of course, the user needs to be certain that the
  values still fit in a 32-bit integer.

docs/site/guide/model_summary.md (+67)

@@ -0,0 +1,67 @@
# Model Summaries

A summary provides details about the architecture of a model, such as layer
types and shapes.

The design proposal can be found [here][design]. This
implementation is a work in progress, so please file an [Issue][new_issue] with
enhancements you would like to see or problems you run into.

**Note:** Model summaries are currently supported on the X10 backend only.

## Viewing a model summary

Create an X10 device and model.

```swift
import TensorFlow

public struct MyModel: Layer {
    public var dense1 = Dense<Float>(inputSize: 1, outputSize: 1)
    public var dense2 = Dense<Float>(inputSize: 4, outputSize: 4)
    public var dense3 = Dense<Float>(inputSize: 4, outputSize: 4)
    public var flatten = Flatten<Float>()

    @differentiable
    public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
        let layer1 = dense1(input)
        let layer2 = layer1.reshaped(to: [1, 4])
        let layer3 = dense2(layer2)
        let layer4 = dense3(layer3)
        return flatten(layer4)
    }
}

let device = Device.defaultXLA
let model0 = MyModel()
let model = MyModel(copying: model0, to: device)
```

Create an input tensor.

```swift
let input = Tensor<Float>(repeating: 1, shape: [1, 4, 1, 1], on: device)
```

Generate a summary of your model.

```swift
let summary = model.summary(input: input)
print(summary)
```

```
Layer                           Output Shape         Attributes
=============================== ==================== ======================
Dense<Float>                    [1, 4, 1, 1]
Dense<Float>                    [1, 4]
Dense<Float>                    [1, 4]
Flatten<Float>                  [1, 4]
```

**Note:** The `summary()` function executes the model in order to obtain
details about its architecture.

[design]: https://docs.google.com/document/d/1hEhMiwLtuzsN3RvIC3FAh6NvtTimU8o_qdzMkGvntVg/view
[new_issue]: https://github.com/tensorflow/swift-apis/issues/new
