TextEncoder.swift
// For licensing see accompanying LICENSE.md file.
// Copyright (C) 2022 Apple Inc. All Rights Reserved.

import Foundation
import CoreML

/// A model for encoding text
@available(iOS 16.2, macOS 13.1, *)
public struct TextEncoder: ResourceManaging {

    /// Text tokenizer
    var tokenizer: BPETokenizer

    /// Embedding model
    var model: ManagedMLModel

    /// Creates a text encoder which embeds a tokenized string
    ///
    /// - Parameters:
    ///   - tokenizer: Tokenizer for input text
    ///   - url: Location of the compiled text encoding Core ML model
    ///   - configuration: Configuration to be used when the model is loaded
    /// - Returns: A text encoder that will lazily load its required resources when needed or requested
    public init(tokenizer: BPETokenizer,
                modelAt url: URL,
                configuration: MLModelConfiguration) {
        self.tokenizer = tokenizer
        self.model = ManagedMLModel(modelAt: url, configuration: configuration)
    }

    /// Ensure the model has been loaded into memory
    public func loadResources() throws {
        try model.loadResources()
    }

    /// Unload the underlying model to free up memory
    public func unloadResources() {
        model.unloadResources()
    }

    /// Encode input text/string
    ///
    /// - Parameters:
    ///   - text: Input text to be tokenized and then embedded
    /// - Returns: Embedding representing the input text
    public func encode(_ text: String) throws -> MLShapedArray<Float32> {

        // Get the model's expected input length
        let inputLength = inputShape.last!

        // Tokenize, padding to the expected length
        var (tokens, ids) = tokenizer.tokenize(input: text, minCount: inputLength)

        // Truncate if necessary (dropLast returns an ArraySlice, so re-wrap in Array)
        if ids.count > inputLength {
            tokens = Array(tokens.dropLast(tokens.count - inputLength))
            ids = Array(ids.dropLast(ids.count - inputLength))
            let truncated = tokenizer.decode(tokens: tokens)
            print("Needed to truncate input '\(text)' to '\(truncated)'")
        }

        // Use the model to generate the embedding
        return try encode(ids: ids)
    }

    /// Prediction queue
    let queue = DispatchQueue(label: "textencoder.predict")
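
    /// Predict the embedding for the given token IDs
    ///
    /// - Parameter ids: Token IDs, already padded or truncated to the model's expected input length
    /// - Returns: The model's `last_hidden_state` output converted to a shaped array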
    func encode(ids: [Int]) throws -> MLShapedArray<Float32> {
        let inputName = inputDescription.name
        let inputShape = self.inputShape

        // The model expects token IDs as floating point values
        let floatIds = ids.map { Float32($0) }
        let inputArray = MLShapedArray<Float32>(scalars: floatIds, shape: inputShape)
        let inputFeatures = try MLDictionaryFeatureProvider(
            dictionary: [inputName: MLMultiArray(inputArray)])

        let result = try model.perform { model in
            try model.prediction(from: inputFeatures)
        }

        let embeddingFeature = result.featureValue(for: "last_hidden_state")
        return MLShapedArray<Float32>(converting: embeddingFeature!.multiArrayValue!)
    }
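
    /// Description of the model's (first) input feature, from which the input name and shape are read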
    var inputDescription: MLFeatureDescription {
        try! model.perform { model in
            model.modelDescription.inputDescriptionsByName.first!.value
        }
    }
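
    /// Expected input shape of the underlying model (typically [1, 77] for Stable Diffusion's CLIP text encoder)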
    var inputShape: [Int] {
        inputDescription.multiArrayConstraint!.shape.map { $0.intValue }
    }
}
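
// Example usage: a minimal sketch, assuming a `BPETokenizer(mergesAt:vocabularyAt:)`
// initializer and local URLs (`mergesURL`, `vocabularyURL`, `textEncoderModelURL`)
// for the tokenizer files and compiled Core ML model.
//
//     let tokenizer = try BPETokenizer(mergesAt: mergesURL, vocabularyAt: vocabularyURL)
//     let configuration = MLModelConfiguration()
//     configuration.computeUnits = .cpuAndNeuralEngine
//     let textEncoder = TextEncoder(tokenizer: tokenizer,
//                                   modelAt: textEncoderModelURL,
//                                   configuration: configuration)
//     try textEncoder.loadResources()
//     let embedding = try textEncoder.encode("a photo of an astronaut riding a horse")
//     textEncoder.unloadResources()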