Skip to content

Commit 1baf295

Browse files
committed
add download buttons & expose llamaState.loadModel
1 parent ff87313 commit 1baf295

File tree

3 files changed

+158
-3
lines changed

3 files changed

+158
-3
lines changed

examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift

+5-3
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,23 @@ import Foundation
33
@MainActor
44
class LlamaState: ObservableObject {
55
@Published var messageLog = ""
6+
@Published var cacheCleared = false
67

78
private var llamaContext: LlamaContext?
8-
private var modelUrl: URL? {
9+
private var defaultModelUrl: URL? {
910
Bundle.main.url(forResource: "ggml-model", withExtension: "gguf", subdirectory: "models")
1011
// Bundle.main.url(forResource: "llama-2-7b-chat", withExtension: "Q2_K.gguf", subdirectory: "models")
1112
}
13+
1214
init() {
1315
do {
14-
try loadModel()
16+
try loadModel(modelUrl: defaultModelUrl)
1517
} catch {
1618
messageLog += "Error!\n"
1719
}
1820
}
1921

20-
private func loadModel() throws {
22+
func loadModel(modelUrl: URL?) throws {
2123
messageLog += "Loading model...\n"
2224
if let modelUrl {
2325
llamaContext = try LlamaContext.create_context(path: modelUrl.path())

examples/llama.swiftui/llama.swiftui/UI/ContentView.swift

+35
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,22 @@ struct ContentView: View {
55

66
@State private var multiLineText = ""
77

8+
private static func cleanupModelCaches() {
9+
// Delete all models (*.gguf)
10+
let fileManager = FileManager.default
11+
let documentsUrl = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0]
12+
do {
13+
let fileURLs = try fileManager.contentsOfDirectory(at: documentsUrl, includingPropertiesForKeys: nil)
14+
for fileURL in fileURLs {
15+
if fileURL.pathExtension == "gguf" {
16+
try fileManager.removeItem(at: fileURL)
17+
}
18+
}
19+
} catch {
20+
print("Error while enumerating files \(documentsUrl.path): \(error.localizedDescription)")
21+
}
22+
}
23+
824
var body: some View {
925
VStack {
1026
// automatically scroll to bottom of text view
@@ -35,6 +51,25 @@ struct ContentView: View {
3551
.foregroundColor(.white)
3652
.cornerRadius(8)
3753
}
54+
55+
VStack {
56+
DownloadButton(
57+
llamaState: llamaState,
58+
modelName: "TheBloke / TinyLlama-1.1B-1T-OpenOrca-GGUF (Q4_0)",
59+
modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true",
60+
filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf"
61+
)
62+
DownloadButton(
63+
llamaState: llamaState,
64+
modelName: "TheBloke / TinyLlama-1.1B-1T-OpenOrca-GGUF (Q8_0)",
65+
modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q8_0.gguf?download=true",
66+
filename: "tinyllama-1.1b-1t-openorca.Q8_0.gguf"
67+
)
68+
Button("Clear downloaded models") {
69+
ContentView.cleanupModelCaches()
70+
llamaState.cacheCleared = true
71+
}
72+
}
3873
}
3974
.padding()
4075
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import SwiftUI
2+
3+
struct DownloadButton: View {
4+
@ObservedObject private var llamaState: LlamaState
5+
private var modelName: String
6+
private var modelUrl: String
7+
private var filename: String
8+
9+
@State private var status: String
10+
11+
@State private var downloadTask: URLSessionDataTask?
12+
@State private var progress = 0.0
13+
@State private var observation: NSKeyValueObservation?
14+
15+
private static func getFileURL(filename: String) -> URL {
16+
FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0].appendingPathComponent(filename)
17+
}
18+
19+
private func checkFileExistenceAndUpdateStatus() {
20+
}
21+
22+
init(llamaState: LlamaState, modelName: String, modelUrl: String, filename: String) {
23+
self.llamaState = llamaState
24+
self.modelName = modelName
25+
self.modelUrl = modelUrl
26+
self.filename = filename
27+
28+
let fileURL = DownloadButton.getFileURL(filename: filename)
29+
status = FileManager.default.fileExists(atPath: fileURL.path) ? "downloaded" : "download"
30+
}
31+
32+
private func download() {
33+
status = "downloading"
34+
downloadTask = URLSession.shared.dataTask(with: URL(string: modelUrl)!) { data, response, error in
35+
if let error = error {
36+
print("Error: \(error.localizedDescription)")
37+
return
38+
}
39+
40+
guard let response = response as? HTTPURLResponse, (200...299).contains(response.statusCode) else {
41+
print("Server error!")
42+
return
43+
}
44+
45+
if let data = data {
46+
do {
47+
let fileURL = DownloadButton.getFileURL(filename: filename)
48+
try data.write(to: fileURL)
49+
50+
llamaState.cacheCleared = false
51+
52+
print("Downloaded model \(modelName) to \(fileURL)")
53+
status = "downloaded"
54+
try llamaState.loadModel(modelUrl: fileURL)
55+
} catch let err {
56+
print("Error: \(err.localizedDescription)")
57+
}
58+
}
59+
}
60+
observation = downloadTask?.progress.observe(\.fractionCompleted) { progress, _ in
61+
self.progress = progress.fractionCompleted
62+
}
63+
downloadTask?.resume()
64+
}
65+
66+
var body: some View {
67+
VStack {
68+
if status == "download" {
69+
Button(action: download) {
70+
Text("Download " + modelName)
71+
}
72+
} else if status == "downloading" {
73+
Button(action: {
74+
downloadTask?.cancel()
75+
status = "download"
76+
}) {
77+
Text("\(modelName) (Downloading \(Int(progress * 100))%)")
78+
}
79+
} else if status == "downloaded" {
80+
Button(action: {
81+
let fileURL = DownloadButton.getFileURL(filename: filename)
82+
if !FileManager.default.fileExists(atPath: fileURL.path) {
83+
download()
84+
return
85+
}
86+
do {
87+
try llamaState.loadModel(modelUrl: fileURL)
88+
} catch let err {
89+
print("Error: \(err.localizedDescription)")
90+
}
91+
}) {
92+
Text("\(modelName) (Downloaded)")
93+
}
94+
} else {
95+
Text("Unknown status")
96+
}
97+
}
98+
.onDisappear() {
99+
downloadTask?.cancel()
100+
}
101+
.onChange(of: llamaState.cacheCleared) { newValue in
102+
if newValue {
103+
downloadTask?.cancel()
104+
let fileURL = DownloadButton.getFileURL(filename: filename)
105+
status = FileManager.default.fileExists(atPath: fileURL.path) ? "downloaded" : "download"
106+
}
107+
}
108+
}
109+
}
110+
111+
#Preview {
112+
DownloadButton(
113+
llamaState: LlamaState(),
114+
modelName: "TheBloke / TinyLlama-1.1B-1T-OpenOrca-GGUF (Q4_0)",
115+
modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true",
116+
filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf",
117+
)
118+
}

0 commit comments

Comments
 (0)