6
6
//
7
7
8
8
import Foundation
9
+ import CryptoKit
10
+ import os
9
11
10
12
public struct HubApi {
11
13
var downloadBase : URL
@@ -29,6 +31,8 @@ public struct HubApi {
29
31
}
30
32
31
33
public static let shared = HubApi ( )
34
+
35
+ private static let logger = Logger ( )
32
36
}
33
37
34
38
private extension HubApi {
@@ -92,18 +96,24 @@ public extension HubApi {
92
96
return ( data, response)
93
97
}
94
98
99
+ /// Throws error if page does not exist or is not accessible.
100
+ /// Allows relative redirects but ignores absolute ones for LFS files.
95
101
func httpHead( for url: URL ) async throws -> ( Data , HTTPURLResponse ) {
96
102
var request = URLRequest ( url: url)
97
103
request. httpMethod = " HEAD "
98
104
if let hfToken = hfToken {
99
105
request. setValue ( " Bearer \( hfToken) " , forHTTPHeaderField: " Authorization " )
100
106
}
101
107
request. setValue ( " identity " , forHTTPHeaderField: " Accept-Encoding " )
102
- let ( data, response) = try await URLSession . shared. data ( for: request)
108
+
109
+ let redirectDelegate = RedirectDelegate ( )
110
+ let session = URLSession ( configuration: . default, delegate: redirectDelegate, delegateQueue: nil )
111
+
112
+ let ( data, response) = try await session. data ( for: request)
103
113
guard let response = response as? HTTPURLResponse else { throw Hub . HubClientError. unexpectedError }
104
114
105
115
switch response. statusCode {
106
- case 200 ..< 300 : break
116
+ case 200 ..< 400 : break // Allow redirects to pass through to the redirect delegate
107
117
case 400 ..< 500 : throw Hub . HubClientError. authorizationRequired
108
118
default : throw Hub . HubClientError. httpStatusCode ( response. statusCode)
109
119
}
@@ -139,6 +149,20 @@ public extension HubApi {
139
149
}
140
150
}
141
151
152
+ /// Additional Errors
153
+ public extension HubApi {
154
+ enum EnvironmentError : LocalizedError {
155
+ case invalidMetadataError( String )
156
+
157
+ public var errorDescription : String ? {
158
+ switch self {
159
+ case . invalidMetadataError( let message) :
160
+ return message
161
+ }
162
+ }
163
+ }
164
+ }
165
+
142
166
/// Configuration loading helpers
143
167
public extension HubApi {
144
168
/// Assumes the file has already been downloaded.
@@ -185,6 +209,9 @@ public extension HubApi {
185
209
let hfToken : String ?
186
210
let endpoint : String ?
187
211
let backgroundSession : Bool
212
+
213
+ let sha256Pattern = " ^[0-9a-f]{64}$ "
214
+ let commitHashPattern = " ^[0-9a-f]{40}$ "
188
215
189
216
var source : URL {
190
217
// https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/tokenizer.json?download=true
@@ -202,6 +229,13 @@ public extension HubApi {
202
229
repoDestination. appending ( path: relativeFilename)
203
230
}
204
231
232
+ var metadataDestination : URL {
233
+ repoDestination
234
+ . appendingPathComponent ( " .cache " )
235
+ . appendingPathComponent ( " huggingface " )
236
+ . appendingPathComponent ( " download " )
237
+ }
238
+
205
239
var downloaded : Bool {
206
240
FileManager . default. fileExists ( atPath: destination. path)
207
241
}
@@ -210,15 +244,158 @@ public extension HubApi {
210
244
let directoryURL = destination. deletingLastPathComponent ( )
211
245
try FileManager . default. createDirectory ( at: directoryURL, withIntermediateDirectories: true , attributes: nil )
212
246
}
213
-
247
+
248
+ func prepareMetadataDestination( ) throws {
249
+ try FileManager . default. createDirectory ( at: metadataDestination, withIntermediateDirectories: true , attributes: nil )
250
+ }
251
+
252
+ /// Reads metadata about a file in the local directory related to a download process.
253
+ ///
254
+ /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L263
255
+ ///
256
+ /// - Parameters:
257
+ /// - localDir: The local directory where metadata files are downloaded.
258
+ /// - filePath: The path of the file for which metadata is being read.
259
+ /// - Throws: An `EnvironmentError.invalidMetadataError` if the metadata file is invalid and cannot be removed.
260
+ /// - Returns: A `LocalDownloadFileMetadata` object if the metadata file exists and is valid, or `nil` if the file is missing or invalid.
261
+ func readDownloadMetadata( localDir: URL , filePath: String ) throws -> LocalDownloadFileMetadata ? {
262
+ let metadataPath = localDir. appending ( path: filePath)
263
+ if FileManager . default. fileExists ( atPath: metadataPath. path) {
264
+ do {
265
+ let contents = try String ( contentsOf: metadataPath, encoding: . utf8)
266
+ let lines = contents. components ( separatedBy: . newlines)
267
+
268
+ guard lines. count >= 3 else {
269
+ throw EnvironmentError . invalidMetadataError ( " Metadata file is missing required fields. " )
270
+ }
271
+
272
+ let commitHash = lines [ 0 ] . trimmingCharacters ( in: . whitespacesAndNewlines)
273
+ let etag = lines [ 1 ] . trimmingCharacters ( in: . whitespacesAndNewlines)
274
+ guard let timestamp = Double ( lines [ 2 ] . trimmingCharacters ( in: . whitespacesAndNewlines) ) else {
275
+ throw EnvironmentError . invalidMetadataError ( " Missing or invalid timestamp. " )
276
+ }
277
+ let timestampDate = Date ( timeIntervalSince1970: timestamp)
278
+
279
+ // TODO: check if file hasn't been modified since the metadata was saved
280
+ // Reference: https://github.com/huggingface/huggingface_hub/blob/2fdc6f48ef5e6b22ee9bcdc1945948ac070da675/src/huggingface_hub/_local_folder.py#L303
281
+
282
+ return LocalDownloadFileMetadata ( commitHash: commitHash, etag: etag, filename: filePath, timestamp: timestampDate)
283
+ } catch {
284
+ do {
285
+ logger. warning ( " Invalid metadata file \( metadataPath) : \( error) . Removing it from disk and continue. " )
286
+ try FileManager . default. removeItem ( at: metadataPath)
287
+ } catch {
288
+ throw EnvironmentError . invalidMetadataError ( " Could not remove corrupted metadata file \( metadataPath) : \( error) " )
289
+ }
290
+ return nil
291
+ }
292
+ }
293
+
294
+ // metadata file does not exist
295
+ return nil
296
+ }
297
+
298
+ func isValidHash( hash: String , pattern: String ) -> Bool {
299
+ let regex = try ? NSRegularExpression ( pattern: pattern)
300
+ let range = NSRange ( location: 0 , length: hash. utf16. count)
301
+ return regex? . firstMatch ( in: hash, options: [ ] , range: range) != nil
302
+ }
303
+
304
+ /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L391
305
+ func writeDownloadMetadata( commitHash: String , etag: String , metadataRelativePath: String ) throws {
306
+ let metadataContent = " \( commitHash) \n \( etag) \n \( Date ( ) . timeIntervalSince1970) \n "
307
+ let metadataPath = metadataDestination. appending ( component: metadataRelativePath)
308
+
309
+ do {
310
+ try FileManager . default. createDirectory ( at: metadataPath. deletingLastPathComponent ( ) , withIntermediateDirectories: true )
311
+ try metadataContent. write ( to: metadataPath, atomically: true , encoding: . utf8)
312
+ } catch {
313
+ throw EnvironmentError . invalidMetadataError ( " Failed to write metadata file \( metadataPath) " )
314
+ }
315
+ }
316
+
317
+ func computeFileHash( file url: URL ) throws -> String {
318
+ // Open file for reading
319
+ guard let fileHandle = try ? FileHandle ( forReadingFrom: url) else {
320
+ throw Hub . HubClientError. unexpectedError
321
+ }
322
+
323
+ defer {
324
+ try ? fileHandle. close ( )
325
+ }
326
+
327
+ var hasher = SHA256 ( )
328
+ let chunkSize = 1024 * 1024 // 1MB chunks
329
+
330
+ while autoreleasepool ( invoking: {
331
+ let nextChunk = try ? fileHandle. read ( upToCount: chunkSize)
332
+
333
+ guard let nextChunk,
334
+ !nextChunk. isEmpty
335
+ else {
336
+ return false
337
+ }
338
+
339
+ hasher. update ( data: nextChunk)
340
+
341
+ return true
342
+ } ) { }
343
+
344
+ let digest = hasher. finalize ( )
345
+ return digest. map { String ( format: " %02x " , $0) } . joined ( )
346
+ }
347
+
348
+
214
349
// Note we go from Combine in Downloader to callback-based progress reporting
215
350
// We'll probably need to support Combine as well to play well with Swift UI
216
351
// (See for example PipelineLoader in swift-coreml-diffusers)
217
352
@discardableResult
218
353
func download( progressHandler: @escaping ( Double ) -> Void ) async throws -> URL {
219
- guard !downloaded else { return destination }
220
-
354
+ let metadataRelativePath = " \( relativeFilename) .metadata "
355
+
356
+ let localMetadata = try readDownloadMetadata ( localDir: metadataDestination, filePath: metadataRelativePath)
357
+ let remoteMetadata = try await HubApi . shared. getFileMetadata ( url: source)
358
+
359
+ let localCommitHash = localMetadata? . commitHash ?? " "
360
+ let remoteCommitHash = remoteMetadata. commitHash ?? " "
361
+
362
+ // Local file exists + metadata exists + commit_hash matches => return file
363
+ if isValidHash ( hash: remoteCommitHash, pattern: commitHashPattern) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash {
364
+ return destination
365
+ }
366
+
367
+ // From now on, etag, commit_hash, url and size are not empty
368
+ guard let remoteCommitHash = remoteMetadata. commitHash,
369
+ let remoteEtag = remoteMetadata. etag,
370
+ remoteMetadata. location != " " else {
371
+ throw EnvironmentError . invalidMetadataError ( " File metadata must have been retrieved from server " )
372
+ }
373
+
374
+ // Local file exists => check if it's up-to-date
375
+ if downloaded {
376
+ // etag matches => update metadata and return file
377
+ if localMetadata? . etag == remoteEtag {
378
+ try writeDownloadMetadata ( commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath)
379
+ return destination
380
+ }
381
+
382
+ // etag is a sha256
383
+ // => means it's an LFS file (large)
384
+ // => let's compute local hash and compare
385
+ // => if match, update metadata and return file
386
+ if isValidHash ( hash: remoteEtag, pattern: sha256Pattern) {
387
+ let fileHash = try computeFileHash ( file: destination)
388
+ if fileHash == remoteEtag {
389
+ try writeDownloadMetadata ( commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath)
390
+ return destination
391
+ }
392
+ }
393
+ }
394
+
395
+ // Otherwise, let's download the file!
221
396
try prepareDestination ( )
397
+ try prepareMetadataDestination ( )
398
+
222
399
let downloader = Downloader ( from: source, to: destination, using: hfToken, inBackground: backgroundSession)
223
400
let downloadSubscriber = downloader. downloadState. sink { state in
224
401
if case . downloading( let progress) = state {
@@ -228,6 +405,9 @@ public extension HubApi {
228
405
_ = try withExtendedLifetime ( downloadSubscriber) {
229
406
try downloader. waitUntilDone ( )
230
407
}
408
+
409
+ try writeDownloadMetadata ( commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath)
410
+
231
411
return destination
232
412
}
233
413
}
@@ -275,20 +455,36 @@ public extension HubApi {
275
455
276
456
/// Metadata
277
457
public extension HubApi {
278
- /// A structure representing metadata for a remote file
458
+ /// Data structure containing information about a file versioned on the Hub
279
459
struct FileMetadata {
280
- /// The file's Git commit hash
460
+ /// The commit hash related to the file
281
461
public let commitHash : String ?
282
462
283
- /// Server-provided ETag for caching
463
+ /// Etag of the file on the server
284
464
public let etag : String ?
285
465
286
- /// Stringified URL location of the file
466
+ /// Location where to download the file. Can be a Hub url or not (CDN).
287
467
public let location : String
288
468
289
- /// The file's size in bytes
469
+ /// Size of the file. In case of an LFS file, contains the size of the actual LFS file, not the pointer.
290
470
public let size : Int ?
291
471
}
472
+
473
+ /// Metadata about a file in the local directory related to a download process
474
+ struct LocalDownloadFileMetadata {
475
+ /// Commit hash of the file in the repo
476
+ public let commitHash : String
477
+
478
+ /// ETag of the file in the repo. Used to check if the file has changed.
479
+ /// For LFS files, this is the sha256 of the file. For regular files, it corresponds to the git hash.
480
+ public let etag : String
481
+
482
+ /// Path of the file in the repo
483
+ public let filename : String
484
+
485
+ /// The timestamp of when the metadata was saved i.e. when the metadata was accurate
486
+ public let timestamp : Date
487
+ }
292
488
293
489
private func normalizeEtag( _ etag: String ? ) -> String ? {
294
490
guard let etag = etag else { return nil }
@@ -297,13 +493,14 @@ public extension HubApi {
297
493
298
494
func getFileMetadata( url: URL ) async throws -> FileMetadata {
299
495
let ( _, response) = try await httpHead ( for: url)
496
+ let location = response. statusCode == 302 ? response. value ( forHTTPHeaderField: " Location " ) : response. url? . absoluteString
300
497
301
498
return FileMetadata (
302
499
commitHash: response. value ( forHTTPHeaderField: " X-Repo-Commit " ) ,
303
500
etag: normalizeEtag (
304
501
( response. value ( forHTTPHeaderField: " X-Linked-Etag " ) ) ?? ( response. value ( forHTTPHeaderField: " Etag " ) )
305
502
) ,
306
- location: ( response . value ( forHTTPHeaderField : " Location " ) ) ?? url. absoluteString,
503
+ location: location ?? url. absoluteString,
307
504
size: Int ( response. value ( forHTTPHeaderField: " X-Linked-Size " ) ?? response. value ( forHTTPHeaderField: " Content-Length " ) ?? " " )
308
505
)
309
506
}
@@ -396,3 +593,43 @@ public extension [String] {
396
593
filter { fnmatch ( glob, $0, 0 ) == 0 }
397
594
}
398
595
}
596
+
597
+ /// Only allow relative redirects and reject others
598
+ /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/file_download.py#L258
599
+ private class RedirectDelegate : NSObject , URLSessionTaskDelegate {
600
+ func urlSession( _ session: URLSession , task: URLSessionTask , willPerformHTTPRedirection response: HTTPURLResponse , newRequest request: URLRequest , completionHandler: @escaping ( URLRequest ? ) -> Void ) {
601
+ // Check if it's a redirect status code (300-399)
602
+ if ( 300 ... 399 ) . contains ( response. statusCode) {
603
+ // Get the Location header
604
+ if let locationString = response. value ( forHTTPHeaderField: " Location " ) ,
605
+ let locationUrl = URL ( string: locationString) {
606
+
607
+ // Check if it's a relative redirect (no host component)
608
+ if locationUrl. host == nil {
609
+ // For relative redirects, construct the new URL using the original request's base
610
+ if let originalUrl = task. originalRequest? . url,
611
+ var components = URLComponents ( url: originalUrl, resolvingAgainstBaseURL: true ) {
612
+ // Update the path component with the relative path
613
+ components. path = locationUrl. path
614
+ components. query = locationUrl. query
615
+
616
+ // Create new request with the resolved URL
617
+ if let resolvedUrl = components. url {
618
+ var newRequest = URLRequest ( url: resolvedUrl)
619
+ // Copy headers from original request
620
+ task. originalRequest? . allHTTPHeaderFields? . forEach { key, value in
621
+ newRequest. setValue ( value, forHTTPHeaderField: key)
622
+ }
623
+ newRequest. setValue ( resolvedUrl. absoluteString, forHTTPHeaderField: " Location " )
624
+ completionHandler ( newRequest)
625
+ return
626
+ }
627
+ }
628
+ }
629
+ }
630
+ }
631
+
632
+ // For all other cases (non-redirects or absolute redirects), prevent redirect
633
+ completionHandler ( nil )
634
+ }
635
+ }
0 commit comments