diff --git a/src/ReactiveUI/Suspension/SuspensionHostExtensions.cs b/src/ReactiveUI/Suspension/SuspensionHostExtensions.cs
index 7e6098cfc2..7bc1270b36 100644
--- a/src/ReactiveUI/Suspension/SuspensionHostExtensions.cs
+++ b/src/ReactiveUI/Suspension/SuspensionHostExtensions.cs
@@ -4,8 +4,10 @@
// See the LICENSE file in the project root for full license information.
using System;
+using System.Reactive;
using System.Reactive.Disposables;
using System.Reactive.Linq;
+using System.Threading;
using Splat;
namespace ReactiveUI;
@@ -16,38 +18,50 @@ namespace ReactiveUI;
public static class SuspensionHostExtensions
{
///
- /// Observe changes to the AppState of a class derived from ISuspensionHost.
+ /// Func used to load app state exactly once.
///
- /// The observable type.
+ private static Func>? ensureLoadAppStateFunc;
+
+ ///
+ /// Supsension driver reference field to prevent introducing breaking change.
+ ///
+ private static ISuspensionDriver? suspensionDriver;
+
+ ///
+ /// Get the current App State of a class derived from ISuspensionHost.
+ ///
+ /// The app state type.
/// The suspension host.
- /// An observable of the app state.
- public static IObservable ObserveAppState(this ISuspensionHost item)
- where T : class
+ /// The app state.
+ public static T GetAppState(this ISuspensionHost item)
{
if (item is null)
{
throw new ArgumentNullException(nameof(item));
}
- return item.WhenAny(suspensionHost => suspensionHost.AppState, observedChange => observedChange.Value)
- .WhereNotNull()
- .Cast();
+ Interlocked.Exchange(ref ensureLoadAppStateFunc, null)?.Invoke();
+
+ return (T)item.AppState!;
}
///
- /// Get the current App State of a class derived from ISuspensionHost.
+ /// Observe changes to the AppState of a class derived from ISuspensionHost.
///
- /// The app state type.
+ /// The observable type.
/// The suspension host.
- /// The app state.
- public static T GetAppState(this ISuspensionHost item)
+ /// An observable of the app state.
+ public static IObservable ObserveAppState(this ISuspensionHost item)
+ where T : class
{
if (item is null)
{
throw new ArgumentNullException(nameof(item));
}
- return (T)item.AppState!;
+ return item.WhenAny(suspensionHost => suspensionHost.AppState, observedChange => observedChange.Value)
+ .WhereNotNull()
+ .Cast();
}
///
@@ -65,32 +79,64 @@ public static IDisposable SetupDefaultSuspendResume(this ISuspensionHost item, I
}
var ret = new CompositeDisposable();
- driver ??= Locator.Current.GetService();
+ suspensionDriver ??= driver ?? Locator.Current.GetService();
- if (driver is null)
+ if (suspensionDriver is null)
{
item.Log().Error("Could not find a valid driver and therefore cannot setup Suspend/Resume.");
return Disposable.Empty;
}
+ ensureLoadAppStateFunc = () => EnsureLoadAppState(item, suspensionDriver);
+
ret.Add(item.ShouldInvalidateState
- .SelectMany(_ => driver.InvalidateState())
+ .SelectMany(_ => suspensionDriver.InvalidateState())
.LoggedCatch(item, Observables.Unit, "Tried to invalidate app state")
.Subscribe(_ => item.Log().Info("Invalidated app state")));
ret.Add(item.ShouldPersistState
- .SelectMany(x => driver.SaveState(item.AppState!).Finally(x.Dispose))
+ .SelectMany(x => suspensionDriver.SaveState(item.AppState!).Finally(x.Dispose))
.LoggedCatch(item, Observables.Unit, "Tried to persist app state")
.Subscribe(_ => item.Log().Info("Persisted application state")));
ret.Add(item.IsResuming.Merge(item.IsLaunchingNew)
- .SelectMany(_ => driver.LoadState())
- .LoggedCatch(
- item,
- Observable.Defer(() => Observable.Return(item.CreateNewAppState?.Invoke())),
- "Failed to restore app state from storage, creating from scratch")
- .Subscribe(x => item.AppState = x ?? item.CreateNewAppState?.Invoke()));
+ .Do(_ => Interlocked.Exchange(ref ensureLoadAppStateFunc, null)?.Invoke())
+ .Subscribe());
return ret;
}
-}
\ No newline at end of file
+
+ ///
+ /// Ensures one time app state load from storage.
+ ///
+ /// The suspension host.
+ /// The suspension driver.
+ /// A completed observable.
+ private static IObservable EnsureLoadAppState(this ISuspensionHost item, ISuspensionDriver? driver = null)
+ {
+ if (item.AppState is not null)
+ {
+ return Observable.Return(Unit.Default);
+ }
+
+ suspensionDriver ??= driver ?? Locator.Current.GetService();
+
+ if (suspensionDriver is null)
+ {
+ item.Log().Error("Could not find a valid driver and therefore cannot load app state.");
+ return Observable.Return(Unit.Default);
+ }
+
+ try
+ {
+ item.AppState = suspensionDriver.LoadState().Wait();
+ }
+ catch (Exception ex)
+ {
+ item.Log().Warn(ex, "Failed to restore app state from storage, creating from scratch");
+ item.AppState = item.CreateNewAppState?.Invoke();
+ }
+
+ return Observable.Return(Unit.Default);
+ }
+}