diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index fdf1256..c55823e 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -6,6 +6,8 @@ // import Foundation +import CryptoKit +import os public struct HubApi { var downloadBase: URL @@ -29,6 +31,8 @@ public struct HubApi { } public static let shared = HubApi() + + private static let logger = Logger() } private extension HubApi { @@ -91,6 +95,8 @@ public extension HubApi { return (data, response) } + /// Throws error if page does not exist or is not accessible. + /// Allows relative redirects but ignores absolute ones for LFS files. func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) { var request = URLRequest(url: url) request.httpMethod = "HEAD" @@ -98,11 +104,15 @@ public extension HubApi { request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization") } request.setValue("identity", forHTTPHeaderField: "Accept-Encoding") - let (data, response) = try await URLSession.shared.data(for: request) + + let redirectDelegate = RedirectDelegate() + let session = URLSession(configuration: .default, delegate: redirectDelegate, delegateQueue: nil) + defer { session.finishTasksAndInvalidate() } // A session retains its delegate until invalidated; release this one-off session on exit + let (data, response) = try await session.data(for: request) guard let response = response as? HTTPURLResponse else { throw Hub.HubClientError.unexpectedError } switch response.statusCode { - case 200..<300: break + case 200..<400: break // Allow redirects to pass through to the redirect delegate case 400..<500: throw Hub.HubClientError.authorizationRequired default: throw Hub.HubClientError.httpStatusCode(response.statusCode) } @@ -138,6 +148,20 @@ public extension HubApi { } } +/// Additional Errors +public extension HubApi { + enum EnvironmentError: LocalizedError { + case invalidMetadataError(String) + + public var errorDescription: String? 
{ + switch self { + case .invalidMetadataError(let message): + return message + } + } + } +} + /// Configuration loading helpers public extension HubApi { /// Assumes the file has already been downloaded. @@ -184,6 +208,9 @@ public extension HubApi { let hfToken: String? let endpoint: String? let backgroundSession: Bool + + let sha256Pattern = "^[0-9a-f]{64}$" + let commitHashPattern = "^[0-9a-f]{40}$" var source: URL { // https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/tokenizer.json?download=true @@ -201,6 +228,13 @@ public extension HubApi { repoDestination.appending(path: relativeFilename) } + var metadataDestination: URL { + repoDestination + .appendingPathComponent(".cache") + .appendingPathComponent("huggingface") + .appendingPathComponent("download") + } + var downloaded: Bool { FileManager.default.fileExists(atPath: destination.path) } @@ -209,15 +243,158 @@ public extension HubApi { let directoryURL = destination.deletingLastPathComponent() try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil) } - + + func prepareMetadataDestination() throws { + try FileManager.default.createDirectory(at: metadataDestination, withIntermediateDirectories: true, attributes: nil) + } + + /// Reads metadata about a file in the local directory related to a download process. + /// + /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L263 + /// + /// - Parameters: + /// - localDir: The local directory where metadata files are downloaded. + /// - filePath: The path of the file for which metadata is being read. + /// - Throws: An `EnvironmentError.invalidMetadataError` if the metadata file is invalid and cannot be removed. + /// - Returns: A `LocalDownloadFileMetadata` object if the metadata file exists and is valid, or `nil` if the file is missing or invalid. 
+ func readDownloadMetadata(localDir: URL, filePath: String) throws -> LocalDownloadFileMetadata? { + let metadataPath = localDir.appending(path: filePath) + if FileManager.default.fileExists(atPath: metadataPath.path) { + do { + let contents = try String(contentsOf: metadataPath, encoding: .utf8) + let lines = contents.components(separatedBy: .newlines) + + guard lines.count >= 3 else { + throw EnvironmentError.invalidMetadataError("Metadata file is missing required fields.") + } + + let commitHash = lines[0].trimmingCharacters(in: .whitespacesAndNewlines) + let etag = lines[1].trimmingCharacters(in: .whitespacesAndNewlines) + guard let timestamp = Double(lines[2].trimmingCharacters(in: .whitespacesAndNewlines)) else { + throw EnvironmentError.invalidMetadataError("Missing or invalid timestamp.") + } + let timestampDate = Date(timeIntervalSince1970: timestamp) + + // TODO: check if file hasn't been modified since the metadata was saved + // Reference: https://github.com/huggingface/huggingface_hub/blob/2fdc6f48ef5e6b22ee9bcdc1945948ac070da675/src/huggingface_hub/_local_folder.py#L303 + + return LocalDownloadFileMetadata(commitHash: commitHash, etag: etag, filename: filePath, timestamp: timestampDate) + } catch { + do { + logger.warning("Invalid metadata file \(metadataPath): \(error). Removing it from disk and continue.") + try FileManager.default.removeItem(at: metadataPath) + } catch { + throw EnvironmentError.invalidMetadataError("Could not remove corrupted metadata file \(metadataPath): \(error)") + } + return nil + } + } + + // metadata file does not exist + return nil + } + + func isValidHash(hash: String, pattern: String) -> Bool { + let regex = try? 
NSRegularExpression(pattern: pattern) + let range = NSRange(location: 0, length: hash.utf16.count) + return regex?.firstMatch(in: hash, options: [], range: range) != nil + } + + /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L391 + func writeDownloadMetadata(commitHash: String, etag: String, metadataRelativePath: String) throws { + let metadataContent = "\(commitHash)\n\(etag)\n\(Date().timeIntervalSince1970)\n" + let metadataPath = metadataDestination.appending(component: metadataRelativePath) + + do { + try FileManager.default.createDirectory(at: metadataPath.deletingLastPathComponent(), withIntermediateDirectories: true) + try metadataContent.write(to: metadataPath, atomically: true, encoding: .utf8) + } catch { + throw EnvironmentError.invalidMetadataError("Failed to write metadata file \(metadataPath)") + } + } + + func computeFileHash(file url: URL) throws -> String { + // Open file for reading + guard let fileHandle = try? FileHandle(forReadingFrom: url) else { + throw Hub.HubClientError.unexpectedError + } + + defer { + try? fileHandle.close() + } + + var hasher = SHA256() + let chunkSize = 1024 * 1024 // 1MB chunks + + while autoreleasepool(invoking: { + let nextChunk = try? 
fileHandle.read(upToCount: chunkSize) + + guard let nextChunk, + !nextChunk.isEmpty + else { + return false + } + + hasher.update(data: nextChunk) + + return true + }) { } + + let digest = hasher.finalize() + return digest.map { String(format: "%02x", $0) }.joined() + } + + // Note we go from Combine in Downloader to callback-based progress reporting // We'll probably need to support Combine as well to play well with Swift UI // (See for example PipelineLoader in swift-coreml-diffusers) @discardableResult func download(progressHandler: @escaping (Double) -> Void) async throws -> URL { - guard !downloaded else { return destination } - + let metadataRelativePath = "\(relativeFilename).metadata" + + let localMetadata = try readDownloadMetadata(localDir: metadataDestination, filePath: metadataRelativePath) + let remoteMetadata = try await HubApi.shared.getFileMetadata(url: source) + + let localCommitHash = localMetadata?.commitHash ?? "" + let remoteCommitHash = remoteMetadata.commitHash ?? 
"" + + // Local file exists + metadata exists + commit_hash matches => return file + if isValidHash(hash: remoteCommitHash, pattern: commitHashPattern) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash { + return destination + } + + // From now on, etag, commit_hash, url and size are not empty + guard let remoteCommitHash = remoteMetadata.commitHash, + let remoteEtag = remoteMetadata.etag, + remoteMetadata.location != "" else { + throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server") + } + + // Local file exists => check if it's up-to-date + if downloaded { + // etag matches => update metadata and return file + if localMetadata?.etag == remoteEtag { + try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath) + return destination + } + + // etag is a sha256 + // => means it's an LFS file (large) + // => let's compute local hash and compare + // => if match, update metadata and return file + if isValidHash(hash: remoteEtag, pattern: sha256Pattern) { + let fileHash = try computeFileHash(file: destination) + if fileHash == remoteEtag { + try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath) + return destination + } + } + } + + // Otherwise, let's download the file! 
try prepareDestination() + try prepareMetadataDestination() + let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession) let downloadSubscriber = downloader.downloadState.sink { state in if case .downloading(let progress) = state { @@ -227,6 +404,9 @@ public extension HubApi { _ = try withExtendedLifetime(downloadSubscriber) { try downloader.waitUntilDone() } + + try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath) + return destination } } @@ -274,20 +454,36 @@ public extension HubApi { /// Metadata public extension HubApi { - /// A structure representing metadata for a remote file + /// Data structure containing information about a file versioned on the Hub struct FileMetadata { - /// The file's Git commit hash + /// The commit hash related to the file public let commitHash: String? - /// Server-provided ETag for caching + /// Etag of the file on the server public let etag: String? - /// Stringified URL location of the file + /// Location where to download the file. Can be a Hub url or not (CDN). public let location: String - /// The file's size in bytes + /// Size of the file. In case of an LFS file, contains the size of the actual LFS file, not the pointer. public let size: Int? } + + /// Metadata about a file in the local directory related to a download process + struct LocalDownloadFileMetadata { + /// Commit hash of the file in the repo + public let commitHash: String + + /// ETag of the file in the repo. Used to check if the file has changed. + /// For LFS files, this is the sha256 of the file. For regular files, it corresponds to the git hash. + public let etag: String + + /// Path of the file in the repo + public let filename: String + + /// The timestamp of when the metadata was saved i.e. when the metadata was accurate + public let timestamp: Date + } private func normalizeEtag(_ etag: String?) -> String? 
{ guard let etag = etag else { return nil } @@ -296,13 +492,14 @@ public extension HubApi { func getFileMetadata(url: URL) async throws -> FileMetadata { let (_, response) = try await httpHead(for: url) + let location = ((300..<400).contains(response.statusCode) ? response.value(forHTTPHeaderField: "Location") : nil) ?? response.url?.absoluteString // Any blocked 3xx redirect (not only 302) carries the real file location in its Location header return FileMetadata( commitHash: response.value(forHTTPHeaderField: "X-Repo-Commit"), etag: normalizeEtag( (response.value(forHTTPHeaderField: "X-Linked-Etag")) ?? (response.value(forHTTPHeaderField: "Etag")) ), - location: (response.value(forHTTPHeaderField: "Location")) ?? url.absoluteString, + location: location ?? url.absoluteString, size: Int(response.value(forHTTPHeaderField: "X-Linked-Size") ?? response.value(forHTTPHeaderField: "Content-Length") ?? "") ) } @@ -395,3 +592,43 @@ public extension [String] { filter { fnmatch(glob, $0, 0) == 0 } } } + +/// Only allow relative redirects and reject others +/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/file_download.py#L258 +private class RedirectDelegate: NSObject, URLSessionTaskDelegate { + func urlSession(_ session: URLSession, task: URLSessionTask, willPerformHTTPRedirection response: HTTPURLResponse, newRequest request: URLRequest, completionHandler: @escaping (URLRequest?) 
-> Void) { + // Check if it's a redirect status code (300-399) + if (300...399).contains(response.statusCode) { + // Get the Location header + if let locationString = response.value(forHTTPHeaderField: "Location"), + let locationUrl = URL(string: locationString) { + + // Check if it's a relative redirect (no host component) + if locationUrl.host == nil { + // For relative redirects, construct the new URL using the original request's base + if let originalUrl = task.originalRequest?.url, + var components = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true) { + // Update the path component with the relative path + components.path = locationUrl.path + components.query = locationUrl.query + + // Create new request with the resolved URL + if let resolvedUrl = components.url { + var newRequest = URLRequest(url: resolvedUrl) + // Copy headers from original request + task.originalRequest?.allHTTPHeaderFields?.forEach { key, value in + newRequest.setValue(value, forHTTPHeaderField: key) + } + newRequest.setValue(resolvedUrl.absoluteString, forHTTPHeaderField: "Location") + completionHandler(newRequest) + return + } + } + } + } + } + + // For all other cases (non-redirects or absolute redirects), prevent redirect + completionHandler(nil) + } +} diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift index 9871ba6..7248c11 100644 --- a/Tests/HubTests/HubApiTests.swift +++ b/Tests/HubTests/HubApiTests.swift @@ -144,10 +144,31 @@ class HubApiTests: XCTestCase { XCTAssertGreaterThan(metadata.size!, 0) } } + + /// Verify with `curl -I https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel` + func testGetLargeFileMetadata() async throws { + do { + let revision = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb" + let etag = "fc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107" + let location = 
"https://cdn-lfs.hf.co/repos/4a/4e/4a4e587f66a2979dcd75e1d7324df8ee9ef74be3582a05bea31c2c26d0d467d0/fc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.mlmodel%3B+filename%3D%22model.mlmodel" + let size = 504766 + + let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel") + let metadata = try await Hub.getFileMetadata(fileURL: url!) + + XCTAssertEqual(metadata.commitHash, revision) + XCTAssertEqual(metadata.etag, etag) + XCTAssertTrue(metadata.location.contains(location)) + XCTAssertEqual(metadata.size, size) + } catch { + XCTFail("\(error)") + } + } } class SnapshotDownloadTests: XCTestCase { let repo = "coreml-projects/Llama-2-7b-chat-coreml" + let lfsRepo = "pcuenq/smol-lfs" let downloadDestination: URL = { let base = FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first! return base.appending(component: "huggingface-tests") @@ -163,7 +184,7 @@ class SnapshotDownloadTests: XCTestCase { } } - func getRelativeFiles(url: URL) -> [String] { + func getRelativeFiles(url: URL, repo: String) -> [String] { var filenames: [String] = [] let prefix = downloadDestination.appending(path: "models/\(repo)").path.appending("/") @@ -185,16 +206,27 @@ class SnapshotDownloadTests: XCTestCase { func testDownload() async throws { let hubApi = HubApi(downloadBase: downloadDestination) var lastProgress: Progress? 
= nil + + // Add debug prints + print("Download destination before: \(downloadDestination.path)") + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in print("Total Progress: \(progress.fractionCompleted)") print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") lastProgress = progress } + + // Add more debug prints + print("Downloaded to: \(downloadedTo.path)") + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) + print("Downloaded filenames: \(downloadedFilenames)") + print("Prefix used in getRelativeFiles: \(downloadDestination.appending(path: "models/\(repo)").path)") + XCTAssertEqual(lastProgress?.fractionCompleted, 1) XCTAssertEqual(lastProgress?.completedUnitCount, 6) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - let downloadedFilenames = getRelativeFiles(url: downloadDestination) XCTAssertEqual( Set(downloadedFilenames), Set([ @@ -220,7 +252,7 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.completedUnitCount, 1) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - let downloadedFilenames = getRelativeFiles(url: downloadDestination) + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) XCTAssertEqual( Set(downloadedFilenames), Set([ @@ -241,7 +273,31 @@ class SnapshotDownloadTests: XCTestCase { XCTAssertEqual(lastProgress?.completedUnitCount, 6) XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) - let downloadedFilenames = getRelativeFiles(url: downloadDestination) + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) + XCTAssertEqual( + Set(downloadedFilenames), + Set([ + "config.json", "tokenizer.json", "tokenizer_config.json", + "llama-2-7b-chat.mlpackage/Manifest.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json", + 
"llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json", + ]) + ) + } + + func testDownloadFileMetadata() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? = nil + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 6) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) XCTAssertEqual( Set(downloadedFilenames), Set([ @@ -251,5 +307,521 @@ class SnapshotDownloadTests: XCTestCase { "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json", ]) ) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: repo) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([ + ".cache/huggingface/download/config.json.metadata", + ".cache/huggingface/download/tokenizer.json.metadata", + ".cache/huggingface/download/tokenizer_config.json.metadata", + ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Manifest.json.metadata", + ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json.metadata", + ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json.metadata", + ]) + ) + } + + func testDownloadFileMetadataExists() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? 
= nil + + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 6) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) + XCTAssertEqual( + Set(downloadedFilenames), + Set([ + "config.json", "tokenizer.json", "tokenizer_config.json", + "llama-2-7b-chat.mlpackage/Manifest.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json", + ]) + ) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + + let configPath = downloadedTo.appending(path: "config.json") + var attributes = try FileManager.default.attributesOfItem(atPath: configPath.path) + let originalTimestamp = attributes[.modificationDate] as! 
Date + + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: repo) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([ + ".cache/huggingface/download/config.json.metadata", + ".cache/huggingface/download/tokenizer.json.metadata", + ".cache/huggingface/download/tokenizer_config.json.metadata", + ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Manifest.json.metadata", + ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json.metadata", + ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json.metadata", + ]) + ) + + let _ = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + attributes = try FileManager.default.attributesOfItem(atPath: configPath.path) + let secondDownloadTimestamp = attributes[.modificationDate] as! Date + + // File will not be downloaded again thus last modified date will remain unchanged + XCTAssertTrue(originalTimestamp == secondDownloadTimestamp) + } + + func testDownloadFileMetadataSame() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? 
= nil + + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "tokenizer.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) + XCTAssertEqual(Set(downloadedFilenames), Set(["tokenizer.json"])) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + + let metadataPath = metadataDestination.appending(path: "tokenizer.json.metadata") + + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: repo) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([ + ".cache/huggingface/download/tokenizer.json.metadata", + ]) + ) + + let originalMetadata = try String(contentsOf: metadataPath, encoding: .utf8) + + let _ = try await hubApi.snapshot(from: repo, matching: "tokenizer.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + let secondDownloadMetadata = try String(contentsOf: metadataPath, encoding: .utf8) + + // File hasn't changed so commit hash and etag will be identical + let originalArr = originalMetadata.components(separatedBy: .newlines) + let secondDownloadArr = secondDownloadMetadata.components(separatedBy: .newlines) + + XCTAssertTrue(originalArr[0] == secondDownloadArr[0]) + XCTAssertTrue(originalArr[1] == secondDownloadArr[1]) + } + + func testDownloadFileMetadataCorrupted() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? 
= nil + + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 6) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) + XCTAssertEqual( + Set(downloadedFilenames), + Set([ + "config.json", "tokenizer.json", "tokenizer_config.json", + "llama-2-7b-chat.mlpackage/Manifest.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json", + ]) + ) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + + let configPath = downloadedTo.appending(path: "config.json") + var attributes = try FileManager.default.attributesOfItem(atPath: configPath.path) + let originalTimestamp = attributes[.modificationDate] as! 
Date + + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: repo) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([ + ".cache/huggingface/download/config.json.metadata", + ".cache/huggingface/download/tokenizer.json.metadata", + ".cache/huggingface/download/tokenizer_config.json.metadata", + ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Manifest.json.metadata", + ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json.metadata", + ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json.metadata", + ]) + ) + + // Corrupt config.json.metadata + print("Testing corrupted file.") + try "a".write(to: metadataDestination.appendingPathComponent("config.json.metadata"), atomically: true, encoding: .utf8) + + let _ = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + attributes = try FileManager.default.attributesOfItem(atPath: configPath.path) + let secondDownloadTimestamp = attributes[.modificationDate] as! 
Date + + // File will be downloaded again thus last modified date will change + XCTAssertTrue(originalTimestamp != secondDownloadTimestamp) + + // Corrupt config.metadata again + print("Testing corrupted timestamp.") + try "a\nb\nc\n".write(to: metadataDestination.appendingPathComponent("config.json.metadata"), atomically: true, encoding: .utf8) + + let _ = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + attributes = try FileManager.default.attributesOfItem(atPath: configPath.path) + let thirdDownloadTimestamp = attributes[.modificationDate] as! Date + + // File will be downloaded again thus last modified date will change + XCTAssertTrue(originalTimestamp != thirdDownloadTimestamp) + } + + func testDownloadLargeFileMetadataCorrupted() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? 
= nil + + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.mlmodel") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) + XCTAssertEqual( + Set(downloadedFilenames), + Set(["llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel"]) + ) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + + let modelPath = downloadedTo.appending(path: "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel") + var attributes = try FileManager.default.attributesOfItem(atPath: modelPath.path) + let originalTimestamp = attributes[.modificationDate] as! 
Date + + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: repo) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([ + ".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata", + ]) + ) + + // Corrupt model.metadata etag + print("Testing corrupted etag.") + let corruptedMetadataString = "a\nfc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020108\n0\n" + let metadataFile = metadataDestination.appendingPathComponent("llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata") + try corruptedMetadataString.write(to: metadataFile, atomically: true, encoding: .utf8) + + let _ = try await hubApi.snapshot(from: repo, matching: "*.mlmodel") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + attributes = try FileManager.default.attributesOfItem(atPath: modelPath.path) + let thirdDownloadTimestamp = attributes[.modificationDate] as! Date + + // File will not be downloaded again because this is an LFS file. + // While downloading LFS files, we first check if local file ETag is the same as remote ETag. + // If that's the case we just update the metadata and keep the local file. + XCTAssertEqual(originalTimestamp, thirdDownloadTimestamp) + + let metadataString = try String(contentsOfFile: metadataFile.path) + + // Updated metadata file needs to have the correct commit hash, etag and timestamp. + // This is being updated because the local etag (SHA256 checksum) matches the remote etag + XCTAssertNotEqual(metadataString, corruptedMetadataString) + } + + func testDownloadLargeFile() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? 
= nil + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.mlmodel") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) + XCTAssertEqual(Set(downloadedFilenames), Set(["llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel"])) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: repo) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([".cache/huggingface/download/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata"]) + ) + + let metadataFile = metadataDestination.appendingPathComponent("llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel.metadata") + let metadataString = try String(contentsOfFile: metadataFile.path) + + let expected = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb\nfc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107" + XCTAssertTrue(metadataString.contains(expected)) + } + + func testDownloadSmolLargeFile() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? 
= nil + let downloadedTo = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(lfsRepo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: lfsRepo) + XCTAssertEqual(Set(downloadedFilenames), Set(["x.bin"])) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: lfsRepo) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([".cache/huggingface/download/x.bin.metadata"]) + ) + + let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata") + let metadataString = try String(contentsOfFile: metadataFile.path) + + let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4" + XCTAssertTrue(metadataString.contains(expected)) + } + + func testRegexValidation() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? 
= nil + let downloadedTo = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(lfsRepo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: lfsRepo) + XCTAssertEqual(Set(downloadedFilenames), Set(["x.bin"])) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: lfsRepo) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([".cache/huggingface/download/x.bin.metadata"]) + ) + + let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata") + let metadataString = try String(contentsOfFile: metadataFile.path) + let metadataArr = metadataString.components(separatedBy: .newlines) + + let commitHash = metadataArr[0] + let etag = metadataArr[1] + + // Not needed for the downloads, just to test validation function + let downloader = HubApi.HubFileDownloader( + repo: Hub.Repo(id: lfsRepo), + repoDestination: downloadedTo, + relativeFilename: "x.bin", + hfToken: nil, + endpoint: nil, + backgroundSession: false + ) + + XCTAssertTrue(downloader.isValidHash(hash: commitHash, pattern: downloader.commitHashPattern)) + XCTAssertTrue(downloader.isValidHash(hash: etag, pattern: downloader.sha256Pattern)) + + XCTAssertFalse(downloader.isValidHash(hash: "\(commitHash)a", pattern: downloader.commitHashPattern)) + XCTAssertFalse(downloader.isValidHash(hash: "\(etag)a", pattern: downloader.sha256Pattern)) + } + + func testLFSFileNoMetadata() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: 
Progress? = nil + + let downloadedTo = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(lfsRepo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: lfsRepo) + XCTAssertEqual(Set(downloadedFilenames), Set(["x.bin"])) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + + let filePath = downloadedTo.appending(path: "x.bin") + var attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) + let originalTimestamp = attributes[.modificationDate] as! Date + + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: lfsRepo) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([".cache/huggingface/download/x.bin.metadata"]) + ) + + let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata") + try FileManager.default.removeItem(atPath: metadataFile.path) + + let _ = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) + let secondDownloadTimestamp = attributes[.modificationDate] as! 
Date + + // File will not be downloaded again thus last modified date will remain unchanged + XCTAssertTrue(originalTimestamp == secondDownloadTimestamp) + XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) + + let metadataString = try String(contentsOfFile: metadataFile.path) + let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4" + + XCTAssertTrue(metadataString.contains(expected)) + } + + func testLFSFileCorruptedMetadata() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? = nil + + let downloadedTo = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(lfsRepo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: lfsRepo) + XCTAssertEqual(Set(downloadedFilenames), Set(["x.bin"])) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + + let filePath = downloadedTo.appending(path: "x.bin") + var attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) + let originalTimestamp = attributes[.modificationDate] as! 
Date + + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: lfsRepo) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([".cache/huggingface/download/x.bin.metadata"]) + ) + + let metadataFile = metadataDestination.appendingPathComponent("x.bin.metadata") + try "a".write(to: metadataFile, atomically: true, encoding: .utf8) + + let _ = try await hubApi.snapshot(from: lfsRepo, matching: "x.bin") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) + let secondDownloadTimestamp = attributes[.modificationDate] as! Date + + // File will not be downloaded again thus last modified date will remain unchanged + XCTAssertTrue(originalTimestamp == secondDownloadTimestamp) + XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) + + let metadataString = try String(contentsOfFile: metadataFile.path) + let expected = "77b984598d90af6143d73d5a2d6214b23eba7e27\n98ea6e4f216f2fb4b69fff9b3a44842c38686ca685f3f55dc48c5d3fb1107be4" + + XCTAssertTrue(metadataString.contains(expected)) + } + + func testNonLFSFileRedownload() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? 
= nil + + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "config.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination, repo: repo) + XCTAssertEqual(Set(downloadedFilenames), Set(["config.json"])) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface/download") + + let filePath = downloadedTo.appending(path: "config.json") + var attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) + let originalTimestamp = attributes[.modificationDate] as! Date + + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination, repo: repo) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([".cache/huggingface/download/config.json.metadata"]) + ) + + let metadataFile = metadataDestination.appendingPathComponent("config.json.metadata") + try FileManager.default.removeItem(atPath: metadataFile.path) + + let _ = try await hubApi.snapshot(from: repo, matching: "config.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + attributes = try FileManager.default.attributesOfItem(atPath: filePath.path) + let secondDownloadTimestamp = attributes[.modificationDate] as! 
Date + + // File will be downloaded again thus last modified date will change + XCTAssertTrue(originalTimestamp != secondDownloadTimestamp) + XCTAssertTrue(FileManager.default.fileExists(atPath: metadataDestination.path)) + + let metadataString = try String(contentsOfFile: metadataFile.path) + let expected = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb\nd6ceb92ce9e3c83ab146dc8e92a93517ac1cc66f" + + XCTAssertTrue(metadataString.contains(expected)) } }