Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 62f932d

Browse files
WordSeg model and tests (#498)
* WordSeg model and tests * Add model files to CMakeLists * Rename structs and remove commented code * Rename Conf -> SNLM.Parameters, lint * Make DataSet internal
1 parent fa0ca07 commit 62f932d

14 files changed

+2208
-6
lines changed

Models/Text/CMakeLists.txt

+7-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,13 @@ add_library(TextModels
1212
ScheduledParameters.swift
1313
TransformerBERT.swift
1414
Utilities.swift
15-
WeightDecayedAdam.swift)
15+
WeightDecayedAdam.swift
16+
WordSeg/DataSet.swift
17+
WordSeg/Model.swift
18+
WordSeg/Lattice.swift
19+
WordSeg/SE-0259.swift
20+
WordSeg/SemiRing.swift
21+
WordSeg/Vocabularies.swift)
1622
set_target_properties(TextModels PROPERTIES
1723
INTERFACE_INCLUDE_DIRECTORIES ${CMAKE_Swift_MODULE_DIRECTORY})
1824
target_compile_options(TextModels PRIVATE

Models/Text/WordSeg/DataSet.swift

+131
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import Foundation
16+
17+
internal struct DataSet {
18+
public let training: [CharacterSequence]
19+
public private(set) var testing: [CharacterSequence]?
20+
public private(set) var validation: [CharacterSequence]?
21+
public let alphabet: Alphabet
22+
23+
private static func load(data: Data) throws -> [String] {
24+
guard let contents: String = String(data: data, encoding: .utf8) else {
25+
throw CharacterErrors.nonUtf8Data
26+
}
27+
return load(contents: contents)
28+
}
29+
30+
private static func load(contents: String) -> [String] {
31+
var strings = [String]()
32+
33+
for line in contents.components(separatedBy: .newlines) {
34+
let stripped: String = line.components(separatedBy: .whitespaces).joined()
35+
if stripped.isEmpty { continue }
36+
strings.append(stripped)
37+
}
38+
return strings
39+
}
40+
41+
private static func makeAlphabet(
42+
datasets training: [String],
43+
_ otherSequences: [String]?...,
44+
eos: String = "</s>",
45+
eow: String = "</w>",
46+
pad: String = "</pad>"
47+
) -> Alphabet {
48+
var letters: Set<Character> = []
49+
50+
for dataset in otherSequences + [training] {
51+
guard let dataset = dataset else { continue }
52+
for sentence in dataset {
53+
for character in sentence {
54+
letters.insert(character)
55+
}
56+
}
57+
}
58+
59+
// Sort the letters to make it easier to interpret ints vs letters.
60+
var sorted = Array(letters)
61+
sorted.sort()
62+
63+
return Alphabet(sorted, eos: eos, eow: eow, pad: pad)
64+
}
65+
66+
private static func convertDataset(_ dataset: [String], alphabet: Alphabet) throws
67+
-> [CharacterSequence]
68+
{
69+
return try dataset.map { try CharacterSequence(alphabet: alphabet, appendingEoSTo: $0) }
70+
}
71+
private static func convertDataset(_ dataset: [String]?, alphabet: Alphabet) throws
72+
-> [CharacterSequence]?
73+
{
74+
if let ds = dataset {
75+
let tmp: [CharacterSequence] = try convertDataset(ds, alphabet: alphabet) // Use tmp to disambiguate function
76+
return tmp
77+
}
78+
return nil
79+
}
80+
81+
public init?(
82+
training trainingFile: String,
83+
validation validationFile: String? = nil,
84+
testing testingFile: String? = nil
85+
) throws {
86+
let trainingData = try Data(
87+
contentsOf: URL(fileURLWithPath: trainingFile),
88+
options: .alwaysMapped)
89+
let training = try Self.load(data: trainingData)
90+
91+
var validation: [String]? = nil
92+
var testing: [String]? = nil
93+
94+
if let validationFile = validationFile {
95+
let data = try Data(
96+
contentsOf: URL(fileURLWithPath: validationFile),
97+
options: .alwaysMapped)
98+
validation = try Self.load(data: data)
99+
}
100+
101+
if let testingFile = testingFile {
102+
let data: Data = try Data(
103+
contentsOf: URL(fileURLWithPath: testingFile),
104+
options: .alwaysMapped)
105+
testing = try Self.load(data: data)
106+
}
107+
self.alphabet = Self.makeAlphabet(datasets: training, validation, testing)
108+
self.training = try Self.convertDataset(training, alphabet: self.alphabet)
109+
self.validation = try Self.convertDataset(validation, alphabet: self.alphabet)
110+
self.testing = try Self.convertDataset(testing, alphabet: self.alphabet)
111+
}
112+
113+
init(training trainingData: Data, validation validationData: Data?, testing testingData: Data?)
114+
throws
115+
{
116+
let training = try Self.load(data: trainingData)
117+
var validation: [String]? = nil
118+
var testing: [String]? = nil
119+
if let validationData = validationData {
120+
validation = try Self.load(data: validationData)
121+
}
122+
if let testingData = testingData {
123+
testing = try Self.load(data: testingData)
124+
}
125+
126+
self.alphabet = Self.makeAlphabet(datasets: training, validation, testing)
127+
self.training = try Self.convertDataset(training, alphabet: self.alphabet)
128+
self.validation = try Self.convertDataset(validation, alphabet: self.alphabet)
129+
self.testing = try Self.convertDataset(testing, alphabet: self.alphabet)
130+
}
131+
}

Models/Text/WordSeg/Lattice.swift

+231
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import TensorFlow
16+
17+
#if os(iOS) || os(macOS) || os(tvOS) || os(watchOS)
18+
import Darwin
19+
#elseif os(Windows)
20+
import ucrt
21+
#else
22+
import Glibc
23+
#endif
24+
25+
/// Lattice
26+
///
27+
/// Represents the lattice used by the WordSeg algorithm.
28+
public struct Lattice: Differentiable {
29+
/// Edge
30+
///
31+
/// Represents an Edge
32+
public struct Edge: Differentiable {
33+
@noDerivative public var start: Int
34+
@noDerivative public var end: Int
35+
@noDerivative public var string: CharacterSequence
36+
public var logp: Float
37+
38+
// expectation
39+
public var score: SemiRing
40+
public var totalScore: SemiRing
41+
42+
@differentiable
43+
init(
44+
start: Int, end: Int, sentence: CharacterSequence, logp: Float,
45+
previous: SemiRing, order: Int
46+
) {
47+
self.start = start
48+
self.end = end
49+
self.string = sentence
50+
self.logp = logp
51+
52+
self.score =
53+
SemiRing(
54+
logp: logp,
55+
// TODO(abdulras): this should really use integeral pow
56+
logr: logp + logf(powf(Float(sentence.count), Float(order))))
57+
self.totalScore = self.score * previous
58+
}
59+
60+
@differentiable
61+
public init(
62+
start: Int, end: Int, string: CharacterSequence, logp: Float,
63+
score: SemiRing, totalScore: SemiRing
64+
) {
65+
self.start = start
66+
self.end = end
67+
self.string = string
68+
self.logp = logp
69+
self.score = score
70+
self.totalScore = totalScore
71+
}
72+
}
73+
74+
/// Node
75+
///
76+
/// Represents a node in the lattice
77+
public struct Node: Differentiable {
78+
@noDerivative public var bestEdge: Edge?
79+
public var bestScore: Float = 0.0
80+
public var edges = [Edge]()
81+
public var semiringScore: SemiRing = SemiRing.one
82+
83+
init() {}
84+
85+
@differentiable
86+
public init(
87+
bestEdge: Edge?, bestScore: Float, edges: [Edge],
88+
semiringScore: SemiRing
89+
) {
90+
self.bestEdge = bestEdge
91+
self.bestScore = bestScore
92+
self.edges = edges
93+
self.semiringScore = semiringScore
94+
}
95+
96+
@differentiable
97+
func computeSemiringScore() -> SemiRing {
98+
// TODO: Reduceinto and +=
99+
edges.differentiableMap { $0.totalScore }.differentiableReduce(SemiRing.zero) { $0 + $1 }
100+
}
101+
}
102+
103+
var positions: [Node]
104+
105+
@differentiable
106+
public subscript(index: Int) -> Node {
107+
get { return positions[index] }
108+
set(v) { positions[index] = v }
109+
//_modify { yield &positions[index] }
110+
}
111+
112+
// TODO: remove dummy. (workaround to make AD thing that lattice is varied)
113+
@differentiable
114+
init(count: Int, _ dummy: Tensor<Float>) {
115+
positions = Array(repeating: Node(), count: count + 1)
116+
}
117+
118+
public init(positions: [Node]) {
119+
self.positions = positions
120+
}
121+
122+
mutating func viterbi(sentence: String) -> [Edge] {
123+
// Forwards pass
124+
for position in 0...sentence.count {
125+
var bestScore = -Float.infinity
126+
var bestEdge: Edge!
127+
for edge in self[position].edges {
128+
let score: Float = self[edge.start].bestScore + edge.logp
129+
if score > bestScore {
130+
bestScore = score
131+
bestEdge = edge
132+
}
133+
}
134+
self[position].bestScore = bestScore
135+
self[position].bestEdge = bestEdge
136+
}
137+
138+
// Backwards
139+
var bestPath: [Edge] = []
140+
var nextEdge = self[sentence.count].bestEdge!
141+
while nextEdge.start != 0 {
142+
bestPath.append(nextEdge)
143+
nextEdge = self[nextEdge.start].bestEdge!
144+
}
145+
bestPath.append(nextEdge)
146+
147+
return bestPath.reversed()
148+
}
149+
}
150+
151+
extension Lattice: CustomStringConvertible {
152+
public var description: String {
153+
"""
154+
[
155+
\(positions.enumerated().map { " \($0.0): \($0.1)" }.joined(separator: "\n\n"))
156+
]
157+
"""
158+
}
159+
}
160+
161+
extension Lattice.Node: CustomStringConvertible {
162+
public var description: String {
163+
var edgesStr: String
164+
if edges.isEmpty {
165+
edgesStr = " <no edges>"
166+
} else {
167+
edgesStr = edges.enumerated().map { " \($0.0) - \($0.1)" }.joined(separator: "\n")
168+
}
169+
return """
170+
best edge: \(String(describing: bestEdge)), best score: \(bestScore), score: \(semiringScore.shortDescription)
171+
\(edgesStr)
172+
"""
173+
}
174+
}
175+
176+
extension Lattice.Edge: CustomStringConvertible {
177+
public var description: String {
178+
"[\(start)->\(end)] logp: \(logp), score: \(score.shortDescription), total score: \(totalScore.shortDescription), sentence: \(string)"
179+
}
180+
}
181+
182+
/// SE-0259-esque equality with tolerance
183+
extension Lattice {
184+
public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool {
185+
guard self.positions.count == other.positions.count else {
186+
print("positions count mismatch: \(self.positions.count) != \(other.positions.count)")
187+
return false
188+
}
189+
return zip(self.positions, other.positions).enumerated()
190+
.map { (index, position) in
191+
let eq = position.0.isAlmostEqual(to: position.1, tolerance: tolerance)
192+
if !eq {
193+
print("mismatch at \(index): \(position.0) != \(position.1)")
194+
}
195+
return eq
196+
}
197+
.reduce(true) { $0 && $1 }
198+
}
199+
}
200+
201+
extension Lattice.Node {
202+
public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool {
203+
guard self.edges.count == other.edges.count else { return false }
204+
205+
if !self.bestScore.isAlmostEqual(to: other.bestScore, tolerance: tolerance) {
206+
return false
207+
}
208+
if let lhs = self.bestEdge, let rhs = other.bestEdge {
209+
if !lhs.isAlmostEqual(to: rhs, tolerance: tolerance) {
210+
return false
211+
}
212+
}
213+
if !self.semiringScore.isAlmostEqual(to: other.semiringScore, tolerance: tolerance) {
214+
return false
215+
}
216+
return zip(self.edges, other.edges)
217+
.map { $0.isAlmostEqual(to: $1, tolerance: tolerance) }
218+
.reduce(true) { $0 && $1 }
219+
}
220+
}
221+
222+
extension Lattice.Edge {
223+
public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool {
224+
return self.start == other.start && self.end == other.end
225+
// TODO: figure out why the string equality is being ignored
226+
// self.string == other.string &&
227+
&& self.logp.isAlmostEqual(to: other.logp, tolerance: tolerance)
228+
&& self.score.isAlmostEqual(to: other.score, tolerance: tolerance)
229+
&& self.totalScore.isAlmostEqual(to: other.totalScore, tolerance: tolerance)
230+
}
231+
}

0 commit comments

Comments
 (0)