Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
Merge pull request #96 from saddam213/ControlNet
Browse files Browse the repository at this point in the history
LCM ControlNet, SDXL ControlNet, InstaFlow ControlNet
  • Loading branch information
saddam213 authored Jan 8, 2024
2 parents eaec794 + 30d2cc9 commit 72473e8
Show file tree
Hide file tree
Showing 9 changed files with 1,093 additions and 29 deletions.
208 changes: 208 additions & 0 deletions OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime.Tensors;
using OnnxStack.Core;
using OnnxStack.Core.Config;
using OnnxStack.Core.Model;
using OnnxStack.Core.Services;
using OnnxStack.StableDiffusion.Common;
using OnnxStack.StableDiffusion.Config;
using OnnxStack.StableDiffusion.Enums;
using OnnxStack.StableDiffusion.Helpers;
using OnnxStack.StableDiffusion.Models;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace OnnxStack.StableDiffusion.Diffusers.InstaFlow
{
public class ControlNetDiffuser : InstaFlowDiffuser
{
private readonly IControlNetImageService _controlNetImageService;

/// <summary>
/// Initializes a new instance of the <see cref="ControlNetDiffuser"/> class.
/// </summary>
/// <param name="configuration">The configuration.</param>
/// <param name="onnxModelService">The onnx model service.</param>
public ControlNetDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, IControlNetImageService controlNetImageService, ILogger<ControlNetDiffuser> logger)
: base(onnxModelService, promptService, logger)
{
_controlNetImageService = controlNetImageService;
}

/// <summary>
/// Gets the type of the diffuser.
/// </summary>
public override DiffuserType DiffuserType => DiffuserType.ControlNet;


/// <summary>
/// Called on each Scheduler step.
/// </summary>
/// <param name="modelOptions">The model options.</param>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <param name="promptEmbeddings">The prompt embeddings.</param>
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
/// <param name="progressCallback">The progress callback.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns></returns>
/// <exception cref="NotImplementedException"></exception>
protected override async Task<DenseTensor<float>> SchedulerStepAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
{
// Get Scheduler
using (var scheduler = GetScheduler(schedulerOptions))
{
// Get timesteps
var timesteps = GetTimesteps(schedulerOptions, scheduler);

// Create latent sample
var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);

// Get Model metadata
var metadata = _onnxModelService.GetModelMetadata(modelOptions.BaseModel, OnnxModelType.Unet);

// Get Model metadata
var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions.ControlNetModel, OnnxModelType.ControlNet);

// Control Image
var controlImage = await PrepareControlImage(modelOptions, promptOptions, schedulerOptions);

// Get the distilled Timestep
var distilledTimestep = 1.0f / timesteps.Count;

// Loop though the timesteps
var step = 0;
foreach (var timestep in timesteps)
{
step++;
var stepTime = Stopwatch.GetTimestamp();
cancellationToken.ThrowIfCancellationRequested();

// Create input tensor.
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
var timestepTensor = CreateTimestepTensor(timestep);
var controlImageTensor = performGuidance ? controlImage.Repeat(2) : controlImage;
var conditioningScale = CreateConditioningScaleTensor(schedulerOptions.ConditioningScale);

var outputChannels = performGuidance ? 2 : 1;
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
{
inferenceParameters.AddInputTensor(inputTensor);
inferenceParameters.AddInputTensor(timestepTensor);
inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);

// ControlNet
using (var controlNetParameters = new OnnxInferenceParameters(controlNetMetadata))
{
controlNetParameters.AddInputTensor(inputTensor);
controlNetParameters.AddInputTensor(timestepTensor);
controlNetParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
controlNetParameters.AddInputTensor(controlImage);
if (controlNetMetadata.Inputs.Count == 5)
controlNetParameters.AddInputTensor(conditioningScale);

// Optimization: Pre-allocate device buffers for inputs
foreach (var item in controlNetMetadata.Outputs)
controlNetParameters.AddOutputBuffer();

// ControlNet inference
var controlNetResults = _onnxModelService.RunInference(modelOptions.ControlNetModel, OnnxModelType.ControlNet, controlNetParameters);

// Add ControlNet outputs to Unet input
foreach (var item in controlNetResults)
inferenceParameters.AddInput(item);

// Add output buffer
inferenceParameters.AddOutputBuffer(outputDimension);

// Unet inference
var results = await _onnxModelService.RunInferenceAsync(modelOptions.BaseModel, OnnxModelType.Unet, inferenceParameters);
using (var result = results.First())
{
var noisePred = result.ToDenseTensor();

// Perform guidance
if (performGuidance)
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);

// Scheduler Step
latents = scheduler.Step(noisePred, timestep, latents).Result;

latents = noisePred
.MultiplyTensorByFloat(distilledTimestep)
.AddTensors(latents);
}
}
}

ReportProgress(progressCallback, step, timesteps.Count, latents);
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
}

// Decode Latents
return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
}
}


/// <summary>
/// Gets the timesteps.
/// </summary>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <returns></returns>
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
{
return scheduler.Timesteps;
}


/// <summary>
/// Prepares the input latents.
/// </summary>
/// <param name="model">The model.</param>
/// <param name="prompt">The prompt.</param>
/// <param name="options">The options.</param>
/// <param name="scheduler">The scheduler.</param>
/// <param name="timesteps">The timesteps.</param>
/// <returns></returns>
protected override Task<DenseTensor<float>> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
{
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
}


/// <summary>
/// Creates the Conditioning Scale tensor.
/// </summary>
/// <param name="conditioningScale">The conditioningScale.</param>
/// <returns></returns>
protected static DenseTensor<double> CreateConditioningScaleTensor(float conditioningScale)
{
return TensorHelper.CreateTensor(new double[] { conditioningScale }, new int[] { 1 });
}


/// <summary>
/// Prepares the control image.
/// </summary>
/// <param name="promptOptions">The prompt options.</param>
/// <param name="schedulerOptions">The scheduler options.</param>
/// <returns></returns>
protected async Task<DenseTensor<float>> PrepareControlImage(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions)
{
var controlImage = promptOptions.InputContolImage;
if (schedulerOptions.IsControlImageProcessingEnabled)
{
controlImage = await _controlNetImageService.PrepareInputImage(modelOptions.ControlNetModel, promptOptions.InputContolImage, schedulerOptions.Height, schedulerOptions.Width);
}
return controlImage.ToDenseTensor(new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width }, false);
}
}
}
Loading

0 comments on commit 72473e8

Please sign in to comment.