Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

WordSeg model and tests #498

Merged
merged 5 commits into from
May 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Models/Text/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ add_library(TextModels
ScheduledParameters.swift
TransformerBERT.swift
Utilities.swift
WeightDecayedAdam.swift)
WeightDecayedAdam.swift
WordSeg/DataSet.swift
WordSeg/Model.swift
WordSeg/Lattice.swift
WordSeg/SE-0259.swift
WordSeg/SemiRing.swift
WordSeg/Vocabularies.swift)
set_target_properties(TextModels PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${CMAKE_Swift_MODULE_DIRECTORY})
target_compile_options(TextModels PRIVATE
Expand Down
131 changes: 131 additions & 0 deletions Models/Text/WordSeg/DataSet.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

/// The corpora used by the WordSeg model, together with the alphabet built from them.
///
/// Every corpus is decoded as UTF-8 and split into lines. All whitespace is removed from each
/// line, lines that end up empty are dropped, and each remaining line is converted to a
/// `CharacterSequence` via `CharacterSequence(alphabet:appendingEoSTo:)`.
internal struct DataSet {
  /// Sequences from the training corpus.
  public let training: [CharacterSequence]
  /// Sequences from the testing corpus, or `nil` when none was supplied.
  public private(set) var testing: [CharacterSequence]?
  /// Sequences from the validation corpus, or `nil` when none was supplied.
  public private(set) var validation: [CharacterSequence]?
  /// Alphabet built from the sorted set of characters found in all supplied corpora.
  public let alphabet: Alphabet

  /// Decodes `data` as UTF-8 and splits it into cleaned-up lines.
  ///
  /// - Throws: `CharacterErrors.nonUtf8Data` when `data` is not valid UTF-8.
  private static func load(data: Data) throws -> [String] {
    if let text = String(data: data, encoding: .utf8) {
      return load(contents: text)
    }
    throw CharacterErrors.nonUtf8Data
  }

  /// Splits `contents` on newlines, removes all whitespace inside each line and drops the lines
  /// that are empty afterwards.
  private static func load(contents: String) -> [String] {
    return contents
      .components(separatedBy: .newlines)
      .map { $0.components(separatedBy: .whitespaces).joined() }
      .filter { !$0.isEmpty }
  }

  /// Builds the alphabet from every character that occurs in any of the given corpora.
  ///
  /// - Parameters:
  ///   - training: The training sentences.
  ///   - otherSequences: Further optional corpora; `nil` entries are skipped.
  ///   - eos: Marker string forwarded to `Alphabet` as `eos`.
  ///   - eow: Marker string forwarded to `Alphabet` as `eow`.
  ///   - pad: Marker string forwarded to `Alphabet` as `pad`.
  private static func makeAlphabet(
    datasets training: [String],
    _ otherSequences: [String]?...,
    eos: String = "</s>",
    eow: String = "</w>",
    pad: String = "</pad>"
  ) -> Alphabet {
    var seen = Set<Character>()
    for case let corpus? in otherSequences + [training] {
      for sentence in corpus {
        seen.formUnion(sentence)
      }
    }

    // Letters are sorted so that the int <-> letter mapping is easy to interpret.
    return Alphabet(seen.sorted(), eos: eos, eow: eow, pad: pad)
  }

  /// Converts every sentence of `dataset` into a `CharacterSequence` over `alphabet`.
  private static func convertDataset(_ dataset: [String], alphabet: Alphabet) throws
    -> [CharacterSequence]
  {
    var sequences = [CharacterSequence]()
    sequences.reserveCapacity(dataset.count)
    for sentence in dataset {
      sequences.append(try CharacterSequence(alphabet: alphabet, appendingEoSTo: sentence))
    }
    return sequences
  }

  /// Optional-propagating variant of `convertDataset(_:alphabet:)`: `nil` in, `nil` out.
  private static func convertDataset(_ dataset: [String]?, alphabet: Alphabet) throws
    -> [CharacterSequence]?
  {
    guard let sentences = dataset else { return nil }
    // The explicit type selects the non-optional overload above.
    let converted: [CharacterSequence] = try convertDataset(sentences, alphabet: alphabet)
    return converted
  }

  /// Reads the file at `path` with the `.alwaysMapped` option.
  private static func mappedContents(ofFile path: String) throws -> Data {
    return try Data(contentsOf: URL(fileURLWithPath: path), options: .alwaysMapped)
  }

  /// Shared tail of both public initializers: builds the alphabet, then converts each corpus.
  private init(
    trainingLines: [String],
    validationLines: [String]?,
    testingLines: [String]?
  ) throws {
    let alphabet = Self.makeAlphabet(datasets: trainingLines, validationLines, testingLines)
    self.alphabet = alphabet
    self.training = try Self.convertDataset(trainingLines, alphabet: alphabet)
    self.validation = try Self.convertDataset(validationLines, alphabet: alphabet)
    self.testing = try Self.convertDataset(testingLines, alphabet: alphabet)
  }

  /// Loads the corpora from files on disk.
  ///
  /// - Parameters:
  ///   - trainingFile: Path of the training corpus.
  ///   - validationFile: Optional path of the validation corpus.
  ///   - testingFile: Optional path of the testing corpus.
  /// - Throws: Any error from reading a file, `CharacterErrors.nonUtf8Data` for non-UTF-8
  ///   content, or any error thrown while building a `CharacterSequence`.
  public init?(
    training trainingFile: String,
    validation validationFile: String? = nil,
    testing testingFile: String? = nil
  ) throws {
    // Each file is read and parsed before the next one is touched.
    let trainingLines = try Self.load(data: try Self.mappedContents(ofFile: trainingFile))

    var validationLines: [String]? = nil
    if let path = validationFile {
      validationLines = try Self.load(data: try Self.mappedContents(ofFile: path))
    }

    var testingLines: [String]? = nil
    if let path = testingFile {
      testingLines = try Self.load(data: try Self.mappedContents(ofFile: path))
    }

    try self.init(
      trainingLines: trainingLines,
      validationLines: validationLines,
      testingLines: testingLines)
  }

  /// Loads the corpora from in-memory data.
  ///
  /// - Parameters:
  ///   - trainingData: UTF-8 contents of the training corpus.
  ///   - validationData: Optional UTF-8 contents of the validation corpus.
  ///   - testingData: Optional UTF-8 contents of the testing corpus.
  /// - Throws: `CharacterErrors.nonUtf8Data` for non-UTF-8 content, or any error thrown while
  ///   building a `CharacterSequence`.
  init(training trainingData: Data, validation validationData: Data?, testing testingData: Data?)
    throws
  {
    let trainingLines = try Self.load(data: trainingData)

    var validationLines: [String]? = nil
    if let data = validationData {
      validationLines = try Self.load(data: data)
    }

    var testingLines: [String]? = nil
    if let data = testingData {
      testingLines = try Self.load(data: data)
    }

    try self.init(
      trainingLines: trainingLines,
      validationLines: validationLines,
      testingLines: testingLines)
  }
}
231 changes: 231 additions & 0 deletions Models/Text/WordSeg/Lattice.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlow

#if os(iOS) || os(macOS) || os(tvOS) || os(watchOS)
import Darwin
#elseif os(Windows)
import ucrt
#else
import Glibc
#endif

/// Lattice
///
/// Represents the lattice used by the WordSeg algorithm.
public struct Lattice: Differentiable {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need Lattice to be public? I think that this should probably be made internal.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing this scope looks to be a bit more involved. Punting to #499

/// Edge
///
/// Represents an Edge
public struct Edge: Differentiable {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar

/// Index of the lattice position this edge starts from; `Lattice.viterbi` looks up
/// `self[edge.start]` to obtain the score of the path leading up to the edge.
@noDerivative public var start: Int
/// Index of the lattice position this edge ends at.
/// NOTE(review): presumably the position whose `edges` array holds this edge — confirm against
/// the code that builds the lattice.
@noDerivative public var end: Int
/// The character sequence carried by this edge.
@noDerivative public var string: CharacterSequence
/// The `logp` value of the edge; `Lattice.viterbi` adds it to the start node's best score.
public var logp: Float

// expectation
/// Semiring score of this edge alone.
public var score: SemiRing
/// Semiring score of this edge combined with the score that preceded it.
public var totalScore: SemiRing

/// Creates an edge and derives both of its semiring scores.
///
/// `score` is `SemiRing(logp: logp, logr: logp + log(sentence.count ^ order))` and
/// `totalScore` is `score * previous`.
///
/// - Parameters:
///   - start: Stored as `start`.
///   - end: Stored as `end`.
///   - sentence: Stored as `string`; its `count` feeds the `logr` term.
///   - logp: Stored as `logp` and used for both semiring components.
///   - previous: Semiring score multiplied into `score` to produce `totalScore`.
///   - order: Exponent applied to `sentence.count` in the `logr` term.
@differentiable
init(
start: Int, end: Int, sentence: CharacterSequence, logp: Float,
previous: SemiRing, order: Int
) {
self.start = start
self.end = end
self.string = sentence
self.logp = logp

self.score =
SemiRing(
logp: logp,
// TODO(abdulras): this should really use integral pow
logr: logp + logf(powf(Float(sentence.count), Float(order))))
self.totalScore = self.score * previous
}

/// Memberwise initializer: stores every argument unchanged in the property of the same name.
@differentiable
public init(
start: Int, end: Int, string: CharacterSequence, logp: Float,
score: SemiRing, totalScore: SemiRing
) {
self.start = start
self.end = end
self.string = string
self.logp = logp
self.score = score
self.totalScore = totalScore
}
}

/// Node
///
/// Represents a node in the lattice
public struct Node: Differentiable {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar

/// Best edge recorded for this node by `Lattice.viterbi`; `nil` until one is found.
/// Excluded from differentiation.
@noDerivative public var bestEdge: Edge?
/// Best path score recorded for this node by `Lattice.viterbi`. Defaults to `0.0`.
public var bestScore: Float = 0.0
/// Edges attached to this node. `Lattice.viterbi` iterates them for the node's position and
/// scores each one from `edge.start`, i.e. it treats them as the edges arriving here.
public var edges = [Edge]()
/// Semiring score of the node. Defaults to `SemiRing.one`.
public var semiringScore: SemiRing = SemiRing.one

/// Creates a node with no edges, no best edge and the default scores.
init() {}

/// Memberwise initializer: stores every argument unchanged in the property of the same name.
@differentiable
public init(
bestEdge: Edge?, bestScore: Float, edges: [Edge],
semiringScore: SemiRing
) {
self.bestEdge = bestEdge
self.bestScore = bestScore
self.edges = edges
self.semiringScore = semiringScore
}

/// Returns the sum of the `totalScore` of every edge, starting from `SemiRing.zero`.
///
/// The result is returned only; `semiringScore` is not modified by this method.
@differentiable
func computeSemiringScore() -> SemiRing {
// TODO: Reduceinto and +=
edges.differentiableMap { $0.totalScore }.differentiableReduce(SemiRing.zero) { $0 + $1 }
}

/// The nodes of the lattice, indexed by position.
var positions: [Node]

/// Reads or replaces the node stored at `index` in `positions`.
@differentiable
public subscript(index: Int) -> Node {
  get { positions[index] }
  set { positions[index] = newValue }
  //_modify { yield &positions[index] }
}

// TODO: remove dummy. (Workaround to make AD think that the lattice is varied.)
/// Creates a lattice of `count + 1` default-initialized nodes.
///
/// - Parameters:
///   - count: One less than the number of nodes to create.
///   - dummy: Not read by the initializer; present only for the AD workaround noted above.
@differentiable
init(count: Int, _ dummy: Tensor<Float>) {
positions = Array(repeating: Node(), count: count + 1)
}

/// Creates a lattice that uses `positions` as its nodes, unchanged.
public init(positions: [Node]) {
self.positions = positions
}

/// Finds the highest-scoring path through the lattice with the Viterbi algorithm.
///
/// The forward pass stores, for every position from 1 to `sentence.count`, the attached edge
/// that maximises `self[edge.start].bestScore + edge.logp`, together with that score. The
/// backward pass then follows the stored edges from the last position back to position 0.
///
/// Position 0 is never rescored: its `bestScore` is used as-is (a default node has `0.0`).
///
/// - Parameter sentence: The sentence the lattice was built for. Only its `count` is used.
/// - Returns: The edges of the best path, ordered from the start of the sentence to its end.
///   Empty when `sentence` is empty.
/// - Precondition: The lattice has more than `sentence.count` positions, and every position on
///   the best path is reachable from position 0.
mutating func viterbi(sentence: String) -> [Edge] {
  // `String.count` walks the whole string, so evaluate it only once.
  let length = sentence.count

  // An empty sentence has no edges to follow; `1...0` would also be an invalid range.
  guard length > 0 else { return [] }
  precondition(
    length < positions.count,
    "lattice has \(positions.count) positions but the sentence needs \(length + 1)")

  // Forwards pass.
  // Starts at 1 because position 0 has no incoming edges. Rescoring it would set its best
  // score to -infinity, every later path score would then be -infinity as well, and no edge
  // could ever win the `score > bestScore` comparison below.
  for position in 1...length {
    var bestScore = -Float.infinity
    var bestEdge: Edge? = nil
    for edge in self[position].edges {
      let score: Float = self[edge.start].bestScore + edge.logp
      if score > bestScore {
        bestScore = score
        bestEdge = edge
      }
    }
    self[position].bestScore = bestScore
    self[position].bestEdge = bestEdge
  }

  // Backwards pass: follow the recorded best edges from the end back to position 0.
  var bestPath: [Edge] = []
  var position = length
  while position != 0 {
    guard let edge = self[position].bestEdge else {
      preconditionFailure("no path from position 0 reaches lattice position \(position)")
    }
    bestPath.append(edge)
    position = edge.start
  }

  return bestPath.reversed()
}
}

extension Lattice: CustomStringConvertible {
  /// A bracketed listing of every position as `index: node`, with a blank line between
  /// consecutive positions.
  public var description: String {
    let rows = positions.enumerated().map { entry in
      " \(entry.offset): \(entry.element)"
    }
    let body = rows.joined(separator: "\n\n")
    return "[\n\(body)\n]"
  }
}

extension Lattice.Node: CustomStringConvertible {
  /// A summary line (best edge, best score, semiring score) followed by one line per edge, or
  /// by a `<no edges>` marker when the node has none.
  public var description: String {
    let summary =
      "best edge: \(String(describing: bestEdge)), best score: \(bestScore), score: \(semiringScore.shortDescription)"

    let edgeLines: String
    if edges.isEmpty {
      edgeLines = " <no edges>"
    } else {
      edgeLines = edges.enumerated()
        .map { entry in " \(entry.offset) - \(entry.element)" }
        .joined(separator: "\n")
    }

    return summary + "\n" + edgeLines
  }
}

extension Lattice.Edge: CustomStringConvertible {
  /// One-line rendering of the edge: its span, `logp`, both semiring scores and its string.
  public var description: String {
    let fields = [
      "[\(start)->\(end)] logp: \(logp)",
      "score: \(score.shortDescription)",
      "total score: \(totalScore.shortDescription)",
      "sentence: \(string)",
    ]
    return fields.joined(separator: ", ")
  }
}

/// SE-0259-esque equality with tolerance
extension Lattice {
  /// Compares two lattices position by position using `Node.isAlmostEqual(to:tolerance:)`.
  ///
  /// A difference in the number of positions is printed and fails immediately. Otherwise every
  /// position is compared and each mismatch is printed, so a single call reports all
  /// differing positions.
  ///
  /// - Parameters:
  ///   - other: The lattice to compare against.
  ///   - tolerance: Tolerance forwarded to the node comparison.
  /// - Returns: `true` when both lattices have the same number of positions and every pair of
  ///   nodes is almost equal.
  public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool {
    if positions.count != other.positions.count {
      print("positions count mismatch: \(self.positions.count) != \(other.positions.count)")
      return false
    }

    var everyPositionMatches = true
    for (index, (mine, theirs)) in zip(positions, other.positions).enumerated() {
      // Deliberately no early exit: all mismatching positions are reported.
      if !mine.isAlmostEqual(to: theirs, tolerance: tolerance) {
        print("mismatch at \(index): \(mine) != \(theirs)")
        everyPositionMatches = false
      }
    }
    return everyPositionMatches
  }
}

extension Lattice.Node {
  /// Compares two nodes: edge count, `bestScore`, `bestEdge`, `semiringScore`, then the edges
  /// pairwise.
  ///
  /// The best edges are compared only when both nodes have one.
  /// NOTE(review): a node with a best edge therefore matches one without — confirm that this
  /// leniency is intended.
  ///
  /// - Parameters:
  ///   - other: The node to compare against.
  ///   - tolerance: Tolerance forwarded to every underlying comparison.
  /// - Returns: `true` when all compared parts are almost equal.
  public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool {
    guard edges.count == other.edges.count,
      bestScore.isAlmostEqual(to: other.bestScore, tolerance: tolerance)
    else {
      return false
    }

    if let mine = bestEdge, let theirs = other.bestEdge,
      !mine.isAlmostEqual(to: theirs, tolerance: tolerance)
    {
      return false
    }

    guard semiringScore.isAlmostEqual(to: other.semiringScore, tolerance: tolerance) else {
      return false
    }

    // Every pair is compared, without stopping at the first difference.
    var everyEdgeMatches = true
    for (mine, theirs) in zip(edges, other.edges) {
      let same = mine.isAlmostEqual(to: theirs, tolerance: tolerance)
      everyEdgeMatches = everyEdgeMatches && same
    }
    return everyEdgeMatches
  }
}

extension Lattice.Edge {
  /// Compares two edges: `start` and `end` exactly, then `logp`, `score` and `totalScore`
  /// within `tolerance`. The `string` property is not compared.
  ///
  /// - Parameters:
  ///   - other: The edge to compare against.
  ///   - tolerance: Tolerance forwarded to the floating-point and semiring comparisons.
  /// - Returns: `true` when all compared parts match.
  public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool {
    guard start == other.start, end == other.end else { return false }
    // TODO: figure out why the string equality is being ignored
    // self.string == other.string &&
    guard logp.isAlmostEqual(to: other.logp, tolerance: tolerance) else { return false }
    guard score.isAlmostEqual(to: other.score, tolerance: tolerance) else { return false }
    return totalScore.isAlmostEqual(to: other.totalScore, tolerance: tolerance)
  }
}
Loading