diff --git a/Package.swift b/Package.swift index de151f4f..3288400e 100644 --- a/Package.swift +++ b/Package.swift @@ -6,8 +6,8 @@ import PackageDescription let package = Package( name: "stable-diffusion", platforms: [ - .macOS(.v11), - .iOS(.v14), + .macOS(.v13), + .iOS(.v16), ], products: [ .library( @@ -18,12 +18,15 @@ let package = Package( targets: ["StableDiffusionCLI"]) ], dependencies: [ - .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.2.3") + .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.2.3"), + .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.8"), ], targets: [ .target( name: "StableDiffusion", - dependencies: [], + dependencies: [ + .product(name: "Transformers", package: "swift-transformers"), + ], path: "swift/StableDiffusion"), .executableTarget( name: "StableDiffusionCLI", diff --git a/README.md b/README.md index 45414e9d..0cfe9275 100644 --- a/README.md +++ b/README.md @@ -246,6 +246,52 @@ An example `` would be `"recipe_4.50_bit_mixedpalett + +## Using Stable Diffusion 3 + +
+<details>
+  <summary> Details (Click to expand) </summary>
+
+### Model Conversion
+
+Stable Diffusion 3 combines new model components with existing ones. For the text encoders, the conversion can be done with a similar command as before, adding the `--sd3-version` flag:
+
+```bash
+python -m python_coreml_stable_diffusion.torch2coreml --model-version stabilityai/stable-diffusion-3-medium --bundle-resources-for-swift-cli --convert-text-encoder --sd3-version -o <output-mlpackages-directory>
+```
+
+For the new components (MMDiT, a new VAE with 16 channels, and the T5 text encoder), there are a number of new CLI flags that utilize the [DiffusionKit](https://www.github.com/argmaxinc/DiffusionKit) repo:
+
+- `--sd3-version`: Indicates to the converter to treat this as a Stable Diffusion 3 model
+- `--convert-mmdit`: Convert the MMDiT model
+- `--convert-vae-decoder`: Convert the new VAE model (the 16-channel version is used when `--sd3-version` is set)
+- `--include-t5`: Download and include a pre-converted T5 text encoder in the conversion
+
+e.g.:
+```bash
+python -m python_coreml_stable_diffusion.torch2coreml --model-version stabilityai/stable-diffusion-3-medium --bundle-resources-for-swift-cli --convert-vae-decoder --convert-mmdit --include-t5 --sd3-version -o <output-mlpackages-directory>
+```
+
+To convert the full pipeline at 1024x1024 resolution, the following command may be used (the latent resolution is 1/8 of the output resolution, hence `--latent-h 128 --latent-w 128`):
+
+```bash
+python -m python_coreml_stable_diffusion.torch2coreml --model-version stabilityai/stable-diffusion-3-medium --bundle-resources-for-swift-cli --convert-text-encoder --convert-vae-decoder --convert-mmdit --include-t5 --sd3-version --latent-h 128 --latent-w 128 -o <output-mlpackages-directory>
+```
+
+Keep in mind that the MMDiT model is quite large and requires increasingly more memory and time to convert as the latent resolution increases.
+
+Also note that the MMDiT model currently requires float32 precision, so it only supports `CPU_AND_GPU` compute units and the `ORIGINAL` attention implementation (the default for this pipeline).
+
+### Swift Inference
+
+Swift inference for Stable Diffusion 3 is similar to previous versions. The only difference is the `--sd3` flag, which indicates that the model is a Stable Diffusion 3 model:
+
+```bash
+swift run StableDiffusionSample <prompt> --resource-path <output-mlpackages-directory>/Resources --output-path <output-dir> --compute-units cpuAndGPU --sd3
+```
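+
+For programmatic use, the Swift call pattern mirrors the `StableDiffusionPipeline` example earlier in this README. The sketch below is illustrative only: `StableDiffusion3Pipeline` and its configuration are assumed names, so check the Swift sources for the exact SD3 API:
+
+```swift
+import CoreML
+import StableDiffusion
+
+let config = MLModelConfiguration()
+config.computeUnits = .cpuAndGPU // MMDiT currently requires float32, so avoid the Neural Engine
+
+let resourceURL = URL(fileURLWithPath: "<output-mlpackages-directory>/Resources")
+
+// `StableDiffusion3Pipeline` is a placeholder name for the SD3-capable pipeline type
+let pipeline = try StableDiffusion3Pipeline(resourcesAt: resourceURL, configuration: config)
+try pipeline.loadResources()
+
+let image = try pipeline.generateImages(
+    configuration: .init(prompt: "a photo of an astronaut riding a horse on mars")
+).first
+```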
+</details>

 ## Using Stable Diffusion XL
@@ -356,6 +402,7 @@ Resources:
 - [`stabilityai/stable-diffusion-2-1-base`](https://huggingface.co/apple/coreml-stable-diffusion-2-1-base)
 - [`stabilityai/stable-diffusion-xl-base-1.0`](https://huggingface.co/apple/coreml-stable-diffusion-xl-base)
 - [`stabilityai/stable-diffusion-xl-{base+refiner}-1.0`](https://huggingface.co/apple/coreml-stable-diffusion-xl-base-with-refiner)
+ - [`stabilityai/stable-diffusion-3-medium`](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
 
 If you want to use any of those models you may download the weights and proceed to [generate images with Python](#image-generation-with-python) or [Swift](#image-generation-with-swift).
 
diff --git a/python_coreml_stable_diffusion/torch2coreml.py b/python_coreml_stable_diffusion/torch2coreml.py
index 602f78e7..a85c2cd0 100644
--- a/python_coreml_stable_diffusion/torch2coreml.py
+++ b/python_coreml_stable_diffusion/torch2coreml.py
@@ -16,7 +16,12 @@
     DiffusionPipeline,
     ControlNetModel
 )
+from diffusionkit.tests.torch2coreml import (
+    convert_mmdit_to_mlpackage,
+    convert_vae_to_mlpackage
+)
 import gc
+from huggingface_hub import snapshot_download
 
 import logging
 
@@ -207,6 +212,26 @@ def _compile_coreml_model(source_model_path, output_dir, final_name):
 
     return target_path
 
+def _download_t5_model(args, t5_save_path):
+    """ Downloads the pre-converted T5 text encoder referenced by
+    `args.text_encoder_t5_url` and copies it into the Resources directory.
+    """
+    t5_url = args.text_encoder_t5_url
+    match = re.match(r'https://huggingface.co/(.+)/resolve/main/(.+)', t5_url)
+    if not match:
+        raise ValueError(f"Invalid Hugging Face URL: {t5_url}")
+    repo_id, model_subpath = match.groups()
+
+    download_path = snapshot_download(
+        repo_id=repo_id,
+        revision="main",
+        allow_patterns=[f"{model_subpath}/*"]
+    )
+    logger.info(f"Downloaded T5 model to {download_path}")
+
+    # Copy the downloaded model to the top level of the Resources directory
+    logger.info(f"Copying T5 model from {download_path} to {t5_save_path}")
+    cache_path = os.path.join(download_path, model_subpath)
+    shutil.copytree(cache_path, t5_save_path)
+
+
 def bundle_resources_for_swift_cli(args):
     """
     - Compiles Core ML models from mlpackage into mlmodelc format
@@ -228,6 +253,7 @@ def bundle_resources_for_swift_cli(args):
         ("refiner", "UnetRefiner"),
         ("refiner_chunk1", "UnetRefinerChunk1"),
         ("refiner_chunk2", "UnetRefinerChunk2"),
+        ("mmdit", "MultiModalDiffusionTransformer"),
         ("control-unet", "ControlledUnet"),
         ("control-unet_chunk1", "ControlledUnetChunk1"),
         ("control-unet_chunk2", "ControlledUnetChunk2"),
@@ -241,7 +267,7 @@ def bundle_resources_for_swift_cli(args):
             logger.warning(
                 f"{source_path} not found, skipping compilation to {target_name}.mlmodelc"
             )
-    
+
     if args.convert_controlnet:
         for controlnet_model_version in args.convert_controlnet:
             controlnet_model_name = controlnet_model_version.replace("/", "_")
@@ -271,6 +297,25 @@ def bundle_resources_for_swift_cli(args):
             f.write(requests.get(args.text_encoder_merges_url).content)
         logger.info("Done")
 
+    # Fetch and save the pre-converted T5 text encoder model
+    t5_model_name = "TextEncoderT5.mlmodelc"
+    t5_save_path = os.path.join(resources_dir, t5_model_name)
+    if args.include_t5:
+        if not os.path.exists(t5_save_path):
+            logger.info("Downloading pre-converted T5 encoder model TextEncoderT5.mlmodelc")
+            _download_t5_model(args, t5_save_path)
+            logger.info("Done")
+        else:
+            logger.info(f"Skipping T5 download as {t5_save_path} already exists")
+
+        # Fetch and save the T5 text tokenizer JSON files
+        logger.info("Downloading and saving T5 tokenizer files tokenizer_config.json and tokenizer.json")
+        with open(os.path.join(resources_dir, "tokenizer_config.json"), "wb") as f:
+            f.write(requests.get(args.text_encoder_t5_config_url).content)
+        with open(os.path.join(resources_dir, "tokenizer.json"), "wb") as f:
+            f.write(requests.get(args.text_encoder_t5_data_url).content)
+        logger.info("Done")
+
     return resources_dir
 
@@ -557,6 +602,61 @@ def forward(self, z):
     del traced_vae_decoder, pipe.vae.decoder, coreml_vae_decoder
     gc.collect()
 
+def convert_vae_decoder_sd3(args):
+    """ Converts the VAE Decoder component of Stable Diffusion 3
+    """
+    out_path = _get_out_path(args, "vae_decoder")
+    if os.path.exists(out_path):
+        logger.info(
+            f"`vae_decoder` already exists at {out_path}, skipping conversion."
+        )
+        return
+
+    # Convert the VAE Decoder model via DiffusionKit
+    converted_vae_path = convert_vae_to_mlpackage(
+        model_version=args.model_version,
+        latent_h=args.latent_h,
+        latent_w=args.latent_w,
+        output_dir=args.o,
+    )
+
+    # Load the converted model
+    coreml_vae_decoder = ct.models.MLModel(converted_vae_path)
+
+    # Set model metadata
+    coreml_vae_decoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
+    coreml_vae_decoder.license = "Stability AI Community License (https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md)"
+    coreml_vae_decoder.version = args.model_version
+    coreml_vae_decoder.short_description = \
+        "Stable Diffusion 3 generates images conditioned on text or other images as input through the diffusion process. " \
+        "Please refer to https://arxiv.org/pdf/2403.03206 for details."
+
+    # Set the input descriptions
+    coreml_vae_decoder.input_description["z"] = \
+        "The denoised latent embeddings from the denoiser model (MMDiT) after the last step of reverse diffusion"
+
+    # Set the output descriptions
+    coreml_vae_decoder.output_description[
+        "image"] = "Generated image normalized to range [-1, 1]"
+
+    # Set package version metadata
+    from python_coreml_stable_diffusion._version import __version__
+    coreml_vae_decoder.user_defined_metadata["com.github.apple.ml-stable-diffusion.version"] = __version__
+    from diffusionkit.version import __version__
+    coreml_vae_decoder.user_defined_metadata["com.github.argmax.diffusionkit.version"] = __version__
+
+    # Save the updated model
+    coreml_vae_decoder.save(out_path)
+
+    logger.info(f"Saved vae_decoder into {out_path}")
+
+    # Delete the intermediate mlpackage produced by DiffusionKit
+    if os.path.exists(converted_vae_path):
+        shutil.rmtree(converted_vae_path)
+
+    del coreml_vae_decoder
+    gc.collect()
+
 def convert_vae_encoder(pipe, args):
     """ Converts the VAE Encoder component of Stable Diffusion
     """
@@ -909,6 +1009,72 @@ def convert_unet(pipe, args, model_name = None):
         chunk_mlprogram.main(args)
 
+def convert_mmdit(args):
+    """ Converts the MMDiT component of Stable Diffusion 3
+    """
+    out_path = _get_out_path(args, "mmdit")
+    if os.path.exists(out_path):
+        logger.info(
+            f"`mmdit` already exists at {out_path}, skipping conversion."
+        )
+        return
+
+    # Convert the MMDiT model via DiffusionKit
+    converted_mmdit_path = convert_mmdit_to_mlpackage(
+        model_version=args.model_version,
+        latent_h=args.latent_h,
+        latent_w=args.latent_w,
+        output_dir=args.o,
+        # FIXME: Hardcoding to CPU_AND_GPU since ANE doesn't support FLOAT32
+        compute_precision=ct.precision.FLOAT32,
+        compute_unit=ct.ComputeUnit.CPU_AND_GPU,
+    )
+
+    # Load the converted model
+    coreml_mmdit = ct.models.MLModel(converted_mmdit_path)
+
+    # Set model metadata
+    coreml_mmdit.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
+    coreml_mmdit.license = "Stability AI Community License (https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md)"
+    coreml_mmdit.version = args.model_version
+    coreml_mmdit.short_description = \
+        "Stable Diffusion 3 generates images conditioned on text or other images as input through the diffusion process. " \
+        "Please refer to https://arxiv.org/pdf/2403.03206 for details."
+
+    # Set the input descriptions
+    coreml_mmdit.input_description["latent_image_embeddings"] = \
+        "The low resolution latent feature maps being denoised through reverse diffusion"
+    coreml_mmdit.input_description["token_level_text_embeddings"] = \
+        "Output embeddings from the associated text_encoder model to condition the generated image on text. " \
+        "A maximum of 77 tokens (~40 words) are allowed. Longer text is truncated."
+    coreml_mmdit.input_description["pooled_text_embeddings"] = \
+        "Pooled output embeddings from the associated text encoders that provide additional conditioning to the MMDiT model."
+    coreml_mmdit.input_description["timestep"] = \
+        "A value emitted by the associated scheduler object to condition the model on a given noise schedule"
+
+    # Set the output descriptions
+    coreml_mmdit.output_description["denoiser_output"] = \
+        "Same shape and dtype as the `latent_image_embeddings` input. " \
" \ + "The predicted noise to facilitate the reverse diffusion (denoising) process" + + # Set package version metadata + from python_coreml_stable_diffusion._version import __version__ + coreml_mmdit.user_defined_metadata["com.github.apple.ml-stable-diffusion.version"] = __version__ + from diffusionkit.version import __version__ + coreml_mmdit.user_defined_metadata["com.github.argmax.diffusionkit.version"] = __version__ + + # Save the updated model + coreml_mmdit.save(out_path) + + logger.info(f"Saved vae_decoder into {out_path}") + + # Delete the original file + if os.path.exists(converted_mmdit_path): + shutil.rmtree(converted_mmdit_path) + + del coreml_mmdit + gc.collect() + def convert_safety_checker(pipe, args): """ Converts the Safety Checker component of Stable Diffusion """ @@ -1288,6 +1454,16 @@ def get_pipeline(args): use_safetensors=True, vae=vae, use_auth_token=True) + elif args.sd3_version: + # SD3 uses standard SDXL diffusers pipeline besides the vae, denoiser, and T5 text encoder + sdxl_base_version = "stabilityai/stable-diffusion-xl-base-1.0" + args.xl_version = True + logger.info(f"SD3 version specified, initializing DiffusionPipeline with {sdxl_base_version} for non-SD3 components..") + pipe = DiffusionPipeline.from_pretrained(sdxl_base_version, + torch_dtype=torch.float16, + variant="fp16", + use_safetensors=True, + use_auth_token=True) else: pipe = DiffusionPipeline.from_pretrained(model_version, torch_dtype=torch.float16, @@ -1316,7 +1492,10 @@ def main(args): # Convert models if args.convert_vae_decoder: logger.info("Converting vae_decoder") - convert_vae_decoder(pipe, args) + if args.sd3_version: + convert_vae_decoder_sd3(args) + else: + convert_vae_decoder(pipe, args) logger.info("Converted vae_decoder") if args.convert_vae_encoder: @@ -1363,6 +1542,11 @@ def main(args): del pipe gc.collect() logger.info(f"Converted refiner") + + if args.convert_mmdit: + logger.info("Converting mmdit") + convert_mmdit(args) + logger.info("Converted mmdit") if args.quantize_nbits is not None: logger.info(f"Quantizing weights to {args.quantize_nbits}-bit precision") @@ -1383,6 +1567,7 @@ def parser_spec(): parser.add_argument("--convert-vae-decoder", action="store_true") parser.add_argument("--convert-vae-encoder", action="store_true") parser.add_argument("--convert-unet", action="store_true") + parser.add_argument("--convert-mmdit", action="store_true") parser.add_argument("--convert-safety-checker", action="store_true") parser.add_argument( "--convert-controlnet", @@ -1489,6 +1674,7 @@ def parser_spec(): "If specified, enable unet to receive additional inputs from controlnet. " "Each input added to corresponding resnet output." 
    )
+    parser.add_argument(
+        "--include-t5",
+        action="store_true",
+        help="If specified, download and bundle a pre-converted T5 text encoder when bundling resources for the Swift CLI.")
 
     # Swift CLI Resource Bundling
     parser.add_argument(
@@ -1508,11 +1694,30 @@ def parser_spec():
         default=
         "https://huggingface.co/openai/clip-vit-base-patch32/resolve/main/merges.txt",
         help="The URL to the merged pairs used in by the text tokenizer.")
+    parser.add_argument(
+        "--text-encoder-t5-url",
+        default=
+        "https://huggingface.co/argmaxinc/coreml-stable-diffusion-3-medium/resolve/main/TextEncoderT5.mlmodelc",
+        help="The URL to the pre-converted T5 encoder model.")
+    parser.add_argument(
+        "--text-encoder-t5-config-url",
+        default=
+        "https://huggingface.co/google-t5/t5-small/resolve/main/tokenizer_config.json",
+        help="The URL to the T5 tokenizer configuration file (tokenizer_config.json).")
+    parser.add_argument(
+        "--text-encoder-t5-data-url",
+        default=
+        "https://huggingface.co/google-t5/t5-small/resolve/main/tokenizer.json",
+        help="The URL to the T5 tokenizer data file (tokenizer.json).")
     parser.add_argument(
         "--xl-version",
         action="store_true",
         help=("If specified, the pre-trained model will be treated as an instantiation of "
               "`diffusers.pipelines.StableDiffusionXLPipeline` instead of `diffusers.pipelines.StableDiffusionPipeline`"))
+    parser.add_argument(
+        "--sd3-version",
+        action="store_true",
+        help=("If specified, the pre-trained model will be treated as a Stable Diffusion 3 model."))
 
     return parser
 
diff --git a/requirements.txt b/requirements.txt
index 8fa7b07a..bd8bb117 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 coremltools>=7.0
 diffusers[torch]
+diffusionkit
 torch
 transformers==4.29.2
 scipy
diff --git a/swift/StableDiffusion/pipeline/Decoder.swift b/swift/StableDiffusion/pipeline/Decoder.swift
index cc8b9b9d..1d39aa32 100644
--- a/swift/StableDiffusion/pipeline/Decoder.swift
+++ b/swift/StableDiffusion/pipeline/Decoder.swift
@@ -1,5 +1,5 @@
 // For licensing see accompanying LICENSE.md file.
-// Copyright (C) 2022 Apple Inc. All Rights Reserved.
+// Copyright (C) 2024 Apple Inc. All Rights Reserved.
 
 import Foundation
 import CoreML
@@ -28,7 +28,7 @@ public struct Decoder: ResourceManaging {
 
     /// Unload the underlying model to free up memory
     public func unloadResources() {
-        model.unloadResources()
+        model.unloadResources()
     }
 
     /// Batch decode latent samples into images
@@ -39,14 +39,15 @@
     /// - Returns: decoded images
     public func decode(
         _ latents: [MLShapedArray<Float32>],
-        scaleFactor: Float32
+        scaleFactor: Float32,
+        shiftFactor: Float32 = 0.0
     ) throws -> [CGImage] {
 
         // Form batch inputs for model
         let inputs: [MLFeatureProvider] = try latents.map { sample in
             // Reference pipeline scales the latent samples before decoding
             let sampleScaled = MLShapedArray<Float32>(
-                scalars: sample.scalars.map { $0 / scaleFactor },
+                scalars: sample.scalars.map { $0 / scaleFactor + shiftFactor },
                 shape: sample.shape)
 
             let dict = [inputName: MLMultiArray(sampleScaled)]
diff --git a/swift/StableDiffusion/pipeline/DiscreteFlowScheduler.swift b/swift/StableDiffusion/pipeline/DiscreteFlowScheduler.swift
new file mode 100644
index 00000000..59e3ea4a
--- /dev/null
+++ b/swift/StableDiffusion/pipeline/DiscreteFlowScheduler.swift
@@ -0,0 +1,123 @@
+// For licensing see accompanying LICENSE.md file.
+// Copyright (C) 2024 Apple Inc. All Rights Reserved.
+
+import CoreML
+
+/// A scheduler used to compute a de-noised image
+@available(iOS 16.2, macOS 13.1, *)
+public final class DiscreteFlowScheduler: Scheduler {
+    public let trainStepCount: Int
+    public let inferenceStepCount: Int
+    public var timeSteps = [Int]()
+    public var betas = [Float]()
+    public var alphas = [Float]()
+    public var alphasCumProd = [Float]()
+
+    public private(set) var modelOutputs: [MLShapedArray<Float32>] = []
+
+    var trainSteps: Float
+    var shift: Float
+    var counter: Int
+    var sigmas = [Float]()
+
+    /// Create a discrete flow-matching scheduler that denoises with Euler steps
+    /// along a shifted sigma schedule.
+    ///
+    /// - Parameters:
+    ///   - stepCount: Number of inference steps to schedule
+    ///   - trainStepCount: Number of training diffusion steps
+    ///   - timeStepShift: Amount to shift the timestep schedule
+    /// - Returns: A scheduler ready for its first step
+    public init(
+        stepCount: Int = 50,
+        trainStepCount: Int = 1000,
+        timeStepShift: Float = 3.0
+    ) {
+        self.trainStepCount = trainStepCount
+        self.inferenceStepCount = stepCount
+        self.trainSteps = Float(trainStepCount)
+        self.shift = timeStepShift
+        self.counter = 0
+
+        let sigmaDistribution = linspace(1, trainSteps, Int(trainSteps)).map { sigmaFromTimestep($0) }
+        let timeStepDistribution = linspace(sigmaDistribution.first!, sigmaDistribution.last!, stepCount).reversed()
+        self.timeSteps = timeStepDistribution.map { Int($0 * trainSteps) }
+        self.sigmas = timeStepDistribution.map { sigmaFromTimestep($0 * trainSteps) }
+    }
+
+    func sigmaFromTimestep(_ timestep: Float) -> Float {
+        if shift == 1.0 {
+            return timestep / trainSteps
+        } else {
+            // shift * timestep / (1 + (shift - 1) * timestep)
+            let t = timestep / trainSteps
+            return shift * t / (1 + (shift - 1) * t)
+        }
+    }
+
+    func timestepsFromSigmas() -> [Float] {
+        return sigmas.map { $0 * trainSteps }
+    }
+
+    /// Convert the model output to the corresponding type the algorithm needs.
+    func convertModelOutput(modelOutput: MLShapedArray<Float32>, timestep: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
+        assert(modelOutput.scalarCount == sample.scalarCount)
+        let stepIndex = timeSteps.firstIndex(of: timestep) ?? counter
+        let sigma = sigmas[stepIndex]
+
+        return MLShapedArray(unsafeUninitializedShape: modelOutput.shape) { result, _ in
+            modelOutput.withUnsafeShapedBufferPointer { noiseScalars, _, _ in
+                sample.withUnsafeShapedBufferPointer { latentScalars, _, _ in
+                    for i in 0..<result.count {
+                        // Convert the velocity prediction into a denoised sample: x0 = x_t - sigma * v
+                        result.initializeElement(at: i, to: latentScalars[i] - noiseScalars[i] * sigma)
+                    }
+                }
+            }
+        }
+    }
+
+    public func calculateTimesteps(strength: Float?) -> [Float] {
+        guard let strength else { return timestepsFromSigmas() }
+        let startStep = max(inferenceStepCount - Int(Float(inferenceStepCount) * strength), 0)
+        let actualTimesteps = Array(timestepsFromSigmas()[startStep...])
+        return actualTimesteps
+    }
+
+    public func step(output: MLShapedArray<Float32>, timeStep t: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
+        let stepIndex = timeSteps.firstIndex(of: t) ?? counter // TODO: allow float timesteps in scheduler step protocol
+        let modelOutput = convertModelOutput(modelOutput: output, timestep: t, sample: sample)
+        modelOutputs.append(modelOutput)
+
+        let sigma = sigmas[stepIndex]
+        var dt = sigma
+        var prevSigma: Float = 0
+        if stepIndex < sigmas.count - 1 {
+            prevSigma = sigmas[stepIndex + 1]
+            dt = prevSigma - sigma
+        }
+
+        let prevSample: MLShapedArray<Float32> = MLShapedArray(unsafeUninitializedShape: modelOutput.shape) { result, _ in
+            modelOutput.withUnsafeShapedBufferPointer { noiseScalars, _, _ in
+                sample.withUnsafeShapedBufferPointer { latentScalars, _, _ in
+                    for i in 0..<result.count {
+                        // `noiseScalars` holds the denoised prediction (x0) from convertModelOutput;
+                        // take an Euler step along the flow: x_prev = x_t + dt * (x_t - x0) / sigma
+                        let derivative = (latentScalars[i] - noiseScalars[i]) / sigma
+                        result.initializeElement(at: i, to: latentScalars[i] + derivative * dt)
+                    }
+                }
+            }
+        }
+
+        counter += 1
+        return prevSample
+    }
+}
diff --git a/swift/StableDiffusion/pipeline/MultiModalDiffusionTransformer.swift b/swift/StableDiffusion/pipeline/MultiModalDiffusionTransformer.swift
new file mode 100644
--- /dev/null
+++ b/swift/StableDiffusion/pipeline/MultiModalDiffusionTransformer.swift
+// For licensing see accompanying LICENSE.md file.
+// Copyright (C) 2024 Apple Inc. All Rights Reserved.
+
+import Foundation
+import CoreML
+
+extension Array where Element == ManagedMLModel {
+    func predictions(from batch: MLBatchProvider) throws -> MLBatchProvider {
+        var results = try self.first!.perform { model in
+            try model.predictions(fromBatch: batch)
+        }
+
+        if self.count == 1 {
+            return results
+        }
+
+        // Manual pipeline batch prediction
+        let inputs = batch.arrayOfFeatureValueDictionaries
+        for stage in self.dropFirst() {
+            // Combine the original inputs with the outputs of the last stage
+            let next = try results.arrayOfFeatureValueDictionaries
+                .enumerated().map { index, dict in
+                    let nextDict = dict.merging(inputs[index]) { out, _ in out }
+                    return try MLDictionaryFeatureProvider(dictionary: nextDict)
+                }
+            let nextBatch = MLArrayBatchProvider(array: next)
+
+            // Predict
+            results = try stage.perform { model in
+                try model.predictions(fromBatch: nextBatch)
+            }
+        }
+        return results
+    }
+}
+
+extension MLFeatureProvider {
+    var featureValueDictionary: [String : MLFeatureValue] {
+        self.featureNames.reduce(into: [String : MLFeatureValue]()) { result, name in
+            result[name] = self.featureValue(for: name)
+        }
+    }
+}
+
+extension MLBatchProvider {
+    var arrayOfFeatureValueDictionaries: [[String : MLFeatureValue]] {
+        (0..<self.count).map {
+            self.features(at: $0).featureValueDictionary
+        }
+    }
+}
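+
+// Usage sketch (illustrative only): pipelining a batch across chunked models.
+// `chunk1URL`, `chunk2URL`, and `inputFeatureProviders` are hypothetical; `ManagedMLModel`
+// is assumed to take a compiled model URL plus an MLModelConfiguration, as elsewhere in this package.
+//
+//     let config = MLModelConfiguration()
+//     config.computeUnits = .cpuAndGPU
+//     let models = [
+//         ManagedMLModel(modelAt: chunk1URL, configuration: config),
+//         ManagedMLModel(modelAt: chunk2URL, configuration: config),
+//     ]
+//     let batch = MLArrayBatchProvider(array: inputFeatureProviders)
+//     // Runs the first chunk, then feeds each later chunk the previous stage's
+//     // outputs merged with the original inputs.
+//     let outputs = try models.predictions(from: batch)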