Skip to content

Add incomplete download support #198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
138 changes: 88 additions & 50 deletions Sources/Hub/Downloader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,39 @@ class Downloader: NSObject, ObservableObject {
enum DownloadError: Error {
case invalidDownloadLocation
case unexpectedError
case tempFileNotFound
}

private(set) lazy var downloadState: CurrentValueSubject<DownloadState, Never> = CurrentValueSubject(.notStarted)
private var stateSubscriber: Cancellable?

private(set) var tempFilePath: URL
private(set) var expectedSize: Int?
private(set) var downloadedSize: Int = 0

private var urlSession: URLSession? = nil

var session: URLSession? = nil
var downloadTask: Task<Void, Error>? = nil

init(
from url: URL,
to destination: URL,
incompleteDestination: URL,
using authToken: String? = nil,
inBackground: Bool = false,
resumeSize: Int = 0,
headers: [String: String]? = nil,
expectedSize: Int? = nil,
timeout: TimeInterval = 10,
numRetries: Int = 5
) {
self.destination = destination
self.expectedSize = expectedSize

// Create incomplete file path based on destination
tempFilePath = incompleteDestination

// If resume size wasn't specified, check for an existing incomplete file
let resumeSize = Self.incompleteFileSize(at: incompleteDestination)

super.init()
let sessionIdentifier = "swift-transformers.hub.downloader"

Expand All @@ -53,9 +67,22 @@ class Downloader: NSObject, ObservableObject {
config.sessionSendsLaunchEvents = true
}

urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)
session = URLSession(configuration: config, delegate: self, delegateQueue: nil)

setupDownload(from: url, with: authToken, resumeSize: resumeSize, headers: headers, expectedSize: expectedSize, timeout: timeout, numRetries: numRetries)
setUpDownload(from: url, with: authToken, resumeSize: resumeSize, headers: headers, expectedSize: expectedSize, timeout: timeout, numRetries: numRetries)
}

/// Check if an incomplete file exists for the destination and returns its size
/// - Parameter destination: The destination URL for the download
/// - Returns: Size of the incomplete file if it exists, otherwise 0
static func incompleteFileSize(at incompletePath: URL) -> Int {
if FileManager.default.fileExists(atPath: incompletePath.path) {
if let attributes = try? FileManager.default.attributesOfItem(atPath: incompletePath.path), let fileSize = attributes[.size] as? Int {
return fileSize
}
}

return 0
}

/// Sets up and initiates a file download operation
Expand All @@ -68,7 +95,7 @@ class Downloader: NSObject, ObservableObject {
/// - expectedSize: Expected file size in bytes for validation
/// - timeout: Time interval before the request times out
/// - numRetries: Number of retry attempts for failed downloads
private func setupDownload(
private func setUpDownload(
from url: URL,
with authToken: String?,
resumeSize: Int,
Expand All @@ -77,59 +104,67 @@ class Downloader: NSObject, ObservableObject {
timeout: TimeInterval,
numRetries: Int
) {
downloadState.value = .downloading(0)
urlSession?.getAllTasks { tasks in
session?.getAllTasks { tasks in
// If there's an existing pending background task with the same URL, let it proceed.
if let existing = tasks.filter({ $0.originalRequest?.url == url }).first {
switch existing.state {
case .running:
// print("Already downloading \(url)")
return
case .suspended:
// print("Resuming suspended download task for \(url)")
existing.resume()
return
case .canceling:
// print("Starting new download task for \(url), previous was canceling")
break
case .completed:
// print("Starting new download task for \(url), previous is complete but the file is no longer present (I think it's cached)")
break
case .canceling, .completed:
existing.cancel()
@unknown default:
// print("Unknown state for running task; cancelling and creating a new one")
existing.cancel()
}
}
var request = URLRequest(url: url)

// Use headers from argument else create an empty header dictionary
var requestHeaders = headers ?? [:]

// Populate header auth and range fields
if let authToken {
requestHeaders["Authorization"] = "Bearer \(authToken)"
}
if resumeSize > 0 {
requestHeaders["Range"] = "bytes=\(resumeSize)-"
}

request.timeoutInterval = timeout
request.allHTTPHeaderFields = requestHeaders

Task {
self.downloadTask = Task {
do {
// Create a temp file to write
let tempURL = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString)
FileManager.default.createFile(atPath: tempURL.path, contents: nil)
let tempFile = try FileHandle(forWritingTo: tempURL)
// Set up the request with appropriate headers
var request = URLRequest(url: url)
var requestHeaders = headers ?? [:]

if let authToken {
requestHeaders["Authorization"] = "Bearer \(authToken)"
}

self.downloadedSize = resumeSize

// Set Range header if we're resuming
if resumeSize > 0 {
requestHeaders["Range"] = "bytes=\(resumeSize)-"

// Calculate and show initial progress
if let expectedSize, expectedSize > 0 {
let initialProgress = Double(resumeSize) / Double(expectedSize)
self.downloadState.value = .downloading(initialProgress)
} else {
self.downloadState.value = .downloading(0)
}
} else {
self.downloadState.value = .downloading(0)
}

defer { tempFile.closeFile() }
try await self.httpGet(request: request, tempFile: tempFile, resumeSize: resumeSize, numRetries: numRetries, expectedSize: expectedSize)
request.timeoutInterval = timeout
request.allHTTPHeaderFields = requestHeaders

// Open the incomplete file for writing
let tempFile = try FileHandle(forWritingTo: self.tempFilePath)

// If resuming, seek to end of file
if resumeSize > 0 {
try tempFile.seekToEnd()
}

try await self.httpGet(request: request, tempFile: tempFile, resumeSize: self.downloadedSize, numRetries: numRetries, expectedSize: expectedSize)

// Clean up and move the completed download to its final destination
tempFile.closeFile()
try FileManager.default.moveDownloadedFile(from: tempURL, to: self.destination)

try Task.checkCancellation()
try FileManager.default.moveDownloadedFile(from: self.tempFilePath, to: self.destination)
self.downloadState.value = .completed(self.destination)
} catch {
self.downloadState.value = .failed(error)
Expand All @@ -156,7 +191,7 @@ class Downloader: NSObject, ObservableObject {
numRetries: Int,
expectedSize: Int?
) async throws {
guard let session = urlSession else {
guard let session else {
throw DownloadError.unexpectedError
}

Expand All @@ -169,16 +204,13 @@ class Downloader: NSObject, ObservableObject {
// Start the download and get the byte stream
let (asyncBytes, response) = try await session.bytes(for: newRequest)

guard let response = response as? HTTPURLResponse else {
guard let httpResponse = response as? HTTPURLResponse else {
throw DownloadError.unexpectedError
}

guard (200..<300).contains(response.statusCode) else {
guard (200..<300).contains(httpResponse.statusCode) else {
throw DownloadError.unexpectedError
}

var downloadedSize = resumeSize

// Create a buffer to collect bytes before writing to disk
var buffer = Data(capacity: chunkSize)

Expand Down Expand Up @@ -213,12 +245,12 @@ class Downloader: NSObject, ObservableObject {
try await Task.sleep(nanoseconds: 1_000_000_000)

let config = URLSessionConfiguration.default
self.urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)
self.session = URLSession(configuration: config, delegate: self, delegateQueue: nil)

try await httpGet(
request: request,
tempFile: tempFile,
resumeSize: downloadedSize,
resumeSize: self.downloadedSize,
numRetries: newNumRetries - 1,
expectedSize: expectedSize
)
Expand Down Expand Up @@ -252,7 +284,9 @@ class Downloader: NSObject, ObservableObject {
}

func cancel() {
urlSession?.invalidateAndCancel()
session?.invalidateAndCancel()
downloadTask?.cancel()
downloadState.value = .failed(URLError(.cancelled))
}
}

Expand Down Expand Up @@ -284,9 +318,13 @@ extension Downloader: URLSessionDownloadDelegate {

extension FileManager {
func moveDownloadedFile(from srcURL: URL, to dstURL: URL) throws {
if fileExists(atPath: dstURL.path) {
if fileExists(atPath: dstURL.path()) {
try removeItem(at: dstURL)
}

let directoryURL = dstURL.deletingLastPathComponent()
try createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil)

try moveItem(at: srcURL, to: dstURL)
}
}
6 changes: 3 additions & 3 deletions Sources/Hub/Hub.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ public extension Hub {
}
}

enum RepoType: String {
enum RepoType: String, Codable {
case models
case datasets
case spaces
}

struct Repo {
struct Repo: Codable {
public let id: String
public let type: RepoType

Expand Down
61 changes: 40 additions & 21 deletions Sources/Hub/HubApi.swift
Original file line number Diff line number Diff line change
Expand Up @@ -366,14 +366,13 @@ public extension HubApi {
FileManager.default.fileExists(atPath: destination.path)
}

func prepareDestination() throws {
let directoryURL = destination.deletingLastPathComponent()
try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil)
}

func prepareMetadataDestination() throws {
let directoryURL = metadataDestination.deletingLastPathComponent()
/// We're using incomplete destination to prepare cache destination because incomplete files include lfs + non-lfs files (vs only lfs for metadata files)
func prepareCacheDestination(_ incompleteDestination: URL) throws {
let directoryURL = incompleteDestination.deletingLastPathComponent()
try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil)
if !FileManager.default.fileExists(atPath: incompleteDestination.path) {
try "".write(to: incompleteDestination, atomically: true, encoding: .utf8)
}
}

/// Note we go from Combine in Downloader to callback-based progress reporting
Expand Down Expand Up @@ -423,22 +422,42 @@ public extension HubApi {
}

// Otherwise, let's download the file!
try prepareDestination()
try prepareMetadataDestination()

let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession, expectedSize: remoteSize)
let downloadSubscriber = downloader.downloadState.sink { state in
if case let .downloading(progress) = state {
progressHandler(progress)
let incompleteDestination = repoMetadataDestination.appending(path: relativeFilename + ".\(remoteEtag).incomplete")
try prepareCacheDestination(incompleteDestination)

let downloader = Downloader(
from: source,
to: destination,
incompleteDestination: incompleteDestination,
using: hfToken,
inBackground: backgroundSession,
expectedSize: remoteSize
)

return try await withTaskCancellationHandler {
let downloadSubscriber = downloader.downloadState.sink { state in
switch state {
case let .downloading(progress):
progressHandler(progress)
case .completed, .failed, .notStarted:
break
}
}
do {
_ = try withExtendedLifetime(downloadSubscriber) {
try downloader.waitUntilDone()
}

try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)

return destination
} catch {
// If download fails, leave the incomplete file in place for future resume
throw error
}
} onCancel: {
downloader.cancel()
}
_ = try withExtendedLifetime(downloadSubscriber) {
try downloader.waitUntilDone()
}

try hub.writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataPath: metadataDestination)

return destination
}
}

Expand Down
Loading
Loading