Skip to content

Add metadata support with tests #155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 248 additions & 11 deletions Sources/Hub/HubApi.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
//

import Foundation
import CryptoKit
import os

public struct HubApi {
var downloadBase: URL
Expand All @@ -29,6 +31,8 @@ public struct HubApi {
}

public static let shared = HubApi()

private static let logger = Logger()
}

private extension HubApi {
Expand Down Expand Up @@ -91,18 +95,24 @@ public extension HubApi {
return (data, response)
}

/// Throws error if page does not exist or is not accessible.
/// Allows relative redirects but ignores absolute ones for LFS files.
func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) {
var request = URLRequest(url: url)
request.httpMethod = "HEAD"
if let hfToken = hfToken {
request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization")
}
request.setValue("identity", forHTTPHeaderField: "Accept-Encoding")
let (data, response) = try await URLSession.shared.data(for: request)

let redirectDelegate = RedirectDelegate()
let session = URLSession(configuration: .default, delegate: redirectDelegate, delegateQueue: nil)

let (data, response) = try await session.data(for: request)
guard let response = response as? HTTPURLResponse else { throw Hub.HubClientError.unexpectedError }

switch response.statusCode {
case 200..<300: break
case 200..<400: break // Allow redirects to pass through to the redirect delegate
case 400..<500: throw Hub.HubClientError.authorizationRequired
default: throw Hub.HubClientError.httpStatusCode(response.statusCode)
}
Expand Down Expand Up @@ -138,6 +148,20 @@ public extension HubApi {
}
}

/// Additional Errors
public extension HubApi {
    /// Errors describing problems with the local download environment.
    enum EnvironmentError: LocalizedError {
        /// Download metadata is invalid or could not be handled.
        /// The associated value is the human-readable message surfaced to callers.
        case invalidMetadataError(String)

        /// The message carried by the error case, exposed through `LocalizedError`.
        public var errorDescription: String? {
            switch self {
            case let .invalidMetadataError(details):
                return details
            }
        }
    }
}

/// Configuration loading helpers
public extension HubApi {
/// Assumes the file has already been downloaded.
Expand Down Expand Up @@ -184,6 +208,9 @@ public extension HubApi {
let hfToken: String?
let endpoint: String?
let backgroundSession: Bool

let sha256Pattern = "^[0-9a-f]{64}$"
let commitHashPattern = "^[0-9a-f]{40}$"

var source: URL {
// https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/tokenizer.json?download=true
Expand All @@ -201,6 +228,13 @@ public extension HubApi {
repoDestination.appending(path: relativeFilename)
}

/// Directory holding per-file download metadata:
/// `<repoDestination>/.cache/huggingface/download`.
var metadataDestination: URL {
    let subdirectories = [".cache", "huggingface", "download"]
    return subdirectories.reduce(repoDestination) { partial, component in
        partial.appendingPathComponent(component)
    }
}

/// Whether something already exists on disk at `destination`.
var downloaded: Bool {
    let localPath = destination.path
    return FileManager.default.fileExists(atPath: localPath)
}
Expand All @@ -209,15 +243,158 @@ public extension HubApi {
let directoryURL = destination.deletingLastPathComponent()
try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil)
}


/// Creates the `metadataDestination` directory, including any missing intermediate directories.
///
/// - Throws: Any error raised by `FileManager.createDirectory`.
func prepareMetadataDestination() throws {
    let fileManager = FileManager.default
    try fileManager.createDirectory(
        at: metadataDestination,
        withIntermediateDirectories: true,
        attributes: nil
    )
}

/// Loads the download metadata stored for a file in the local directory.
///
/// The metadata file is expected to hold at least three lines: commit hash, etag and a
/// timestamp expressed in seconds since 1970. A file that cannot be read or parsed is
/// treated as corrupted: a warning is logged, the file is deleted, and `nil` is returned.
///
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L263
///
/// - Parameters:
///   - localDir: The local directory where metadata files are downloaded.
///   - filePath: The path, relative to `localDir`, of the metadata file to read.
/// - Throws: An `EnvironmentError.invalidMetadataError` if the metadata file is invalid and cannot be removed.
/// - Returns: A `LocalDownloadFileMetadata` object if the metadata file exists and is valid, or `nil` if the file is missing or invalid.
func readDownloadMetadata(localDir: URL, filePath: String) throws -> LocalDownloadFileMetadata? {
    let metadataPath = localDir.appending(path: filePath)

    guard FileManager.default.fileExists(atPath: metadataPath.path) else {
        // metadata file does not exist
        return nil
    }

    do {
        let rawText = try String(contentsOf: metadataPath, encoding: .utf8)
        let fields = rawText.components(separatedBy: .newlines)

        if fields.count < 3 {
            throw EnvironmentError.invalidMetadataError("Metadata file is missing required fields.")
        }

        let commitHash = fields[0].trimmingCharacters(in: .whitespacesAndNewlines)
        let etag = fields[1].trimmingCharacters(in: .whitespacesAndNewlines)
        let rawTimestamp = fields[2].trimmingCharacters(in: .whitespacesAndNewlines)
        guard let secondsSince1970 = Double(rawTimestamp) else {
            throw EnvironmentError.invalidMetadataError("Missing or invalid timestamp.")
        }

        // TODO: check if file hasn't been modified since the metadata was saved
        // Reference: https://github.com/huggingface/huggingface_hub/blob/2fdc6f48ef5e6b22ee9bcdc1945948ac070da675/src/huggingface_hub/_local_folder.py#L303

        return LocalDownloadFileMetadata(
            commitHash: commitHash,
            etag: etag,
            filename: filePath,
            timestamp: Date(timeIntervalSince1970: secondsSince1970)
        )
    } catch {
        // Any read or parse failure lands here: drop the corrupted file and report "no metadata".
        logger.warning("Invalid metadata file \(metadataPath): \(error). Removing it from disk and continue.")
        do {
            try FileManager.default.removeItem(at: metadataPath)
        } catch {
            throw EnvironmentError.invalidMetadataError("Could not remove corrupted metadata file \(metadataPath): \(error)")
        }
        return nil
    }
}

/// Tells whether `hash` contains a match for the regular expression `pattern`.
///
/// - Parameters:
///   - hash: The string to test.
///   - pattern: An `NSRegularExpression` pattern; anchor it (`^…$`) to constrain the whole string.
/// - Returns: `true` when the pattern compiles and matches somewhere in `hash`;
///   `false` when there is no match or the pattern is not a valid regular expression.
func isValidHash(hash: String, pattern: String) -> Bool {
    guard let expression = try? NSRegularExpression(pattern: pattern) else {
        return false
    }
    let wholeString = NSRange(hash.startIndex..., in: hash)
    let firstHit = expression.rangeOfFirstMatch(in: hash, options: [], range: wholeString)
    return firstHit.location != NSNotFound
}

/// Writes the download metadata for a file below `metadataDestination`.
///
/// The file holds three newline-terminated lines: the commit hash, the etag, and the
/// current time in seconds since 1970. Missing parent directories are created first.
///
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L391
///
/// - Parameters:
///   - commitHash: Commit hash to record for the file.
///   - etag: Etag to record for the file.
///   - metadataRelativePath: Path of the metadata file relative to `metadataDestination`;
///     may contain subdirectories.
/// - Throws: `EnvironmentError.invalidMetadataError` if the directory or file cannot be written.
func writeDownloadMetadata(commitHash: String, etag: String, metadataRelativePath: String) throws {
    let metadataContent = "\(commitHash)\n\(etag)\n\(Date().timeIntervalSince1970)\n"
    // Resolve the path the same way `readDownloadMetadata` does (`appending(path:)`).
    // `appending(component:)` treats the argument as a single component, so a relative
    // path with subdirectories could end up at a location the reader never looks at.
    let metadataPath = metadataDestination.appending(path: metadataRelativePath)

    do {
        try FileManager.default.createDirectory(at: metadataPath.deletingLastPathComponent(), withIntermediateDirectories: true)
        try metadataContent.write(to: metadataPath, atomically: true, encoding: .utf8)
    } catch {
        // Keep the underlying cause in the message so write failures are diagnosable.
        throw EnvironmentError.invalidMetadataError("Failed to write metadata file \(metadataPath): \(error)")
    }
}

/// Computes the SHA-256 digest of a file, reading it in 1 MB chunks.
///
/// - Parameter url: Location of the file to hash.
/// - Throws: `Hub.HubClientError.unexpectedError` if the file cannot be opened for reading,
///   or the underlying error if reading a chunk fails.
/// - Returns: The digest as a lowercase hexadecimal string.
func computeFileHash(file url: URL) throws -> String {
    // Open file for reading
    guard let fileHandle = try? FileHandle(forReadingFrom: url) else {
        throw Hub.HubClientError.unexpectedError
    }

    defer {
        try? fileHandle.close()
    }

    var hasher = SHA256()
    let chunkSize = 1024 * 1024 // 1MB chunks

    // Each chunk is read inside its own autorelease pool so its buffer is released promptly.
    // Read errors are propagated: swallowing them would end the loop early and return the
    // digest of a truncated file as if it were the digest of the whole file.
    while try autoreleasepool(invoking: {
        guard let nextChunk = try fileHandle.read(upToCount: chunkSize),
              !nextChunk.isEmpty
        else {
            // End of file reached.
            return false
        }

        hasher.update(data: nextChunk)
        return true
    }) { }

    let digest = hasher.finalize()
    return digest.map { String(format: "%02x", $0) }.joined()
}


// Note we go from Combine in Downloader to callback-based progress reporting
// We'll probably need to support Combine as well to play well with SwiftUI
// (See for example PipelineLoader in swift-coreml-diffusers)
@discardableResult
func download(progressHandler: @escaping (Double) -> Void) async throws -> URL {
guard !downloaded else { return destination }

let metadataRelativePath = "\(relativeFilename).metadata"

let localMetadata = try readDownloadMetadata(localDir: metadataDestination, filePath: metadataRelativePath)
let remoteMetadata = try await HubApi.shared.getFileMetadata(url: source)

let localCommitHash = localMetadata?.commitHash ?? ""
let remoteCommitHash = remoteMetadata.commitHash ?? ""

// Local file exists + metadata exists + commit_hash matches => return file
if isValidHash(hash: remoteCommitHash, pattern: commitHashPattern) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash {
return destination
}

// From now on, etag, commit_hash, url and size are not empty
guard let remoteCommitHash = remoteMetadata.commitHash,
let remoteEtag = remoteMetadata.etag,
remoteMetadata.location != "" else {
throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server")
}

// Local file exists => check if it's up-to-date
if downloaded {
// etag matches => update metadata and return file
if localMetadata?.etag == remoteEtag {
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath)
return destination
}

// etag is a sha256
// => means it's an LFS file (large)
// => let's compute local hash and compare
// => if match, update metadata and return file
if isValidHash(hash: remoteEtag, pattern: sha256Pattern) {
let fileHash = try computeFileHash(file: destination)
if fileHash == remoteEtag {
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath)
return destination
}
}
}

// Otherwise, let's download the file!
try prepareDestination()
try prepareMetadataDestination()

let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession)
let downloadSubscriber = downloader.downloadState.sink { state in
if case .downloading(let progress) = state {
Expand All @@ -227,6 +404,9 @@ public extension HubApi {
_ = try withExtendedLifetime(downloadSubscriber) {
try downloader.waitUntilDone()
}

try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath)

return destination
}
}
Expand Down Expand Up @@ -274,20 +454,36 @@ public extension HubApi {

/// Metadata
public extension HubApi {
/// Data structure containing information about a remote file versioned on the Hub.
struct FileMetadata {
/// The Git commit hash related to the file.
public let commitHash: String?

/// Server-provided ETag of the file, usable for caching.
public let etag: String?

/// Stringified URL location where to download the file. Can be a Hub url or not (CDN).
public let location: String

/// Size of the file in bytes. In case of an LFS file, contains the size of the actual LFS file, not the pointer.
public let size: Int?
}

/// Metadata about a file in the local directory related to a download process.
///
/// Instances are produced by `readDownloadMetadata` from a metadata file whose
/// first three lines are the commit hash, the etag and a timestamp.
struct LocalDownloadFileMetadata {
/// Commit hash of the file in the repo
public let commitHash: String

/// ETag of the file in the repo. Used to check if the file has changed.
/// For LFS files, this is the sha256 of the file. For regular files, it corresponds to the git hash.
public let etag: String

/// Path of the file in the repo
/// NOTE(review): `readDownloadMetadata` stores the `filePath` argument it was given here, and
/// `download` passes the `.metadata` relative path — confirm which path is intended.
public let filename: String

/// The timestamp of when the metadata was saved i.e. when the metadata was accurate
/// (parsed from seconds since 1970 in the metadata file).
public let timestamp: Date
}

private func normalizeEtag(_ etag: String?) -> String? {
guard let etag = etag else { return nil }
Expand All @@ -296,13 +492,14 @@ public extension HubApi {

func getFileMetadata(url: URL) async throws -> FileMetadata {
let (_, response) = try await httpHead(for: url)
let location = response.statusCode == 302 ? response.value(forHTTPHeaderField: "Location") : response.url?.absoluteString

return FileMetadata(
commitHash: response.value(forHTTPHeaderField: "X-Repo-Commit"),
etag: normalizeEtag(
(response.value(forHTTPHeaderField: "X-Linked-Etag")) ?? (response.value(forHTTPHeaderField: "Etag"))
),
location: (response.value(forHTTPHeaderField: "Location")) ?? url.absoluteString,
location: location ?? url.absoluteString,
size: Int(response.value(forHTTPHeaderField: "X-Linked-Size") ?? response.value(forHTTPHeaderField: "Content-Length") ?? "")
)
}
Expand Down Expand Up @@ -395,3 +592,43 @@ public extension [String] {
filter { fnmatch(glob, $0, 0) == 0 }
}
}

/// Session task delegate that follows relative redirects only; every other redirect is refused.
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/file_download.py#L258
private class RedirectDelegate: NSObject, URLSessionTaskDelegate {
    /// Hands the session either a rebuilt request (relative redirect) or `nil` (do not redirect).
    func urlSession(_ session: URLSession, task: URLSessionTask, willPerformHTTPRedirection response: HTTPURLResponse, newRequest request: URLRequest, completionHandler: @escaping (URLRequest?) -> Void) {
        // `nil` covers non-redirect statuses, absolute redirects and anything that fails to resolve.
        let followUp = relativeRedirectRequest(for: task, response: response)
        completionHandler(followUp)
    }

    /// Builds the request for a relative redirect, or returns `nil` when the redirect must not be followed.
    private func relativeRedirectRequest(for task: URLSessionTask, response: HTTPURLResponse) -> URLRequest? {
        // Only 3xx responses carrying a parseable Location header are candidates.
        guard (300...399).contains(response.statusCode),
              let locationString = response.value(forHTTPHeaderField: "Location"),
              let locationUrl = URL(string: locationString)
        else {
            return nil
        }

        // A Location with a host component is an absolute redirect: refuse it.
        guard locationUrl.host == nil else {
            return nil
        }

        // Resolve the relative target against the original request's URL.
        guard let originalUrl = task.originalRequest?.url,
              let baseComponents = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true)
        else {
            return nil
        }

        var targetComponents = baseComponents
        targetComponents.path = locationUrl.path
        targetComponents.query = locationUrl.query

        guard let resolvedUrl = targetComponents.url else {
            return nil
        }

        // Carry over the headers of the original request, then record the resolved URL.
        var redirected = URLRequest(url: resolvedUrl)
        for (field, value) in task.originalRequest?.allHTTPHeaderFields ?? [:] {
            redirected.setValue(value, forHTTPHeaderField: field)
        }
        redirected.setValue(resolvedUrl.absoluteString, forHTTPHeaderField: "Location")
        return redirected
    }
}
Loading