Skip to content

Commit abf5b16

Browse files
Add metadata support with tests (#155)
* add metadata and resumable download support with tests * fix download progress handling * fix httpHead by disabling redirection for lfs * remove unnecessary test case * fix metadata location for lfs files and relative redirects * add test case for large file metadata * preserve original repo file structure for metadata files * add test case for lfs and more comments * Minor cleanup * Remove resumable downloads for future PR * Cleanup unused variable * fix metadata download path and add tests * add and test separate validation for commit hashes * only redownload lfs files when missing or checksum does not match --------- Co-authored-by: ZachNagengast <[email protected]>
1 parent 55710dd commit abf5b16

File tree

2 files changed

+824
-15
lines changed

2 files changed

+824
-15
lines changed

Sources/Hub/HubApi.swift

Lines changed: 248 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
//
77

88
import Foundation
9+
import CryptoKit
10+
import os
911

1012
public struct HubApi {
1113
var downloadBase: URL
@@ -29,6 +31,8 @@ public struct HubApi {
2931
}
3032

3133
public static let shared = HubApi()
34+
35+
private static let logger = Logger()
3236
}
3337

3438
private extension HubApi {
@@ -92,18 +96,24 @@ public extension HubApi {
9296
return (data, response)
9397
}
9498

99+
/// Throws error if page does not exist or is not accessible.
100+
/// Allows relative redirects but ignores absolute ones for LFS files.
95101
func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) {
96102
var request = URLRequest(url: url)
97103
request.httpMethod = "HEAD"
98104
if let hfToken = hfToken {
99105
request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization")
100106
}
101107
request.setValue("identity", forHTTPHeaderField: "Accept-Encoding")
102-
let (data, response) = try await URLSession.shared.data(for: request)
108+
109+
let redirectDelegate = RedirectDelegate()
110+
let session = URLSession(configuration: .default, delegate: redirectDelegate, delegateQueue: nil)
111+
112+
let (data, response) = try await session.data(for: request)
103113
guard let response = response as? HTTPURLResponse else { throw Hub.HubClientError.unexpectedError }
104114

105115
switch response.statusCode {
106-
case 200..<300: break
116+
case 200..<400: break // Allow redirects to pass through to the redirect delegate
107117
case 400..<500: throw Hub.HubClientError.authorizationRequired
108118
default: throw Hub.HubClientError.httpStatusCode(response.statusCode)
109119
}
@@ -139,6 +149,20 @@ public extension HubApi {
139149
}
140150
}
141151

152+
/// Additional errors
public extension HubApi {
    /// Errors raised while reading, writing, or validating local download metadata.
    enum EnvironmentError: LocalizedError {
        /// The metadata is invalid or could not be handled; the associated value is the explanation.
        case invalidMetadataError(String)

        /// The explanation carried by the error case, exposed through `LocalizedError`.
        public var errorDescription: String? {
            switch self {
            case let .invalidMetadataError(reason):
                return reason
            }
        }
    }
}
165+
142166
/// Configuration loading helpers
143167
public extension HubApi {
144168
/// Assumes the file has already been downloaded.
@@ -185,6 +209,9 @@ public extension HubApi {
185209
let hfToken: String?
186210
let endpoint: String?
187211
let backgroundSession: Bool
212+
213+
let sha256Pattern = "^[0-9a-f]{64}$"
214+
let commitHashPattern = "^[0-9a-f]{40}$"
188215

189216
var source: URL {
190217
// https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/tokenizer.json?download=true
@@ -202,6 +229,13 @@ public extension HubApi {
202229
repoDestination.appending(path: relativeFilename)
203230
}
204231

232+
/// Directory holding the download metadata files: `<repoDestination>/.cache/huggingface/download`.
var metadataDestination: URL {
    [".cache", "huggingface", "download"].reduce(repoDestination) { base, component in
        base.appendingPathComponent(component)
    }
}
238+
205239
var downloaded: Bool {
206240
FileManager.default.fileExists(atPath: destination.path)
207241
}
@@ -210,15 +244,158 @@ public extension HubApi {
210244
let directoryURL = destination.deletingLastPathComponent()
211245
try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil)
212246
}
213-
247+
248+
/// Creates the metadata directory, including any missing intermediate directories.
///
/// - Throws: Whatever `FileManager.createDirectory` throws when the directory cannot be created.
func prepareMetadataDestination() throws {
    let fileManager = FileManager.default
    try fileManager.createDirectory(at: metadataDestination, withIntermediateDirectories: true)
}
251+
252+
/// Loads the download metadata stored for a file, if there is any.
///
/// The metadata file is expected to hold at least three lines: commit hash, etag, and a timestamp
/// expressed in seconds since 1970. A file that cannot be read or parsed is deleted and treated as absent.
///
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L263
///
/// - Parameters:
///   - localDir: Directory that contains the metadata files.
///   - filePath: Path of the metadata file, relative to `localDir`.
/// - Throws: `EnvironmentError.invalidMetadataError` when an invalid metadata file cannot be removed.
/// - Returns: The parsed `LocalDownloadFileMetadata`, or `nil` when the file is missing or invalid.
func readDownloadMetadata(localDir: URL, filePath: String) throws -> LocalDownloadFileMetadata? {
    let metadataPath = localDir.appending(path: filePath)

    // metadata file does not exist
    guard FileManager.default.fileExists(atPath: metadataPath.path) else {
        return nil
    }

    do {
        let fields = try String(contentsOf: metadataPath, encoding: .utf8)
            .components(separatedBy: .newlines)

        if fields.count < 3 {
            throw EnvironmentError.invalidMetadataError("Metadata file is missing required fields.")
        }

        let commitHash = fields[0].trimmingCharacters(in: .whitespacesAndNewlines)
        let etag = fields[1].trimmingCharacters(in: .whitespacesAndNewlines)
        let rawTimestamp = fields[2].trimmingCharacters(in: .whitespacesAndNewlines)

        guard let secondsSince1970 = Double(rawTimestamp) else {
            throw EnvironmentError.invalidMetadataError("Missing or invalid timestamp.")
        }

        // TODO: check if file hasn't been modified since the metadata was saved
        // Reference: https://github.com/huggingface/huggingface_hub/blob/2fdc6f48ef5e6b22ee9bcdc1945948ac070da675/src/huggingface_hub/_local_folder.py#L303

        return LocalDownloadFileMetadata(
            commitHash: commitHash,
            etag: etag,
            filename: filePath,
            timestamp: Date(timeIntervalSince1970: secondsSince1970)
        )
    } catch {
        // Any read or parse failure lands here: drop the unusable file and report "no metadata".
        logger.warning("Invalid metadata file \(metadataPath): \(error). Removing it from disk and continue.")
        do {
            try FileManager.default.removeItem(at: metadataPath)
        } catch {
            throw EnvironmentError.invalidMetadataError("Could not remove corrupted metadata file \(metadataPath): \(error)")
        }
        return nil
    }
}
297+
298+
/// Tells whether `hash` contains a match for the regular expression `pattern`.
///
/// - Parameters:
///   - hash: The string to test.
///   - pattern: An `NSRegularExpression` pattern; anchor it (`^…$`) to constrain the whole string.
/// - Returns: `true` when a match is found; `false` when there is none or `pattern` does not compile.
func isValidHash(hash: String, pattern: String) -> Bool {
    guard let matcher = try? NSRegularExpression(pattern: pattern) else {
        return false
    }
    let wholeInput = NSRange(hash.startIndex..<hash.endIndex, in: hash)
    return matcher.rangeOfFirstMatch(in: hash, options: [], range: wholeInput).location != NSNotFound
}
303+
304+
/// Writes the download metadata file for a file of this repo.
///
/// The file holds three newline-terminated lines: the commit hash, the etag, and the current time in
/// seconds since 1970. It is stored at `metadataRelativePath` under `metadataDestination`; missing
/// parent directories are created first.
///
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L391
///
/// - Parameters:
///   - commitHash: Commit hash to record.
///   - etag: Etag to record.
///   - metadataRelativePath: Path of the metadata file relative to `metadataDestination`.
/// - Throws: `EnvironmentError.invalidMetadataError` when the directory or the file cannot be written.
func writeDownloadMetadata(commitHash: String, etag: String, metadataRelativePath: String) throws {
    let metadataContent = "\(commitHash)\n\(etag)\n\(Date().timeIntervalSince1970)\n"
    // Resolve the location with `appending(path:)`, the same call `readDownloadMetadata` uses, so a
    // relative path containing subdirectories maps to the same nested file on write and on read.
    let metadataPath = metadataDestination.appending(path: metadataRelativePath)

    do {
        try FileManager.default.createDirectory(at: metadataPath.deletingLastPathComponent(), withIntermediateDirectories: true)
        try metadataContent.write(to: metadataPath, atomically: true, encoding: .utf8)
    } catch {
        // Keep the underlying cause in the message; otherwise the reason for the failure is lost.
        throw EnvironmentError.invalidMetadataError("Failed to write metadata file \(metadataPath): \(error)")
    }
}
316+
317+
/// Computes the SHA-256 digest of the file at `url`.
///
/// The file is read in 1 MB chunks, each inside its own autorelease pool.
///
/// - Parameter url: Location of the file to hash.
/// - Throws: `Hub.HubClientError.unexpectedError` when the file cannot be opened for reading, or the
///   underlying error when a read fails part-way through.
/// - Returns: The digest as a lowercase hexadecimal string.
func computeFileHash(file url: URL) throws -> String {
    // Open file for reading
    guard let fileHandle = try? FileHandle(forReadingFrom: url) else {
        throw Hub.HubClientError.unexpectedError
    }

    defer {
        try? fileHandle.close()
    }

    var hasher = SHA256()
    let chunkSize = 1024 * 1024 // 1MB chunks

    var reachedEnd = false
    while !reachedEnd {
        // A failed read must surface as an error: treating it as end-of-file would return the digest
        // of a truncated prefix as if it were the hash of the whole file.
        reachedEnd = try autoreleasepool { () throws -> Bool in
            guard let chunk = try fileHandle.read(upToCount: chunkSize), !chunk.isEmpty else {
                return true
            }
            hasher.update(data: chunk)
            return false
        }
    }

    let digest = hasher.finalize()
    return digest.map { String(format: "%02x", $0) }.joined()
}
347+
348+
214349
// Note we go from Combine in Downloader to callback-based progress reporting
215350
// We'll probably need to support Combine as well to play well with Swift UI
216351
// (See for example PipelineLoader in swift-coreml-diffusers)
217352
@discardableResult
218353
func download(progressHandler: @escaping (Double) -> Void) async throws -> URL {
219-
guard !downloaded else { return destination }
220-
354+
let metadataRelativePath = "\(relativeFilename).metadata"
355+
356+
let localMetadata = try readDownloadMetadata(localDir: metadataDestination, filePath: metadataRelativePath)
357+
let remoteMetadata = try await HubApi.shared.getFileMetadata(url: source)
358+
359+
let localCommitHash = localMetadata?.commitHash ?? ""
360+
let remoteCommitHash = remoteMetadata.commitHash ?? ""
361+
362+
// Local file exists + metadata exists + commit_hash matches => return file
363+
if isValidHash(hash: remoteCommitHash, pattern: commitHashPattern) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash {
364+
return destination
365+
}
366+
367+
// From now on, etag, commit_hash, url and size are not empty
368+
guard let remoteCommitHash = remoteMetadata.commitHash,
369+
let remoteEtag = remoteMetadata.etag,
370+
remoteMetadata.location != "" else {
371+
throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server")
372+
}
373+
374+
// Local file exists => check if it's up-to-date
375+
if downloaded {
376+
// etag matches => update metadata and return file
377+
if localMetadata?.etag == remoteEtag {
378+
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath)
379+
return destination
380+
}
381+
382+
// etag is a sha256
383+
// => means it's an LFS file (large)
384+
// => let's compute local hash and compare
385+
// => if match, update metadata and return file
386+
if isValidHash(hash: remoteEtag, pattern: sha256Pattern) {
387+
let fileHash = try computeFileHash(file: destination)
388+
if fileHash == remoteEtag {
389+
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath)
390+
return destination
391+
}
392+
}
393+
}
394+
395+
// Otherwise, let's download the file!
221396
try prepareDestination()
397+
try prepareMetadataDestination()
398+
222399
let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession)
223400
let downloadSubscriber = downloader.downloadState.sink { state in
224401
if case .downloading(let progress) = state {
@@ -228,6 +405,9 @@ public extension HubApi {
228405
_ = try withExtendedLifetime(downloadSubscriber) {
229406
try downloader.waitUntilDone()
230407
}
408+
409+
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath)
410+
231411
return destination
232412
}
233413
}
@@ -275,20 +455,36 @@ public extension HubApi {
275455

276456
/// Metadata
277457
public extension HubApi {
278-
/// A structure representing metadata for a remote file
458+
/// Data structure containing information about a file versioned on the Hub
279459
struct FileMetadata {
    /// Commit hash associated with the file, when the server reports one.
    public let commitHash: String?

    /// Server-side etag of the file, when available.
    public let etag: String?

    /// URL string the file can be downloaded from; may point to the Hub or elsewhere (CDN).
    public let location: String

    /// File size, when known. For an LFS file this is the size of the actual file, not of the pointer.
    public let size: Int?
}
472+
473+
/// Download-related metadata recorded locally for a file.
struct LocalDownloadFileMetadata {
    /// Commit hash of the file in the repo.
    public let commitHash: String

    /// Etag of the file in the repo, used to detect whether the file changed.
    /// It is the sha256 of the file for LFS files and corresponds to the git hash for regular files.
    public let etag: String

    /// Path of the file in the repo.
    public let filename: String

    /// When the metadata was saved, i.e. the moment at which it was accurate.
    public let timestamp: Date
}
292488

293489
private func normalizeEtag(_ etag: String?) -> String? {
294490
guard let etag = etag else { return nil }
@@ -297,13 +493,14 @@ public extension HubApi {
297493

298494
/// Retrieves the metadata of the remote file at `url` through `httpHead(for:)`.
///
/// - Parameter url: URL of the file on the server.
/// - Throws: Any error thrown by `httpHead(for:)`.
/// - Returns: A `FileMetadata` built from the response headers. `location` is the `Location` header
///   for a 302 response and the response URL otherwise, falling back to `url` when neither is present.
func getFileMetadata(url: URL) async throws -> FileMetadata {
    let (_, response) = try await httpHead(for: url)

    // Small accessor to keep the header lookups below readable.
    func header(_ name: String) -> String? {
        response.value(forHTTPHeaderField: name)
    }

    let location: String?
    if response.statusCode == 302 {
        location = header("Location")
    } else {
        location = response.url?.absoluteString
    }

    // The `X-Linked-*` headers take precedence over the plain ones when both are present.
    let rawEtag = header("X-Linked-Etag") ?? header("Etag")
    let rawSize = header("X-Linked-Size") ?? header("Content-Length") ?? ""

    return FileMetadata(
        commitHash: header("X-Repo-Commit"),
        etag: normalizeEtag(rawEtag),
        location: location ?? url.absoluteString,
        size: Int(rawSize)
    )
}
@@ -396,3 +593,43 @@ public extension [String] {
396593
filter { fnmatch(glob, $0, 0) == 0 }
397594
}
398595
}
596+
597+
/// Only allow relative redirects and reject others
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/file_download.py#L258
private class RedirectDelegate: NSObject, URLSessionTaskDelegate {
    /// Decides whether a redirect is followed.
    ///
    /// A 3xx response whose `Location` header has no host is followed: the path and query of the
    /// task's original URL are replaced with those of the `Location` value, and the original request's
    /// HTTP method and headers are carried over. Every other case is answered with `nil`, so the
    /// redirect is not followed.
    func urlSession(_ session: URLSession, task: URLSessionTask, willPerformHTTPRedirection response: HTTPURLResponse, newRequest request: URLRequest, completionHandler: @escaping (URLRequest?) -> Void) {
        guard (300...399).contains(response.statusCode),
              let locationString = response.value(forHTTPHeaderField: "Location"),
              let locationUrl = URL(string: locationString),
              // Relative redirect: the Location value has no host component.
              locationUrl.host == nil,
              let originalRequest = task.originalRequest,
              let originalUrl = originalRequest.url,
              let baseComponents = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true)
        else {
            // Non-redirects, absolute redirects, and anything that cannot be resolved: prevent redirect.
            completionHandler(nil)
            return
        }

        // Keep scheme/host of the original URL; take path and query from the Location header.
        var components = baseComponents
        components.path = locationUrl.path
        components.query = locationUrl.query

        guard let resolvedUrl = components.url else {
            completionHandler(nil)
            return
        }

        var newRequest = URLRequest(url: resolvedUrl)
        // `URLRequest(url:)` defaults to GET. Carry the original method over so that a HEAD probe
        // stays a HEAD after the redirect instead of turning into a full download of the body.
        newRequest.httpMethod = originalRequest.httpMethod
        // Copy headers from original request
        originalRequest.allHTTPHeaderFields?.forEach { key, value in
            newRequest.setValue(value, forHTTPHeaderField: key)
        }
        newRequest.setValue(resolvedUrl.absoluteString, forHTTPHeaderField: "Location")
        completionHandler(newRequest)
    }
}

0 commit comments

Comments
 (0)