-
Notifications
You must be signed in to change notification settings - Fork 127
Add metadata support with tests #155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
02d2571
f590932
22b6892
bedfc7a
af26e60
b4e1c49
26707b8
5839d33
9d39cf1
97b6163
fe2f32b
30adb75
2869c8f
ff89442
0e30b28
7290768
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -6,6 +6,8 @@ | |||||
// | ||||||
|
||||||
import Foundation | ||||||
import CryptoKit | ||||||
import os | ||||||
|
||||||
public struct HubApi { | ||||||
var downloadBase: URL | ||||||
|
@@ -29,6 +31,8 @@ public struct HubApi { | |||||
} | ||||||
|
||||||
public static let shared = HubApi() | ||||||
|
||||||
private static let logger = Logger() | ||||||
} | ||||||
|
||||||
private extension HubApi { | ||||||
|
@@ -91,18 +95,24 @@ public extension HubApi { | |||||
return (data, response) | ||||||
} | ||||||
|
||||||
/// Throws error if page does not exist or is not accessible. | ||||||
/// Allows relative redirects but ignores absolute ones for LFS files. | ||||||
func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) { | ||||||
var request = URLRequest(url: url) | ||||||
request.httpMethod = "HEAD" | ||||||
if let hfToken = hfToken { | ||||||
request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization") | ||||||
} | ||||||
request.setValue("identity", forHTTPHeaderField: "Accept-Encoding") | ||||||
let (data, response) = try await URLSession.shared.data(for: request) | ||||||
|
||||||
let redirectDelegate = RedirectDelegate() | ||||||
let session = URLSession(configuration: .default, delegate: redirectDelegate, delegateQueue: nil) | ||||||
|
||||||
let (data, response) = try await session.data(for: request) | ||||||
guard let response = response as? HTTPURLResponse else { throw Hub.HubClientError.unexpectedError } | ||||||
|
||||||
switch response.statusCode { | ||||||
case 200..<300: break | ||||||
case 200..<400: break // Allow redirects to pass through to the redirect delegate | ||||||
case 400..<500: throw Hub.HubClientError.authorizationRequired | ||||||
default: throw Hub.HubClientError.httpStatusCode(response.statusCode) | ||||||
} | ||||||
|
@@ -138,6 +148,20 @@ public extension HubApi { | |||||
} | ||||||
} | ||||||
|
||||||
/// Additional Errors | ||||||
public extension HubApi { | ||||||
enum EnvironmentError: LocalizedError { | ||||||
case invalidMetadataError(String) | ||||||
|
||||||
public var errorDescription: String? { | ||||||
switch self { | ||||||
case .invalidMetadataError(let message): | ||||||
return message | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
/// Configuration loading helpers | ||||||
public extension HubApi { | ||||||
/// Assumes the file has already been downloaded. | ||||||
|
@@ -201,6 +225,13 @@ public extension HubApi { | |||||
repoDestination.appending(path: relativeFilename) | ||||||
} | ||||||
|
||||||
var metadataDestination: URL { | ||||||
repoDestination | ||||||
.appendingPathComponent(".cache") | ||||||
.appendingPathComponent("huggingface") | ||||||
.appendingPathComponent("download") | ||||||
} | ||||||
|
||||||
var downloaded: Bool { | ||||||
FileManager.default.fileExists(atPath: destination.path) | ||||||
} | ||||||
|
@@ -209,15 +240,158 @@ public extension HubApi { | |||||
let directoryURL = destination.deletingLastPathComponent() | ||||||
try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil) | ||||||
} | ||||||
|
||||||
|
||||||
func prepareMetadataDestination() throws { | ||||||
try FileManager.default.createDirectory(at: metadataDestination, withIntermediateDirectories: true, attributes: nil) | ||||||
} | ||||||
|
||||||
/// Reads metadata about a file in the local directory related to a download process. | ||||||
/// | ||||||
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L263 | ||||||
/// | ||||||
/// - Parameters: | ||||||
/// - localDir: The local directory where metadata files are downloaded. | ||||||
/// - filePath: The path of the file for which metadata is being read. | ||||||
/// - Throws: An `EnvironmentError.invalidMetadataError` if the metadata file is invalid and cannot be removed. | ||||||
/// - Returns: A `LocalDownloadFileMetadata` object if the metadata file exists and is valid, or `nil` if the file is missing or invalid. | ||||||
func readDownloadMetadata(localDir: URL, filePath: String) throws -> LocalDownloadFileMetadata? { | ||||||
let metadataPath = localDir.appending(path: filePath) | ||||||
if FileManager.default.fileExists(atPath: metadataPath.path) { | ||||||
do { | ||||||
let contents = try String(contentsOf: metadataPath, encoding: .utf8) | ||||||
let lines = contents.components(separatedBy: .newlines) | ||||||
|
||||||
guard lines.count >= 3 else { | ||||||
throw EnvironmentError.invalidMetadataError("Metadata file is missing required fields.") | ||||||
} | ||||||
|
||||||
let commitHash = lines[0].trimmingCharacters(in: .whitespacesAndNewlines) | ||||||
let etag = lines[1].trimmingCharacters(in: .whitespacesAndNewlines) | ||||||
guard let timestamp = Double(lines[2].trimmingCharacters(in: .whitespacesAndNewlines)) else { | ||||||
throw EnvironmentError.invalidMetadataError("Missing or invalid timestamp.") | ||||||
} | ||||||
let timestampDate = Date(timeIntervalSince1970: timestamp) | ||||||
|
||||||
// TODO: check if file hasn't been modified since the metadata was saved | ||||||
pcuenca marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
return LocalDownloadFileMetadata(commitHash: commitHash, etag: etag, filename: filePath, timestamp: timestampDate) | ||||||
} catch { | ||||||
do { | ||||||
logger.warning("Invalid metadata file \(metadataPath): \(error). Removing it from disk and continue.") | ||||||
try FileManager.default.removeItem(at: metadataPath) | ||||||
} catch { | ||||||
throw EnvironmentError.invalidMetadataError("Could not remove corrupted metadata file \(metadataPath): \(error)") | ||||||
} | ||||||
return nil | ||||||
} | ||||||
} | ||||||
|
||||||
// metadata file does not exist | ||||||
return nil | ||||||
} | ||||||
|
||||||
func isValidSHA256(_ hash: String) -> Bool { | ||||||
let sha256Pattern = "^[0-9a-f]{64}$" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, you are right, but that's for SHA256 etags that represent LFS files. Commit hashes will never match this, we have to use this pattern instead. |
||||||
let regex = try? NSRegularExpression(pattern: sha256Pattern) | ||||||
let range = NSRange(location: 0, length: hash.utf16.count) | ||||||
return regex?.firstMatch(in: hash, options: [], range: range) != nil | ||||||
} | ||||||
|
||||||
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L391 | ||||||
func writeDownloadMetadata(commitHash: String, etag: String, metadataRelativePath: String) throws { | ||||||
let metadataContent = "\(commitHash)\n\(etag)\n\(Date().timeIntervalSince1970)\n" | ||||||
let metadataPath = metadataDestination.appending(component: metadataRelativePath) | ||||||
|
||||||
do { | ||||||
try FileManager.default.createDirectory(at: metadataPath.deletingLastPathComponent(), withIntermediateDirectories: true) | ||||||
pcuenca marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
try metadataContent.write(to: metadataPath, atomically: true, encoding: .utf8) | ||||||
} catch { | ||||||
throw EnvironmentError.invalidMetadataError("Failed to write metadata file \(metadataPath)") | ||||||
} | ||||||
} | ||||||
|
||||||
func computeFileHash(file url: URL) throws -> String { | ||||||
// Open file for reading | ||||||
guard let fileHandle = try? FileHandle(forReadingFrom: url) else { | ||||||
throw Hub.HubClientError.unexpectedError | ||||||
} | ||||||
|
||||||
defer { | ||||||
try? fileHandle.close() | ||||||
} | ||||||
|
||||||
var hasher = SHA256() | ||||||
let chunkSize = 1024 * 1024 // 1MB chunks | ||||||
|
||||||
while autoreleasepool(invoking: { | ||||||
let nextChunk = try? fileHandle.read(upToCount: chunkSize) | ||||||
|
||||||
guard let nextChunk, | ||||||
!nextChunk.isEmpty | ||||||
else { | ||||||
return false | ||||||
} | ||||||
|
||||||
hasher.update(data: nextChunk) | ||||||
|
||||||
return true | ||||||
}) { } | ||||||
|
||||||
let digest = hasher.finalize() | ||||||
return digest.map { String(format: "%02x", $0) }.joined() | ||||||
} | ||||||
|
||||||
|
||||||
// Note we go from Combine in Downloader to callback-based progress reporting | ||||||
// We'll probably need to support Combine as well to play well with Swift UI | ||||||
// (See for example PipelineLoader in swift-coreml-diffusers) | ||||||
@discardableResult | ||||||
func download(progressHandler: @escaping (Double) -> Void) async throws -> URL { | ||||||
guard !downloaded else { return destination } | ||||||
|
||||||
let metadataRelativePath = "\(relativeFilename).metadata" | ||||||
|
||||||
let localMetadata = try readDownloadMetadata(localDir: metadataDestination, filePath: metadataRelativePath) | ||||||
let remoteMetadata = try await HubApi.shared.getFileMetadata(url: source) | ||||||
|
||||||
let localCommitHash = localMetadata?.commitHash ?? "" | ||||||
let remoteCommitHash = remoteMetadata.commitHash ?? "" | ||||||
|
||||||
// Local file exists + metadata exists + commit_hash matches => return file | ||||||
if isValidSHA256(remoteCommitHash) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||||||
return destination | ||||||
} | ||||||
|
||||||
// From now on, etag, commit_hash, url and size are not empty | ||||||
guard let remoteCommitHash = remoteMetadata.commitHash, | ||||||
let remoteEtag = remoteMetadata.etag, | ||||||
remoteMetadata.location != "" else { | ||||||
throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server") | ||||||
} | ||||||
|
||||||
// Local file exists => check if it's up-to-date | ||||||
if downloaded { | ||||||
// etag matches => update metadata and return file | ||||||
if localMetadata?.etag == remoteEtag { | ||||||
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath) | ||||||
return destination | ||||||
} | ||||||
|
||||||
// metadata is outdated + etag is a sha256 | ||||||
// => means it's an LFS file (large) | ||||||
// => let's compute local hash and compare | ||||||
// => if match, update metadata and return file | ||||||
if localMetadata != nil && isValidSHA256(remoteEtag) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This |
||||||
let fileHash = try computeFileHash(file: destination) | ||||||
if fileHash == remoteEtag { | ||||||
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath) | ||||||
return destination | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
// Otherwise, let's download the file! | ||||||
try prepareDestination() | ||||||
try prepareMetadataDestination() | ||||||
|
||||||
let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession) | ||||||
let downloadSubscriber = downloader.downloadState.sink { state in | ||||||
if case .downloading(let progress) = state { | ||||||
|
@@ -227,6 +401,9 @@ public extension HubApi { | |||||
_ = try withExtendedLifetime(downloadSubscriber) { | ||||||
try downloader.waitUntilDone() | ||||||
} | ||||||
|
||||||
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath) | ||||||
|
||||||
return destination | ||||||
} | ||||||
} | ||||||
|
@@ -274,20 +451,36 @@ public extension HubApi { | |||||
|
||||||
/// Metadata | ||||||
public extension HubApi { | ||||||
/// A structure representing metadata for a remote file | ||||||
/// Data structure containing information about a file versioned on the Hub | ||||||
struct FileMetadata { | ||||||
/// The file's Git commit hash | ||||||
/// The commit hash related to the file | ||||||
public let commitHash: String? | ||||||
|
||||||
/// Server-provided ETag for caching | ||||||
/// Etag of the file on the server | ||||||
public let etag: String? | ||||||
|
||||||
/// Stringified URL location of the file | ||||||
/// Location where to download the file. Can be a Hub url or not (CDN). | ||||||
public let location: String | ||||||
|
||||||
/// The file's size in bytes | ||||||
/// Size of the file. In case of an LFS file, contains the size of the actual LFS file, not the pointer. | ||||||
public let size: Int? | ||||||
} | ||||||
|
||||||
/// Metadata about a file in the local directory related to a download process | ||||||
struct LocalDownloadFileMetadata { | ||||||
/// Commit hash of the file in the repo | ||||||
public let commitHash: String | ||||||
|
||||||
/// ETag of the file in the repo. Used to check if the file has changed. | ||||||
/// For LFS files, this is the sha256 of the file. For regular files, it corresponds to the git hash. | ||||||
public let etag: String | ||||||
|
||||||
/// Path of the file in the repo | ||||||
public let filename: String | ||||||
|
||||||
/// The timestamp of when the metadata was saved i.e. when the metadata was accurate | ||||||
public let timestamp: Date | ||||||
} | ||||||
|
||||||
private func normalizeEtag(_ etag: String?) -> String? { | ||||||
guard let etag = etag else { return nil } | ||||||
|
@@ -296,13 +489,14 @@ public extension HubApi { | |||||
|
||||||
func getFileMetadata(url: URL) async throws -> FileMetadata { | ||||||
let (_, response) = try await httpHead(for: url) | ||||||
let location = response.statusCode == 302 ? response.value(forHTTPHeaderField: "Location") : response.url?.absoluteString | ||||||
|
||||||
return FileMetadata( | ||||||
commitHash: response.value(forHTTPHeaderField: "X-Repo-Commit"), | ||||||
etag: normalizeEtag( | ||||||
(response.value(forHTTPHeaderField: "X-Linked-Etag")) ?? (response.value(forHTTPHeaderField: "Etag")) | ||||||
), | ||||||
location: (response.value(forHTTPHeaderField: "Location")) ?? url.absoluteString, | ||||||
location: location ?? url.absoluteString, | ||||||
size: Int(response.value(forHTTPHeaderField: "X-Linked-Size") ?? response.value(forHTTPHeaderField: "Content-Length") ?? "") | ||||||
) | ||||||
} | ||||||
|
@@ -395,3 +589,43 @@ public extension [String] { | |||||
filter { fnmatch(glob, $0, 0) == 0 } | ||||||
} | ||||||
} | ||||||
|
||||||
/// Only allow relative redirects and reject others | ||||||
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/file_download.py#L258 | ||||||
private class RedirectDelegate: NSObject, URLSessionTaskDelegate { | ||||||
func urlSession(_ session: URLSession, task: URLSessionTask, willPerformHTTPRedirection response: HTTPURLResponse, newRequest request: URLRequest, completionHandler: @escaping (URLRequest?) -> Void) { | ||||||
// Check if it's a redirect status code (300-399) | ||||||
if (300...399).contains(response.statusCode) { | ||||||
// Get the Location header | ||||||
if let locationString = response.value(forHTTPHeaderField: "Location"), | ||||||
let locationUrl = URL(string: locationString) { | ||||||
|
||||||
// Check if it's a relative redirect (no host component) | ||||||
if locationUrl.host == nil { | ||||||
// For relative redirects, construct the new URL using the original request's base | ||||||
if let originalUrl = task.originalRequest?.url, | ||||||
var components = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true) { | ||||||
// Update the path component with the relative path | ||||||
components.path = locationUrl.path | ||||||
components.query = locationUrl.query | ||||||
|
||||||
// Create new request with the resolved URL | ||||||
if let resolvedUrl = components.url { | ||||||
var newRequest = URLRequest(url: resolvedUrl) | ||||||
// Copy headers from original request | ||||||
task.originalRequest?.allHTTPHeaderFields?.forEach { key, value in | ||||||
newRequest.setValue(value, forHTTPHeaderField: key) | ||||||
} | ||||||
newRequest.setValue(resolvedUrl.absoluteString, forHTTPHeaderField: "Location") | ||||||
completionHandler(newRequest) | ||||||
return | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
// For all other cases (non-redirects or absolute redirects), prevent redirect | ||||||
completionHandler(nil) | ||||||
} | ||||||
} |
Uh oh!
There was an error while loading. Please reload this page.