Skip to content

Commit

Permalink
Merge pull request LykosAI#787 from ionite34/fix-invalid-filenames
Browse files Browse the repository at this point in the history
consolidate model import & remove invalid filename characters
  • Loading branch information
mohnjiles authored Aug 19, 2024
2 parents bd2a420 + 311a176 commit b47b8f6
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 248 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning 2.0](https://semver.org/spec/v2
- Fixed SwarmUI settings being overwritten on launch
- Fixed Forge output folder links pointing to the incorrect folder
- LORAs are now sorted by model name properly in the Extra Networks dropdown
- Fixed errors when downloading models with invalid characters in the file name

## v2.12.0-pre.1
### Added
Expand Down
1 change: 1 addition & 0 deletions StabilityMatrix.Avalonia/Services/IModelImportService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Task DoImport(
CivitFile? selectedFile = null,
IProgress<ProgressReport>? progress = null,
Func<Task>? onImportComplete = null,
Func<Task>? onImportCanceled = null,
Func<Task>? onImportFailed = null
);
}
91 changes: 9 additions & 82 deletions StabilityMatrix.Avalonia/Services/ModelDownloadLinkHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ public class ModelDownloadLinkHandler(
ICivitApi civitApi,
INotificationService notificationService,
ISettingsManager settingsManager,
IDownloadService downloadService,
ITrackedDownloadService trackedDownloadService
IModelImportService modelImportService
) : IAsyncDisposable, IModelDownloadLinkHandler
{
private IAsyncDisposable? uriHandlerSubscription;
Expand Down Expand Up @@ -161,93 +160,21 @@ private void UriReceivedHandler(Uri receivedUri)

var rootModelsDirectory = new DirectoryPath(settingsManager.ModelsDirectory);
var downloadDirectory = rootModelsDirectory.JoinDir(
selectedFile.Type == CivitFileType.VAE
selectedFile?.Type == CivitFileType.VAE
? SharedFolderType.VAE.GetStringValue()
: model.Type.ConvertTo<SharedFolderType>().GetStringValue()
);

downloadDirectory.Create();
var downloadPath = downloadDirectory.JoinFile(selectedFile.Name);

// Create tracked download
var download = trackedDownloadService.NewDownload(selectedFile.DownloadUrl, downloadPath);

// Download model info and preview first
var saveCmInfoTask = SaveCmInfo(model, modelVersion, selectedFile, downloadDirectory);
var savePreviewImageTask = SavePreviewImage(modelVersion, downloadPath);

Task.WaitAll([saveCmInfoTask, savePreviewImageTask]);

var cmInfoPath = saveCmInfoTask.Result;
var previewImagePath = savePreviewImageTask.Result;

// Add hash info
download.ExpectedHashSha256 = selectedFile.Hashes.SHA256;

// Add files to cleanup list
download.ExtraCleanupFileNames.Add(cmInfoPath);
if (previewImagePath is not null)
{
download.ExtraCleanupFileNames.Add(previewImagePath);
}

// Add hash context action
download.ContextAction = CivitPostDownloadContextAction.FromCivitFile(selectedFile);

download.Start();
var importTask = modelImportService.DoImport(
model,
downloadDirectory,
selectedVersion: modelVersion,
selectedFile: selectedFile
);
importTask.Wait();

Dispatcher.UIThread.Post(
() => notificationService.Show("Download Started", $"Downloading {selectedFile.Name}")
);
}

private static async Task<FilePath> SaveCmInfo(
CivitModel model,
CivitModelVersion modelVersion,
CivitFile modelFile,
DirectoryPath downloadDirectory
)
{
var modelFileName = Path.GetFileNameWithoutExtension(modelFile.Name);
var modelInfo = new ConnectedModelInfo(model, modelVersion, modelFile, DateTime.UtcNow);

await modelInfo.SaveJsonToDirectory(downloadDirectory, modelFileName);

var jsonName = $"{modelFileName}.cm-info.json";
return downloadDirectory.JoinFile(jsonName);
}

/// <summary>
/// Saves the preview image to the same directory as the model file
/// </summary>
/// <param name="modelVersion"></param>
/// <param name="modelFilePath"></param>
/// <returns>The file path of the saved preview image</returns>
private async Task<FilePath?> SavePreviewImage(CivitModelVersion modelVersion, FilePath modelFilePath)
{
// Skip if model has no images
if (modelVersion.Images == null || modelVersion.Images.Count == 0)
{
return null;
}

var image = modelVersion.Images.FirstOrDefault(x => x.Type == "image");
if (image is null)
return null;

var imageExtension = Path.GetExtension(image.Url).TrimStart('.');
if (imageExtension is "jpg" or "jpeg" or "png")
{
var imageDownloadPath = modelFilePath.Directory!.JoinFile(
$"{modelFilePath.NameWithoutExtension}.preview.{imageExtension}"
);

var imageTask = downloadService.DownloadToFileAsync(image.Url, imageDownloadPath);
await notificationService.TryAsync(imageTask, "Could not download preview image");

return imageDownloadPath;
}

return null;
}
}
7 changes: 6 additions & 1 deletion StabilityMatrix.Avalonia/Services/ModelImportService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ public async Task DoImport(
CivitFile? selectedFile = null,
IProgress<ProgressReport>? progress = null,
Func<Task>? onImportComplete = null,
Func<Task>? onImportCanceled = null,
Func<Task>? onImportFailed = null
)
{
Expand Down Expand Up @@ -112,6 +113,10 @@ public async Task DoImport(
// Folders might be missing if user didn't install any packages yet
downloadFolder.Create();

// Fix invalid chars in FileName
modelFile.Name = Path.GetInvalidFileNameChars()
.Aggregate(modelFile.Name, (current, c) => current.Replace(c, '_'));

var downloadPath = downloadFolder.JoinFile(modelFile.Name);

// Download model info and preview first
Expand Down Expand Up @@ -145,7 +150,7 @@ public async Task DoImport(
}
else if (e == ProgressState.Cancelled)
{
// todo?
onImportCanceled?.Invoke().SafeFireAndForget();
}
else if (e == ProgressState.Failed)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public partial class CheckpointBrowserCardViewModel : ProgressViewModel
private readonly ServiceManager<ViewModelBase> dialogFactory;
private readonly INotificationService notificationService;
private readonly IModelIndexService modelIndexService;
private readonly IModelImportService modelImportService;

public Action<CheckpointBrowserCardViewModel>? OnDownloadStart { get; set; }

Expand Down Expand Up @@ -87,7 +88,8 @@ public CheckpointBrowserCardViewModel(
ISettingsManager settingsManager,
ServiceManager<ViewModelBase> dialogFactory,
INotificationService notificationService,
IModelIndexService modelIndexService
IModelIndexService modelIndexService,
IModelImportService modelImportService
)
{
this.downloadService = downloadService;
Expand All @@ -96,6 +98,7 @@ IModelIndexService modelIndexService
this.dialogFactory = dialogFactory;
this.notificationService = notificationService;
this.modelIndexService = modelIndexService;
this.modelImportService = modelImportService;

// Update image when nsfw setting changes
AddDisposable(
Expand Down Expand Up @@ -284,56 +287,6 @@ private async Task ShowVersionDialog(CivitModel model)
DelayedClearProgress(TimeSpan.FromMilliseconds(1000));
}

private static async Task<FilePath> SaveCmInfo(
CivitModel model,
CivitModelVersion modelVersion,
CivitFile modelFile,
DirectoryPath downloadDirectory
)
{
var modelFileName = Path.GetFileNameWithoutExtension(modelFile.Name);
var modelInfo = new ConnectedModelInfo(model, modelVersion, modelFile, DateTime.UtcNow);

await modelInfo.SaveJsonToDirectory(downloadDirectory, modelFileName);

var jsonName = $"{modelFileName}.cm-info.json";
return downloadDirectory.JoinFile(jsonName);
}

/// <summary>
/// Saves the preview image to the same directory as the model file
/// </summary>
/// <param name="modelVersion"></param>
/// <param name="modelFilePath"></param>
/// <returns>The file path of the saved preview image</returns>
private async Task<FilePath?> SavePreviewImage(CivitModelVersion modelVersion, FilePath modelFilePath)
{
// Skip if model has no images
if (modelVersion.Images == null || modelVersion.Images.Count == 0)
{
return null;
}

var image = modelVersion.Images.FirstOrDefault(x => x.Type == "image");
if (image is null)
return null;

var imageExtension = Path.GetExtension(image.Url).TrimStart('.');
if (imageExtension is "jpg" or "jpeg" or "png")
{
var imageDownloadPath = modelFilePath.Directory!.JoinFile(
$"{modelFilePath.NameWithoutExtension}.preview.{imageExtension}"
);

var imageTask = downloadService.DownloadToFileAsync(image.Url, imageDownloadPath);
await notificationService.TryAsync(imageTask, "Could not download preview image");

return imageDownloadPath;
}

return null;
}

private async Task DoImport(
CivitModel model,
DirectoryPath downloadFolder,
Expand Down Expand Up @@ -378,56 +331,37 @@ private async Task DoImport(
return;
}

// Folders might be missing if user didn't install any packages yet
downloadFolder.Create();

var downloadPath = downloadFolder.JoinFile(modelFile.Name);

// Download model info and preview first
var cmInfoPath = await SaveCmInfo(model, modelVersion, modelFile, downloadFolder);
var previewImagePath = await SavePreviewImage(modelVersion, downloadPath);

// Create tracked download
var download = trackedDownloadService.NewDownload(modelFile.DownloadUrl, downloadPath);

// Add hash info
download.ExpectedHashSha256 = modelFile.Hashes.SHA256;

// Add files to cleanup list
download.ExtraCleanupFileNames.Add(cmInfoPath);
if (previewImagePath is not null)
{
download.ExtraCleanupFileNames.Add(previewImagePath);
}

// Attach for progress updates
download.ProgressStateChanged += (s, e) =>
{
if (e == ProgressState.Success)
await modelImportService.DoImport(
model,
downloadFolder,
modelVersion,
modelFile,
onImportComplete: () =>
{
Text = "Import Complete";
IsIndeterminate = false;
Value = 100;
CheckIfInstalled();
DelayedClearProgress(TimeSpan.FromMilliseconds(800));
}
else if (e == ProgressState.Cancelled)
return Task.CompletedTask;
},
onImportCanceled: () =>
{
Text = "Cancelled";
DelayedClearProgress(TimeSpan.FromMilliseconds(500));
}
else if (e == ProgressState.Failed)
return Task.CompletedTask;
},
onImportFailed: () =>
{
Text = "Download Failed";
DelayedClearProgress(TimeSpan.FromMilliseconds(800));
}
};

// Add hash context action
download.ContextAction = CivitPostDownloadContextAction.FromCivitFile(modelFile);
download.Start();
return Task.CompletedTask;
}
);
}

private void DelayedClearProgress(TimeSpan delay)
Expand Down
Loading

0 comments on commit b47b8f6

Please sign in to comment.