From 8c2891ee71d782739f8695aee7c24e6287baac01 Mon Sep 17 00:00:00 2001 From: Michelle Casbon Date: Fri, 8 May 2020 11:46:01 -0400 Subject: [PATCH 1/5] WordSeg model and tests --- Models/Text/WordSeg/DataSet.swift | 122 +++++ Models/Text/WordSeg/Lattice.swift | 225 +++++++++ Models/Text/WordSeg/Model.swift | 438 ++++++++++++++++++ Models/Text/WordSeg/README.md | 11 + Models/Text/WordSeg/SE-0259.swift | 153 ++++++ Models/Text/WordSeg/SemiRing.swift | 89 ++++ Models/Text/WordSeg/Vocabularies.swift | 179 +++++++ Tests/TextTests/CMakeLists.txt | 4 + .../WordSegmentationTests/ExampleData.swift | 429 +++++++++++++++++ .../WordSegmentationTests/ProbeLayers.swift | 209 +++++++++ .../TorchParameters.swift | 46 ++ .../WordSegmentationTests.swift | 162 +++++++ Tests/TextTests/XCTestManifests.swift | 4 + 13 files changed, 2071 insertions(+) create mode 100644 Models/Text/WordSeg/DataSet.swift create mode 100644 Models/Text/WordSeg/Lattice.swift create mode 100644 Models/Text/WordSeg/Model.swift create mode 100644 Models/Text/WordSeg/README.md create mode 100644 Models/Text/WordSeg/SE-0259.swift create mode 100644 Models/Text/WordSeg/SemiRing.swift create mode 100644 Models/Text/WordSeg/Vocabularies.swift create mode 100644 Tests/TextTests/WordSegmentationTests/ExampleData.swift create mode 100644 Tests/TextTests/WordSegmentationTests/ProbeLayers.swift create mode 100644 Tests/TextTests/WordSegmentationTests/TorchParameters.swift create mode 100644 Tests/TextTests/WordSegmentationTests/WordSegmentationTests.swift diff --git a/Models/Text/WordSeg/DataSet.swift b/Models/Text/WordSeg/DataSet.swift new file mode 100644 index 00000000000..719f06f4422 --- /dev/null +++ b/Models/Text/WordSeg/DataSet.swift @@ -0,0 +1,122 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +public struct DataSet { + public let training: [CharacterSequence] + public private(set) var testing: [CharacterSequence]? + public private(set) var validation: [CharacterSequence]? 
+  public let alphabet: Alphabet
+
+  private static func load(data: Data) throws -> [String] {
+    guard let contents: String = String(data: data, encoding: .utf8) else {
+      throw CharacterErrors.nonUtf8Data
+    }
+    return load(contents: contents)
+  }
+
+  private static func load(contents: String) -> [String] {
+    var strings = [String]()
+
+    for line in contents.components(separatedBy: .newlines) {
+      let stripped: String = line.components(separatedBy: .whitespaces).joined()
+      if stripped.isEmpty { continue }
+      strings.append(stripped)
+    }
+    return strings
+  }
+
+  private static func makeAlphabet(
+    datasets training: [String],
+    _ otherSequences: [String]?...,
+    eos: String = "</s>",
+    eow: String = "</w>",
+    pad: String = "<pad>"
+  ) -> Alphabet {
+    var letters: Set<Character> = []
+
+    for dataset in otherSequences + [training] {
+      guard let dataset = dataset else { continue }
+      for sentence in dataset {
+        for character in sentence {
+          letters.insert(character)
+        }
+      }
+    }
+
+    // Sort the letters to make it easier to interpret ints vs letters.
+    var sorted = Array(letters)
+    sorted.sort()
+
+    return Alphabet(sorted, eos: eos, eow: eow, pad: pad)
+  }
+
+  private static func convertDataset(_ dataset: [String], alphabet: Alphabet) throws -> [CharacterSequence] {
+    return try dataset.map { try CharacterSequence(alphabet: alphabet, appendingEoSTo: $0) }
+  }
+  private static func convertDataset(_ dataset: [String]?, alphabet: Alphabet) throws -> [CharacterSequence]? {
+    if let ds = dataset {
+      let tmp: [CharacterSequence] = try convertDataset(ds, alphabet: alphabet)  // Use tmp to disambiguate function
+      return tmp
+    }
+    return nil
+  }
+
+  public init?(
+    training trainingFile: String,
+    validation validationFile: String? = nil,
+    testing testingFile: String? = nil
+  ) throws {
+    let trainingData = try Data(contentsOf: URL(fileURLWithPath: trainingFile),
+                                options: .alwaysMapped)
+    let training = try Self.load(data: trainingData)
+
+    var validation: [String]? = nil
+    var testing: [String]? = nil
+
+    if let validationFile = validationFile {
+      let data = try Data(contentsOf: URL(fileURLWithPath: validationFile),
+                          options: .alwaysMapped)
+      validation = try Self.load(data: data)
+    }
+
+    if let testingFile = testingFile {
+      let data: Data = try Data(contentsOf: URL(fileURLWithPath: testingFile),
+                                options: .alwaysMapped)
+      testing = try Self.load(data: data)
+    }
+    self.alphabet = Self.makeAlphabet(datasets: training, validation, testing)
+    self.training = try Self.convertDataset(training, alphabet: self.alphabet)
+    self.validation = try Self.convertDataset(validation, alphabet: self.alphabet)
+    self.testing = try Self.convertDataset(testing, alphabet: self.alphabet)
+  }
+
+  init(training trainingData: Data, validation validationData: Data?, testing testingData: Data?) throws {
+    let training = try Self.load(data: trainingData)
+    var validation: [String]? = nil
+    var testing: [String]? = nil
+    if let validationData = validationData {
+      validation = try Self.load(data: validationData)
+    }
+    if let testingData = testingData {
+      testing = try Self.load(data: testingData)
+    }
+
+    self.alphabet = Self.makeAlphabet(datasets: training, validation, testing)
+    self.training = try Self.convertDataset(training, alphabet: self.alphabet)
+    self.validation = try Self.convertDataset(validation, alphabet: self.alphabet)
+    self.testing = try Self.convertDataset(testing, alphabet: self.alphabet)
+  }
+}
diff --git a/Models/Text/WordSeg/Lattice.swift b/Models/Text/WordSeg/Lattice.swift
new file mode 100644
index 00000000000..4dd0c1ec2e9
--- /dev/null
+++ b/Models/Text/WordSeg/Lattice.swift
@@ -0,0 +1,225 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import TensorFlow
+
+#if os(iOS) || os(macOS) || os(tvOS) || os(watchOS)
+  import Darwin
+#elseif os(Windows)
+  import ucrt
+#else
+  import Glibc
+#endif
+
+/// Lattice
+///
+/// Represents the lattice used by the WordSeg algorithm.
+public struct Lattice: Differentiable {
+  /// Edge
+  ///
+  /// Represents an edge between two positions in the lattice.
+  public struct Edge: Differentiable {
+    @noDerivative public var start: Int
+    @noDerivative public var end: Int
+    @noDerivative public var string: CharacterSequence
+    public var logp: Float
+
+    // expectation
+    public var score: SemiRing
+    public var totalScore: SemiRing
+
+    @differentiable
+    init(start: Int, end: Int, sentence: CharacterSequence, logp: Float,
+         previous: SemiRing, order: Int) {
+      self.start = start
+      self.end = end
+      self.string = sentence
+      self.logp = logp
+
+      self.score =
+        SemiRing(logp: logp,
+                 // TODO(abdulras): this should really use integral pow
+                 logr: logp + logf(powf(Float(sentence.count), Float(order))))
+      self.totalScore = self.score * previous
+    }
+
+    @differentiable
+    public init(start: Int, end: Int, string: CharacterSequence, logp: Float,
+                score: SemiRing, totalScore: SemiRing) {
+      self.start = start
+      self.end = end
+      self.string = string
+      self.logp = logp
+      self.score = score
+      self.totalScore = totalScore
+    }
+  }
+
+  /// Node
+  ///
+  /// Represents a node in the lattice
+  public struct Node: Differentiable {
+    @noDerivative public var bestEdge: Edge?
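+    // `bestEdge` and `bestScore` are populated by `viterbi(sentence:)` during
+    // decoding; `semiringScore` is the differentiable aggregate read during
+    // training.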
+    public var bestScore: Float = 0.0
+    public var edges = [Edge]()
+    public var semiringScore: SemiRing = SemiRing.one
+
+    init() {}
+
+    @differentiable
+    public init(bestEdge: Edge?, bestScore: Float, edges: [Edge],
+                semiringScore: SemiRing) {
+      self.bestEdge = bestEdge
+      self.bestScore = bestScore
+      self.edges = edges
+      self.semiringScore = semiringScore
+    }
+
+    @differentiable
+    func computeSemiringScore() -> SemiRing {
+      // TODO: reduce(into:) and +=
+      edges.differentiableMap { $0.totalScore }.differentiableReduce(SemiRing.zero) { $0 + $1 }
+    }
+  }
+
+  var positions: [Node]
+
+  @differentiable
+  public subscript(index: Int) -> Node {
+    get { return positions[index] }
+    set(v) { positions[index] = v }
+    //_modify { yield &positions[index] }
+  }
+
+  // TODO: remove dummy. (workaround to make AD think that the lattice is varied)
+  @differentiable
+  init(count: Int, _ dummy: Tensor<Float>) {
+    positions = Array(repeating: Node(), count: count + 1)
+  }
+
+  public init(positions: [Node]) {
+    self.positions = positions
+  }
+
+  mutating func viterbi(sentence: String) -> [Edge] {
+    // Forwards pass
+    for position in 0...sentence.count {
+      var bestScore = -Float.infinity
+      var bestEdge: Edge!
+      for edge in self[position].edges {
+        let score: Float = self[edge.start].bestScore + edge.logp
+        if score > bestScore {
+          bestScore = score
+          bestEdge = edge
+        }
+      }
+      self[position].bestScore = bestScore
+      self[position].bestEdge = bestEdge
+    }
+
+    // Backwards pass
+    var bestPath: [Edge] = []
+    var nextEdge = self[sentence.count].bestEdge!
+    while nextEdge.start != 0 {
+      bestPath.append(nextEdge)
+      nextEdge = self[nextEdge.start].bestEdge!
+    }
+    bestPath.append(nextEdge)
+
+    return bestPath.reversed()
+  }
+}
+
+extension Lattice: CustomStringConvertible {
+  public var description: String {
+    """
+    [
+    \(positions.enumerated().map { "  \($0.0): \($0.1)" }.joined(separator: "\n\n"))
+    ]
+    """
+  }
+}
+
+extension Lattice.Node: CustomStringConvertible {
+  public var description: String {
+    var edgesStr: String
+    if edges.isEmpty {
+      edgesStr = " "
+    } else {
+      edgesStr = edges.enumerated().map { "  \($0.0) - \($0.1)" }.joined(separator: "\n")
+    }
+    return """
+    best edge: \(String(describing: bestEdge)), best score: \(bestScore), score: \(semiringScore.shortDescription)
+    \(edgesStr)
+    """
+  }
+}
+
+extension Lattice.Edge: CustomStringConvertible {
+  public var description: String {
+    "[\(start)->\(end)] logp: \(logp), score: \(score.shortDescription), total score: \(totalScore.shortDescription), sentence: \(string)"
+  }
+}
+
+/// SE-0259-esque equality with tolerance
+extension Lattice {
+  public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool {
+    guard self.positions.count == other.positions.count else {
+      print("positions count mismatch: \(self.positions.count) != \(other.positions.count)")
+      return false
+    }
+    return zip(self.positions, other.positions).enumerated()
+      .map { (index, position) in
+        let eq = position.0.isAlmostEqual(to: position.1, tolerance: tolerance)
+        if !eq {
+          print("mismatch at \(index): \(position.0) != \(position.1)")
+        }
+        return eq
+      }
+      .reduce(true) { $0 && $1 }
+  }
+}
+
+extension Lattice.Node {
+  public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool {
+    guard self.edges.count == other.edges.count else { return false }
+
+    if !self.bestScore.isAlmostEqual(to: other.bestScore, tolerance: tolerance) {
+      return false
+    }
+    if let lhs = self.bestEdge, let rhs = other.bestEdge {
+      if !lhs.isAlmostEqual(to: rhs, tolerance: tolerance) {
+        return false
+      }
+    }
+    if
!self.semiringScore.isAlmostEqual(to: other.semiringScore, tolerance: tolerance) { + return false + } + return zip(self.edges, other.edges) + .map { $0.isAlmostEqual(to: $1, tolerance: tolerance) } + .reduce(true) { $0 && $1 } + } +} + +extension Lattice.Edge { + public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool { + return self.start == other.start && + self.end == other.end && + // TODO: figure out why the string equality is being ignored + // self.string == other.string && + self.logp.isAlmostEqual(to: other.logp, tolerance: tolerance) && + self.score.isAlmostEqual(to: other.score, tolerance: tolerance) && + self.totalScore.isAlmostEqual(to: other.totalScore, tolerance: tolerance) + } +} diff --git a/Models/Text/WordSeg/Model.swift b/Models/Text/WordSeg/Model.swift new file mode 100644 index 00000000000..f0612559978 --- /dev/null +++ b/Models/Text/WordSeg/Model.swift @@ -0,0 +1,438 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Original Paper: +// "Learning to Discover, Ground, and Use Words with Segmental Neural Language +// Models" +// Kazuya Kawakami, Chris Dyer, Phil Blunsom +// https://www.aclweb.org/anthology/P19-1645.pdf +// This implementation is not affiliated with DeepMind and has not been +// verified by the authors. + +import TensorFlow + +public struct Conf { + public var ndim: Int + public var dropoutProb: Double + public var chrVocab: Alphabet + public var strVocab: Lexicon + public var order: Int + + public init( + ndim: Int, + dropoutProb: Double, + chrVocab: Alphabet, + strVocab: Lexicon, + order: Int + ) { + self.ndim = ndim + self.dropoutProb = dropoutProb + self.chrVocab = chrVocab + self.strVocab = strVocab + self.order = order + } +} + +/// SNLM +/// +/// A representation of the Segmental Neural Language Model. 
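+/// As laid out below, the model combines a character-level encoder LSTM
+/// (`embEnc`, `lstmEnc`), a two-way interpolation MLP (`mlpInterpolation`),
+/// a lexical-memory MLP over the string vocabulary (`mlpMemory`), and a
+/// character-level decoder (`embDec`, `lstmDec`, `denseDec`).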
+///
+/// \ref https://www.aclweb.org/anthology/P19-1645.pdf
+public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
+  @noDerivative public var conf: Conf
+
+  // MARK: - Encoder
+
+  public var embEnc: Embedding<Float>
+  public var lstmEnc: LSTM<Float>
+
+  // MARK: - Interpolation weight
+
+  public var mlpInterpolation: MLP
+
+  // MARK: - Lexical memory
+
+  public var mlpMemory: MLP
+
+  // MARK: - Character-level decoder
+
+  public var embDec: Embedding<Float>
+  public var lstmDec: LSTM<Float>
+  public var denseDec: Dense<Float>
+
+  // MARK: - Other layers
+
+  public var drop: Dropout<Float>
+
+  // MARK: - Initializer
+
+  public init(conf: Conf) {
+    self.conf = conf
+
+    // Encoder
+    self.embEnc = Embedding(vocabularySize: conf.chrVocab.count, embeddingSize: conf.ndim)
+    self.lstmEnc = LSTM(LSTMCell(inputSize: conf.ndim, hiddenSize: conf.ndim))
+
+    // Interpolation weight
+    self.mlpInterpolation = MLP(
+      nIn: conf.ndim,
+      nHidden: conf.ndim,
+      nOut: 2,
+      dropoutProbability: conf.dropoutProb)
+
+    // Lexical memory
+    self.mlpMemory = MLP(
+      nIn: conf.ndim,
+      nHidden: conf.ndim,
+      nOut: conf.strVocab.count,
+      dropoutProbability: conf.dropoutProb)
+
+    // Character-level decoder
+    self.embDec = Embedding(vocabularySize: conf.chrVocab.count, embeddingSize: conf.ndim)
+    self.lstmDec = LSTM(LSTMCell(inputSize: conf.ndim, hiddenSize: conf.ndim))
+    self.denseDec = Dense(inputSize: conf.ndim, outputSize: conf.chrVocab.count)
+
+    // Other layers
+    self.drop = Dropout(probability: conf.dropoutProb)
+  }
+
+  // MARK: - Encode
+
+  /// Returns the hidden states of the encoder LSTM applied to the given sentence.
+  public func encode(_ x: CharacterSequence) -> [Tensor<Float>] {
+    let embedded = drop(embEnc(x.tensor))
+    // TODO: If I inline `makeEncoderInput`, it breaks AD.
+    let encoderStates = lstmEnc(makeEncoderInput(embedded))
+    // TODO: Need to add dropout here, but it breaks AD.
+    // TODO: If I inline `computeEncoderResult`, it breaks AD.
+    return computeEncoderResult(encoderStates)
+  }
+
+  // MARK: - Decode
+
+  /// Returns log probabilities for each of the candidates.
+  public func decode(_ candidates: [CharacterSequence], _ state: Tensor<Float>) -> Tensor<Float> {
+    // TODO: Shouldn't use a closure here.
+    let maxLen = { candidates.map { $0.count }.max()! + 1 }()
+    var xBatch: [Int32] = []
+    var yBatch: [Int32] = []
+    for candidate in candidates {
+      let padding = Array(repeating: conf.chrVocab.pad, count: maxLen - candidate.count - 1)
+
+      // x is </w>{sentence}{padding}
+      xBatch.append(conf.chrVocab.eow)
+      xBatch.append(contentsOf: candidate.characters)
+      xBatch.append(contentsOf: padding)
+
+      // y is {sentence}</w>{padding}
+      yBatch.append(contentsOf: candidate.characters)
+      yBatch.append(conf.chrVocab.eow)
+      yBatch.append(contentsOf: padding)
+    }
+
+    // Shapes are [time x batch] so that we can unstack the time dimension into the array that
+    // the LSTM wants as input.
+    let x: Tensor<Int32> = Tensor(shape: [candidates.count, maxLen], scalars: xBatch).transposed()
+    let y: Tensor<Int32> = Tensor(shape: [candidates.count, maxLen], scalars: yBatch).transposed()
+
+    // [time x batch x ndim]
+    let embeddedX = drop(embDec(x))
+
+    // [batch x ndim]
+    let stateBatch = state.rankLifted().tiled(multiples: Tensor([Int32(candidates.count), 1]))
+
+    // [time] array of LSTM states whose `hidden` and `cell` fields have shape [batch x ndim]
+    let decoderStates = lstmDec.callAsFunction2(
+      embeddedX.unstacked(),
+      initialState: LSTMCell<Float>.State(
+        cell: Tensor(zeros: stateBatch.shape),
+        hidden: stateBatch))
+
+    // [time x batch x ndim]
+    // TODO: Need to add dropout here, but it breaks AD.
+    // TODO: If I inline `computeDecoderResult`, it breaks AD.
+    let decoderResult = computeDecoderResult(decoderStates)
+
+    // [time x batch x chrVocab.count]
+    let logits = denseDec(decoderResult)
+
+    // [time x batch]
+    let logp = -1 * softmaxCrossEntropy(
+      logits: logits.reshaped(to: [logits.shape[0] * logits.shape[1], logits.shape[2]]),
+      labels: y.flattened(),
+      reduction: identity).reshaped(to: y.shape)
+
+    // [time x batch]
+    let logpExcludingPad = logp * Tensor<Float>(y .!= conf.chrVocab.pad)
+
+    // [batch]
+    let candidateLogP = logpExcludingPad.transposed().sum(squeezingAxes: 1)
+
+    return candidateLogP
+  }
+
+  // MARK: - buildLattice
+
+  //def get_logp_lex(self, logp_lex, candidate):
+  //  if candidate in self.str_vocab:
+  //    candidate_idx = self.str_vocab[candidate]
+  //    return logp_lex[candidate_idx]
+  //  else:
+  //    return torch.log(torch_util.var_from_scaler(0.0, "FloatTensor", self.gpu))
+
+  func get_logp_lex(_ logp_lex: [Float], _ candidate: CharacterSequence) -> Float {
+    guard let index = conf.strVocab.dictionary[candidate] else {
+      return -Float.infinity
+    }
+    return logp_lex[Int(index)]
+  }
+
+  // TODO: Triggers compiler crash.
+  @differentiable
+  public func buildLattice(_ sentence: CharacterSequence, maxLen: Int) -> Lattice {
+    var lattice = Lattice(count: sentence.count, embEnc.embeddings)
+    let states = encode(sentence)
+    let logg_batch = mlpInterpolation(Tensor(stacking: states))
+    let logp_lex_batch = mlpMemory(Tensor(stacking: states))
+    for pos in 0..<sentence.count {
+      var candidates: [CharacterSequence] = []
+      for span in 1..<min(sentence.count + 1 - pos, maxLen + 1) {
+        let candidate = CharacterSequence(
+          alphabet: conf.chrVocab,
+          characters: sentence[pos..<pos + span])
+        if candidate.count != 1 && candidate.last == conf.chrVocab.eos {
+          // Prohibit strings such as ["t", "h", "e", "</s>"]
+          continue
+        }
+        candidates.append(candidate)
+      }
+
+      //# Calculate probabilities
+      //current_state = states[pos]
+      //logg = logg_batch[pos]
+      //logp_lex = logp_lex_batch[pos]
+      //logp_chr = self.decode(candidates, current_state)
+      //# Update semiring score
+      //if pos != 0:
+      //  lattice[pos]["semiring_score"] = semiring.add(lattice[pos]["edges"],
+      //                                                self.gpu)
+      let current_state = states[pos]
+      let logg = scalarsWithADHack(logg_batch[pos])  // [2]
+      let logp_lex = scalarsWithADHack(logp_lex_batch[pos])  // [strVocab.chr.count]
+      let logp_chr = scalarsWithADHack(decode(candidates, current_state))  // [candidates.count]
+      if pos != 0 {
+        // TODO: Mutate in place when AD supports it.
+        let updatedNode = Lattice.Node(
+          bestEdge: lattice[pos].bestEdge,
+          bestScore: lattice[pos].bestScore,
+          edges: lattice[pos].edges,
+          semiringScore: lattice[pos].computeSemiringScore()
+        )
+        lattice.positions = update(lattice.positions, index: pos, value: updatedNode)
+      }
+
+      //for i, candidate in enumerate(candidates):
+      //  next_pos = pos + len(candidate)
+      //  logp_lex_i = self.get_logp_lex(logp_lex, candidate)
+      //  logp_chr_i = logp_chr[i]
+      //  logp_i = torch_util.logsumexp([logg[0] + logp_lex_i,
+      //                                 logg[1] + logp_chr_i])
+      //  # Create an edge.
+      //  edge = make_edge(pos, next_pos, candidate, logp_i,
+      //                   lattice[pos]["semiring_score"], self.order, self.gpu)
+      //  lattice[next_pos]["edges"].append(edge)
+
+      for (i, candidate) in candidates.enumerated() {
+        let next_pos = pos + candidate.count
+        let logp_lex_i = get_logp_lex(logp_lex, candidate)
+        let logp_chr_i = logp_chr[i]
+        let logp_i = logSumExp(logg[0] + logp_lex_i, logg[1] + logp_chr_i)
+        let edge = Lattice.Edge(
+          start: pos,
+          end: next_pos,
+          sentence: candidate,
+          logp: logp_i,
+          previous: lattice[pos].semiringScore,
+          order: conf.order)
+
+        // TODO: Mutate in place when AD supports it.
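+        // Rebuilding the node and writing it back through
+        // `update(_:index:value:)` (defined below with a hand-written
+        // derivative) works around the lack of differentiable in-place
+        // `Array` mutation.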
+        let updatedNode = Lattice.Node(
+          bestEdge: lattice[next_pos].bestEdge,
+          bestScore: lattice[next_pos].bestScore,
+          edges: lattice[next_pos].edges + [edge],
+          semiringScore: lattice[next_pos].semiringScore
+        )
+        lattice.positions = update(lattice.positions, index: next_pos, value: updatedNode)
+      }
+    }
+
+    //lattice[sentence.count].recomputeSemiringScore()
+    // TODO: Mutate in place when AD supports it.
+    let updatedNode = Lattice.Node(
+      bestEdge: lattice[sentence.count].bestEdge,
+      bestScore: lattice[sentence.count].bestScore,
+      edges: lattice[sentence.count].edges,
+      semiringScore: lattice[sentence.count].computeSemiringScore()
+    )
+    lattice.positions = update(lattice.positions, index: sentence.count, value: updatedNode)
+
+    return lattice
+  }
+}
+
+func update<T: Differentiable>(_ arr: [T], index: Int, value: T) -> [T] {
+  var m = arr
+  m[index] = value
+  return m
+}
+
+@derivative(of: update)
+func vjpupdate<T: Differentiable>(_ arr: [T], index: Int, value: T) -> (
+  value: [T],
+  pullback: (Array<T>.TangentVector) -> (Array<T>.TangentVector, T.TangentVector)
+) {
+  func pullback(_ tv: Array<T>.TangentVector) -> (Array<T>.TangentVector, T.TangentVector) {
+    var m = tv
+    m[index] = T.TangentVector.zero
+    return (m, tv[index])
+  }
+  return (update(arr, index: index, value: value), pullback)
+}
+
+public struct MLP: Layer {
+  public var dense1: Dense<Float>
+  public var dropout: Dropout<Float>
+  public var dense2: Dense<Float>
+
+  public init(nIn: Int, nHidden: Int, nOut: Int, dropoutProbability: Double) {
+    dense1 = Dense(inputSize: nIn, outputSize: nHidden, activation: tanh)
+    dropout = Dropout(probability: dropoutProbability)
+    dense2 = Dense(inputSize: nHidden, outputSize: nOut, activation: logSoftmax)
+  }
+
+  @differentiable
+  public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
+    return dense2(dropout(dense1(input)))
+  }
+}
+
+@differentiable
+func computeDecoderResult(_ states: [LSTMCell<Float>.State]) -> Tensor<Float> {
+  Tensor(stacking: states.differentiableMap(extractHidden))
+}
+
+@differentiable
+func extractHidden(_ state: LSTMCell<Float>.State) -> Tensor<Float> {
+  return state.hidden
+}
+
+@differentiable
+func computeEncoderResult(_ states: [LSTMCell<Float>.State]) -> [Tensor<Float>] {
+  states.differentiableMap(extractHiddenSqueezed)
+}
+
+@differentiable
+func extractHiddenSqueezed(_ state: LSTMCell<Float>.State) -> Tensor<Float> {
+  return state.hidden.squeezingShape(at: 0)
+}
+
+@differentiable
+func rankLift(_ x: Tensor<Float>) -> Tensor<Float> {
+  return x.rankLifted()
+}
+
+@differentiable
+func makeEncoderInput(_ x: Tensor<Float>) -> [Tensor<Float>] {
+  return x.unstacked().differentiableMap(rankLift)
+}
+
+// TODO: Move this derivative into tensorflow-apis
+extension RecurrentLayer {
+  @differentiable(wrt: (self, inputs, initialState))
+  public func callAsFunction2(
+    _ inputs: [Cell.TimeStepInput],
+    initialState: Cell.State
+  ) -> [Cell.TimeStepOutput] {
+    if inputs.isEmpty { return [Cell.TimeStepOutput]() }
+    var currentHiddenState = initialState
+    var timeStepOutputs: [Cell.TimeStepOutput] = []
+    for timeStepInput in inputs {
+      let output = cell(input: timeStepInput, state: currentHiddenState)
+      currentHiddenState = output.state
+      timeStepOutputs.append(output.output)
+    }
+    return timeStepOutputs
+  }
+
+  @usableFromInline
+  @derivative(of: callAsFunction2, wrt: (self, inputs, initialState))
+  internal func _vjpCallAsFunctionWrtMore(
+    _ inputs: [Cell.TimeStepInput],
+    initialState: Cell.State
+  ) -> (
+    value: [Cell.TimeStepOutput],
+    pullback: (Array<Cell.TimeStepOutput>.TangentVector)
+      -> (TangentVector, Array<Cell.TimeStepInput>.TangentVector, Cell.State.TangentVector)
+  ) {
+    let timeStepCount = inputs.count
+    var currentHiddenState =
initialState + var timeStepOutputs: [Cell.TimeStepOutput] = [] + timeStepOutputs.reserveCapacity(timeStepCount) + var backpropagators: [Cell.Backpropagator] = [] + backpropagators.reserveCapacity(timeStepCount) + for timestep in inputs { + let (output, backpropagator) = cell.appliedForBackpropagation( + to: .init(input: timestep, state: currentHiddenState)) + currentHiddenState = output.state + timeStepOutputs.append(output.output) + backpropagators.append(backpropagator) + } + return (timeStepOutputs, { 𝛁outputs in + precondition(𝛁outputs.base.count == timeStepCount, + "The number of output gradients must equal the number of time steps") + var 𝛁cell = Cell.TangentVector.zero + var 𝛁state = Cell.State.TangentVector.zero + var reversed𝛁inputs: [Cell.TimeStepInput.TangentVector] = [] + reversed𝛁inputs.reserveCapacity(timeStepCount) + for (𝛁output, backpropagator) in zip(𝛁outputs.base, backpropagators).reversed() { + let (new𝛁cell, 𝛁input) = backpropagator(.init(output: 𝛁output, state: 𝛁state)) + 𝛁cell += new𝛁cell + 𝛁state = 𝛁input.state + reversed𝛁inputs.append(𝛁input.input) + } + return (.init(cell: 𝛁cell), .init(Array(reversed𝛁inputs.reversed())), 𝛁state) + }) + } +} + +// TODO: Better way of dealing with this problem. +func scalarsWithADHack(_ t: Tensor) -> [Float] { + t.scalars +} + +@derivative(of: scalarsWithADHack) +func vjpScalarsHack(_ t: Tensor) -> (value: [Float], pullback: (Array.TangentVector) -> Tensor) { + // TODO: Capture less stuff. + func pullback(_ tv: Array.TangentVector) -> Tensor { + if tv.count == 0 { + return Tensor(zeros: t.shape) + } + return Tensor(shape: t.shape, scalars: tv.base) + } + return (t.scalars, pullback) +} diff --git a/Models/Text/WordSeg/README.md b/Models/Text/WordSeg/README.md new file mode 100644 index 00000000000..a5254400901 --- /dev/null +++ b/Models/Text/WordSeg/README.md @@ -0,0 +1,11 @@ +# WordSeg Model + +This is a Swift implementation of the paper +["Learning to Discover, Ground, and Use Words with Segmental Neural Language +Models"][paper] +by Kazuya Kawakami, Chris Dyer, and Phil Blunsom. + +This implementation is not affiliated with DeepMind and has not been verified by +the authors. + +[paper]: https://www.aclweb.org/anthology/P19-1645.pdf diff --git a/Models/Text/WordSeg/SE-0259.swift b/Models/Text/WordSeg/SE-0259.swift new file mode 100644 index 00000000000..4dbdf7185fe --- /dev/null +++ b/Models/Text/WordSeg/SE-0259.swift @@ -0,0 +1,153 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file comes from SE-0259, currently undergoing the Swift Evolution +// process. +// https://github.com/apple/swift-evolution/blob/master/proposals/0259-approximately-equal.md + +extension FloatingPoint { + /// Test approximate equality with relative tolerance. + /// + /// Do not use this function to check if a number is approximately + /// zero; no reasoned relative tolerance can do what you want for + /// that case. Use `isAlmostZero` instead for that case. 
+  ///
+  /// The relation defined by this predicate is symmetric and reflexive
+  /// (except for NaN), but *is not* transitive. Because of this, it is
+  /// often unsuitable for use for key comparisons, but it can be used
+  /// successfully in many other contexts.
+  ///
+  /// The internet is full of advice about what not to do when comparing
+  /// floating-point values:
+  ///
+  /// - "Never compare floats for equality."
+  /// - "Always use an epsilon."
+  /// - "Floating-point values are always inexact."
+  ///
+  /// Much of this advice is false, and most of the rest is technically
+  /// correct but misleading. Almost none of it provides specific and
+  /// correct recommendations for what you *should* do if you need to
+  /// compare floating-point numbers.
+  ///
+  /// There is no uniformly correct notion of "approximate equality", and
+  /// there is no uniformly correct tolerance that can be applied without
+  /// careful analysis. This function considers two values to be almost
+  /// equal if the relative difference between them is smaller than the
+  /// specified `tolerance`.
+  ///
+  /// The default value of `tolerance` is `sqrt(.ulpOfOne)`; this value
+  /// comes from the common numerical analysis wisdom that if you don't
+  /// know anything about a computation, you should assume that roughly
+  /// half the bits may have been lost to rounding. This is generally a
+  /// pretty safe choice of tolerance--if two values agree to half
+  /// their bits but are not meaningfully almost equal, the computation
+  /// is likely ill-conditioned and should be reformulated.
+  ///
+  /// For more complete guidance on an appropriate choice of tolerance,
+  /// consult with a friendly numerical analyst.
+  ///
+  /// - Parameters:
+  ///   - other: the value to compare with `self`
+  ///   - tolerance: the relative tolerance to use for the comparison.
+  ///     Should be in the range [.ulpOfOne, 1).
+  ///
+  /// - Returns: `true` if `self` is almost equal to `other`; otherwise
+  ///   `false`.
+  @inlinable
+  public func isAlmostEqual(
+    to other: Self,
+    tolerance: Self = Self.ulpOfOne.squareRoot()
+  ) -> Bool {
+    // Tolerances outside of [.ulpOfOne, 1) yield well-defined but useless
+    // results, so this is enforced by an assert rather than a precondition.
+    assert(tolerance >= .ulpOfOne && tolerance < 1, "tolerance should be in [.ulpOfOne, 1).")
+    // The simple computation below does not necessarily give sensible
+    // results if one of self or other is infinite; we need to rescale
+    // the computation in that case.
+    guard self.isFinite && other.isFinite else {
+      return rescaledAlmostEqual(to: other, tolerance: tolerance)
+    }
+    // This should eventually be rewritten to use a scaling facility to be
+    // defined on FloatingPoint suitable for hypot and scaled sums, but the
+    // following is good enough to be useful for now.
+    let scale = max(abs(self), abs(other), .leastNormalMagnitude)
+    return abs(self - other) < scale*tolerance
+  }
+
+  /// Test if this value is nearly zero with a specified `absoluteTolerance`.
+  ///
+  /// This test uses an *absolute*, rather than *relative*, tolerance,
+  /// because no number should be equal to zero when a relative tolerance
+  /// is used.
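+  ///
+  /// For example, `(0.1 + 0.2 - 0.3).isAlmostZero()` is `true` with the
+  /// default tolerance, even though `0.1 + 0.2 - 0.3 == 0.0` is `false`.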
+ /// + /// Some very rough guidelines for selecting a non-default tolerance for + /// your computation can be provided: + /// + /// - If this value is the result of floating-point additions or + /// subtractions, use a tolerance of `.ulpOfOne * n * scale`, where + /// `n` is the number of terms that were summed and `scale` is the + /// magnitude of the largest term in the sum. + /// + /// - If this value is the result of floating-point multiplications, + /// consider each term of the product: what is the smallest value that + /// should be meaningfully distinguished from zero? Multiply those terms + /// together to get a tolerance. + /// + /// - More generally, use half of the smallest value that should be + /// meaningfully distinct from zero for the purposes of your computation. + /// + /// For more complete guidance on an appropriate choice of tolerance, + /// consult with a friendly numerical analyst. + /// + /// - Parameter absoluteTolerance: values with magnitude smaller than + /// this value will be considered to be zero. Must be greater than + /// zero. + /// + /// - Returns: `true` if `abs(self)` is less than `absoluteTolerance`. + /// `false` otherwise. + @inlinable + public func isAlmostZero( + absoluteTolerance tolerance: Self = Self.ulpOfOne.squareRoot() + ) -> Bool { + assert(tolerance > 0) + return abs(self) < tolerance + } + + /// Rescales self and other to give meaningful results when one of them + /// is infinite. We also handle NaN here so that the fast path doesn't + /// need to worry about it. + @usableFromInline + internal func rescaledAlmostEqual(to other: Self, tolerance: Self) -> Bool { + // NaN is considered to be not approximately equal to anything, not even + // itself. + if self.isNaN || other.isNaN { return false } + if self.isInfinite { + if other.isInfinite { return self == other } + // Self is infinite and other is finite. Replace self with the binade + // of the greatestFiniteMagnitude, and reduce the exponent of other by + // one to compensate. + let scaledSelf = Self(sign: self.sign, + exponent: Self.greatestFiniteMagnitude.exponent, + significand: 1) + let scaledOther = Self(sign: .plus, + exponent: -1, + significand: other) + // Now both values are finite, so re-run the naive comparison. + return scaledSelf.isAlmostEqual(to: scaledOther, tolerance: tolerance) + } + // If self is finite and other is infinite, flip order and use scaling + // defined above, since this relation is symmetric. + return other.rescaledAlmostEqual(to: self, tolerance: tolerance) + } +} diff --git a/Models/Text/WordSeg/SemiRing.swift b/Models/Text/WordSeg/SemiRing.swift new file mode 100644 index 00000000000..39909dbb29c --- /dev/null +++ b/Models/Text/WordSeg/SemiRing.swift @@ -0,0 +1,89 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
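+
+// A quick worked example of the log-space semiring defined in this file:
+// with p1 = 0.25 and p2 = 0.5, semiring addition combines alternative paths,
+// so logSumExp(log(0.25), log(0.5)) == log(0.75), while semiring
+// multiplication extends a path, so log(0.25) + log(0.5) == log(0.125).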
+
+#if os(iOS) || os(macOS) || os(tvOS) || os(watchOS)
+import Darwin
+#elseif os(Windows)
+import ucrt
+#else
+import Glibc
+#endif
+
+/// logSumExp(_:_:)
+///
+/// Specialized logSumExp for two floats.
+@differentiable
+public func logSumExp(_ lhs: Float, _ rhs: Float) -> Float {
+  let maxVal = max(lhs, rhs)
+  let sumExp = exp(lhs - maxVal) + exp(rhs - maxVal)
+  return maxVal + log(sumExp)
+}
+
+@derivative(of: logSumExp)
+public func vjpLogSumExp(_ lhs: Float, _ rhs: Float) -> (
+  value: Float,
+  pullback: (Float) -> (Float, Float)
+) {
+  func pb(v: Float) -> (Float, Float) {
+    let maxVal = max(lhs, rhs)
+    let sumExp = exp(lhs - maxVal) + exp(rhs - maxVal)
+    return (v * exp(lhs - maxVal) / sumExp, v * exp(rhs - maxVal) / sumExp)
+  }
+  return (logSumExp(lhs, rhs), pb)
+}
+
+/// SemiRing
+///
+/// Represents a semiring value with log-space components `logp` and `logr`.
+public struct SemiRing: Differentiable {
+  public var logp: Float
+  public var logr: Float
+
+  @differentiable
+  public init(logp: Float, logr: Float) {
+    self.logp = logp
+    self.logr = logr
+  }
+
+  static var zero: SemiRing { SemiRing(logp: -Float.infinity, logr: -Float.infinity) }
+  static var one: SemiRing { SemiRing(logp: 0.0, logr: -Float.infinity) }
+}
+
+@differentiable
+func *(_ lhs: SemiRing, _ rhs: SemiRing) -> SemiRing {
+  return SemiRing(logp: lhs.logp + rhs.logp,
+                  logr: logSumExp(lhs.logp + rhs.logr, rhs.logp + lhs.logr))
+}
+
+@differentiable
+func +(_ lhs: SemiRing, _ rhs: SemiRing) -> SemiRing {
+  return SemiRing(logp: logSumExp(lhs.logp, rhs.logp),
+                  logr: logSumExp(lhs.logr, rhs.logr))
+}
+
+extension SemiRing {
+  var shortDescription: String {
+    "(\(logp), \(logr))"
+  }
+}
+
+/// SE-0259-esque equality with tolerance
+extension SemiRing {
+  // TODO(abdulras) see if we can use ulp as a default tolerance
+  @inlinable
+  public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool {
+    return self.logp.isAlmostEqual(to: other.logp, tolerance: tolerance) &&
+      self.logr.isAlmostEqual(to: other.logr, tolerance: tolerance)
+  }
+}
diff --git a/Models/Text/WordSeg/Vocabularies.swift b/Models/Text/WordSeg/Vocabularies.swift
new file mode 100644
index 00000000000..bb3e9131d83
--- /dev/null
+++ b/Models/Text/WordSeg/Vocabularies.swift
@@ -0,0 +1,179 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import TensorFlow
+import ModelSupport
+
+/// Alphabet maps from characters in a string to Int32 representations.
+///
+/// Note: we map from String in order to support multi-character metadata sequences such as </s>.
+///
+/// In Python implementations, this is sometimes called the character vocabulary.
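+///
+/// For example, over the letters {"a", "b"} with the marker strings used in
+/// the tests, the mapping is: "a" → 0, "b" → 1, "</s>" → 2, "</w>" → 3,
+/// "<pad>" → 4.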
+public struct Alphabet {
+  public typealias Element = String
+
+  var dictionary: BijectiveDictionary<String, Int32>
+
+  let eos: Int32
+  let eow: Int32
+  let pad: Int32
+
+  public init<C: Collection>(_ letters: C, eos: String, eow: String, pad: String)
+  where C.Element == Character {
+    self.dictionary = .init(zip(letters.lazy.map { String($0) }, 0...))
+
+    self.eos = Int32(self.dictionary.count)
+    self.dictionary[eos] = self.eos
+
+    self.eow = Int32(self.dictionary.count)
+    self.dictionary[eow] = self.eow
+
+    self.pad = Int32(self.dictionary.count)
+    self.dictionary[pad] = self.pad
+  }
+
+  public init<C: Collection>(_ letters: C, eos: String, eow: String, pad: String)
+  where C.Element == Element {
+    self.dictionary = .init(zip(letters.lazy.map { String($0) }, 0...))
+
+    self.eos = Int32(self.dictionary.count)
+    self.dictionary[eos] = self.eos
+
+    self.eow = Int32(self.dictionary.count)
+    self.dictionary[eow] = self.eow
+
+    self.pad = Int32(self.dictionary.count)
+    self.dictionary[pad] = self.pad
+  }
+
+  var count: Int { return dictionary.count }
+
+  subscript(key: String) -> Int32? {
+    return dictionary[key]
+  }
+}
+
+/// An Int32-based representation of a string to be used with the WordSeg model.
+public struct CharacterSequence: Hashable {
+  let characters: [Int32]
+  private let eos: Int32
+
+  public init(_debug: Int) {
+    self.characters = []
+    self.eos = -1
+  }
+
+  public init(alphabet: Alphabet, appendingEoSTo string: String) throws {
+    var characters = [Int32]()
+    characters.reserveCapacity(string.count + 1)
+    for (index, character) in string.enumerated() {
+      guard let value = alphabet[String(character)] else {
+        throw CharacterErrors.unknownCharacter(character: character, index: index, sentence: string)
+      }
+      characters.append(value)
+    }
+    characters.append(alphabet.eos)
+    self.init(alphabet: alphabet, characters: characters)
+  }
+
+  private init(alphabet: Alphabet, characters: [Int32]) {
+    self.characters = characters
+    self.eos = alphabet.eos
+  }
+
+  public init(alphabet: Alphabet, characters: ArraySlice<Int32>) {
+    self.characters = Array(characters)
+    self.eos = alphabet.eos
+  }
+
+  subscript(index: Int32) -> Int32 {
+    return characters[Int(index)]
+  }
+
+  subscript(range: Range<Int>) -> ArraySlice<Int32> {
+    return characters[range]
+  }
+
+  public var count: Int { return characters.count }
+  var last: Int32? { return characters.last }
+  var tensor: Tensor<Int32> {
+    Tensor([self.eos] + characters[0 ..< characters.count - 1])
+  }
+}
+
+extension CharacterSequence: CustomStringConvertible {
+  public var description: String {
+    "\(characters)"
+  }
+}
+
+/// A mapping from character sequences to logical words.
+///
+/// In Python implementations, this is sometimes called the String Vocabulary (which is in
+/// contrast with the character vocabulary which maps the alphabet to Int32's).
+public struct Lexicon {
+  public typealias Element = CharacterSequence
+
+  // TODO(marcrasi): if the value is not used to construct Tensor, switch to Int
+  var dictionary: BijectiveDictionary<Element, Int32>
+
+  var count: Int { return dictionary.count }
+
+  public init<C: Collection>(_ sequences: C) where C.Element == Element {
+    self.dictionary = .init(zip(sequences, 0...))
+  }
+
+  public init(
+    from sequences: [CharacterSequence],
+    alphabet: Alphabet,
+    maxLength: Int,
+    minFreq: Int
+  ) {
+    var histogram: [ArraySlice<Int32>: Int] = [:]
+
+    for sentence in sequences {
+      // NOTE: the use of `sentence.count - 1` is to ensure that we ignore the
+      // trailing `EoS` marker.
+      for i in 0 ..< sentence.count - 1 {
+        for j in 1 ... maxLength {
+          let e = min(i + j, sentence.count - 1)
+          // Store only substrings of length at least 2.
+ guard e - i > 1 else { continue } + histogram[sentence[i ..< e], default: 0] += 1 + } + } + } + + let frequentWordCandidates = histogram.filter { $0.1 >= minFreq } + let vocab = frequentWordCandidates.map { CharacterSequence(alphabet: alphabet, characters: $0.0) } + + self.init(vocab) + } +} + +public enum CharacterErrors: Error { + case unknownCharacter(character: Character, index: Int, sentence: String) + case nonUtf8Data +} + +extension CharacterErrors: CustomStringConvertible { + public var description: String { + switch self { + case let .unknownCharacter(character, index, sentence): + return "Unknown character '\(character)' encountered at index \(index) while converting sentence \"\(sentence)\" to a character sequence." + case .nonUtf8Data: + return "Non-UTF8 data encountered." + } + } +} diff --git a/Tests/TextTests/CMakeLists.txt b/Tests/TextTests/CMakeLists.txt index 79a22834a26..796104ee119 100644 --- a/Tests/TextTests/CMakeLists.txt +++ b/Tests/TextTests/CMakeLists.txt @@ -1,5 +1,9 @@ add_library(TextTests Inference.swift + WordSegmentationTests/ExampleData.swift + WordSegmentationTests/ProbeLayers.swift + WordSegmentationTests/TorchParameters.swift + WordSegmentationTests/WordSegmentationTests.swift XCTestManifests.swift) set_target_properties(TextTests PROPERTIES RUNTIME_OUTPUT_DIRECTORY $ diff --git a/Tests/TextTests/WordSegmentationTests/ExampleData.swift b/Tests/TextTests/WordSegmentationTests/ExampleData.swift new file mode 100644 index 00000000000..7b0b2b2552c --- /dev/null +++ b/Tests/TextTests/WordSegmentationTests/ExampleData.swift @@ -0,0 +1,429 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
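+
+// The constants below are golden values captured from a PyTorch reference
+// implementation of this model; the tests compare the Swift model's outputs
+// and gradients against them within a small tolerance.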
+ +import TensorFlow +import TextModels + +// Generated by `blaze run //experimental/users/marcrasi/probe_wordseg:probe` +enum Example1 { + static let parameters = TorchSNLMParameters( + emb_enc: TorchEmbeddingParameters( + weight: Tensor( + [[-0.03872304 ,-0.05338321 ], + [ 0.051871493 ,-0.024247635 ], + [-0.05366139 , 0.05159463 ], + [ 0.047521368 , 0.010823324 ], + [-0.0012689009,-0.07714497 ]] + ) + ), + lstm_enc: TorchLSTMParameters( + weight_ih_l0: Tensor( + [[ 0.029065304, 0.035071626], + [ 0.05220075 , 0.07659577 ], + [-0.03418843 , 0.054938704], + [-0.076050006, 0.011613108], + [-0.07208062 ,-0.028602105], + [ 0.07454459 ,-0.045288637], + [ 0.056929305, 0.07525672 ], + [-0.044981938,-0.06287278 ]] + ), + weight_hh_l0: Tensor( + [[ 0.054794624,-0.047684688], + [ 0.059219867, 0.035521813], + [ 0.07620007 ,-0.024476558], + [ 0.06970246 ,-0.012173824], + [-0.01848751 ,-0.040551327], + [ 0.050961852, 0.03887064 ], + [ 0.054652542,-0.015333958], + [ 0.048499897, 0.044002943]] + ), + bias_ih_l0: Tensor( + [-0.061996475, 0.07747544 , 0.04455457 ,-0.016387768,-0.07249636 , + 0.047923 ,-0.040993594,-0.011194035] + ), + bias_hh_l0: Tensor( + [ 1.9561797e-02,-4.7253761e-02,-1.0812007e-02, 7.7792764e-02, + -3.5792589e-05,-3.5584867e-02, 5.0887465e-05,-3.5637055e-02] + ) + ), + mlp_interpolation: TorchMLPParameters( + linear1: TorchLinearParameters( + weight: Tensor( + [[-0.034370955,-0.07964967 ], + [ 0.019697152,-0.059639234]] + ), + bias: Tensor( + [0.047385678,0.05388096 ] + ) + ), + linear2: TorchLinearParameters( + weight: Tensor( + [[-0.04711317 , 0.07374136 ], + [-0.058497474,-0.011900932]] + ), + bias: Tensor( + [-0.008636497, 0.052838206] + ) + ) + ), + mlp_memory: TorchMLPParameters( + linear1: TorchLinearParameters( + weight: Tensor( + [[-0.06895532 , 0.0675631 ], + [ 0.062555104,-0.006975107]] + ), + bias: Tensor( + [0.021453634 ,0.0056577697] + ) + ), + linear2: TorchLinearParameters( + weight: Tensor( + [[ 0.01337228 , 0.07287796 ], + [-0.025111437,-0.021482762], + [-0.05161675 ,-0.06811503 ], + [-0.072463006, 0.015226476]] + ), + bias: Tensor( + [ 0.0032938793,-0.043962937 , 0.043240592 , 0.0678826 ] + ) + ) + ), + emb_dec: TorchEmbeddingParameters( + weight: Tensor( + [[ 0.055306703, 0.07086456 ], + [ 0.03423354 ,-0.015132636], + [-0.04077827 , 0.016811028], + [-0.037189033,-0.07027687 ], + [ 0.054974243, 0.017300054]] + ) + ), + lstm_dec: TorchLSTMParameters( + weight_ih_l0: Tensor( + [[-0.053585358,-0.0642758 ], + [-0.07246614 , 0.025658146], + [ 0.034285776,-0.014611781], + [ 0.058412 , 0.047652483], + [ 0.065825045, 0.042562716], + [ 0.050531074, 0.047255352], + [-0.03512928 , 0.004992813], + [ 0.005484812,-0.054734543]] + ), + weight_hh_l0: Tensor( + [[ 0.07812595 , 0.0031644031], + [-0.04185462 , 0.03933753 ], + [-0.044581212 ,-0.018176649 ], + [ 0.07533194 , 0.0030083433], + [-0.045243938 ,-0.026109837 ], + [ 0.046121553 ,-0.053141937 ], + [ 0.011378422 , 0.067420706 ], + [-0.05194992 , 0.044939123 ]] + ), + bias_ih_l0: Tensor( + [ 0.049315616,-0.05961135 ,-0.047641095, 0.056274325,-0.071667776, + 0.049188778,-0.05663743 ,-0.051864214] + ), + bias_hh_l0: Tensor( + [ 0.055988565 , 0.01968135 ,-0.057932526 , 0.024752177 ,-0.029085837 , + -0.03911104 ,-0.0015038475, 0.051634952 ] + ) + ), + linear_dec: TorchLinearParameters( + weight: Tensor( + [[-0.03417959 , 0.04824567 ], + [ 0.0559683 , 0.0076355636], + [-0.03857645 , 0.015529476 ], + [ 0.057112962 , 0.036605842 ], + [ 0.023432933 ,-0.023976203 ]] + ), + bias: Tensor( + [-0.03640049 ,-0.057923757, 0.05912192 , 0.03688284 
, 0.06261988 ] + ) + ) + ) + static let expectedEncoding = Tensor( + [[-0.016786836 , 0.0014875316], + [-0.024643049 , 0.003509362 ], + [-0.030508908 , 0.005803111 ], + [-0.031535156 , 0.0056067896]] + ) + static let expectedMLPInterpolationOutput = Tensor( + [[-0.72172225,-0.665366 ], + [-0.7217337 ,-0.6653552 ], + [-0.72174466,-0.66534483], + [-0.7217447 ,-0.6653447 ]] + ) + static let expectedMLPMemoryOutput = Tensor( + [[-1.400075 ,-1.4486395,-1.3622522,-1.3377004], + [-1.4000794,-1.4486222,-1.3622293,-1.3377339], + [-1.4000806,-1.4486088,-1.3622129,-1.3377609], + [-1.4000823,-1.448607 ,-1.3622097,-1.3377641]] + ) + static let expectedDecoded = Tensor( + [-6.5631247,-4.930211 ] + ) + static let lattice = Lattice( + positions: [ + Lattice.Node( + bestEdge: nil, + bestScore: 0.0, + edges: [ + ], + semiringScore: SemiRing(logp: 0.0, logr: -Float.infinity) + ), + Lattice.Node( + bestEdge: nil, + bestScore: 0.0, + edges: [ + Lattice.Edge( + start: 0, + end: 1, + string: CharacterSequence(_debug: 1), + logp: -3.9122378826141357, + score: SemiRing(logp: -3.9122378826141357, logr: -3.9122378826141357), + totalScore: SemiRing(logp: -3.9122378826141357, logr: -3.9122378826141357) + ) + ], + semiringScore: SemiRing(logp: -3.9122378826141357, logr: -3.9122378826141357) + ), + Lattice.Node( + bestEdge: nil, + bestScore: 0.0, + edges: [ + Lattice.Edge( + start: 0, + end: 2, + string: CharacterSequence(_debug: 1), + logp: -2.0545620918273926, + score: SemiRing(logp: -2.0545620918273926, logr: 1.4111738204956055), + totalScore: SemiRing(logp: -2.0545620918273926, logr: 1.4111738204956055) + ), + Lattice.Edge( + start: 1, + end: 2, + string: CharacterSequence(_debug: 1), + logp: -3.936295986175537, + score: SemiRing(logp: -3.936295986175537, logr: -3.936295986175537), + totalScore: SemiRing(logp: -7.848533630371094, logr: -7.155386447906494) + ) + ], + semiringScore: SemiRing(logp: -2.051520824432373, logr: 1.411364197731018) + ), + Lattice.Node( + bestEdge: nil, + bestScore: 0.0, + edges: [ + Lattice.Edge( + start: 0, + end: 3, + string: CharacterSequence(_debug: 1), + logp: -7.253466606140137, + score: SemiRing(logp: -7.253466606140137, logr: -1.7604050636291504), + totalScore: SemiRing(logp: -7.253466606140137, logr: -1.7604050636291504) + ), + Lattice.Edge( + start: 1, + end: 3, + string: CharacterSequence(_debug: 1), + logp: -2.0307273864746094, + score: SemiRing(logp: -2.0307273864746094, logr: 1.4350085258483887), + totalScore: SemiRing(logp: -5.942965507507324, logr: -2.446457624435425) + ), + Lattice.Edge( + start: 2, + end: 3, + string: CharacterSequence(_debug: 1), + logp: -3.912229061126709, + score: SemiRing(logp: -3.912229061126709, logr: -3.912229061126709), + totalScore: SemiRing(logp: -5.963749885559082, logr: -2.4700069427490234) + ) + ], + semiringScore: SemiRing(logp: -5.1324286460876465, logr: -1.0695605278015137) + ), + Lattice.Node( + bestEdge: nil, + bestScore: 0.0, + edges: [ + Lattice.Edge( + start: 0, + end: 4, + string: CharacterSequence(_debug: 1), + logp: -8.936973571777344, + score: SemiRing(logp: -8.936973571777344, logr: -2.0055017471313477), + totalScore: SemiRing(logp: -8.936973571777344, logr: -2.0055017471313477) + ), + Lattice.Edge( + start: 1, + end: 4, + string: CharacterSequence(_debug: 1), + logp: -7.277979850769043, + score: SemiRing(logp: -7.277979850769043, logr: -1.7849183082580566), + totalScore: SemiRing(logp: -11.190217971801758, logr: -5.69304895401001) + ), + Lattice.Edge( + start: 2, + end: 4, + string: CharacterSequence(_debug: 1), + logp: 
-2.0545456409454346, + score: SemiRing(logp: -2.0545456409454346, logr: 1.4111902713775635), + totalScore: SemiRing(logp: -4.106066703796387, logr: 0.051392197608947754) + ), + Lattice.Edge( + start: 3, + end: 4, + string: CharacterSequence(_debug: 1), + logp: -3.936281204223633, + score: SemiRing(logp: -3.936281204223633, logr: -3.936281204223633), + totalScore: SemiRing(logp: -9.068710327148438, logr: -4.98878812789917) + ) + ], + semiringScore: SemiRing(logp: -4.090378284454346, logr: 0.1802457720041275) + ) + ] + ) + static let gradWrtLogR = TorchSNLMParameters( + emb_enc: TorchEmbeddingParameters( + weight: Tensor( + [[-1.0885849e-05, 2.5600420e-05], + [-2.0048559e-05, 5.2774609e-05], + [-2.1089169e-05, 6.0592978e-05], + [ 0.0000000e+00, 0.0000000e+00], + [ 0.0000000e+00, 0.0000000e+00]] + ) + ), + lstm_enc: TorchLSTMParameters( + weight_ih_l0: Tensor( + [[-5.5954405e-07, 2.1259700e-07], + [-9.2493622e-08, 1.2687329e-07], + [ 4.8279912e-07,-5.5000936e-07], + [-1.1893751e-07, 9.6952142e-08], + [ 1.7744465e-05,-6.3974912e-06], + [ 2.3670214e-05,-6.8028967e-06], + [ 6.9926745e-07, 1.0588951e-07], + [-3.6788197e-07, 1.1367500e-07]] + ), + weight_hh_l0: Tensor( + [[-6.1557824e-07, 7.9759928e-08], + [ 1.8759494e-07,-2.4723642e-08], + [-3.8983197e-07, 5.1562441e-08], + [ 7.6582531e-08,-1.0336059e-08], + [ 1.6366921e-05,-2.1041824e-06], + [ 2.5587158e-05,-3.2796638e-06], + [-7.9833825e-07, 1.1357896e-07], + [ 2.2801314e-07,-3.2397072e-08]] + ), + bias_ih_l0: Tensor( + [ 5.0008435e-05,-1.0976097e-05, 1.7231478e-05,-3.3032948e-06, + -1.3631615e-03,-2.0683720e-03, 5.0045674e-05,-1.1407620e-05] + ), + bias_hh_l0: Tensor( + [ 5.0008435e-05,-1.0976097e-05, 1.7231478e-05,-3.3032948e-06, + -1.3631615e-03,-2.0683720e-03, 5.0045674e-05,-1.1407620e-05] + ) + ), + mlp_interpolation: TorchMLPParameters( + linear1: TorchLinearParameters( + weight: Tensor( + [[-2.162833e-04, 3.404662e-05], + [-1.626215e-03, 2.559960e-04]] + ), + bias: Tensor( + [0.008965784,0.06741227 ] + ) + ), + linear2: TorchLinearParameters( + weight: Tensor( + [[ 0.03779146 , 0.041938413], + [-0.03779146 ,-0.041938413]] + ), + bias: Tensor( + [ 0.7893658,-0.7893658] + ) + ) + ), + mlp_memory: TorchMLPParameters( + linear1: TorchLinearParameters( + weight: Tensor( + [[ 0.0006783868 ,-0.00010455749], + [ 0.002728255 ,-0.00042056653]] + ), + bias: Tensor( + [-0.028684907,-0.11537111 ] + ) + ), + linear2: TorchLinearParameters( + weight: Tensor( + [[-0.0098278085,-0.0017497203], + [-0.009362103 ,-0.0016668034], + [ 0.029616904 , 0.005273029 ], + [-0.010426991 ,-0.0018565056]] + ), + bias: Tensor( + [-0.4213174 ,-0.40135252, 1.2696781 ,-0.44700825] + ) + ) + ), + emb_dec: TorchEmbeddingParameters( + weight: Tensor( + [[2.8222625e-04,1.9048888e-04], + [1.8905671e-04,1.5085124e-04], + [0.0000000e+00,0.0000000e+00], + [3.5629273e-06,2.5625954e-05], + [0.0000000e+00,0.0000000e+00]] + ) + ), + lstm_dec: TorchLSTMParameters( + weight_ih_l0: Tensor( + [[-1.3549015e-05,-1.4535895e-05], + [ 3.1118125e-07,-1.6063163e-07], + [-9.2411565e-06,-8.3111609e-06], + [ 3.1314161e-07,-9.0987783e-08], + [ 2.9817267e-04, 3.1956812e-04], + [ 2.4017896e-05,-1.0910597e-04], + [-1.9980842e-05,-2.4779467e-05], + [ 1.5680398e-07,-8.3687371e-07]] + ), + weight_hh_l0: Tensor( + [[ 7.7975146e-06,-7.3286355e-07], + [-4.4332864e-07, 4.7302478e-08], + [ 7.1966651e-06,-7.2089733e-07], + [-3.4162869e-07, 4.1011177e-08], + [-1.7509436e-04, 1.6362043e-05], + [-1.0614634e-04, 1.1837798e-05], + [ 9.3842154e-06,-8.1990322e-07], + [-4.8998282e-07, 6.6971602e-08]] + ), + bias_ih_l0: 
Tensor( + [-1.8724482e-04, 1.3764327e-05,-1.8988468e-04, 8.8408688e-06, + 4.2812643e-03, 3.6167959e-03,-2.0367328e-04, 1.3456454e-05] + ), + bias_hh_l0: Tensor( + [-1.8724482e-04, 1.3764327e-05,-1.8988468e-04, 8.8408688e-06, + 4.2812643e-03, 3.6167959e-03,-2.0367328e-04, 1.3456454e-05] + ) + ), + linear_dec: TorchLinearParameters( + weight: Tensor( + [[-0.0041010594 , 0.00012692134], + [-0.005898418 , 0.0008138743 ], + [ 0.006042951 ,-0.0006127681 ], + [-0.0020920308 , 0.00028523407], + [ 0.0060485564 ,-0.0006132616 ]] + ), + bias: Tensor( + [ 0.14542846 , 0.14904119 ,-0.16054428 , 0.026781656,-0.16070703 ] + ) + ) + ) +} + diff --git a/Tests/TextTests/WordSegmentationTests/ProbeLayers.swift b/Tests/TextTests/WordSegmentationTests/ProbeLayers.swift new file mode 100644 index 00000000000..25999eb2e07 --- /dev/null +++ b/Tests/TextTests/WordSegmentationTests/ProbeLayers.swift @@ -0,0 +1,209 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest + +import TensorFlow +import TextModels + +extension SNLM { + /// Sets the model parameters to the given parameters exported from the pytorch model. + mutating func setTorchParameters(_ p: TorchSNLMParameters) { + setEmbedding(&embEnc, to: p.emb_enc) + setLSTM(&lstmEnc.cell, to: p.lstm_enc) + setMLP(&mlpInterpolation, to: p.mlp_interpolation) + setMLP(&mlpMemory, to: p.mlp_memory) + setEmbedding(&embDec, to: p.emb_dec) + setLSTM(&lstmDec.cell, to: p.lstm_dec) + setDense(&denseDec, to: p.linear_dec) + } + + private func checkShapeAndSet(_ tensor: inout Tensor, to value: Tensor) { + assert(tensor.shape == value.shape, "shape mismatch while setting: \(tensor.shape) to \(value.shape)") + tensor = value + } + + // Sets the given Embedding parameters to the given pytorch embedding parameters. + private func setEmbedding(_ embedding: inout Embedding, to p: TorchEmbeddingParameters) { + checkShapeAndSet(&embedding.embeddings, to: p.weight) + } + + /// Sets the given LSTM cell's parameters to the given pytorch LSTM parameters. 
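+  /// PyTorch packs the four LSTM gates in (input, forget, cell, output) order;
+  /// judging from the chunk reordering below (0, 2, 1, 3), the fused layout
+  /// that the swift-apis `LSTMCell` expects is (input, cell, forget, output).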
+  private func setLSTM(_ lstm: inout LSTMCell<Float>, to p: TorchLSTMParameters) {
+    let fusedWeightTorch = p.weight_ih_l0.concatenated(with: p.weight_hh_l0, alongAxis: 1).transposed()
+    let i = fusedWeightTorch.shape[0]
+    let j = fusedWeightTorch.shape[1] / 4
+    // Reorder the four gate blocks: PyTorch packs them as [input, forget, cell, output],
+    // while the fused TensorFlow layout expects [input, cell, forget, output].
+    let fusedWeightTF = Tensor<Float>(
+      concatenating: [
+        fusedWeightTorch.slice(lowerBounds: [0, 0], upperBounds: [i, j]),
+        fusedWeightTorch.slice(lowerBounds: [0, 2 * j], upperBounds: [i, 3 * j]),
+        fusedWeightTorch.slice(lowerBounds: [0, j], upperBounds: [i, 2 * j]),
+        fusedWeightTorch.slice(lowerBounds: [0, 3 * j], upperBounds: [i, 4 * j])
+      ],
+      alongAxis: 1
+    )
+
+    let fusedBiasTorch = (p.bias_ih_l0 + p.bias_hh_l0)
+    let k = fusedBiasTorch.shape[0] / 4
+    let fusedBiasTF = Tensor<Float>(
+      concatenating: [
+        fusedBiasTorch.slice(lowerBounds: [0], upperBounds: [k]),
+        fusedBiasTorch.slice(lowerBounds: [2 * k], upperBounds: [3 * k]),
+        fusedBiasTorch.slice(lowerBounds: [k], upperBounds: [2 * k]),
+        fusedBiasTorch.slice(lowerBounds: [3 * k], upperBounds: [4 * k])
+      ]
+    )
+
+    checkShapeAndSet(&lstm.fusedWeight, to: fusedWeightTF)
+    checkShapeAndSet(&lstm.fusedBias, to: fusedBiasTF)
+  }
+
+  /// Sets the given MLP's parameters to the given pytorch MLP parameters.
+  private func setMLP(_ mlp: inout MLP, to p: TorchMLPParameters) {
+    setDense(&mlp.dense1, to: p.linear1)
+    setDense(&mlp.dense2, to: p.linear2)
+  }
+
+  /// Sets the given Dense's parameters to the given pytorch linear parameters.
+  private func setDense(_ dense: inout Dense<Float>, to p: TorchLinearParameters) {
+    checkShapeAndSet(&dense.weight, to: p.weight.transposed())
+    checkShapeAndSet(&dense.bias, to: p.bias)
+  }
+}
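// A quick way to see the gate shuffle in `setLSTM` above, sketched outside the
// patch: PyTorch packs the four LSTM gate blocks as [input, forget, cell, output],
// while the fused layout assembled above expects [input, cell, forget, output],
// so the middle two blocks trade places. Toy values, one column per gate block;
// illustrative only, not part of the patch.
import TensorFlow

let j = 1  // toy hidden size
let torchFused = Tensor<Float>(shape: [1, 4], scalars: [1, 2, 3, 4])  // [i | f | g | o]
let tfFused = Tensor<Float>(
  concatenating: [
    torchFused.slice(lowerBounds: [0, 0], upperBounds: [1, j]),          // input
    torchFused.slice(lowerBounds: [0, 2 * j], upperBounds: [1, 3 * j]),  // cell
    torchFused.slice(lowerBounds: [0, j], upperBounds: [1, 2 * j]),      // forget
    torchFused.slice(lowerBounds: [0, 3 * j], upperBounds: [1, 4 * j]),  // output
  ],
  alongAxis: 1)
print(tfFused)  // [[1.0, 3.0, 2.0, 4.0]], i.e. [i | g | f | o]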
+func tangentVector(from torchGradient: TorchSNLMParameters, model: SNLM) -> SNLM.TangentVector {
+  var model = model
+  model.setTorchParameters(torchGradient)
+
+  // `model.setTorchParameters` sums each pair of PyTorch LSTM bias vectors,
+  // which is correct for parameters but not for gradients: both PyTorch biases
+  // receive the same gradient as the fused bias, so their sum is twice the
+  // fused-bias gradient. Halve it to compensate.
+  model.lstmEnc.cell.fusedBias /= 2
+  model.lstmDec.cell.fusedBias /= 2
+
+  return model.differentiableVectorView
+}
+
+func almostEqual(_ lhs: SNLM.TangentVector, _ rhs: SNLM.TangentVector, relTol: Float, zeroTol: Float) -> Bool {
+  var success = true
+  for (kpIndex, kp) in lhs.recursivelyAllKeyPaths(to: Tensor<Float>.self).enumerated() {
+    let t1 = lhs[keyPath: kp]
+    let t2 = rhs[keyPath: kp]
+    if t1.shape != t2.shape {
+      print("Shape mismatch on tensor \(kpIndex)")
+      success = false
+      continue
+    }
+    for (elementIndex, (t1e, t2e)) in zip(t1.scalars, t2.scalars).enumerated() {
+      if !(t1e.isAlmostEqual(to: t2e, tolerance: relTol) || (t1e.isAlmostZero(absoluteTolerance: zeroTol) && t2e.isAlmostZero(absoluteTolerance: zeroTol))) {
+        print("Mismatch on tensor \(kpIndex) element \(elementIndex): \(t1e) & \(t2e)")
+        success = false
+      }
+    }
+  }
+  return success
+}
+
+class ProbeLayerTests: XCTestCase {
+  func testProbeEncoder() {
+    // chrVocab is:
+    // 0 - a
+    // 1 - b
+    // 2 - </s>
+    // 3 - </w>
+    // 4 - <pad>
+    let chrVocab: Alphabet = Alphabet([
+      "a",
+      "b",
+    ], eos: "</s>", eow: "</w>", pad: "<pad>")
+
+    // strVocab is:
+    // 0 - "aa"
+    // 1 - "bb"
+    // 2 - "ab"
+    // 3 - "ba"
+    let strVocab: Lexicon = Lexicon([
+      CharacterSequence(alphabet: chrVocab, characters: [0, 0]),  // "aa"
+      CharacterSequence(alphabet: chrVocab, characters: [1, 1]),  // "bb"
+      CharacterSequence(alphabet: chrVocab, characters: [0, 1]),  // "ab"
+      CharacterSequence(alphabet: chrVocab, characters: [1, 0])   // "ba"
+    ])
+
+    var model = SNLM(conf: Conf(
+      ndim: 2,
+      dropoutProb: 0,
+      chrVocab: chrVocab,
+      strVocab: strVocab,
+      order: 5))
+
+    model.setTorchParameters(Example1.parameters)
+
+    print("Encoding")
+    let encoderStates = model.encode(CharacterSequence(alphabet: chrVocab, characters: [0, 1, 0, 1]))  // "abab"
+    let encoderStatesTensor = Tensor(stacking: encoderStates)
+    print("Expected: \(Example1.expectedEncoding)")
+    print("Actual: \(encoderStatesTensor)")
+    XCTAssert(abs(encoderStatesTensor - Example1.expectedEncoding).max().scalarized() < 1e-6)
+    print("OK!\n")
+
+    print("MLP Interpolation")
+    let mlpInterpolationOutput = model.mlpInterpolation(encoderStatesTensor)
+    print("Expected: \(Example1.expectedMLPInterpolationOutput)")
+    print("Actual: \(mlpInterpolationOutput)")
+    XCTAssert(abs(mlpInterpolationOutput - Example1.expectedMLPInterpolationOutput).max().scalarized() < 1e-6)
+    print("OK!\n")
+
+    print("MLP Memory")
+    let mlpMemoryOutput = model.mlpMemory(encoderStatesTensor)
+    print("Expected: \(Example1.expectedMLPMemoryOutput)")
+    print("Actual: \(mlpMemoryOutput)")
+    XCTAssert(abs(mlpMemoryOutput - Example1.expectedMLPMemoryOutput).max().scalarized() < 1e-6)
+    print("OK!\n")
+
+    print("Decode")
+    let decoded = model.decode(
+      [
+        CharacterSequence(alphabet: chrVocab, characters: [0, 0, 0]),  // "aaa"
+        CharacterSequence(alphabet: chrVocab, characters: [0, 1])      // "ab"
+      ],
+      encoderStates[0]
+    )
+    print("Expected: \(Example1.expectedDecoded)")
+    print("Actual: \(decoded)")
+    XCTAssert(abs(decoded - Example1.expectedDecoded).max().scalarized() < 1e-6)
+    print("OK!\n")
+
+    print("Build Lattice")
+    let abab = CharacterSequence(alphabet: chrVocab, characters: [0, 1, 0, 1])
+    let lattice = model.buildLattice(abab, maxLen: 5)
+    XCTAssert(lattice.isAlmostEqual(to: Example1.lattice, tolerance: 1e-5))
+
+    print("Gradient")
+    func f(_ x: SNLM) -> Float {
+      x.buildLattice(abab, maxLen: 5)[4].semiringScore.logr
+    }
+    let (_, grad) = valueWithGradient(at: model, in: f)
+    let expectedGrad = tangentVector(from: Example1.gradWrtLogR, model: model)
+
+    if !almostEqual(grad, expectedGrad, relTol: 1e-5, zeroTol: 1e-6) {
+      print("\nExpected grad:\n\(expectedGrad)\n\n")
+      print("Actual grad:\n\(grad)")
+      XCTAssert(false, "Gradients wrong")
+    }
+  }
+
+  static var allTests = [
+    ("testProbeEncoder", testProbeEncoder)
+  ]
+}
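// Why `tangentVector` above halves the fused biases, with made-up numbers: the
// fused bias equals bias_ih + bias_hh, and each PyTorch bias receives exactly
// the gradient of that sum, so re-using the parameter-import path (which adds
// the two tensors) doubles the fused-bias gradient. A sketch, not patch code.
import TensorFlow

let g = Tensor<Float>([0.5, -1.0])  // hypothetical fused-bias gradient
let importedAsParameters = g + g    // what setTorchParameters would compute
print(importedAsParameters / 2)     // [0.5, -1.0]: the gradient the Swift model expects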
diff --git a/Tests/TextTests/WordSegmentationTests/TorchParameters.swift b/Tests/TextTests/WordSegmentationTests/TorchParameters.swift
new file mode 100644
index 00000000000..2e8ba122cee
--- /dev/null
+++ b/Tests/TextTests/WordSegmentationTests/TorchParameters.swift
@@ -0,0 +1,46 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import TensorFlow
+
+struct TorchSNLMParameters {
+  var emb_enc: TorchEmbeddingParameters
+  var lstm_enc: TorchLSTMParameters
+  var mlp_interpolation: TorchMLPParameters
+  var mlp_memory: TorchMLPParameters
+  var emb_dec: TorchEmbeddingParameters
+  var lstm_dec: TorchLSTMParameters
+  var linear_dec: TorchLinearParameters
+}
+
+struct TorchEmbeddingParameters {
+  var weight: Tensor<Float>
+}
+
+struct TorchLSTMParameters {
+  var weight_ih_l0: Tensor<Float>
+  var weight_hh_l0: Tensor<Float>
+  var bias_ih_l0: Tensor<Float>
+  var bias_hh_l0: Tensor<Float>
+}
+
+struct TorchMLPParameters {
+  var linear1: TorchLinearParameters
+  var linear2: TorchLinearParameters
+}
+
+struct TorchLinearParameters {
+  var weight: Tensor<Float>
+  var bias: Tensor<Float>
+}
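// The byte-buffer fixture in `test_DataSetLoad` below is just the UTF-8 bytes
// of "alpha\n". An equivalent, shorter way to build the same `Data`, assuming
// Foundation's `String.data(using:)`; a sketch for orientation, not patch code.
import Foundation

let training = "alpha\n".data(using: .utf8)!  // same bytes as the buffer below
let dataset = try? DataSet(training: training, validation: nil, testing: nil)
// Expect four distinct letters plus the three marker symbols (count 7) and one
// training sequence, as the test below asserts.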
diff --git a/Tests/TextTests/WordSegmentationTests/WordSegmentationTests.swift b/Tests/TextTests/WordSegmentationTests/WordSegmentationTests.swift
new file mode 100644
index 00000000000..3679ce1bb53
--- /dev/null
+++ b/Tests/TextTests/WordSegmentationTests/WordSegmentationTests.swift
@@ -0,0 +1,162 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import XCTest
+
+@testable
+import TextModels
+
+class DataSetTests: XCTestCase {
+  func test_DataSetLoad() {
+    let buffer: [UInt8] = [
+      0x61, 0x6c, 0x70, 0x68, 0x61, 0x0a,  // "alpha\n"
+    ]
+
+    var dataset: DataSet?
+    buffer.withUnsafeBytes { pointer in
+      guard let address = pointer.baseAddress else { return }
+      let training: Data =
+        Data(bytesNoCopy: UnsafeMutableRawPointer(mutating: address),
+             count: pointer.count, deallocator: .none)
+      dataset = try? DataSet(training: training, validation: nil, testing: nil)
+    }
+
+    // 'a', 'h', 'l', 'p', '</s>', '</w>', '<pad>'
+    XCTAssertEqual(dataset?.alphabet.count, 7)
+    XCTAssertEqual(dataset?.training.count, 1)
+  }
+
+  static var allTests = [
+    ("test_DataSetLoad", test_DataSetLoad)
+  ]
+}
+
+class SemiRingTests: XCTestCase {
+  func test_SemiRingAdd() {
+    let value: SemiRing =
+      SemiRing(logp: 1.0, logr: 2.0) + SemiRing(logp: 3.0, logr: 4.0)
+    XCTAssertEqual(value.logp, 3.126928, accuracy: 0.000001)
+    XCTAssertEqual(value.logr, 4.126928, accuracy: 0.000001)
+  }
+
+  func test_SemiRingInit() {
+    let value: SemiRing = SemiRing(logp: 1.0, logr: 2.0)
+    XCTAssertEqual(value.logp, 1.0)
+    XCTAssertEqual(value.logr, 2.0)
+  }
+
+  func test_SemiRingZero() {
+    let value: SemiRing = SemiRing.zero
+    XCTAssertEqual(value.logp, -Float.infinity)
+    XCTAssertEqual(value.logr, -Float.infinity)
+  }
+
+  func test_SemiRingAdditiveIdentity() {
+    let value: SemiRing = SemiRing.zero + SemiRing(logp: 1.0, logr: 2.0)
+    XCTAssertEqual(value.logp, 1.0)
+    XCTAssertEqual(value.logr, 2.0)
+  }
+
+  func test_SemiRingOne() {
+    let value: SemiRing = SemiRing.one
+    XCTAssertEqual(value.logp, 0.0)
+    XCTAssertEqual(value.logr, -Float.infinity)
+  }
+
+  func test_SemiRingMultiplicativeIdentity() {
+    let value: SemiRing = SemiRing.one * SemiRing(logp: 1.0, logr: 2.0)
+    XCTAssertEqual(value.logp, 1.0)
+    XCTAssertEqual(value.logr, 2.0)
+  }
+
+  func test_SemiRingMultiply() {
+    let value: SemiRing =
+      SemiRing(logp: 1.0, logr: 2.0) * SemiRing(logp: 3.0, logr: 4.0)
+    XCTAssertEqual(value.logp, 4.0)
+    XCTAssertEqual(value.logr, 5.693147, accuracy: 0.000001)
+  }
+
+  static var allTests = [
+    ("test_SemiRingAdd", test_SemiRingAdd),
+    ("test_SemiRingInit", test_SemiRingInit),
+    ("test_SemiRingZero", test_SemiRingZero),
+    ("test_SemiRingAdditiveIdentity", test_SemiRingAdditiveIdentity),
+    ("test_SemiRingOne", test_SemiRingOne),
+    ("test_SemiRingMultiplicativeIdentity", test_SemiRingMultiplicativeIdentity),
+    ("test_SemiRingMultiply", test_SemiRingMultiply),
+  ]
+}
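// Where the expected constants in SemiRingTests above come from: in log space,
// `+` is an elementwise logSumExp and `*` adds the logp values while
// logSumExp-ing the cross terms of logr. A standalone check against a plain
// Foundation implementation of logSumExp (a sketch, not the library's
// definition):
import Foundation

func logSumExp(_ a: Double, _ b: Double) -> Double {
  let m = max(a, b)
  return m + log(exp(a - m) + exp(b - m))
}

print(logSumExp(1, 3), logSumExp(2, 4))    // 3.126928 4.126928  (test_SemiRingAdd)
print(1.0 + 3.0, logSumExp(1 + 4, 3 + 2))  // 4.0 5.693147       (test_SemiRingMultiply)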
+
+class VocabularyTests: XCTestCase {
+  func test_AlphabetaConstruct() {
+    let characters: Alphabet = Alphabet([
+      "a",
+      "b",
+      "c",
+    ], eos: "</s>", eow: "</w>", pad: "<pad>")
+    // a, b, c, EOS, EOW, PAD
+    XCTAssertEqual(characters.count, 6)
+    XCTAssertEqual(characters.eos, 3)
+    XCTAssertEqual(characters.eow, 4)
+    XCTAssertEqual(characters.pad, 5)
+  }
+
+  func test_CharacterSequenceConstruct() {
+    let characters: Alphabet = Alphabet([
+      "a",
+      "c",
+      "t",
+    ], eos: "</s>", eow: "</w>", pad: "<pad>")
+    let cat: CharacterSequence? = try? CharacterSequence(alphabet: characters, appendingEoSTo: "cat")
+    XCTAssertNotEqual(cat, nil)
+    // FIXME(abdulras) should the EoS be visible?
+    XCTAssertEqual(cat?.characters, [Int32(1), Int32(0), Int32(2), characters.eos])
+
+    let bat: CharacterSequence? = try? CharacterSequence(alphabet: characters, appendingEoSTo: "bat")
+    XCTAssertEqual(bat, nil)
+  }
+
+  func test_LexiconConstruct() {
+    let characters: Alphabet = Alphabet([
+      "a", "b", "e", "g", "h", "l", "m", "p", "t",
+    ], eos: "</s>", eow: "</w>", pad: "<pad>")
+    let strings: Lexicon = Lexicon([
+      try! CharacterSequence(alphabet: characters, appendingEoSTo: "alpha"),
+      try! CharacterSequence(alphabet: characters, appendingEoSTo: "beta"),
+      try! CharacterSequence(alphabet: characters, appendingEoSTo: "gamma"),
+    ])
+
+    XCTAssertEqual(strings.count, 3)
+  }
+
+  func test_LexiconFromSequences() {
+    let alphabet: Alphabet = Alphabet([
+      "a", "b", "e", "g", "h", "l", "m", "p", "t",
+    ], eos: "</s>", eow: "</w>", pad: "<pad>")
+    let lexicon: Lexicon = Lexicon(from: [
+      try! CharacterSequence(alphabet: alphabet, appendingEoSTo: "alpha"),
+      try! CharacterSequence(alphabet: alphabet, appendingEoSTo: "beta"),
+      try! CharacterSequence(alphabet: alphabet, appendingEoSTo: "gamma"),
+    ], alphabet: alphabet, maxLength: 5, minFreq: 4)
+
+    XCTAssertEqual(lexicon.count, 3)
+  }
+
+  static var allTests = [
+    ("test_AlphabetaConstruct", test_AlphabetaConstruct),
+    ("test_CharacterSequenceConstruct", test_CharacterSequenceConstruct),
+    ("test_LexiconConstruct", test_LexiconConstruct),
+    ("test_LexiconFromSequences", test_LexiconFromSequences),
+  ]
+}
diff --git a/Tests/TextTests/XCTestManifests.swift b/Tests/TextTests/XCTestManifests.swift
index 8deb96aa4c7..24f05d8ad36 100644
--- a/Tests/TextTests/XCTestManifests.swift
+++ b/Tests/TextTests/XCTestManifests.swift
@@ -18,6 +18,10 @@ import XCTest
 public func allTests() -> [XCTestCaseEntry] {
   return [
     testCase(TextInferenceTests.allTests),
+    testCase(DataSetTests.allTests),
+    testCase(SemiRingTests.allTests),
+    testCase(VocabularyTests.allTests),
+    testCase(ProbeLayerTests.allTests),
   ]
 }
 #endif

From 09fea49370a751547ecb60d43cbed94c43bb082e Mon Sep 17 00:00:00 2001
From: Michelle Casbon
Date: Fri, 8 May 2020 12:57:50 -0400
Subject: [PATCH 2/5] Add model files to CMakeLists

---
 Models/Text/CMakeLists.txt | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/Models/Text/CMakeLists.txt b/Models/Text/CMakeLists.txt
index 510cd340e3a..0ae3c10c081 100644
--- a/Models/Text/CMakeLists.txt
+++ b/Models/Text/CMakeLists.txt
@@ -12,7 +12,13 @@ add_library(TextModels
   ScheduledParameters.swift
   TransformerBERT.swift
   Utilities.swift
-  WeightDecayedAdam.swift)
+  WeightDecayedAdam.swift
+  WordSeg/DataSet.swift
+  WordSeg/Model.swift
+  WordSeg/Lattice.swift
+  WordSeg/SE-0259.swift
+  WordSeg/SemiRing.swift
+  WordSeg/Vocabularies.swift)
 set_target_properties(TextModels PROPERTIES
   INTERFACE_INCLUDE_DIRECTORIES ${CMAKE_Swift_MODULE_DIRECTORY})
 target_compile_options(TextModels PRIVATE

From e0515b5e93f87a91ca08c7a8771a8e9e4c234b31 Mon Sep 17 00:00:00 2001
From: Michelle Casbon
Date: Fri, 8 May 2020 18:39:52 -0400
Subject: [PATCH 3/5] Rename structs and remove commented code

---
 Tests/TextTests/CMakeLists.txt                |  2 +-
 .../WordSegmentationTests/ExampleData.swift   | 48 ++++++++--------
 ...TorchParameters.swift => Parameters.swift} | 28 +++++-----
 .../WordSegmentationTests/ProbeLayers.swift   | 56 +++++++++----------
 4 files changed, 67 insertions(+), 67 deletions(-)
 rename Tests/TextTests/WordSegmentationTests/{TorchParameters.swift => Parameters.swift} (63%)

diff --git a/Tests/TextTests/CMakeLists.txt b/Tests/TextTests/CMakeLists.txt
index 796104ee119..0a3bb0d45ba 100644
--- a/Tests/TextTests/CMakeLists.txt
+++ b/Tests/TextTests/CMakeLists.txt
@@ -2,7 +2,7 @@ add_library(TextTests
   Inference.swift
   WordSegmentationTests/ExampleData.swift
   WordSegmentationTests/ProbeLayers.swift
-  WordSegmentationTests/TorchParameters.swift
+  WordSegmentationTests/Parameters.swift
   WordSegmentationTests/WordSegmentationTests.swift
   XCTestManifests.swift)
 set_target_properties(TextTests PROPERTIES
diff --git a/Tests/TextTests/WordSegmentationTests/ExampleData.swift b/Tests/TextTests/WordSegmentationTests/ExampleData.swift
index 7b0b2b2552c..1797c1b6607 
100644 --- a/Tests/TextTests/WordSegmentationTests/ExampleData.swift +++ b/Tests/TextTests/WordSegmentationTests/ExampleData.swift @@ -17,8 +17,8 @@ import TextModels // Generated by `blaze run //experimental/users/marcrasi/probe_wordseg:probe` enum Example1 { - static let parameters = TorchSNLMParameters( - emb_enc: TorchEmbeddingParameters( + static let parameters = SNLMParameters( + emb_enc: EmbeddingParameters( weight: Tensor( [[-0.03872304 ,-0.05338321 ], [ 0.051871493 ,-0.024247635 ], @@ -27,7 +27,7 @@ enum Example1 { [-0.0012689009,-0.07714497 ]] ) ), - lstm_enc: TorchLSTMParameters( + lstm_enc: LSTMParameters( weight_ih_l0: Tensor( [[ 0.029065304, 0.035071626], [ 0.05220075 , 0.07659577 ], @@ -57,8 +57,8 @@ enum Example1 { -3.5792589e-05,-3.5584867e-02, 5.0887465e-05,-3.5637055e-02] ) ), - mlp_interpolation: TorchMLPParameters( - linear1: TorchLinearParameters( + mlp_interpolation: MLPParameters( + linear1: LinearParameters( weight: Tensor( [[-0.034370955,-0.07964967 ], [ 0.019697152,-0.059639234]] @@ -67,7 +67,7 @@ enum Example1 { [0.047385678,0.05388096 ] ) ), - linear2: TorchLinearParameters( + linear2: LinearParameters( weight: Tensor( [[-0.04711317 , 0.07374136 ], [-0.058497474,-0.011900932]] @@ -77,8 +77,8 @@ enum Example1 { ) ) ), - mlp_memory: TorchMLPParameters( - linear1: TorchLinearParameters( + mlp_memory: MLPParameters( + linear1: LinearParameters( weight: Tensor( [[-0.06895532 , 0.0675631 ], [ 0.062555104,-0.006975107]] @@ -87,7 +87,7 @@ enum Example1 { [0.021453634 ,0.0056577697] ) ), - linear2: TorchLinearParameters( + linear2: LinearParameters( weight: Tensor( [[ 0.01337228 , 0.07287796 ], [-0.025111437,-0.021482762], @@ -99,7 +99,7 @@ enum Example1 { ) ) ), - emb_dec: TorchEmbeddingParameters( + emb_dec: EmbeddingParameters( weight: Tensor( [[ 0.055306703, 0.07086456 ], [ 0.03423354 ,-0.015132636], @@ -108,7 +108,7 @@ enum Example1 { [ 0.054974243, 0.017300054]] ) ), - lstm_dec: TorchLSTMParameters( + lstm_dec: LSTMParameters( weight_ih_l0: Tensor( [[-0.053585358,-0.0642758 ], [-0.07246614 , 0.025658146], @@ -138,7 +138,7 @@ enum Example1 { -0.03911104 ,-0.0015038475, 0.051634952 ] ) ), - linear_dec: TorchLinearParameters( + linear_dec: LinearParameters( weight: Tensor( [[-0.03417959 , 0.04824567 ], [ 0.0559683 , 0.0076355636], @@ -291,8 +291,8 @@ enum Example1 { ) ] ) - static let gradWrtLogR = TorchSNLMParameters( - emb_enc: TorchEmbeddingParameters( + static let gradWrtLogR = SNLMParameters( + emb_enc: EmbeddingParameters( weight: Tensor( [[-1.0885849e-05, 2.5600420e-05], [-2.0048559e-05, 5.2774609e-05], @@ -301,7 +301,7 @@ enum Example1 { [ 0.0000000e+00, 0.0000000e+00]] ) ), - lstm_enc: TorchLSTMParameters( + lstm_enc: LSTMParameters( weight_ih_l0: Tensor( [[-5.5954405e-07, 2.1259700e-07], [-9.2493622e-08, 1.2687329e-07], @@ -331,8 +331,8 @@ enum Example1 { -1.3631615e-03,-2.0683720e-03, 5.0045674e-05,-1.1407620e-05] ) ), - mlp_interpolation: TorchMLPParameters( - linear1: TorchLinearParameters( + mlp_interpolation: MLPParameters( + linear1: LinearParameters( weight: Tensor( [[-2.162833e-04, 3.404662e-05], [-1.626215e-03, 2.559960e-04]] @@ -341,7 +341,7 @@ enum Example1 { [0.008965784,0.06741227 ] ) ), - linear2: TorchLinearParameters( + linear2: LinearParameters( weight: Tensor( [[ 0.03779146 , 0.041938413], [-0.03779146 ,-0.041938413]] @@ -351,8 +351,8 @@ enum Example1 { ) ) ), - mlp_memory: TorchMLPParameters( - linear1: TorchLinearParameters( + mlp_memory: MLPParameters( + linear1: LinearParameters( weight: Tensor( [[ 0.0006783868 ,-0.00010455749], [ 
0.002728255 ,-0.00042056653]] @@ -361,7 +361,7 @@ enum Example1 { [-0.028684907,-0.11537111 ] ) ), - linear2: TorchLinearParameters( + linear2: LinearParameters( weight: Tensor( [[-0.0098278085,-0.0017497203], [-0.009362103 ,-0.0016668034], @@ -373,7 +373,7 @@ enum Example1 { ) ) ), - emb_dec: TorchEmbeddingParameters( + emb_dec: EmbeddingParameters( weight: Tensor( [[2.8222625e-04,1.9048888e-04], [1.8905671e-04,1.5085124e-04], @@ -382,7 +382,7 @@ enum Example1 { [0.0000000e+00,0.0000000e+00]] ) ), - lstm_dec: TorchLSTMParameters( + lstm_dec: LSTMParameters( weight_ih_l0: Tensor( [[-1.3549015e-05,-1.4535895e-05], [ 3.1118125e-07,-1.6063163e-07], @@ -412,7 +412,7 @@ enum Example1 { 4.2812643e-03, 3.6167959e-03,-2.0367328e-04, 1.3456454e-05] ) ), - linear_dec: TorchLinearParameters( + linear_dec: LinearParameters( weight: Tensor( [[-0.0041010594 , 0.00012692134], [-0.005898418 , 0.0008138743 ], diff --git a/Tests/TextTests/WordSegmentationTests/TorchParameters.swift b/Tests/TextTests/WordSegmentationTests/Parameters.swift similarity index 63% rename from Tests/TextTests/WordSegmentationTests/TorchParameters.swift rename to Tests/TextTests/WordSegmentationTests/Parameters.swift index 2e8ba122cee..2b40588eaa9 100644 --- a/Tests/TextTests/WordSegmentationTests/TorchParameters.swift +++ b/Tests/TextTests/WordSegmentationTests/Parameters.swift @@ -14,33 +14,33 @@ import TensorFlow -struct TorchSNLMParameters { - var emb_enc: TorchEmbeddingParameters - var lstm_enc: TorchLSTMParameters - var mlp_interpolation: TorchMLPParameters - var mlp_memory: TorchMLPParameters - var emb_dec: TorchEmbeddingParameters - var lstm_dec: TorchLSTMParameters - var linear_dec: TorchLinearParameters +struct SNLMParameters { + var emb_enc: EmbeddingParameters + var lstm_enc: LSTMParameters + var mlp_interpolation: MLPParameters + var mlp_memory: MLPParameters + var emb_dec: EmbeddingParameters + var lstm_dec: LSTMParameters + var linear_dec: LinearParameters } -struct TorchEmbeddingParameters { +struct EmbeddingParameters { var weight: Tensor } -struct TorchLSTMParameters { +struct LSTMParameters { var weight_ih_l0: Tensor var weight_hh_l0: Tensor var bias_ih_l0: Tensor var bias_hh_l0: Tensor } -struct TorchMLPParameters { - var linear1: TorchLinearParameters - var linear2: TorchLinearParameters +struct MLPParameters { + var linear1: LinearParameters + var linear2: LinearParameters } -struct TorchLinearParameters { +struct LinearParameters { var weight: Tensor var bias: Tensor } diff --git a/Tests/TextTests/WordSegmentationTests/ProbeLayers.swift b/Tests/TextTests/WordSegmentationTests/ProbeLayers.swift index 25999eb2e07..eacbf4118e0 100644 --- a/Tests/TextTests/WordSegmentationTests/ProbeLayers.swift +++ b/Tests/TextTests/WordSegmentationTests/ProbeLayers.swift @@ -18,8 +18,8 @@ import TensorFlow import TextModels extension SNLM { - /// Sets the model parameters to the given parameters exported from the pytorch model. - mutating func setTorchParameters(_ p: TorchSNLMParameters) { + /// Sets the model parameters to the given parameters exported from the model. + mutating func setParameters(_ p: SNLMParameters) { setEmbedding(&embEnc, to: p.emb_enc) setLSTM(&lstmEnc.cell, to: p.lstm_enc) setMLP(&mlpInterpolation, to: p.mlp_interpolation) @@ -34,34 +34,34 @@ extension SNLM { tensor = value } - // Sets the given Embedding parameters to the given pytorch embedding parameters. 
- private func setEmbedding(_ embedding: inout Embedding, to p: TorchEmbeddingParameters) { + // Sets the given Embedding parameters to the given embedding parameters. + private func setEmbedding(_ embedding: inout Embedding, to p: EmbeddingParameters) { checkShapeAndSet(&embedding.embeddings, to: p.weight) } - /// Sets the given LSTM cell's parameters to the given pytorch LSTM parameters. - private func setLSTM(_ lstm: inout LSTMCell, to p: TorchLSTMParameters) { - let fusedWeightTorch = p.weight_ih_l0.concatenated(with: p.weight_hh_l0, alongAxis: 1).transposed() - let i = fusedWeightTorch.shape[0] - let j = fusedWeightTorch.shape[1] / 4 + /// Sets the given LSTM cell's parameters to the given LSTM parameters. + private func setLSTM(_ lstm: inout LSTMCell, to p: LSTMParameters) { + let fusedWeight = p.weight_ih_l0.concatenated(with: p.weight_hh_l0, alongAxis: 1).transposed() + let i = fusedWeight.shape[0] + let j = fusedWeight.shape[1] / 4 let fusedWeightTF = Tensor( concatenating: [ - fusedWeightTorch.slice(lowerBounds: [0, 0], upperBounds: [i, j]), - fusedWeightTorch.slice(lowerBounds: [0, 2 * j], upperBounds: [i, 3 * j]), - fusedWeightTorch.slice(lowerBounds: [0, j], upperBounds: [i, 2 * j]), - fusedWeightTorch.slice(lowerBounds: [0, 3 * j], upperBounds: [i, 4 * j]) + fusedWeight.slice(lowerBounds: [0, 0], upperBounds: [i, j]), + fusedWeight.slice(lowerBounds: [0, 2 * j], upperBounds: [i, 3 * j]), + fusedWeight.slice(lowerBounds: [0, j], upperBounds: [i, 2 * j]), + fusedWeight.slice(lowerBounds: [0, 3 * j], upperBounds: [i, 4 * j]) ], alongAxis: 1 ) - let fusedBiasTorch = (p.bias_ih_l0 + p.bias_hh_l0) - let k = fusedBiasTorch.shape[0] / 4 + let fusedBias = (p.bias_ih_l0 + p.bias_hh_l0) + let k = fusedBias.shape[0] / 4 let fusedBiasTF = Tensor( concatenating: [ - fusedBiasTorch.slice(lowerBounds: [0], upperBounds: [k]), - fusedBiasTorch.slice(lowerBounds: [2 * k], upperBounds: [3 * k]), - fusedBiasTorch.slice(lowerBounds: [k], upperBounds: [2 * k]), - fusedBiasTorch.slice(lowerBounds: [3 * k], upperBounds: [4 * k]) + fusedBias.slice(lowerBounds: [0], upperBounds: [k]), + fusedBias.slice(lowerBounds: [2 * k], upperBounds: [3 * k]), + fusedBias.slice(lowerBounds: [k], upperBounds: [2 * k]), + fusedBias.slice(lowerBounds: [3 * k], upperBounds: [4 * k]) ] ) @@ -69,26 +69,26 @@ extension SNLM { checkShapeAndSet(&lstm.fusedBias, to: fusedBiasTF) } - /// Sets the given MLP's parameters to the given pytorch MLP parameters. - private func setMLP(_ mlp: inout MLP, to p: TorchMLPParameters) { + /// Sets the given MLP's parameters to the given MLP parameters. + private func setMLP(_ mlp: inout MLP, to p: MLPParameters) { setDense(&mlp.dense1, to: p.linear1) setDense(&mlp.dense2, to: p.linear2) } - /// Sets the given Dense's parameters to the given pytorch linear parameters. - private func setDense(_ dense: inout Dense, to p: TorchLinearParameters) { + /// Sets the given Dense's parameters to the given linear parameters. 
+ private func setDense(_ dense: inout Dense, to p: LinearParameters) { checkShapeAndSet(&dense.weight, to: p.weight.transposed()) checkShapeAndSet(&dense.bias, to: p.bias) } } -func tangentVector(from torchGradient: TorchSNLMParameters, model: SNLM) -> SNLM.TangentVector { +func tangentVector(from gradient: SNLMParameters, model: SNLM) -> SNLM.TangentVector { var model = model - model.setTorchParameters(torchGradient) + model.setParameters(gradient) - // `model.setTorchParameters` is for model parameters, not for gradients, so + // `model.setParameters` is for model parameters, not for gradients, so // we need to adjust the LSTM biases, whose gradients work differently than - // `model.setTorchParameters` does. + // `model.setParameters` does. model.lstmEnc.cell.fusedBias /= 2 model.lstmDec.cell.fusedBias /= 2 @@ -147,7 +147,7 @@ class ProbeLayerTests: XCTestCase { strVocab: strVocab, order: 5)) - model.setTorchParameters(Example1.parameters) + model.setParameters(Example1.parameters) print("Encoding") let encoderStates = model.encode(CharacterSequence(alphabet: chrVocab, characters: [0, 1, 0, 1])) // "abab" From 46914dfd65621a0736a71ec88165032612cd8867 Mon Sep 17 00:00:00 2001 From: Michelle Casbon Date: Fri, 8 May 2020 18:57:40 -0400 Subject: [PATCH 4/5] Rename Conf -> SNLM.Parameters, lint --- Models/Text/WordSeg/DataSet.swift | 27 +- Models/Text/WordSeg/Lattice.swift | 58 +-- Models/Text/WordSeg/Model.swift | 238 ++++++----- Models/Text/WordSeg/SE-0259.swift | 20 +- Models/Text/WordSeg/SemiRing.swift | 24 +- Models/Text/WordSeg/Vocabularies.swift | 25 +- .../WordSegmentationTests/ExampleData.swift | 382 ++++++++++-------- .../WordSegmentationTests/ProbeLayers.swift | 62 +-- .../WordSegmentationTests.swift | 67 +-- Tests/TextTests/XCTestManifests.swift | 18 +- 10 files changed, 523 insertions(+), 398 deletions(-) diff --git a/Models/Text/WordSeg/DataSet.swift b/Models/Text/WordSeg/DataSet.swift index 719f06f4422..cbef8bd6c99 100644 --- a/Models/Text/WordSeg/DataSet.swift +++ b/Models/Text/WordSeg/DataSet.swift @@ -63,10 +63,14 @@ public struct DataSet { return Alphabet(sorted, eos: eos, eow: eow, pad: pad) } - private static func convertDataset(_ dataset: [String], alphabet: Alphabet) throws -> [CharacterSequence] { + private static func convertDataset(_ dataset: [String], alphabet: Alphabet) throws + -> [CharacterSequence] + { return try dataset.map { try CharacterSequence(alphabet: alphabet, appendingEoSTo: $0) } } - private static func convertDataset(_ dataset: [String]?, alphabet: Alphabet) throws -> [CharacterSequence]? { + private static func convertDataset(_ dataset: [String]?, alphabet: Alphabet) throws + -> [CharacterSequence]? + { if let ds = dataset { let tmp: [CharacterSequence] = try convertDataset(ds, alphabet: alphabet) // Use tmp to disambiguate function return tmp @@ -79,22 +83,25 @@ public struct DataSet { validation validationFile: String? = nil, testing testingFile: String? = nil ) throws { - let trainingData = try Data(contentsOf: URL(fileURLWithPath: trainingFile), - options: .alwaysMapped) + let trainingData = try Data( + contentsOf: URL(fileURLWithPath: trainingFile), + options: .alwaysMapped) let training = try Self.load(data: trainingData) var validation: [String]? = nil var testing: [String]? 
= nil if let validationFile = validationFile { - let data = try Data(contentsOf: URL(fileURLWithPath: validationFile), - options: .alwaysMapped) + let data = try Data( + contentsOf: URL(fileURLWithPath: validationFile), + options: .alwaysMapped) validation = try Self.load(data: data) } if let testingFile = testingFile { - let data: Data = try Data(contentsOf: URL(fileURLWithPath: testingFile), - options: .alwaysMapped) + let data: Data = try Data( + contentsOf: URL(fileURLWithPath: testingFile), + options: .alwaysMapped) testing = try Self.load(data: data) } self.alphabet = Self.makeAlphabet(datasets: training, validation, testing) @@ -103,7 +110,9 @@ public struct DataSet { self.testing = try Self.convertDataset(testing, alphabet: self.alphabet) } - init(training trainingData: Data, validation validationData: Data?, testing testingData: Data?) throws { + init(training trainingData: Data, validation validationData: Data?, testing testingData: Data?) + throws + { let training = try Self.load(data: trainingData) var validation: [String]? = nil var testing: [String]? = nil diff --git a/Models/Text/WordSeg/Lattice.swift b/Models/Text/WordSeg/Lattice.swift index 4dd0c1ec2e9..55c6fce79ea 100644 --- a/Models/Text/WordSeg/Lattice.swift +++ b/Models/Text/WordSeg/Lattice.swift @@ -40,23 +40,28 @@ public struct Lattice: Differentiable { public var totalScore: SemiRing @differentiable - init(start: Int, end: Int, sentence: CharacterSequence, logp: Float, - previous: SemiRing, order: Int) { + init( + start: Int, end: Int, sentence: CharacterSequence, logp: Float, + previous: SemiRing, order: Int + ) { self.start = start self.end = end self.string = sentence self.logp = logp self.score = - SemiRing(logp: logp, - // TODO(abdulras): this should really use integeral pow - logr: logp + logf(powf(Float(sentence.count), Float(order)))) + SemiRing( + logp: logp, + // TODO(abdulras): this should really use integeral pow + logr: logp + logf(powf(Float(sentence.count), Float(order)))) self.totalScore = self.score * previous } @differentiable - public init(start: Int, end: Int, string: CharacterSequence, logp: Float, - score: SemiRing, totalScore: SemiRing) { + public init( + start: Int, end: Int, string: CharacterSequence, logp: Float, + score: SemiRing, totalScore: SemiRing + ) { self.start = start self.end = end self.string = string @@ -78,8 +83,10 @@ public struct Lattice: Differentiable { init() {} @differentiable - public init(bestEdge: Edge?, bestScore: Float, edges: [Edge], - semiringScore: SemiRing) { + public init( + bestEdge: Edge?, bestScore: Float, edges: [Edge], + semiringScore: SemiRing + ) { self.bestEdge = bestEdge self.bestScore = bestScore self.edges = edges @@ -160,9 +167,9 @@ extension Lattice.Node: CustomStringConvertible { edgesStr = edges.enumerated().map { " \($0.0) - \($0.1)" }.joined(separator: "\n") } return """ - best edge: \(String(describing: bestEdge)), best score: \(bestScore), score: \(semiringScore.shortDescription) - \(edgesStr) - """ + best edge: \(String(describing: bestEdge)), best score: \(bestScore), score: \(semiringScore.shortDescription) + \(edgesStr) + """ } } @@ -180,14 +187,14 @@ extension Lattice { return false } return zip(self.positions, other.positions).enumerated() - .map { (index, position) in - let eq = position.0.isAlmostEqual(to: position.1, tolerance: tolerance) - if !eq { - print("mismatch at \(index): \(position.0) != \(position.1)") - } - return eq + .map { (index, position) in + let eq = position.0.isAlmostEqual(to: position.1, tolerance: tolerance) + if 
!eq { + print("mismatch at \(index): \(position.0) != \(position.1)") } - .reduce(true) { $0 && $1 } + return eq + } + .reduce(true) { $0 && $1 } } } @@ -207,19 +214,18 @@ extension Lattice.Node { return false } return zip(self.edges, other.edges) - .map { $0.isAlmostEqual(to: $1, tolerance: tolerance) } - .reduce(true) { $0 && $1 } + .map { $0.isAlmostEqual(to: $1, tolerance: tolerance) } + .reduce(true) { $0 && $1 } } } extension Lattice.Edge { public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool { - return self.start == other.start && - self.end == other.end && + return self.start == other.start && self.end == other.end // TODO: figure out why the string equality is being ignored // self.string == other.string && - self.logp.isAlmostEqual(to: other.logp, tolerance: tolerance) && - self.score.isAlmostEqual(to: other.score, tolerance: tolerance) && - self.totalScore.isAlmostEqual(to: other.totalScore, tolerance: tolerance) + && self.logp.isAlmostEqual(to: other.logp, tolerance: tolerance) + && self.score.isAlmostEqual(to: other.score, tolerance: tolerance) + && self.totalScore.isAlmostEqual(to: other.totalScore, tolerance: tolerance) } } diff --git a/Models/Text/WordSeg/Model.swift b/Models/Text/WordSeg/Model.swift index f0612559978..9107c687799 100644 --- a/Models/Text/WordSeg/Model.swift +++ b/Models/Text/WordSeg/Model.swift @@ -22,35 +22,35 @@ import TensorFlow -public struct Conf { - public var ndim: Int - public var dropoutProb: Double - public var chrVocab: Alphabet - public var strVocab: Lexicon - public var order: Int - - public init( - ndim: Int, - dropoutProb: Double, - chrVocab: Alphabet, - strVocab: Lexicon, - order: Int - ) { - self.ndim = ndim - self.dropoutProb = dropoutProb - self.chrVocab = chrVocab - self.strVocab = strVocab - self.order = order - } -} - /// SNLM /// /// A representation of the Segmental Neural Language Model. 
/// /// \ref https://www.aclweb.org/anthology/P19-1645.pdf public struct SNLM: EuclideanDifferentiable, KeyPathIterable { - @noDerivative public var conf: Conf + public struct Parameters { + public var ndim: Int + public var dropoutProb: Double + public var chrVocab: Alphabet + public var strVocab: Lexicon + public var order: Int + + public init( + ndim: Int, + dropoutProb: Double, + chrVocab: Alphabet, + strVocab: Lexicon, + order: Int + ) { + self.ndim = ndim + self.dropoutProb = dropoutProb + self.chrVocab = chrVocab + self.strVocab = strVocab + self.order = order + } + } + + @noDerivative public var parameters: Parameters // MARK: - Encoder @@ -77,34 +77,46 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable { // MARK: - Initializer - public init(conf: Conf) { - self.conf = conf + public init(parameters: Parameters) { + self.parameters = parameters // Encoder - self.embEnc = Embedding(vocabularySize: conf.chrVocab.count, embeddingSize: conf.ndim) - self.lstmEnc = LSTM(LSTMCell(inputSize: conf.ndim, hiddenSize: conf.ndim)) + self.embEnc = Embedding( + vocabularySize: parameters.chrVocab.count, + embeddingSize: parameters.ndim) + self.lstmEnc = LSTM( + LSTMCell( + inputSize: parameters.ndim, + hiddenSize: + parameters.ndim)) // Interpolation weight self.mlpInterpolation = MLP( - nIn: conf.ndim, - nHidden: conf.ndim, + nIn: parameters.ndim, + nHidden: parameters.ndim, nOut: 2, - dropoutProbability: conf.dropoutProb) + dropoutProbability: parameters.dropoutProb) // Lexical memory self.mlpMemory = MLP( - nIn: conf.ndim, - nHidden: conf.ndim, - nOut: conf.strVocab.count, - dropoutProbability: conf.dropoutProb) + nIn: parameters.ndim, + nHidden: parameters.ndim, + nOut: parameters.strVocab.count, + dropoutProbability: parameters.dropoutProb) // Character-level decoder - self.embDec = Embedding(vocabularySize: conf.chrVocab.count, embeddingSize: conf.ndim) - self.lstmDec = LSTM(LSTMCell(inputSize: conf.ndim, hiddenSize: conf.ndim)) - self.denseDec = Dense(inputSize: conf.ndim, outputSize: conf.chrVocab.count) + self.embDec = Embedding( + vocabularySize: parameters.chrVocab.count, + embeddingSize: parameters.ndim) + self.lstmDec = LSTM( + LSTMCell( + inputSize: parameters.ndim, + hiddenSize: + parameters.ndim)) + self.denseDec = Dense(inputSize: parameters.ndim, outputSize: parameters.chrVocab.count) // Other layers - self.drop = Dropout(probability: conf.dropoutProb) + self.drop = Dropout(probability: parameters.dropoutProb) } // MARK: - Encode @@ -128,16 +140,16 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable { var xBatch: [Int32] = [] var yBatch: [Int32] = [] for candidate in candidates { - let padding = Array(repeating: conf.chrVocab.pad, count: maxLen - candidate.count - 1) + let padding = Array(repeating: parameters.chrVocab.pad, count: maxLen - candidate.count - 1) // x is {sentence}{padding} - xBatch.append(conf.chrVocab.eow) + xBatch.append(parameters.chrVocab.eow) xBatch.append(contentsOf: candidate.characters) xBatch.append(contentsOf: padding) // y is {sentence}{padding} yBatch.append(contentsOf: candidate.characters) - yBatch.append(conf.chrVocab.eow) + yBatch.append(parameters.chrVocab.eow) yBatch.append(contentsOf: padding) } @@ -168,13 +180,16 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable { let logits = denseDec(decoderResult) // [time x batch] - let logp = -1 * softmaxCrossEntropy( - logits: logits.reshaped(to: [logits.shape[0] * logits.shape[1], logits.shape[2]]), - labels: y.flattened(), - reduction: identity).reshaped(to: 
y.shape) + let logp = + -1 + * softmaxCrossEntropy( + logits: logits.reshaped(to: [logits.shape[0] * logits.shape[1], logits.shape[2]]), + labels: y.flattened(), + reduction: identity + ).reshaped(to: y.shape) // [time x batch] - let logpExcludingPad = logp * Tensor(y .!= conf.chrVocab.pad) + let logpExcludingPad = logp * Tensor(y .!= parameters.chrVocab.pad) // [batch] let candidateLogP = logpExcludingPad.transposed().sum(squeezingAxes: 1) @@ -192,7 +207,7 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable { // return torch.log(torch_util.var_from_scaler(0.0, "FloatTensor", self.gpu)) func get_logp_lex(_ logp_lex: [Float], _ candidate: CharacterSequence) -> Float { - guard let index = conf.strVocab.dictionary[candidate] else { + guard let index = parameters.strVocab.dictionary[candidate] else { return -Float.infinity } return logp_lex[Int(index)] @@ -207,14 +222,15 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable { let logp_lex_batch = mlpMemory(Tensor(stacking: states)) for pos in 0.."] continue } @@ -232,9 +248,9 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable { // lattice[pos]["semiring_score"] = semiring.add(lattice[pos]["edges"], // self.gpu) let current_state = states[pos] - let logg = scalarsWithADHack(logg_batch[pos]) // [2] - let logp_lex = scalarsWithADHack(logp_lex_batch[pos]) // [strVocab.chr.count] - let logp_chr = scalarsWithADHack(decode(candidates, current_state)) // [candidates.count] + let logg = scalarsWithADHack(logg_batch[pos]) // [2] + let logp_lex = scalarsWithADHack(logp_lex_batch[pos]) // [strVocab.chr.count] + let logp_chr = scalarsWithADHack(decode(candidates, current_state)) // [candidates.count] if pos != 0 { // TODO: Mutate in place when AD supports it. let updatedNode = Lattice.Node( @@ -268,7 +284,7 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable { sentence: candidate, logp: logp_i, previous: lattice[pos].semiringScore, - order: conf.order) + order: parameters.order) // TODO: Mutate in place when AD supports it. 
let updatedNode = Lattice.Node( @@ -363,61 +379,65 @@ func makeEncoderInput(_ x: Tensor) -> [Tensor] { // TODO: Move this derivative into tensorflow-apis extension RecurrentLayer { - @differentiable(wrt: (self, inputs, initialState)) - public func callAsFunction2( - _ inputs: [Cell.TimeStepInput], - initialState: Cell.State - ) -> [Cell.TimeStepOutput] { - if inputs.isEmpty { return [Cell.TimeStepOutput]() } - var currentHiddenState = initialState - var timeStepOutputs: [Cell.TimeStepOutput] = [] - for timeStepInput in inputs { - let output = cell(input: timeStepInput, state: currentHiddenState) - currentHiddenState = output.state - timeStepOutputs.append(output.output) - } - return timeStepOutputs + @differentiable(wrt: (self, inputs, initialState)) + public func callAsFunction2( + _ inputs: [Cell.TimeStepInput], + initialState: Cell.State + ) -> [Cell.TimeStepOutput] { + if inputs.isEmpty { return [Cell.TimeStepOutput]() } + var currentHiddenState = initialState + var timeStepOutputs: [Cell.TimeStepOutput] = [] + for timeStepInput in inputs { + let output = cell(input: timeStepInput, state: currentHiddenState) + currentHiddenState = output.state + timeStepOutputs.append(output.output) } + return timeStepOutputs + } - @usableFromInline - @derivative(of: callAsFunction2, wrt: (self, inputs, initialState)) - internal func _vjpCallAsFunctionWrtMore( - _ inputs: [Cell.TimeStepInput], - initialState: Cell.State - ) -> ( - value: [Cell.TimeStepOutput], - pullback: (Array.TangentVector) - -> (TangentVector, Array.TangentVector, Cell.State.TangentVector) - ) { - let timeStepCount = inputs.count - var currentHiddenState = initialState - var timeStepOutputs: [Cell.TimeStepOutput] = [] - timeStepOutputs.reserveCapacity(timeStepCount) - var backpropagators: [Cell.Backpropagator] = [] - backpropagators.reserveCapacity(timeStepCount) - for timestep in inputs { - let (output, backpropagator) = cell.appliedForBackpropagation( - to: .init(input: timestep, state: currentHiddenState)) - currentHiddenState = output.state - timeStepOutputs.append(output.output) - backpropagators.append(backpropagator) - } - return (timeStepOutputs, { 𝛁outputs in - precondition(𝛁outputs.base.count == timeStepCount, - "The number of output gradients must equal the number of time steps") - var 𝛁cell = Cell.TangentVector.zero - var 𝛁state = Cell.State.TangentVector.zero - var reversed𝛁inputs: [Cell.TimeStepInput.TangentVector] = [] - reversed𝛁inputs.reserveCapacity(timeStepCount) - for (𝛁output, backpropagator) in zip(𝛁outputs.base, backpropagators).reversed() { - let (new𝛁cell, 𝛁input) = backpropagator(.init(output: 𝛁output, state: 𝛁state)) - 𝛁cell += new𝛁cell - 𝛁state = 𝛁input.state - reversed𝛁inputs.append(𝛁input.input) - } - return (.init(cell: 𝛁cell), .init(Array(reversed𝛁inputs.reversed())), 𝛁state) - }) + @usableFromInline + @derivative(of: callAsFunction2, wrt: (self, inputs, initialState)) + internal func _vjpCallAsFunctionWrtMore( + _ inputs: [Cell.TimeStepInput], + initialState: Cell.State + ) -> ( + value: [Cell.TimeStepOutput], + pullback: (Array.TangentVector) + -> (TangentVector, Array.TangentVector, Cell.State.TangentVector) + ) { + let timeStepCount = inputs.count + var currentHiddenState = initialState + var timeStepOutputs: [Cell.TimeStepOutput] = [] + timeStepOutputs.reserveCapacity(timeStepCount) + var backpropagators: [Cell.Backpropagator] = [] + backpropagators.reserveCapacity(timeStepCount) + for timestep in inputs { + let (output, backpropagator) = cell.appliedForBackpropagation( + to: .init(input: 
timestep, state: currentHiddenState)) + currentHiddenState = output.state + timeStepOutputs.append(output.output) + backpropagators.append(backpropagator) } + return ( + timeStepOutputs, + { 𝛁outputs in + precondition( + 𝛁outputs.base.count == timeStepCount, + "The number of output gradients must equal the number of time steps") + var 𝛁cell = Cell.TangentVector.zero + var 𝛁state = Cell.State.TangentVector.zero + var reversed𝛁inputs: [Cell.TimeStepInput.TangentVector] = [] + reversed𝛁inputs.reserveCapacity(timeStepCount) + for (𝛁output, backpropagator) in zip(𝛁outputs.base, backpropagators).reversed() { + let (new𝛁cell, 𝛁input) = backpropagator(.init(output: 𝛁output, state: 𝛁state)) + 𝛁cell += new𝛁cell + 𝛁state = 𝛁input.state + reversed𝛁inputs.append(𝛁input.input) + } + return (.init(cell: 𝛁cell), .init(Array(reversed𝛁inputs.reversed())), 𝛁state) + } + ) + } } // TODO: Better way of dealing with this problem. @@ -426,7 +446,9 @@ func scalarsWithADHack(_ t: Tensor) -> [Float] { } @derivative(of: scalarsWithADHack) -func vjpScalarsHack(_ t: Tensor) -> (value: [Float], pullback: (Array.TangentVector) -> Tensor) { +func vjpScalarsHack(_ t: Tensor) -> ( + value: [Float], pullback: (Array.TangentVector) -> Tensor +) { // TODO: Capture less stuff. func pullback(_ tv: Array.TangentVector) -> Tensor { if tv.count == 0 { diff --git a/Models/Text/WordSeg/SE-0259.swift b/Models/Text/WordSeg/SE-0259.swift index 4dbdf7185fe..bc9e43b0792 100644 --- a/Models/Text/WordSeg/SE-0259.swift +++ b/Models/Text/WordSeg/SE-0259.swift @@ -68,7 +68,7 @@ extension FloatingPoint { public func isAlmostEqual( to other: Self, tolerance: Self = Self.ulpOfOne.squareRoot() - ) -> Bool { + ) -> Bool { // Tolerances outside of [.ulpOfOne, 1) yield well-defined but useless // results, so this is enforced by an assert rathern than a precondition. assert(tolerance >= .ulpOfOne && tolerance < 1, "tolerance should be in [.ulpOfOne, 1).") @@ -82,7 +82,7 @@ extension FloatingPoint { // defined on FloatingPoint suitable for hypot and scaled sums, but the // following is good enough to be useful for now. let scale = max(abs(self), abs(other), .leastNormalMagnitude) - return abs(self - other) < scale*tolerance + return abs(self - other) < scale * tolerance } /// Test if this value is nearly zero with a specified `absoluteTolerance`. @@ -119,7 +119,7 @@ extension FloatingPoint { @inlinable public func isAlmostZero( absoluteTolerance tolerance: Self = Self.ulpOfOne.squareRoot() - ) -> Bool { + ) -> Bool { assert(tolerance > 0) return abs(self) < tolerance } @@ -137,12 +137,14 @@ extension FloatingPoint { // Self is infinite and other is finite. Replace self with the binade // of the greatestFiniteMagnitude, and reduce the exponent of other by // one to compensate. - let scaledSelf = Self(sign: self.sign, - exponent: Self.greatestFiniteMagnitude.exponent, - significand: 1) - let scaledOther = Self(sign: .plus, - exponent: -1, - significand: other) + let scaledSelf = Self( + sign: self.sign, + exponent: Self.greatestFiniteMagnitude.exponent, + significand: 1) + let scaledOther = Self( + sign: .plus, + exponent: -1, + significand: other) // Now both values are finite, so re-run the naive comparison. return scaledSelf.isAlmostEqual(to: scaledOther, tolerance: tolerance) } diff --git a/Models/Text/WordSeg/SemiRing.swift b/Models/Text/WordSeg/SemiRing.swift index 39909dbb29c..284a028b039 100644 --- a/Models/Text/WordSeg/SemiRing.swift +++ b/Models/Text/WordSeg/SemiRing.swift @@ -13,11 +13,11 @@ // limitations under the License. 
#if os(iOS) || os(macOS) || os(tvOS) || os(watchOS) -import Darwin + import Darwin #elseif os(Windows) -import ucrt + import ucrt #else -import Glibc + import Glibc #endif /// logSumExp(_:_:) @@ -61,15 +61,17 @@ public struct SemiRing: Differentiable { } @differentiable -func *(_ lhs: SemiRing, _ rhs: SemiRing) -> SemiRing { - return SemiRing(logp: lhs.logp + rhs.logp, - logr: logSumExp(lhs.logp + rhs.logr, rhs.logp + lhs.logr)) +func * (_ lhs: SemiRing, _ rhs: SemiRing) -> SemiRing { + return SemiRing( + logp: lhs.logp + rhs.logp, + logr: logSumExp(lhs.logp + rhs.logr, rhs.logp + lhs.logr)) } @differentiable -func +(_ lhs: SemiRing, _ rhs: SemiRing) -> SemiRing { - return SemiRing(logp: logSumExp(lhs.logp, rhs.logp), - logr: logSumExp(lhs.logr, rhs.logr)) +func + (_ lhs: SemiRing, _ rhs: SemiRing) -> SemiRing { + return SemiRing( + logp: logSumExp(lhs.logp, rhs.logp), + logr: logSumExp(lhs.logr, rhs.logr)) } extension SemiRing { @@ -83,7 +85,7 @@ extension SemiRing { // TODO(abdulras) see if we can use ulp as a default tolerance @inlinable public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool { - return self.logp.isAlmostEqual(to: other.logp, tolerance: tolerance) && - self.logr.isAlmostEqual(to: other.logr, tolerance: tolerance) + return self.logp.isAlmostEqual(to: other.logp, tolerance: tolerance) + && self.logr.isAlmostEqual(to: other.logr, tolerance: tolerance) } } diff --git a/Models/Text/WordSeg/Vocabularies.swift b/Models/Text/WordSeg/Vocabularies.swift index bb3e9131d83..bcfd3abaaea 100644 --- a/Models/Text/WordSeg/Vocabularies.swift +++ b/Models/Text/WordSeg/Vocabularies.swift @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -import TensorFlow import ModelSupport +import TensorFlow /// Alphabet maps from characters in a string to Int32 representations. /// @@ -30,7 +30,7 @@ public struct Alphabet { let pad: Int32 public init(_ letters: C, eos: String, eow: String, pad: String) - where C.Element == Character { + where C.Element == Character { self.dictionary = .init(zip(letters.lazy.map { String($0) }, 0...)) self.eos = Int32(self.dictionary.count) @@ -44,7 +44,7 @@ public struct Alphabet { } public init(_ letters: C, eos: String, eow: String, pad: String) - where C.Element == Element { + where C.Element == Element { self.dictionary = .init(zip(letters.lazy.map { String($0) }, 0...)) self.eos = Int32(self.dictionary.count) @@ -93,7 +93,7 @@ public struct CharacterSequence: Hashable { } public init(alphabet: Alphabet, characters: ArraySlice) { - self.characters = Array(characters) + self.characters = [Int32](characters) self.eos = alphabet.eos } @@ -108,7 +108,7 @@ public struct CharacterSequence: Hashable { public var count: Int { return characters.count } var last: Int32? { return characters.last } var tensor: Tensor { - Tensor([self.eos] + characters[0 ..< characters.count - 1]) + Tensor([self.eos] + characters[0..:Int] = [:] + var histogram: [ArraySlice: Int] = [:] for sentence in sequences { // NOTE: the use of `sentence.count - 1` is to ensure that we ignore the // trailing `EoS` marker. - for i in 0 ..< sentence.count - 1 { - for j in 1 ... maxLength { + for i in 0.. 
1 else { continue } - histogram[sentence[i ..< e], default: 0] += 1 + histogram[sentence[i..= minFreq } - let vocab = frequentWordCandidates.map { CharacterSequence(alphabet: alphabet, characters: $0.0) } + let vocab = frequentWordCandidates.map { + CharacterSequence(alphabet: alphabet, characters: $0.0) + } self.init(vocab) } @@ -171,7 +173,8 @@ extension CharacterErrors: CustomStringConvertible { public var description: String { switch self { case let .unknownCharacter(character, index, sentence): - return "Unknown character '\(character)' encountered at index \(index) while converting sentence \"\(sentence)\" to a character sequence." + return + "Unknown character '\(character)' encountered at index \(index) while converting sentence \"\(sentence)\" to a character sequence." case .nonUtf8Data: return "Non-UTF8 data encountered." } diff --git a/Tests/TextTests/WordSegmentationTests/ExampleData.swift b/Tests/TextTests/WordSegmentationTests/ExampleData.swift index 1797c1b6607..ac390fd4fa0 100644 --- a/Tests/TextTests/WordSegmentationTests/ExampleData.swift +++ b/Tests/TextTests/WordSegmentationTests/ExampleData.swift @@ -20,57 +20,71 @@ enum Example1 { static let parameters = SNLMParameters( emb_enc: EmbeddingParameters( weight: Tensor( - [[-0.03872304 ,-0.05338321 ], - [ 0.051871493 ,-0.024247635 ], - [-0.05366139 , 0.05159463 ], - [ 0.047521368 , 0.010823324 ], - [-0.0012689009,-0.07714497 ]] + [ + [-0.03872304, -0.05338321], + [0.051871493, -0.024247635], + [-0.05366139, 0.05159463], + [0.047521368, 0.010823324], + [-0.0012689009, -0.07714497], + ] ) ), lstm_enc: LSTMParameters( weight_ih_l0: Tensor( - [[ 0.029065304, 0.035071626], - [ 0.05220075 , 0.07659577 ], - [-0.03418843 , 0.054938704], - [-0.076050006, 0.011613108], - [-0.07208062 ,-0.028602105], - [ 0.07454459 ,-0.045288637], - [ 0.056929305, 0.07525672 ], - [-0.044981938,-0.06287278 ]] + [ + [0.029065304, 0.035071626], + [0.05220075, 0.07659577], + [-0.03418843, 0.054938704], + [-0.076050006, 0.011613108], + [-0.07208062, -0.028602105], + [0.07454459, -0.045288637], + [0.056929305, 0.07525672], + [-0.044981938, -0.06287278], + ] ), weight_hh_l0: Tensor( - [[ 0.054794624,-0.047684688], - [ 0.059219867, 0.035521813], - [ 0.07620007 ,-0.024476558], - [ 0.06970246 ,-0.012173824], - [-0.01848751 ,-0.040551327], - [ 0.050961852, 0.03887064 ], - [ 0.054652542,-0.015333958], - [ 0.048499897, 0.044002943]] + [ + [0.054794624, -0.047684688], + [0.059219867, 0.035521813], + [0.07620007, -0.024476558], + [0.06970246, -0.012173824], + [-0.01848751, -0.040551327], + [0.050961852, 0.03887064], + [0.054652542, -0.015333958], + [0.048499897, 0.044002943], + ] ), bias_ih_l0: Tensor( - [-0.061996475, 0.07747544 , 0.04455457 ,-0.016387768,-0.07249636 , - 0.047923 ,-0.040993594,-0.011194035] + [ + -0.061996475, 0.07747544, 0.04455457, -0.016387768, -0.07249636, + 0.047923, -0.040993594, -0.011194035, + ] ), bias_hh_l0: Tensor( - [ 1.9561797e-02,-4.7253761e-02,-1.0812007e-02, 7.7792764e-02, - -3.5792589e-05,-3.5584867e-02, 5.0887465e-05,-3.5637055e-02] + [ + 1.9561797e-02, -4.7253761e-02, -1.0812007e-02, 7.7792764e-02, + -3.5792589e-05, -3.5584867e-02, 5.0887465e-05, -3.5637055e-02, + ] ) ), mlp_interpolation: MLPParameters( linear1: LinearParameters( weight: Tensor( - [[-0.034370955,-0.07964967 ], - [ 0.019697152,-0.059639234]] + [ + [-0.034370955, -0.07964967], + [0.019697152, -0.059639234], + ] ), bias: Tensor( - [0.047385678,0.05388096 ] + [0.047385678, 0.05388096] ) ), linear2: LinearParameters( weight: Tensor( - [[-0.04711317 , 0.07374136 ], - 
[-0.058497474,-0.011900932]] + [ + [-0.04711317, 0.07374136], + [-0.058497474, -0.011900932], + ] ), bias: Tensor( [-0.008636497, 0.052838206] @@ -80,105 +94,126 @@ enum Example1 { mlp_memory: MLPParameters( linear1: LinearParameters( weight: Tensor( - [[-0.06895532 , 0.0675631 ], - [ 0.062555104,-0.006975107]] + [ + [-0.06895532, 0.0675631], + [0.062555104, -0.006975107], + ] ), bias: Tensor( - [0.021453634 ,0.0056577697] + [0.021453634, 0.0056577697] ) ), linear2: LinearParameters( weight: Tensor( - [[ 0.01337228 , 0.07287796 ], - [-0.025111437,-0.021482762], - [-0.05161675 ,-0.06811503 ], - [-0.072463006, 0.015226476]] + [ + [0.01337228, 0.07287796], + [-0.025111437, -0.021482762], + [-0.05161675, -0.06811503], + [-0.072463006, 0.015226476], + ] ), bias: Tensor( - [ 0.0032938793,-0.043962937 , 0.043240592 , 0.0678826 ] + [0.0032938793, -0.043962937, 0.043240592, 0.0678826] ) ) ), emb_dec: EmbeddingParameters( weight: Tensor( - [[ 0.055306703, 0.07086456 ], - [ 0.03423354 ,-0.015132636], - [-0.04077827 , 0.016811028], - [-0.037189033,-0.07027687 ], - [ 0.054974243, 0.017300054]] + [ + [0.055306703, 0.07086456], + [0.03423354, -0.015132636], + [-0.04077827, 0.016811028], + [-0.037189033, -0.07027687], + [0.054974243, 0.017300054], + ] ) ), lstm_dec: LSTMParameters( weight_ih_l0: Tensor( - [[-0.053585358,-0.0642758 ], - [-0.07246614 , 0.025658146], - [ 0.034285776,-0.014611781], - [ 0.058412 , 0.047652483], - [ 0.065825045, 0.042562716], - [ 0.050531074, 0.047255352], - [-0.03512928 , 0.004992813], - [ 0.005484812,-0.054734543]] + [ + [-0.053585358, -0.0642758], + [-0.07246614, 0.025658146], + [0.034285776, -0.014611781], + [0.058412, 0.047652483], + [0.065825045, 0.042562716], + [0.050531074, 0.047255352], + [-0.03512928, 0.004992813], + [0.005484812, -0.054734543], + ] ), weight_hh_l0: Tensor( - [[ 0.07812595 , 0.0031644031], - [-0.04185462 , 0.03933753 ], - [-0.044581212 ,-0.018176649 ], - [ 0.07533194 , 0.0030083433], - [-0.045243938 ,-0.026109837 ], - [ 0.046121553 ,-0.053141937 ], - [ 0.011378422 , 0.067420706 ], - [-0.05194992 , 0.044939123 ]] + [ + [0.07812595, 0.0031644031], + [-0.04185462, 0.03933753], + [-0.044581212, -0.018176649], + [0.07533194, 0.0030083433], + [-0.045243938, -0.026109837], + [0.046121553, -0.053141937], + [0.011378422, 0.067420706], + [-0.05194992, 0.044939123], + ] ), bias_ih_l0: Tensor( - [ 0.049315616,-0.05961135 ,-0.047641095, 0.056274325,-0.071667776, - 0.049188778,-0.05663743 ,-0.051864214] + [ + 0.049315616, -0.05961135, -0.047641095, 0.056274325, -0.071667776, + 0.049188778, -0.05663743, -0.051864214, + ] ), bias_hh_l0: Tensor( - [ 0.055988565 , 0.01968135 ,-0.057932526 , 0.024752177 ,-0.029085837 , - -0.03911104 ,-0.0015038475, 0.051634952 ] + [ + 0.055988565, 0.01968135, -0.057932526, 0.024752177, -0.029085837, + -0.03911104, -0.0015038475, 0.051634952, + ] ) ), linear_dec: LinearParameters( weight: Tensor( - [[-0.03417959 , 0.04824567 ], - [ 0.0559683 , 0.0076355636], - [-0.03857645 , 0.015529476 ], - [ 0.057112962 , 0.036605842 ], - [ 0.023432933 ,-0.023976203 ]] + [ + [-0.03417959, 0.04824567], + [0.0559683, 0.0076355636], + [-0.03857645, 0.015529476], + [0.057112962, 0.036605842], + [0.023432933, -0.023976203], + ] ), bias: Tensor( - [-0.03640049 ,-0.057923757, 0.05912192 , 0.03688284 , 0.06261988 ] + [-0.03640049, -0.057923757, 0.05912192, 0.03688284, 0.06261988] ) ) ) static let expectedEncoding = Tensor( - [[-0.016786836 , 0.0014875316], - [-0.024643049 , 0.003509362 ], - [-0.030508908 , 0.005803111 ], - [-0.031535156 , 0.0056067896]] + [ + 
[-0.016786836, 0.0014875316], + [-0.024643049, 0.003509362], + [-0.030508908, 0.005803111], + [-0.031535156, 0.0056067896], + ] ) static let expectedMLPInterpolationOutput = Tensor( - [[-0.72172225,-0.665366 ], - [-0.7217337 ,-0.6653552 ], - [-0.72174466,-0.66534483], - [-0.7217447 ,-0.6653447 ]] + [ + [-0.72172225, -0.665366], + [-0.7217337, -0.6653552], + [-0.72174466, -0.66534483], + [-0.7217447, -0.6653447], + ] ) static let expectedMLPMemoryOutput = Tensor( - [[-1.400075 ,-1.4486395,-1.3622522,-1.3377004], - [-1.4000794,-1.4486222,-1.3622293,-1.3377339], - [-1.4000806,-1.4486088,-1.3622129,-1.3377609], - [-1.4000823,-1.448607 ,-1.3622097,-1.3377641]] + [ + [-1.400075, -1.4486395, -1.3622522, -1.3377004], + [-1.4000794, -1.4486222, -1.3622293, -1.3377339], + [-1.4000806, -1.4486088, -1.3622129, -1.3377609], + [-1.4000823, -1.448607, -1.3622097, -1.3377641], + ] ) static let expectedDecoded = Tensor( - [-6.5631247,-4.930211 ] + [-6.5631247, -4.930211] ) static let lattice = Lattice( positions: [ Lattice.Node( bestEdge: nil, bestScore: 0.0, - edges: [ - ], + edges: [], semiringScore: SemiRing(logp: 0.0, logr: -Float.infinity) ), Lattice.Node( @@ -215,7 +250,7 @@ enum Example1 { logp: -3.936295986175537, score: SemiRing(logp: -3.936295986175537, logr: -3.936295986175537), totalScore: SemiRing(logp: -7.848533630371094, logr: -7.155386447906494) - ) + ), ], semiringScore: SemiRing(logp: -2.051520824432373, logr: 1.411364197731018) ), @@ -246,7 +281,7 @@ enum Example1 { logp: -3.912229061126709, score: SemiRing(logp: -3.912229061126709, logr: -3.912229061126709), totalScore: SemiRing(logp: -5.963749885559082, logr: -2.4700069427490234) - ) + ), ], semiringScore: SemiRing(logp: -5.1324286460876465, logr: -1.0695605278015137) ), @@ -285,145 +320,174 @@ enum Example1 { logp: -3.936281204223633, score: SemiRing(logp: -3.936281204223633, logr: -3.936281204223633), totalScore: SemiRing(logp: -9.068710327148438, logr: -4.98878812789917) - ) + ), ], semiringScore: SemiRing(logp: -4.090378284454346, logr: 0.1802457720041275) - ) + ), ] ) static let gradWrtLogR = SNLMParameters( emb_enc: EmbeddingParameters( weight: Tensor( - [[-1.0885849e-05, 2.5600420e-05], - [-2.0048559e-05, 5.2774609e-05], - [-2.1089169e-05, 6.0592978e-05], - [ 0.0000000e+00, 0.0000000e+00], - [ 0.0000000e+00, 0.0000000e+00]] + [ + [-1.0885849e-05, 2.5600420e-05], + [-2.0048559e-05, 5.2774609e-05], + [-2.1089169e-05, 6.0592978e-05], + [0.0000000e+00, 0.0000000e+00], + [0.0000000e+00, 0.0000000e+00], + ] ) ), lstm_enc: LSTMParameters( weight_ih_l0: Tensor( - [[-5.5954405e-07, 2.1259700e-07], - [-9.2493622e-08, 1.2687329e-07], - [ 4.8279912e-07,-5.5000936e-07], - [-1.1893751e-07, 9.6952142e-08], - [ 1.7744465e-05,-6.3974912e-06], - [ 2.3670214e-05,-6.8028967e-06], - [ 6.9926745e-07, 1.0588951e-07], - [-3.6788197e-07, 1.1367500e-07]] + [ + [-5.5954405e-07, 2.1259700e-07], + [-9.2493622e-08, 1.2687329e-07], + [4.8279912e-07, -5.5000936e-07], + [-1.1893751e-07, 9.6952142e-08], + [1.7744465e-05, -6.3974912e-06], + [2.3670214e-05, -6.8028967e-06], + [6.9926745e-07, 1.0588951e-07], + [-3.6788197e-07, 1.1367500e-07], + ] ), weight_hh_l0: Tensor( - [[-6.1557824e-07, 7.9759928e-08], - [ 1.8759494e-07,-2.4723642e-08], - [-3.8983197e-07, 5.1562441e-08], - [ 7.6582531e-08,-1.0336059e-08], - [ 1.6366921e-05,-2.1041824e-06], - [ 2.5587158e-05,-3.2796638e-06], - [-7.9833825e-07, 1.1357896e-07], - [ 2.2801314e-07,-3.2397072e-08]] + [ + [-6.1557824e-07, 7.9759928e-08], + [1.8759494e-07, -2.4723642e-08], + [-3.8983197e-07, 5.1562441e-08], + 
[7.6582531e-08, -1.0336059e-08], + [1.6366921e-05, -2.1041824e-06], + [2.5587158e-05, -3.2796638e-06], + [-7.9833825e-07, 1.1357896e-07], + [2.2801314e-07, -3.2397072e-08], + ] ), bias_ih_l0: Tensor( - [ 5.0008435e-05,-1.0976097e-05, 1.7231478e-05,-3.3032948e-06, - -1.3631615e-03,-2.0683720e-03, 5.0045674e-05,-1.1407620e-05] + [ + 5.0008435e-05, -1.0976097e-05, 1.7231478e-05, -3.3032948e-06, + -1.3631615e-03, -2.0683720e-03, 5.0045674e-05, -1.1407620e-05, + ] ), bias_hh_l0: Tensor( - [ 5.0008435e-05,-1.0976097e-05, 1.7231478e-05,-3.3032948e-06, - -1.3631615e-03,-2.0683720e-03, 5.0045674e-05,-1.1407620e-05] + [ + 5.0008435e-05, -1.0976097e-05, 1.7231478e-05, -3.3032948e-06, + -1.3631615e-03, -2.0683720e-03, 5.0045674e-05, -1.1407620e-05, + ] ) ), mlp_interpolation: MLPParameters( linear1: LinearParameters( weight: Tensor( - [[-2.162833e-04, 3.404662e-05], - [-1.626215e-03, 2.559960e-04]] + [ + [-2.162833e-04, 3.404662e-05], + [-1.626215e-03, 2.559960e-04], + ] ), bias: Tensor( - [0.008965784,0.06741227 ] + [0.008965784, 0.06741227] ) ), linear2: LinearParameters( weight: Tensor( - [[ 0.03779146 , 0.041938413], - [-0.03779146 ,-0.041938413]] + [ + [0.03779146, 0.041938413], + [-0.03779146, -0.041938413], + ] ), bias: Tensor( - [ 0.7893658,-0.7893658] + [0.7893658, -0.7893658] ) ) ), mlp_memory: MLPParameters( linear1: LinearParameters( weight: Tensor( - [[ 0.0006783868 ,-0.00010455749], - [ 0.002728255 ,-0.00042056653]] + [ + [0.0006783868, -0.00010455749], + [0.002728255, -0.00042056653], + ] ), bias: Tensor( - [-0.028684907,-0.11537111 ] + [-0.028684907, -0.11537111] ) ), linear2: LinearParameters( weight: Tensor( - [[-0.0098278085,-0.0017497203], - [-0.009362103 ,-0.0016668034], - [ 0.029616904 , 0.005273029 ], - [-0.010426991 ,-0.0018565056]] + [ + [-0.0098278085, -0.0017497203], + [-0.009362103, -0.0016668034], + [0.029616904, 0.005273029], + [-0.010426991, -0.0018565056], + ] ), bias: Tensor( - [-0.4213174 ,-0.40135252, 1.2696781 ,-0.44700825] + [-0.4213174, -0.40135252, 1.2696781, -0.44700825] ) ) ), emb_dec: EmbeddingParameters( weight: Tensor( - [[2.8222625e-04,1.9048888e-04], - [1.8905671e-04,1.5085124e-04], - [0.0000000e+00,0.0000000e+00], - [3.5629273e-06,2.5625954e-05], - [0.0000000e+00,0.0000000e+00]] + [ + [2.8222625e-04, 1.9048888e-04], + [1.8905671e-04, 1.5085124e-04], + [0.0000000e+00, 0.0000000e+00], + [3.5629273e-06, 2.5625954e-05], + [0.0000000e+00, 0.0000000e+00], + ] ) ), lstm_dec: LSTMParameters( weight_ih_l0: Tensor( - [[-1.3549015e-05,-1.4535895e-05], - [ 3.1118125e-07,-1.6063163e-07], - [-9.2411565e-06,-8.3111609e-06], - [ 3.1314161e-07,-9.0987783e-08], - [ 2.9817267e-04, 3.1956812e-04], - [ 2.4017896e-05,-1.0910597e-04], - [-1.9980842e-05,-2.4779467e-05], - [ 1.5680398e-07,-8.3687371e-07]] + [ + [-1.3549015e-05, -1.4535895e-05], + [3.1118125e-07, -1.6063163e-07], + [-9.2411565e-06, -8.3111609e-06], + [3.1314161e-07, -9.0987783e-08], + [2.9817267e-04, 3.1956812e-04], + [2.4017896e-05, -1.0910597e-04], + [-1.9980842e-05, -2.4779467e-05], + [1.5680398e-07, -8.3687371e-07], + ] ), weight_hh_l0: Tensor( - [[ 7.7975146e-06,-7.3286355e-07], - [-4.4332864e-07, 4.7302478e-08], - [ 7.1966651e-06,-7.2089733e-07], - [-3.4162869e-07, 4.1011177e-08], - [-1.7509436e-04, 1.6362043e-05], - [-1.0614634e-04, 1.1837798e-05], - [ 9.3842154e-06,-8.1990322e-07], - [-4.8998282e-07, 6.6971602e-08]] + [ + [7.7975146e-06, -7.3286355e-07], + [-4.4332864e-07, 4.7302478e-08], + [7.1966651e-06, -7.2089733e-07], + [-3.4162869e-07, 4.1011177e-08], + [-1.7509436e-04, 1.6362043e-05], + 
[-1.0614634e-04, 1.1837798e-05], + [9.3842154e-06, -8.1990322e-07], + [-4.8998282e-07, 6.6971602e-08], + ] ), bias_ih_l0: Tensor( - [-1.8724482e-04, 1.3764327e-05,-1.8988468e-04, 8.8408688e-06, - 4.2812643e-03, 3.6167959e-03,-2.0367328e-04, 1.3456454e-05] + [ + -1.8724482e-04, 1.3764327e-05, -1.8988468e-04, 8.8408688e-06, + 4.2812643e-03, 3.6167959e-03, -2.0367328e-04, 1.3456454e-05, + ] ), bias_hh_l0: Tensor( - [-1.8724482e-04, 1.3764327e-05,-1.8988468e-04, 8.8408688e-06, - 4.2812643e-03, 3.6167959e-03,-2.0367328e-04, 1.3456454e-05] + [ + -1.8724482e-04, 1.3764327e-05, -1.8988468e-04, 8.8408688e-06, + 4.2812643e-03, 3.6167959e-03, -2.0367328e-04, 1.3456454e-05, + ] ) ), linear_dec: LinearParameters( weight: Tensor( - [[-0.0041010594 , 0.00012692134], - [-0.005898418 , 0.0008138743 ], - [ 0.006042951 ,-0.0006127681 ], - [-0.0020920308 , 0.00028523407], - [ 0.0060485564 ,-0.0006132616 ]] + [ + [-0.0041010594, 0.00012692134], + [-0.005898418, 0.0008138743], + [0.006042951, -0.0006127681], + [-0.0020920308, 0.00028523407], + [0.0060485564, -0.0006132616], + ] ), bias: Tensor( - [ 0.14542846 , 0.14904119 ,-0.16054428 , 0.026781656,-0.16070703 ] + [0.14542846, 0.14904119, -0.16054428, 0.026781656, -0.16070703] ) ) ) } - diff --git a/Tests/TextTests/WordSegmentationTests/ProbeLayers.swift b/Tests/TextTests/WordSegmentationTests/ProbeLayers.swift index eacbf4118e0..8e2f66718ee 100644 --- a/Tests/TextTests/WordSegmentationTests/ProbeLayers.swift +++ b/Tests/TextTests/WordSegmentationTests/ProbeLayers.swift @@ -12,10 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -import XCTest - import TensorFlow import TextModels +import XCTest extension SNLM { /// Sets the model parameters to the given parameters exported from the model. 
@@ -30,7 +29,9 @@ extension SNLM {
   }
 
   private func checkShapeAndSet(_ tensor: inout Tensor<Float>, to value: Tensor<Float>) {
-    assert(tensor.shape == value.shape, "shape mismatch while setting: \(tensor.shape) to \(value.shape)")
+    assert(
+      tensor.shape == value.shape, "shape mismatch while setting: \(tensor.shape) to \(value.shape)"
+    )
     tensor = value
   }
 
@@ -49,7 +50,7 @@ extension SNLM {
         fusedWeight.slice(lowerBounds: [0, 0], upperBounds: [i, j]),
         fusedWeight.slice(lowerBounds: [0, 2 * j], upperBounds: [i, 3 * j]),
         fusedWeight.slice(lowerBounds: [0, j], upperBounds: [i, 2 * j]),
-        fusedWeight.slice(lowerBounds: [0, 3 * j], upperBounds: [i, 4 * j])
+        fusedWeight.slice(lowerBounds: [0, 3 * j], upperBounds: [i, 4 * j]),
       ],
       alongAxis: 1
     )
@@ -61,7 +62,7 @@ extension SNLM {
       fusedBias.slice(lowerBounds: [0], upperBounds: [k]),
       fusedBias.slice(lowerBounds: [2 * k], upperBounds: [3 * k]),
       fusedBias.slice(lowerBounds: [k], upperBounds: [2 * k]),
-      fusedBias.slice(lowerBounds: [3 * k], upperBounds: [4 * k])
+      fusedBias.slice(lowerBounds: [3 * k], upperBounds: [4 * k]),
     ]
   )
 
@@ -95,7 +96,9 @@ func tangentVector(from gradient: SNLMParameters, model: SNLM) -> SNLM.TangentVe
   return model.differentiableVectorView
 }
 
-func almostEqual(_ lhs: SNLM.TangentVector, _ rhs: SNLM.TangentVector, relTol: Float, zeroTol: Float) -> Bool {
+func almostEqual(
+  _ lhs: SNLM.TangentVector, _ rhs: SNLM.TangentVector, relTol: Float, zeroTol: Float
+) -> Bool {
   var success = true
   for (kpIndex, kp) in lhs.recursivelyAllKeyPaths(to: Tensor<Float>.self).enumerated() {
     let t1 = lhs[keyPath: kp]
@@ -106,7 +109,10 @@ func almostEqual(_ lhs: SNLM.TangentVector, _ rhs: SNLM.TangentVector, relTol: F
       continue
     }
     for (elementIndex, (t1e, t2e)) in zip(t1.scalars, t2.scalars).enumerated() {
-      if !(t1e.isAlmostEqual(to: t2e, tolerance: relTol) || (t1e.isAlmostZero(absoluteTolerance: zeroTol) && t2e.isAlmostZero(absoluteTolerance: zeroTol))) {
+      if !(t1e.isAlmostEqual(to: t2e, tolerance: relTol)
+        || (t1e.isAlmostZero(absoluteTolerance: zeroTol)
+          && t2e.isAlmostZero(absoluteTolerance: zeroTol)))
+      {
         print("Mismatch on tensor \(kpIndex) element \(elementIndex): \(t1e) & \(t2e)")
         success = false
       }
@@ -123,34 +129,36 @@ class ProbeLayerTests: XCTestCase {
     // 2 - </s>
     // 3 - </w>
     // 4 - </pad>
-
-    let chrVocab: Alphabet = Alphabet([
-      "a",
-      "b",
-    ], eos: "</s>", eow: "</w>", pad: "</pad>")
-
+    let chrVocab: Alphabet = Alphabet(
+      [
+        "a",
+        "b",
+      ], eos: "</s>", eow: "</w>", pad: "</pad>")
     // strVocab is:
     // 0 - aaaa
     // 1 - bbbb
     // 2 - abab
     let strVocab: Lexicon = Lexicon([
-      CharacterSequence(alphabet: chrVocab, characters: [0, 0]), // "aa"
-      CharacterSequence(alphabet: chrVocab, characters: [1, 1]), // "bb"
-      CharacterSequence(alphabet: chrVocab, characters: [0, 1]), // "ab"
-      CharacterSequence(alphabet: chrVocab, characters: [1, 0]) // "ba"
+      CharacterSequence(alphabet: chrVocab, characters: [0, 0]),  // "aa"
+      CharacterSequence(alphabet: chrVocab, characters: [1, 1]),  // "bb"
+      CharacterSequence(alphabet: chrVocab, characters: [0, 1]),  // "ab"
+      CharacterSequence(alphabet: chrVocab, characters: [1, 0]),  // "ba"
     ])
 
-    var model = SNLM(conf: Conf(
-      ndim: 2,
-      dropoutProb: 0,
-      chrVocab: chrVocab,
-      strVocab: strVocab,
-      order: 5))
+    var model = SNLM(
+      parameters: SNLM.Parameters(
+        ndim: 2,
+        dropoutProb: 0,
+        chrVocab: chrVocab,
+        strVocab: strVocab,
+        order: 5))
 
     model.setParameters(Example1.parameters)
 
     print("Encoding")
-    let encoderStates = model.encode(CharacterSequence(alphabet: chrVocab, characters: [0, 1, 0, 1])) // "abab"
+    let encoderStates = model.encode(
+      CharacterSequence(alphabet: chrVocab, characters: [0, 1, 0, 1]))  // "abab"
     let encoderStatesTensor = Tensor(stacking: encoderStates)
     print("Expected: \(Example1.expectedEncoding)")
     print("Actual: \(encoderStatesTensor)")
@@ -161,7 +169,9 @@ class ProbeLayerTests: XCTestCase {
     let mlpInterpolationOutput = model.mlpInterpolation(encoderStatesTensor)
     print("Expected: \(Example1.expectedMLPInterpolationOutput)")
     print("Actual: \(mlpInterpolationOutput)")
-    XCTAssert(abs(mlpInterpolationOutput - Example1.expectedMLPInterpolationOutput).max().scalarized() < 1e-6)
+    XCTAssert(
+      abs(mlpInterpolationOutput - Example1.expectedMLPInterpolationOutput).max().scalarized()
+        < 1e-6)
     print("OK!\n")
 
     print("MLP Memory")
@@ -174,8 +184,8 @@ class ProbeLayerTests: XCTestCase {
     print("Decode")
     let decoded = model.decode(
       [
-        CharacterSequence(alphabet: chrVocab, characters: [0, 0, 0]), // "aaa"
-        CharacterSequence(alphabet: chrVocab, characters: [0, 1]) // "ab"
+        CharacterSequence(alphabet: chrVocab, characters: [0, 0, 0]),  // "aaa"
+        CharacterSequence(alphabet: chrVocab, characters: [0, 1]),  // "ab"
      ],
      encoderStates[0]
    )
diff --git a/Tests/TextTests/WordSegmentationTests/WordSegmentationTests.swift b/Tests/TextTests/WordSegmentationTests/WordSegmentationTests.swift
index 3679ce1bb53..8303b98e551 100644
--- a/Tests/TextTests/WordSegmentationTests/WordSegmentationTests.swift
+++ b/Tests/TextTests/WordSegmentationTests/WordSegmentationTests.swift
@@ -14,21 +14,21 @@
 
 import XCTest
 
-@testable
-import TextModels
+@testable import TextModels
 
 class DataSetTests: XCTestCase {
   func test_DataSetLoad() {
     let buffer: [UInt8] = [
-      0x61, 0x6c, 0x70, 0x68, 0x61, 0x0a, // alpha.
+      0x61, 0x6c, 0x70, 0x68, 0x61, 0x0a,  // alpha.
     ]
 
     var dataset: DataSet?
     buffer.withUnsafeBytes { pointer in
       guard let address = pointer.baseAddress else { return }
       let training: Data =
-        Data(bytesNoCopy: UnsafeMutableRawPointer(mutating: address),
-             count: pointer.count, deallocator: .none)
+        Data(
+          bytesNoCopy: UnsafeMutableRawPointer(mutating: address),
+          count: pointer.count, deallocator: .none)
 
       dataset = try? DataSet(training: training, validation: nil, testing: nil)
     }
@@ -45,7 +45,7 @@ class DataSetTests: XCTestCase {
 class SemiRingTests: XCTestCase {
   func test_SemiRingAdd() {
     let value: SemiRing =
-        SemiRing(logp: 1.0, logr: 2.0) + SemiRing(logp: 3.0, logr: 4.0)
+      SemiRing(logp: 1.0, logr: 2.0) + SemiRing(logp: 3.0, logr: 4.0)
     XCTAssertEqual(value.logp, 3.126928, accuracy: 0.000001)
     XCTAssertEqual(value.logr, 4.126928, accuracy: 0.000001)
   }
@@ -82,7 +82,7 @@ class SemiRingTests: XCTestCase {
 
   func test_SemiRingMultiply() {
     let value: SemiRing =
-        SemiRing(logp: 1.0, logr: 2.0) * SemiRing(logp: 3.0, logr: 4.0)
+      SemiRing(logp: 1.0, logr: 2.0) * SemiRing(logp: 3.0, logr: 4.0)
     XCTAssertEqual(value.logp, 4.0)
     XCTAssertEqual(value.logr, 5.693147, accuracy: 0.000001)
   }
@@ -100,11 +100,12 @@ class SemiRingTests: XCTestCase {
 
 class VocabularyTests: XCTestCase {
   func test_AlphabetaConstruct() {
-    let characters: Alphabet = Alphabet([
-      "a",
-      "b",
-      "c",
-    ], eos: "</s>", eow: "</w>", pad: "</pad>")
+    let characters: Alphabet = Alphabet(
+      [
+        "a",
+        "b",
+        "c",
+      ], eos: "</s>", eow: "</w>", pad: "</pad>")
 
     // a, b, c, EOS, EOW, PAD
     XCTAssertEqual(characters.count, 6)
     XCTAssertEqual(characters.eos, 3)
@@ -113,24 +114,28 @@ class VocabularyTests: XCTestCase {
   }
 
   func test_CharacterSequenceConstruct() {
-    let characters: Alphabet = Alphabet([
-      "a",
-      "c",
-      "t",
-    ], eos: "</s>", eow: "</w>", pad: "</pad>")
-    let cat: CharacterSequence? = try? CharacterSequence(alphabet: characters, appendingEoSTo: "cat")
+    let characters: Alphabet = Alphabet(
+      [
+        "a",
+        "c",
+        "t",
+      ], eos: "</s>", eow: "</w>", pad: "</pad>")
+    let cat: CharacterSequence? = try? CharacterSequence(
+      alphabet: characters, appendingEoSTo: "cat")
 
     XCTAssertNotEqual(cat, nil)
     // FIXME(abdulras) should the EoS be visible?
     XCTAssertEqual(cat?.characters, [Int32(1), Int32(0), Int32(2), characters.eos])
 
-    let bat: CharacterSequence? = try? CharacterSequence(alphabet: characters, appendingEoSTo: "bat")
+    let bat: CharacterSequence? = try? CharacterSequence(
+      alphabet: characters, appendingEoSTo: "bat")
     XCTAssertEqual(bat, nil)
   }
 
   func test_LexiconConstruct() {
-    let characters: Alphabet = Alphabet([
-      "a", "b", "e", "g", "h", "l", "m", "p", "t",
-    ], eos: "</s>", eow: "</w>", pad: "</pad>")
+    let characters: Alphabet = Alphabet(
+      [
+        "a", "b", "e", "g", "h", "l", "m", "p", "t",
+      ], eos: "</s>", eow: "</w>", pad: "</pad>")
     let strings: Lexicon = Lexicon([
       try! CharacterSequence(alphabet: characters, appendingEoSTo: "alpha"),
       try! CharacterSequence(alphabet: characters, appendingEoSTo: "beta"),
       try! CharacterSequence(alphabet: characters, appendingEoSTo: "gamma"),
     ])
 
     XCTAssertEqual(strings.count, 3)
   }
 
   func test_LexiconFromSequences() {
-    let alphabet: Alphabet = Alphabet([
-      "a", "b", "e", "g", "h", "l", "m", "p", "t",
-    ], eos: "</s>", eow: "</w>", pad: "</pad>")
-    let lexicon: Lexicon = Lexicon(from: [
-      try! CharacterSequence(alphabet: alphabet, appendingEoSTo: "alpha"),
-      try! CharacterSequence(alphabet: alphabet, appendingEoSTo: "beta"),
-      try! CharacterSequence(alphabet: alphabet, appendingEoSTo: "gamma"),
-    ], alphabet: alphabet, maxLength: 5, minFreq: 4)
+    let alphabet: Alphabet = Alphabet(
+      [
+        "a", "b", "e", "g", "h", "l", "m", "p", "t",
+      ], eos: "</s>", eow: "</w>", pad: "</pad>")
+    let lexicon: Lexicon = Lexicon(
+      from: [
+        try! CharacterSequence(alphabet: alphabet, appendingEoSTo: "alpha"),
+        try! CharacterSequence(alphabet: alphabet, appendingEoSTo: "beta"),
+        try! CharacterSequence(alphabet: alphabet, appendingEoSTo: "gamma"),
+      ], alphabet: alphabet, maxLength: 5, minFreq: 4)
 
     XCTAssertEqual(lexicon.count, 3)
   }
 
diff --git a/Tests/TextTests/XCTestManifests.swift b/Tests/TextTests/XCTestManifests.swift
index 24f05d8ad36..95d2ceb858b 100644
--- a/Tests/TextTests/XCTestManifests.swift
+++ b/Tests/TextTests/XCTestManifests.swift
@@ -15,13 +15,13 @@
 import XCTest
 
 #if !os(macOS)
-public func allTests() -> [XCTestCaseEntry] {
-  return [
-    testCase(TextInferenceTests.allTests),
-    testCase(DataSetTests.allTests),
-    testCase(SemiRingTests.allTests),
-    testCase(VocabularyTests.allTests),
-    testCase(ProbeLayerTests.allTests),
-  ]
-}
+  public func allTests() -> [XCTestCaseEntry] {
+    return [
+      testCase(TextInferenceTests.allTests),
+      testCase(DataSetTests.allTests),
+      testCase(SemiRingTests.allTests),
+      testCase(VocabularyTests.allTests),
+      testCase(ProbeLayerTests.allTests),
+    ]
+  }
 #endif

From 33a059913e82401f2f551fad30068f3f7c3fcd46 Mon Sep 17 00:00:00 2001
From: Michelle Casbon
Date: Fri, 8 May 2020 19:04:39 -0400
Subject: [PATCH 5/5] Make DataSet internal

---
 Models/Text/WordSeg/DataSet.swift | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Models/Text/WordSeg/DataSet.swift b/Models/Text/WordSeg/DataSet.swift
index cbef8bd6c99..2e8b532c103 100644
--- a/Models/Text/WordSeg/DataSet.swift
+++ b/Models/Text/WordSeg/DataSet.swift
@@ -14,7 +14,7 @@
 
 import Foundation
 
-public struct DataSet {
+internal struct DataSet {
   public let training: [CharacterSequence]
   public private(set) var testing: [CharacterSequence]?
   public private(set) var validation: [CharacterSequence]?
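
Note on the SemiRing expectations exercised above: the values in test_SemiRingAdd
and test_SemiRingMultiply follow from log-space semiring arithmetic, where `+` is
elementwise logSumExp over the (logp, logr) pair and `*` adds log-probabilities
while combining logr by the product rule. The same two operations accumulate the
semiringScore/totalScore values in Example1's lattice nodes. A minimal standalone
sketch of that arithmetic, inferred from the expected test values rather than
taken from the WordSeg sources (`LogSemiRing` and `logSumExp` are illustrative
names, not the library's API):

import Foundation

// Sketch of the log-space semiring the tests above appear to exercise.
// Each element carries (logp, logr); "+" is elementwise logSumExp and
// "*" adds logp while combining logr via d(p1*p2) -> p1*r2 + p2*r1,
// all tracked in log space.
struct LogSemiRing {
  var logp: Float
  var logr: Float

  // Numerically stable log(e^a + e^b); tolerates -infinity operands,
  // matching the SemiRing(logp: 0.0, logr: -Float.infinity) start node.
  static func logSumExp(_ a: Float, _ b: Float) -> Float {
    let m = max(a, b)
    if m == -Float.infinity { return -Float.infinity }
    return m + Float(log(exp(Double(a - m)) + exp(Double(b - m))))
  }

  static func + (lhs: LogSemiRing, rhs: LogSemiRing) -> LogSemiRing {
    LogSemiRing(
      logp: logSumExp(lhs.logp, rhs.logp),
      logr: logSumExp(lhs.logr, rhs.logr))
  }

  static func * (lhs: LogSemiRing, rhs: LogSemiRing) -> LogSemiRing {
    LogSemiRing(
      logp: lhs.logp + rhs.logp,
      logr: logSumExp(lhs.logp + rhs.logr, lhs.logr + rhs.logp))
  }
}

// Reproduces test_SemiRingAdd: logSumExp(1, 3) = 3.126928...,
// logSumExp(2, 4) = 4.126928...
let sum = LogSemiRing(logp: 1.0, logr: 2.0) + LogSemiRing(logp: 3.0, logr: 4.0)
print(sum.logp, sum.logr)

// Reproduces test_SemiRingMultiply: logp = 1 + 3 = 4.0,
// logr = logSumExp(1 + 4, 2 + 3) = 5 + log(2) = 5.693147...
let product = LogSemiRing(logp: 1.0, logr: 2.0) * LogSemiRing(logp: 3.0, logr: 4.0)
print(product.logp, product.logr)

Running the sketch prints 3.126928 4.126928 and 4.0 5.693147, matching the
tolerances asserted in SemiRingTests.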