diff --git a/src/ReactiveUI/Suspension/SuspensionHostExtensions.cs b/src/ReactiveUI/Suspension/SuspensionHostExtensions.cs index 7e6098cfc2..7bc1270b36 100644 --- a/src/ReactiveUI/Suspension/SuspensionHostExtensions.cs +++ b/src/ReactiveUI/Suspension/SuspensionHostExtensions.cs @@ -4,8 +4,10 @@ // See the LICENSE file in the project root for full license information. using System; +using System.Reactive; using System.Reactive.Disposables; using System.Reactive.Linq; +using System.Threading; using Splat; namespace ReactiveUI; @@ -16,38 +18,50 @@ namespace ReactiveUI; public static class SuspensionHostExtensions { /// - /// Observe changes to the AppState of a class derived from ISuspensionHost. + /// Func used to load app state exactly once. /// - /// The observable type. + private static Func>? ensureLoadAppStateFunc; + + /// + /// Supsension driver reference field to prevent introducing breaking change. + /// + private static ISuspensionDriver? suspensionDriver; + + /// + /// Get the current App State of a class derived from ISuspensionHost. + /// + /// The app state type. /// The suspension host. - /// An observable of the app state. - public static IObservable ObserveAppState(this ISuspensionHost item) - where T : class + /// The app state. + public static T GetAppState(this ISuspensionHost item) { if (item is null) { throw new ArgumentNullException(nameof(item)); } - return item.WhenAny(suspensionHost => suspensionHost.AppState, observedChange => observedChange.Value) - .WhereNotNull() - .Cast(); + Interlocked.Exchange(ref ensureLoadAppStateFunc, null)?.Invoke(); + + return (T)item.AppState!; } /// - /// Get the current App State of a class derived from ISuspensionHost. + /// Observe changes to the AppState of a class derived from ISuspensionHost. /// - /// The app state type. + /// The observable type. /// The suspension host. - /// The app state. - public static T GetAppState(this ISuspensionHost item) + /// An observable of the app state. + public static IObservable ObserveAppState(this ISuspensionHost item) + where T : class { if (item is null) { throw new ArgumentNullException(nameof(item)); } - return (T)item.AppState!; + return item.WhenAny(suspensionHost => suspensionHost.AppState, observedChange => observedChange.Value) + .WhereNotNull() + .Cast(); } /// @@ -65,32 +79,64 @@ public static IDisposable SetupDefaultSuspendResume(this ISuspensionHost item, I } var ret = new CompositeDisposable(); - driver ??= Locator.Current.GetService(); + suspensionDriver ??= driver ?? Locator.Current.GetService(); - if (driver is null) + if (suspensionDriver is null) { item.Log().Error("Could not find a valid driver and therefore cannot setup Suspend/Resume."); return Disposable.Empty; } + ensureLoadAppStateFunc = () => EnsureLoadAppState(item, suspensionDriver); + ret.Add(item.ShouldInvalidateState - .SelectMany(_ => driver.InvalidateState()) + .SelectMany(_ => suspensionDriver.InvalidateState()) .LoggedCatch(item, Observables.Unit, "Tried to invalidate app state") .Subscribe(_ => item.Log().Info("Invalidated app state"))); ret.Add(item.ShouldPersistState - .SelectMany(x => driver.SaveState(item.AppState!).Finally(x.Dispose)) + .SelectMany(x => suspensionDriver.SaveState(item.AppState!).Finally(x.Dispose)) .LoggedCatch(item, Observables.Unit, "Tried to persist app state") .Subscribe(_ => item.Log().Info("Persisted application state"))); ret.Add(item.IsResuming.Merge(item.IsLaunchingNew) - .SelectMany(_ => driver.LoadState()) - .LoggedCatch( - item, - Observable.Defer(() => Observable.Return(item.CreateNewAppState?.Invoke())), - "Failed to restore app state from storage, creating from scratch") - .Subscribe(x => item.AppState = x ?? item.CreateNewAppState?.Invoke())); + .Do(_ => Interlocked.Exchange(ref ensureLoadAppStateFunc, null)?.Invoke()) + .Subscribe()); return ret; } -} \ No newline at end of file + + /// + /// Ensures one time app state load from storage. + /// + /// The suspension host. + /// The suspension driver. + /// A completed observable. + private static IObservable EnsureLoadAppState(this ISuspensionHost item, ISuspensionDriver? driver = null) + { + if (item.AppState is not null) + { + return Observable.Return(Unit.Default); + } + + suspensionDriver ??= driver ?? Locator.Current.GetService(); + + if (suspensionDriver is null) + { + item.Log().Error("Could not find a valid driver and therefore cannot load app state."); + return Observable.Return(Unit.Default); + } + + try + { + item.AppState = suspensionDriver.LoadState().Wait(); + } + catch (Exception ex) + { + item.Log().Warn(ex, "Failed to restore app state from storage, creating from scratch"); + item.AppState = item.CreateNewAppState?.Invoke(); + } + + return Observable.Return(Unit.Default); + } +}