Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Initial compatibility with stock toolchain #722

Merged
merged 3 commits into from
Dec 17, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Models/Text/WordSeg/Lattice.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import _Differentiation
import ModelSupport
import TensorFlow

Expand Down
4 changes: 4 additions & 0 deletions Models/Text/WordSeg/Model.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
import ModelSupport
import TensorFlow

#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
import Numerics
#endif

/// Types that can be optimized by an optimizer.
///
/// TODO: Consider promoting this into a public protocol in swift-apis?
Expand Down
2 changes: 2 additions & 0 deletions Models/Text/WordSeg/SemiRing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import _Differentiation

#if os(iOS) || os(macOS) || os(tvOS) || os(watchOS)
import Darwin
#elseif os(Windows)
Expand Down
10 changes: 7 additions & 3 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ let package = Package(
.package(url: "https://github.com/apple/swift-protobuf.git", from: "1.10.0"),
.package(url: "https://github.com/apple/swift-argument-parser", .branch("main")),
.package(url: "https://github.com/google/swift-benchmark", from: "0.1.0"),
.package(url: "https://github.com/apple/swift-numerics", .branch("main")),
],
targets: [
.target(
Expand All @@ -35,12 +36,15 @@ let package = Package(
.target(name: "Datasets", dependencies: ["ModelSupport"], path: "Datasets"),
.target(name: "STBImage", path: "Support/STBImage"),
.target(
name: "ModelSupport", dependencies: ["STBImage"], path: "Support",
exclude: ["STBImage"]),
name: "ModelSupport",
dependencies: ["STBImage", .product(name: "Numerics", package: "swift-numerics"),],
path: "Support", exclude: ["STBImage"]),
.target(name: "TensorBoard", dependencies: ["SwiftProtobuf", "ModelSupport", "TrainingLoop"], path: "TensorBoard"),
.target(name: "ImageClassificationModels", path: "Models/ImageClassification"),
.target(name: "VideoClassificationModels", path: "Models/Spatiotemporal"),
.target(name: "TextModels", dependencies: ["Checkpoints", "Datasets", "SwiftProtobuf"], path: "Models/Text"),
.target(name: "TextModels",
dependencies: ["Checkpoints", "Datasets", "SwiftProtobuf", .product(name: "Numerics", package: "swift-numerics")],
path: "Models/Text"),
.target(name: "RecommendationModels", path: "Models/Recommendation"),
.target(name: "TrainingLoop", dependencies: ["ModelSupport"], path: "TrainingLoop"),
.target(
Expand Down
17 changes: 17 additions & 0 deletions Support/AnyLayer.swift
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// TODO: Re-enable this for the stock toolchain when it can be realigned with VectorProtocol.
#if !TENSORFLOW_USE_STANDARD_TOOLCHAIN
import TensorFlow
import _Differentiation

Expand Down Expand Up @@ -253,3 +269,4 @@ extension AnyLayer: Layer {
return _callAsFunction(input)
}
}
#endif
17 changes: 17 additions & 0 deletions Support/AnyLayerTangentVector.swift
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// TODO: Re-enable this for the stock toolchain when it can be realigned with VectorProtocol.
#if !TENSORFLOW_USE_STANDARD_TOOLCHAIN
import TensorFlow
import _Differentiation

Expand Down Expand Up @@ -824,3 +840,4 @@ extension AnyLayerTangentVector: ElementaryFunctions {
return .init(box: x.box._root(n))
}
}
#endif
8 changes: 8 additions & 0 deletions Tests/SupportTests/AnyLayerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ import XCTest
import ModelSupport
import TensorFlow

// TODO: Re-enable these once AnyLayer can be aligned with the new VectorProtocol.
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
final class AnyLayerTests: XCTestCase {
static var allTests = [(String, (XCTestCase) -> () -> Void)]()
}
#else

struct ElementaryFunctionsTests<WrapperTestType: XCTestCase, Reference: ElementaryFunctions, Test: ElementaryFunctions> {
let createReference: ([Float]) -> Reference
let referenceToTest: (Reference) -> Test
Expand Down Expand Up @@ -167,3 +174,4 @@ final class AnyLayerTests: XCTestCase {
("testTangentVectorElementaryFunctions", testTangentVectorElementaryFunctions),
]
}
#endif