This repository was archived by the owner on Apr 23, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 149
WordSeg model and tests #498
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
8c2891e
WordSeg model and tests
texasmichelle 09fea49
Add model files to CMakeLists
texasmichelle e0515b5
Rename structs and remove commented code
texasmichelle 46914df
Rename Conf -> SNLM.Parameters, lint
texasmichelle 33a0599
Make DataSet internal
texasmichelle File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
// Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import Foundation | ||
|
||
/// A corpus for the WordSeg task: training text, optional validation and
/// testing text, and the alphabet covering every character they contain.
internal struct DataSet {
  public let training: [CharacterSequence]
  public private(set) var testing: [CharacterSequence]?
  public private(set) var validation: [CharacterSequence]?
  public let alphabet: Alphabet

  /// Decodes `data` as UTF-8 and splits it into whitespace-stripped lines.
  ///
  /// - Throws: `CharacterErrors.nonUtf8Data` when `data` is not valid UTF-8.
  private static func load(data: Data) throws -> [String] {
    guard let text = String(data: data, encoding: .utf8) else {
      throw CharacterErrors.nonUtf8Data
    }
    return load(contents: text)
  }

  /// Splits `contents` into lines, removes all whitespace inside each line,
  /// and drops any line that ends up empty.
  private static func load(contents: String) -> [String] {
    return contents
      .components(separatedBy: .newlines)
      .map { $0.components(separatedBy: .whitespaces).joined() }
      .filter { !$0.isEmpty }
  }

  /// Builds an `Alphabet` from every character appearing in `training` and in
  /// any of the optional extra datasets.
  private static func makeAlphabet(
    datasets training: [String],
    _ otherSequences: [String]?...,
    eos: String = "</s>",
    eow: String = "</w>",
    pad: String = "</pad>"
  ) -> Alphabet {
    var characters: Set<Character> = []

    for corpus in otherSequences + [training] {
      guard let corpus = corpus else { continue }
      for sentence in corpus {
        characters.formUnion(sentence)
      }
    }

    // Sorted order makes the int <-> letter mapping easier to interpret.
    return Alphabet(characters.sorted(), eos: eos, eow: eow, pad: pad)
  }

  /// Converts each string into a `CharacterSequence` with an EoS marker appended.
  private static func convertDataset(_ dataset: [String], alphabet: Alphabet) throws
    -> [CharacterSequence]
  {
    return try dataset.map { try CharacterSequence(alphabet: alphabet, appendingEoSTo: $0) }
  }

  /// Optional-propagating variant: `nil` in, `nil` out.
  private static func convertDataset(_ dataset: [String]?, alphabet: Alphabet) throws
    -> [CharacterSequence]?
  {
    guard let dataset = dataset else { return nil }
    // The explicit type selects the non-optional overload above.
    let converted: [CharacterSequence] = try convertDataset(dataset, alphabet: alphabet)
    return converted
  }

  /// Loads the dataset from files on disk.
  ///
  /// - Throws: any file-reading error, or `CharacterErrors.nonUtf8Data`.
  public init?(
    training trainingFile: String,
    validation validationFile: String? = nil,
    testing testingFile: String? = nil
  ) throws {
    // Reads one file as memory-mapped data and splits it into stripped lines.
    func read(_ path: String) throws -> [String] {
      let data = try Data(
        contentsOf: URL(fileURLWithPath: path),
        options: .alwaysMapped)
      return try Self.load(data: data)
    }

    let trainingText = try read(trainingFile)
    let validationText = try validationFile.map(read)
    let testingText = try testingFile.map(read)

    self.alphabet = Self.makeAlphabet(datasets: trainingText, validationText, testingText)
    self.training = try Self.convertDataset(trainingText, alphabet: self.alphabet)
    self.validation = try Self.convertDataset(validationText, alphabet: self.alphabet)
    self.testing = try Self.convertDataset(testingText, alphabet: self.alphabet)
  }

  /// Loads the dataset from in-memory data blobs.
  init(training trainingData: Data, validation validationData: Data?, testing testingData: Data?)
    throws
  {
    let trainingText = try Self.load(data: trainingData)
    let validationText = try validationData.map { try Self.load(data: $0) }
    let testingText = try testingData.map { try Self.load(data: $0) }

    self.alphabet = Self.makeAlphabet(datasets: trainingText, validationText, testingText)
    self.training = try Self.convertDataset(trainingText, alphabet: self.alphabet)
    self.validation = try Self.convertDataset(validationText, alphabet: self.alphabet)
    self.testing = try Self.convertDataset(testingText, alphabet: self.alphabet)
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
// Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import TensorFlow | ||
|
||
#if os(iOS) || os(macOS) || os(tvOS) || os(watchOS) | ||
import Darwin | ||
#elseif os(Windows) | ||
import ucrt | ||
#else | ||
import Glibc | ||
#endif | ||
|
||
/// Lattice
///
/// Represents the lattice used by the WordSeg algorithm: one `Node` per
/// position in the sentence (plus one past the end), with scored `Edge`s
/// between positions representing candidate segments.
public struct Lattice: Differentiable {
  /// Edge
  ///
  /// Represents an Edge: a candidate segment `string` spanning positions
  /// `start` to `end`, with its log-probability and semiring scores.
  public struct Edge: Differentiable {
    @noDerivative public var start: Int
    @noDerivative public var end: Int
    @noDerivative public var string: CharacterSequence
    /// Log-probability of this edge's segment.
    public var logp: Float

    // expectation
    public var score: SemiRing
    /// `score` semiring-multiplied with the score accumulated up to `start`.
    public var totalScore: SemiRing

    /// Builds an edge and derives its scores: `score.logr` adds
    /// `order * log(length)` to `logp`, and `totalScore` chains in `previous`.
    @differentiable
    init(
      start: Int, end: Int, sentence: CharacterSequence, logp: Float,
      previous: SemiRing, order: Int
    ) {
      self.start = start
      self.end = end
      self.string = sentence
      self.logp = logp

      self.score =
        SemiRing(
          logp: logp,
          // TODO(abdulras): this should really use integral pow
          logr: logp + logf(powf(Float(sentence.count), Float(order))))
      self.totalScore = self.score * previous
    }

    /// Memberwise initializer with all scores supplied by the caller.
    @differentiable
    public init(
      start: Int, end: Int, string: CharacterSequence, logp: Float,
      score: SemiRing, totalScore: SemiRing
    ) {
      self.start = start
      self.end = end
      self.string = string
      self.logp = logp
      self.score = score
      self.totalScore = totalScore
    }
  }

  /// Node
  ///
  /// Represents a node in the lattice: the incoming edges at one sentence
  /// position, plus the best (Viterbi) edge/score once `viterbi` has run.
  public struct Node: Differentiable {
    // Set by `viterbi(sentence:)`; nil until then (and nil for position 0,
    // which has no incoming edges).
    @noDerivative public var bestEdge: Edge?
    public var bestScore: Float = 0.0
    public var edges = [Edge]()
    public var semiringScore: SemiRing = SemiRing.one

    init() {}

    /// Memberwise initializer.
    @differentiable
    public init(
      bestEdge: Edge?, bestScore: Float, edges: [Edge],
      semiringScore: SemiRing
    ) {
      self.bestEdge = bestEdge
      self.bestScore = bestScore
      self.edges = edges
      self.semiringScore = semiringScore
    }

    /// Semiring-sum of the total scores of all incoming edges.
    @differentiable
    func computeSemiringScore() -> SemiRing {
      // TODO: use reduce(into:) and +=
      edges.differentiableMap { $0.totalScore }.differentiableReduce(SemiRing.zero) { $0 + $1 }
    }
  }

  // One node per sentence position, plus one for the end-of-sentence position.
  var positions: [Node]

  /// Accesses the node at `index`; differentiable so lattice construction can
  /// be traced through by AD.
  @differentiable
  public subscript(index: Int) -> Node {
    get { return positions[index] }
    set(v) { positions[index] = v }
    //_modify { yield &positions[index] }
  }

  // TODO: remove dummy. (workaround to make AD think that lattice is varied)
  @differentiable
  init(count: Int, _ dummy: Tensor<Float>) {
    positions = Array(repeating: Node(), count: count + 1)
  }

  /// Builds a lattice directly from precomputed nodes.
  public init(positions: [Node]) {
    self.positions = positions
  }

  /// Returns the maximum-log-probability segmentation of `sentence` as a
  /// sequence of edges from position 0 to `sentence.count`, via dynamic
  /// programming over the already-populated edges.
  ///
  /// Mutating: records `bestScore`/`bestEdge` on every node as a side effect.
  /// NOTE(review): the backward pass force-unwraps `bestEdge`; it assumes
  /// every position past 0 is reachable from position 0 — confirm callers
  /// always build a fully connected lattice.
  mutating func viterbi(sentence: String) -> [Edge] {
    // Forwards pass
    for position in 0...sentence.count {
      var bestScore = -Float.infinity
      var bestEdge: Edge!
      for edge in self[position].edges {
        // Best score reaching this edge's start, extended by this edge.
        let score: Float = self[edge.start].bestScore + edge.logp
        if score > bestScore {
          bestScore = score
          bestEdge = edge
        }
      }
      self[position].bestScore = bestScore
      self[position].bestEdge = bestEdge
    }

    // Backwards: follow best edges from the end until reaching position 0.
    var bestPath: [Edge] = []
    var nextEdge = self[sentence.count].bestEdge!
    while nextEdge.start != 0 {
      bestPath.append(nextEdge)
      nextEdge = self[nextEdge.start].bestEdge!
    }
    bestPath.append(nextEdge)

    return bestPath.reversed()
  }
}
|
||
extension Lattice: CustomStringConvertible {
  /// Renders every position's node, indexed and separated by blank lines,
  /// inside square brackets.
  public var description: String {
    """
    [
    \(positions.enumerated().map { " \($0.0): \($0.1)" }.joined(separator: "\n\n"))
    ]
    """
  }
}
|
||
extension Lattice.Node: CustomStringConvertible {
  /// One line of best-edge/score summary followed by an indexed list of
  /// incoming edges (or a placeholder when there are none).
  public var description: String {
    var edgesStr: String
    if edges.isEmpty {
      edgesStr = " <no edges>"
    } else {
      edgesStr = edges.enumerated().map { " \($0.0) - \($0.1)" }.joined(separator: "\n")
    }
    return """
      best edge: \(String(describing: bestEdge)), best score: \(bestScore), score: \(semiringScore.shortDescription)
      \(edgesStr)
      """
  }
}
|
||
extension Lattice.Edge: CustomStringConvertible {
  /// Compact single-line summary: span, log-probability, both semiring
  /// scores, and the segment text.
  public var description: String {
    "[\(start)->\(end)] logp: \(logp), score: \(score.shortDescription), total score: \(totalScore.shortDescription), sentence: \(string)"
  }
}
|
||
/// SE-0259-esque equality with tolerance
extension Lattice {
  /// Returns true when both lattices have the same number of positions and
  /// every pair of corresponding nodes is almost equal within `tolerance`.
  ///
  /// Deliberately checks every position (no early exit) so that all
  /// mismatches are printed, which aids test debugging.
  public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool {
    guard self.positions.count == other.positions.count else {
      print("positions count mismatch: \(self.positions.count) != \(other.positions.count)")
      return false
    }
    var allEqual = true
    for (index, pair) in zip(self.positions, other.positions).enumerated() {
      let (ours, theirs) = pair
      if !ours.isAlmostEqual(to: theirs, tolerance: tolerance) {
        print("mismatch at \(index): \(ours) != \(theirs)")
        allEqual = false
      }
    }
    return allEqual
  }
}
|
||
extension Lattice.Node {
  /// Returns true when the two nodes agree (within `tolerance`) on best
  /// score, best edge, semiring score, and every incoming edge.
  public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool {
    guard self.edges.count == other.edges.count else { return false }

    if !self.bestScore.isAlmostEqual(to: other.bestScore, tolerance: tolerance) {
      return false
    }
    // Compare bestEdge exhaustively. The previous `if let lhs, let rhs`
    // silently skipped the comparison when exactly one side was nil, letting
    // unequal nodes pass; a nil/non-nil mismatch must be inequality.
    switch (self.bestEdge, other.bestEdge) {
    case (nil, nil):
      break
    case let (lhs?, rhs?):
      if !lhs.isAlmostEqual(to: rhs, tolerance: tolerance) {
        return false
      }
    default:
      return false
    }
    if !self.semiringScore.isAlmostEqual(to: other.semiringScore, tolerance: tolerance) {
      return false
    }
    return zip(self.edges, other.edges)
      .allSatisfy { $0.isAlmostEqual(to: $1, tolerance: tolerance) }
  }
}
|
||
extension Lattice.Edge {
  /// Returns true when both edges span the same positions and their
  /// log-probability and semiring scores agree within `tolerance`.
  public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool {
    // TODO: figure out why the string equality is being ignored
    // self.string == other.string &&
    guard self.start == other.start, self.end == other.end else { return false }
    return self.logp.isAlmostEqual(to: other.logp, tolerance: tolerance)
      && self.score.isAlmostEqual(to: other.score, tolerance: tolerance)
      && self.totalScore.isAlmostEqual(to: other.totalScore, tolerance: tolerance)
  }
}
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need `Lattice` to be public? I think that this should probably be made `internal`.
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changing this scope looks to be a bit more involved. Punting to #499