diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs index c4dc167979..253c86aa5b 100644 --- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs +++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs @@ -17,6 +17,7 @@ namespace Microsoft.ML.Runtime.Api /// /// A helper class to create data views based on the user-provided types. /// + [BestFriend] internal static class DataViewConstructionUtils { public static IDataView CreateFromList(IHostEnvironment env, IList data, diff --git a/src/Microsoft.ML.Api/PredictionEngine.cs b/src/Microsoft.ML.Api/PredictionEngine.cs index 9c4522722d..c8d6826610 100644 --- a/src/Microsoft.ML.Api/PredictionEngine.cs +++ b/src/Microsoft.ML.Api/PredictionEngine.cs @@ -122,6 +122,33 @@ public void Reset() } } + public sealed class PredictionEngine : PredictionEngineBase + where TSrc : class + where TDst : class, new() + { + internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns, + SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) + : base(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition) + { + } + + /// + /// Run prediction pipeline on one example. + /// + /// The example to run on. + /// The object to store the prediction in. If it's null, a new one will be created, otherwise the old one + /// is reused. + public override void Predict(TSrc example, ref TDst prediction) + { + Contracts.CheckValue(example, nameof(example)); + ExtractValues(example); + if (prediction == null) + prediction = new TDst(); + + FillValues(prediction); + } + } + /// /// A class that runs the previously trained model (and the preceding transform pipeline) on the /// in-memory data, one example at a time. @@ -130,14 +157,17 @@ public void Reset() /// /// The user-defined type that holds the example. /// The user-defined type that holds the prediction. - public sealed class PredictionEngine + public abstract class PredictionEngineBase where TSrc : class where TDst : class, new() { private readonly DataViewConstructionUtils.InputRow _inputRow; private readonly IRowReadableAs _outputRow; private readonly Action _disposer; + [BestFriend] + private protected ITransformer Transformer { get; } + [BestFriend] private static Func StreamChecker(IHostEnvironment env, Stream modelStream) { env.CheckValue(modelStream, nameof(modelStream)); @@ -150,29 +180,37 @@ private static Func StreamChecker(IHostEnvironment env, }; } - internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns, + [BestFriend] + private protected PredictionEngineBase(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) { Contracts.CheckValue(env, nameof(env)); env.AssertValue(transformer); + Transformer = transformer; var makeMapper = TransformerChecker(env, transformer); env.AssertValue(makeMapper); - _inputRow = DataViewConstructionUtils.CreateInputRow(env, inputSchemaDefinition); - var mapper = makeMapper(_inputRow.Schema); + PredictionEngineCore(env, _inputRow, makeMapper(_inputRow.Schema), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, out _disposer, out _outputRow); + } + + internal virtual void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns, + SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs outputRow) + { var cursorable = TypedCursorable.Create(env, new EmptyDataView(env, mapper.Schema), ignoreMissingColumns, outputSchemaDefinition); - var outputRow = mapper.GetRow(_inputRow, col => true, out _disposer); - _outputRow = cursorable.GetRow(outputRow); + var outputRowLocal = mapper.GetRow(_inputRow, col => true, out disposer); + outputRow = cursorable.GetRow(outputRowLocal); } - private static Func TransformerChecker(IExceptionContext ectx, ITransformer transformer) + protected virtual Func TransformerChecker(IExceptionContext ectx, ITransformer transformer) { ectx.CheckValue(transformer, nameof(transformer)); ectx.CheckParam(transformer.IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper"); return transformer.GetRowToRowMapper; } - ~PredictionEngine() + internal virtual ITransformer ProcessTransformer(ITransformer transformer) => transformer; + + ~PredictionEngineBase() { _disposer?.Invoke(); } @@ -189,19 +227,16 @@ public TDst Predict(TSrc example) return result; } + protected void ExtractValues(TSrc example) => _inputRow.ExtractValues(example); + + protected void FillValues(TDst prediction) => _outputRow.FillValues(prediction); + /// /// Run prediction pipeline on one example. /// /// The example to run on. /// The object to store the prediction in. If it's null, a new one will be created, otherwise the old one /// is reused. - public void Predict(TSrc example, ref TDst prediction) - { - Contracts.CheckValue(example, nameof(example)); - _inputRow.ExtractValues(example); - if (prediction == null) - prediction = new TDst(); - _outputRow.FillValues(prediction); - } + public abstract void Predict(TSrc example, ref TDst prediction); } } \ No newline at end of file diff --git a/src/Microsoft.ML.Api/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Api/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..05db89ef9f --- /dev/null +++ b/src/Microsoft.ML.Api/Properties/AssemblyInfo.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Runtime.CompilerServices; +using Microsoft.ML; + +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)] + +[assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs index 38f4965cea..c7944c7704 100644 --- a/src/Microsoft.ML.Api/TypedCursor.cs +++ b/src/Microsoft.ML.Api/TypedCursor.cs @@ -77,6 +77,7 @@ public interface ICursorable /// Similarly to the 'DataView{T}, this class uses IL generation to create the 'poke' methods that /// write directly into the fields of the user-defined type. /// + [BestFriend] internal sealed class TypedCursorable : ICursorable where TRow : class { diff --git a/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs b/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs index bdcea0c0ae..1efb2f7e18 100644 --- a/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs +++ b/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs @@ -13,12 +13,16 @@ namespace Microsoft.ML.Runtime.Internal.Utilities /// an item will result in discarding the least recently added item. /// [BestFriend] - internal sealed class FixedSizeQueue + internal sealed class FixedSizeQueue : ICloneable { private readonly T[] _array; private int _startIndex; private int _count; + public int StartIndex => _startIndex; + + public T[] Buffer => _array; + public FixedSizeQueue(int capacity) { Contracts.Assert(capacity > 0, "Array capacity should be greater than zero"); @@ -26,6 +30,13 @@ public FixedSizeQueue(int capacity) AssertValid(); } + public FixedSizeQueue(int capacity, int startIndex, T[] buffer) : this(capacity) + { + _startIndex = startIndex; + for (int index = 0; index < capacity; index++) + _array[index] = buffer[index]; + } + [Conditional("DEBUG")] private void AssertValid() { @@ -136,5 +147,8 @@ public void Clear() _count = 0; AssertValid(); } + + public object Clone() => new FixedSizeQueue(Capacity, StartIndex, Buffer); + } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs index f6c0a4bfd4..73c9be64ba 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -36,20 +36,33 @@ public enum TransformerScope Everything = Training | Testing | Scoring } + [BestFriend] + internal interface ITransformerAccessor + { + ITransformer[] Transformers { get; } + TransformerScope[] Scopes { get; } + } + /// /// A chain of transformers (possibly empty) that end with a . /// For an empty chain, is always . /// - public sealed class TransformerChain : ITransformer, ICanSaveModel, IEnumerable + public sealed class TransformerChain : ITransformer, ICanSaveModel, IEnumerable, ITransformerAccessor where TLastTransformer : class, ITransformer { - private readonly ITransformer[] _transformers; - private readonly TransformerScope[] _scopes; + [BestFriend] + internal readonly ITransformer[] Transformers; + [BestFriend] + internal readonly TransformerScope[] Scopes; public readonly TLastTransformer LastTransformer; private const string TransformDirTemplate = "Transform_{0:000}"; - public bool IsRowToRowMapper => _transformers.All(t => t.IsRowToRowMapper); + public bool IsRowToRowMapper => Transformers.All(t => t.IsRowToRowMapper); + + ITransformer[] ITransformerAccessor.Transformers => Transformers; + + TransformerScope[] ITransformerAccessor.Scopes => Scopes; private static VersionInfo GetVersionInfo() { @@ -72,12 +85,12 @@ public TransformerChain(IEnumerable transformers, IEnumerable 0) == (LastTransformer != null)); - Contracts.Check(_transformers.Length == _scopes.Length); + Contracts.Check((Transformers.Length > 0) == (LastTransformer != null)); + Contracts.Check(Transformers.Length == Scopes.Length); } /// @@ -91,14 +104,14 @@ public TransformerChain(params ITransformer[] transformers) if (Utils.Size(transformers) == 0) { - _transformers = new ITransformer[0]; - _scopes = new TransformerScope[0]; + Transformers = new ITransformer[0]; + Scopes = new TransformerScope[0]; LastTransformer = null; } else { - _transformers = transformers.ToArray(); - _scopes = transformers.Select(x => TransformerScope.Everything).ToArray(); + Transformers = transformers.ToArray(); + Scopes = transformers.Select(x => TransformerScope.Everything).ToArray(); LastTransformer = transformers.Last() as TLastTransformer; Contracts.Check(LastTransformer != null); } @@ -109,7 +122,7 @@ public Schema GetOutputSchema(Schema inputSchema) Contracts.CheckValue(inputSchema, nameof(inputSchema)); var s = inputSchema; - foreach (var xf in _transformers) + foreach (var xf in Transformers) s = xf.GetOutputSchema(s); return s; } @@ -123,7 +136,7 @@ public IDataView Transform(IDataView input) GetOutputSchema(input.Schema); var dv = input; - foreach (var xf in _transformers) + foreach (var xf in Transformers) dv = xf.Transform(dv); return dv; } @@ -132,12 +145,12 @@ public TransformerChain GetModelFor(TransformerScope scopeFilter) { var xfs = new List(); var scopes = new List(); - for (int i = 0; i < _transformers.Length; i++) + for (int i = 0; i < Transformers.Length; i++) { - if ((_scopes[i] & scopeFilter) != TransformerScope.None) + if ((Scopes[i] & scopeFilter) != TransformerScope.None) { - xfs.Add(_transformers[i]); - scopes.Add(_scopes[i]); + xfs.Add(Transformers[i]); + scopes.Add(Scopes[i]); } } return new TransformerChain(xfs.ToArray(), scopes.ToArray()); @@ -147,7 +160,7 @@ public TransformerChain Append(TNewLast transformer, Transfo where TNewLast : class, ITransformer { Contracts.CheckValue(transformer, nameof(transformer)); - return new TransformerChain(_transformers.AppendElement(transformer), _scopes.AppendElement(scope)); + return new TransformerChain(Transformers.AppendElement(transformer), Scopes.AppendElement(scope)); } public void Save(ModelSaveContext ctx) @@ -155,13 +168,13 @@ public void Save(ModelSaveContext ctx) ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); - ctx.Writer.Write(_transformers.Length); + ctx.Writer.Write(Transformers.Length); - for (int i = 0; i < _transformers.Length; i++) + for (int i = 0; i < Transformers.Length; i++) { - ctx.Writer.Write((int)_scopes[i]); + ctx.Writer.Write((int)Scopes[i]); var dirName = string.Format(TransformDirTemplate, i); - ctx.SaveModel(_transformers[i], dirName); + ctx.SaveModel(Transformers[i], dirName); } } @@ -171,16 +184,16 @@ public void Save(ModelSaveContext ctx) internal TransformerChain(IHostEnvironment env, ModelLoadContext ctx) { int len = ctx.Reader.ReadInt32(); - _transformers = new ITransformer[len]; - _scopes = new TransformerScope[len]; + Transformers = new ITransformer[len]; + Scopes = new TransformerScope[len]; for (int i = 0; i < len; i++) { - _scopes[i] = (TransformerScope)(ctx.Reader.ReadInt32()); + Scopes[i] = (TransformerScope)(ctx.Reader.ReadInt32()); var dirName = string.Format(TransformDirTemplate, i); - ctx.LoadModel(env, out _transformers[i], dirName); + ctx.LoadModel(env, out Transformers[i], dirName); } if (len > 0) - LastTransformer = _transformers[len - 1] as TLastTransformer; + LastTransformer = Transformers[len - 1] as TLastTransformer; else LastTransformer = null; } @@ -198,7 +211,7 @@ public void SaveTo(IHostEnvironment env, Stream outputStream) } } - public IEnumerator GetEnumerator() => ((IEnumerable)_transformers).GetEnumerator(); + public IEnumerator GetEnumerator() => ((IEnumerable)Transformers).GetEnumerator(); IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); @@ -207,11 +220,11 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) Contracts.CheckValue(inputSchema, nameof(inputSchema)); Contracts.Check(IsRowToRowMapper, nameof(GetRowToRowMapper) + " method called despite " + nameof(IsRowToRowMapper) + " being false."); - IRowToRowMapper[] mappers = new IRowToRowMapper[_transformers.Length]; + IRowToRowMapper[] mappers = new IRowToRowMapper[Transformers.Length]; Schema schema = inputSchema; for (int i = 0; i < mappers.Length; ++i) { - mappers[i] = _transformers[i].GetRowToRowMapper(schema); + mappers[i] = Transformers[i].GetRowToRowMapper(schema); schema = mappers[i].Schema; } return new CompositeRowToRowMapper(inputSchema, mappers); diff --git a/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs b/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs index f07595e52a..914dcf2a95 100644 --- a/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs +++ b/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. @@ -19,6 +19,9 @@ public sealed class CompositeRowToRowMapper : IRowToRowMapper public Schema InputSchema { get; } public Schema Schema { get; } + [BestFriend] + internal IRowToRowMapper[] InnerMappers => _innerMappers; + /// /// Out of a series of mappers, construct a seemingly unitary mapper that is able to apply them in sequence. /// @@ -85,6 +88,7 @@ public IRow GetRow(IRow input, Func active, out Action disposer) // We want the last disposer to be called first, so the order of the addition here is important. } } + return result; } diff --git a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs index c60717bce1..d98cc8e89b 100644 --- a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs +++ b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs @@ -454,6 +454,7 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo _wTrans = new CpuAlignedMatrixRow(_rank, _windowSize, SseUtils.CbAlign); int i = 0; _wTrans.CopyFrom(tempArray, ref i); + _y = new CpuAlignedVector(_rank, SseUtils.CbAlign); } _buffer = new FixedSizeQueue(_seriesLength); diff --git a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs index f5f4af3104..bfc5e7b874 100644 --- a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs +++ b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs @@ -20,12 +20,16 @@ public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment : base(args, name, env) { InitialWindowSize = 0; + StateRef = new State(); + StateRef.InitState(WindowSize, InitialWindowSize, this, Host); } public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name) : base(env, ctx, name) { Host.CheckDecode(InitialWindowSize == 0); + StateRef = new State(ctx); + StateRef.InitState(this, Host); } public override Schema GetOutputSchema(Schema inputSchema) @@ -46,15 +50,52 @@ public override void Save(ModelSaveContext ctx) { ctx.CheckAtModel(); Host.Assert(InitialWindowSize == 0); + base.Save(ctx); // *** Binary format *** // - - base.Save(ctx); + // State: StateRef + StateRef.Save(ctx); } public sealed class State : AnomalyDetectionStateBase { + public State() + { + + } + + public State(ModelLoadContext ctx) : base(ctx) + { + WindowedBuffer = new FixedSizeQueue( + ctx.Reader.ReadInt32(), ctx.Reader.ReadInt32(), ctx.Reader.ReadSingleArray()); + + InitialWindowedBuffer = new FixedSizeQueue( + ctx.Reader.ReadInt32(), ctx.Reader.ReadInt32(), ctx.Reader.ReadSingleArray()); + } + + public override void Save(ModelSaveContext ctx) + { + base.Save(ctx); + + ctx.Writer.Write(WindowedBuffer.Capacity); + ctx.Writer.Write(WindowedBuffer.StartIndex); + ctx.Writer.WriteSingleArray(WindowedBuffer.Buffer); + + ctx.Writer.Write(InitialWindowedBuffer.Capacity); + ctx.Writer.Write(InitialWindowedBuffer.StartIndex); + ctx.Writer.WriteSingleArray(InitialWindowedBuffer.Buffer); + } + + public override void CloneCore(StateBase state) + { + base.CloneCore(state); + Contracts.Assert(state is State); + var stateLocal = state as State; + stateLocal.WindowedBuffer = (FixedSizeQueue)WindowedBuffer.Clone(); + stateLocal.InitialWindowedBuffer = (FixedSizeQueue)InitialWindowedBuffer.Clone(); + } + private protected override void LearnStateFromDataCore(FixedSizeQueue data) { // This method is empty because there is no need for initial tuning for this transform. @@ -70,6 +111,10 @@ private protected override double ComputeRawAnomalyScore(ref Single input, Fixed // This transform treats the input sequenence as the raw anomaly score. return (double)input; } + + public override void Consume(float value) + { + } } } } diff --git a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs index 1fb5b2e62e..741bba0601 100644 --- a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs +++ b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs @@ -12,6 +12,7 @@ using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.TimeSeriesProcessing; +using Microsoft.ML.TimeSeries; using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase; [assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), typeof(IidChangePointDetector.Arguments), typeof(SignatureDataTransform), @@ -110,6 +111,14 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData return new IidChangePointDetector(env, args).MakeDataTransform(input); } + public override IStatefulTransformer Clone() + { + var clone = (IidChangePointDetector)MemberwiseClone(); + clone.StateRef = (State)clone.StateRef.Clone(); + clone.StateRef.InitState(clone, Host); + return clone; + } + internal IidChangePointDetector(IHostEnvironment env, Arguments args) : base(new BaseArguments(args), LoaderSignature, env) { diff --git a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs index 9ccb302eb6..5e93116f42 100644 --- a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs +++ b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs @@ -11,6 +11,7 @@ using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.TimeSeriesProcessing; +using Microsoft.ML.TimeSeries; using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase; [assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IDataTransform), typeof(IidSpikeDetector), typeof(IidSpikeDetector.Arguments), typeof(SignatureDataTransform), @@ -106,6 +107,14 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData return new IidSpikeDetector(env, args).MakeDataTransform(input); } + public override IStatefulTransformer Clone() + { + var clone = (IidSpikeDetector)MemberwiseClone(); + clone.StateRef = (State)clone.StateRef.Clone(); + clone.StateRef.InitState(clone, Host); + return clone; + } + internal IidSpikeDetector(IHostEnvironment env, Arguments args) : base(new BaseArguments(args), LoaderSignature, env) { diff --git a/src/Microsoft.ML.TimeSeries/PredictionFunction.cs b/src/Microsoft.ML.TimeSeries/PredictionFunction.cs new file mode 100644 index 0000000000..ac502a9e31 --- /dev/null +++ b/src/Microsoft.ML.TimeSeries/PredictionFunction.cs @@ -0,0 +1,263 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; + +namespace Microsoft.ML.TimeSeries +{ + public interface IStatefulRowToRowMapper : IRowToRowMapper + { + } + + public interface IStatefulTransformer : ITransformer + { + IRowToRowMapper GetStatefulRowToRowMapper(Schema inputSchema); + + IStatefulTransformer Clone(); + } + + public interface IStatefulRow : IRow + { + Action GetPinger(); + } + + public interface IStatefulRowMapper : IRowMapper + { + void CloneState(); + + Action CreatePinger(IRow input, Func activeOutput, out Action disposer); + } + + /// + /// A class that runs the previously trained model (and the preceding transform pipeline) on the + /// in-memory data, one example at a time. + /// This can also be used with trained pipelines that do not end with a predictor: in this case, the + /// 'prediction' will be just the outcome of all the transformations. + /// + /// The user-defined type that holds the example. + /// The user-defined type that holds the prediction. + public sealed class TimeSeriesPredictionFunction : PredictionEngineBase + where TSrc : class + where TDst : class, new() + { + private Action _pinger; + private long _rowPosition; + private ITransformer InputTransformer { get; set; } + + public void CheckPoint(IHostEnvironment env, string modelPath) + { + using (var file = File.Create(modelPath)) + if (Transformer is ITransformerAccessor) + { + + new TransformerChain + (((ITransformerAccessor)Transformer).Transformers, + ((ITransformerAccessor)Transformer).Scopes).SaveTo(env, file); + } + else + Transformer.SaveTo(env, file); + } + + private static ITransformer CloneTransformers(ITransformer transformer) + { + ITransformer[] transformersClone = null; + TransformerScope[] scopeClone = null; + if (transformer is ITransformerAccessor) + { + ITransformerAccessor accessor = (ITransformerAccessor)transformer; + transformersClone = accessor.Transformers.Select(x => x).ToArray(); + scopeClone = accessor.Scopes.Select(x => x).ToArray(); + int index = 0; + foreach (var xf in transformersClone) + transformersClone[index++] = xf is IStatefulTransformer ? ((IStatefulTransformer)xf).Clone() : xf; + + return new TransformerChain(transformersClone, scopeClone); + } + else + return transformer is IStatefulTransformer ? ((IStatefulTransformer)transformer).Clone() : transformer; + } + + public TimeSeriesPredictionFunction(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns, + SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) : + base(env, CloneTransformers(transformer), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition) + { + } + + internal override ITransformer ProcessTransformer(ITransformer transformer) => CloneTransformers(transformer); + + public IRow GetStatefulRows(IRow input, IRowToRowMapper mapper, Func active, + List rows, out Action disposer) + { + Contracts.CheckValue(input, nameof(input)); + Contracts.CheckValue(active, nameof(active)); + + disposer = null; + IRowToRowMapper[] innerMappers = new IRowToRowMapper[0]; + if (mapper is CompositeRowToRowMapper) + innerMappers = ((CompositeRowToRowMapper)mapper).InnerMappers; + + if (innerMappers.Length == 0) + { + bool differentActive = false; + for (int c = 0; c < input.Schema.ColumnCount; ++c) + { + bool wantsActive = active(c); + bool isActive = input.IsColumnActive(c); + differentActive |= wantsActive != isActive; + + if (wantsActive && !isActive) + throw Contracts.ExceptParam(nameof(input), $"Mapper required column '{input.Schema.GetColumnName(c)}' active but it was not."); + } + + var row = mapper.GetRow(input, active, out disposer); + if (row is IStatefulRow) + rows.Add((IStatefulRow)row); + + return row; + } + + // For each of the inner mappers, we will be calling their GetRow method, but to do so we need to know + // what we need from them. The last one will just have the input, but the rest will need to be + // computed based on the dependencies of the next one in the chain. + var deps = new Func[innerMappers.Length]; + deps[deps.Length - 1] = active; + for (int i = deps.Length - 1; i >= 1; --i) + deps[i - 1] = innerMappers[i].GetDependencies(deps[i]); + + IRow result = input; + for (int i = 0; i < innerMappers.Length; ++i) + { + Action localDisp; + result = GetStatefulRows(result, innerMappers[i], deps[i], rows, out localDisp); + if (result is IStatefulRow) + rows.Add((IStatefulRow)result); + + if (localDisp != null) + { + if (disposer == null) + disposer = localDisp; + else + disposer = localDisp + disposer; + // We want the last disposer to be called first, so the order of the addition here is important. + } + } + + return result; + } + + private Action CreatePinger(List rows) + { + Action[] pingers = new Action[rows.Count]; + int index = 0; + foreach (var row in rows) + pingers[index++] = row.GetPinger(); + + return (long position) => + { + foreach (var ping in pingers) + ping(position); + }; + } + + internal override void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns, + SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs outputRow) + { + List rows = new List(); + IRow outputRowLocal = outputRowLocal = GetStatefulRows(inputRow, mapper, col => true, rows, out disposer); + var cursorable = TypedCursorable.Create(env, new EmptyDataView(env, mapper.Schema), ignoreMissingColumns, outputSchemaDefinition); + _pinger = CreatePinger(rows); + outputRow = cursorable.GetRow(outputRowLocal); + } + + private bool IsRowToRowMapper(ITransformer transformer) + { + if (transformer is ITransformerAccessor) + return ((ITransformerAccessor)transformer).Transformers.All(t => t.IsRowToRowMapper || t is IStatefulTransformer); + else + return transformer.IsRowToRowMapper || transformer is IStatefulTransformer; + } + + private IRowToRowMapper GetRowToRowMapper(Schema inputSchema) + { + Contracts.CheckValue(inputSchema, nameof(inputSchema)); + Contracts.Check(IsRowToRowMapper(InputTransformer), nameof(GetRowToRowMapper) + + " method called despite " + nameof(IsRowToRowMapper) + " being false. or transformer not being " + nameof(IStatefulTransformer)); + + if (!(InputTransformer is ITransformerAccessor)) + if (InputTransformer is IStatefulTransformer) + return ((IStatefulTransformer)InputTransformer).GetStatefulRowToRowMapper(inputSchema); + else + return InputTransformer.GetRowToRowMapper(inputSchema); + + Contracts.Check(InputTransformer is ITransformerAccessor); + + var transformers = ((ITransformerAccessor)InputTransformer).Transformers; + IRowToRowMapper[] mappers = new IRowToRowMapper[transformers.Length]; + Schema schema = inputSchema; + for (int i = 0; i < mappers.Length; ++i) + { + if (transformers[i] is IStatefulTransformer) + mappers[i] = ((IStatefulTransformer)transformers[i]).GetStatefulRowToRowMapper(schema); + else + mappers[i] = transformers[i].GetRowToRowMapper(schema); + + schema = mappers[i].Schema; + } + return new CompositeRowToRowMapper(inputSchema, mappers); + } + + protected override Func TransformerChecker(IExceptionContext ectx, ITransformer transformer) + { + ectx.CheckValue(transformer, nameof(transformer)); + ectx.CheckParam(IsRowToRowMapper(transformer), nameof(transformer), "Must be a row to row mapper or " + nameof(IStatefulTransformer)); + InputTransformer = transformer; + return GetRowToRowMapper; + } + + /// + /// Run prediction pipeline on one example. + /// + /// The example to run on. + /// The object to store the prediction in. If it's null, a new one will be created, otherwise the old one + /// is reused. + public override void Predict(TSrc example, ref TDst prediction) + { + Contracts.CheckValue(example, nameof(example)); + ExtractValues(example); + if (prediction == null) + prediction = new TDst(); + + // Update state. + _pinger(_rowPosition); + + // Predict. + FillValues(prediction); + + _rowPosition++; + } + } + + public static class PredictionFunctionExtensions + { + public static TimeSeriesPredictionFunction CreateTimeSeriesPredictionFunction(this ITransformer transformer, IHostEnvironment env, + bool ignoreMissingColumns = false, SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) + where TSrc : class + where TDst : class, new() + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(transformer, nameof(transformer)); + env.CheckValueOrNull(inputSchemaDefinition); + env.CheckValueOrNull(outputSchemaDefinition); + return new TimeSeriesPredictionFunction(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition); + } + } +} diff --git a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs index 398fb86f19..b28ce37329 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs @@ -1,15 +1,17 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. using System; using System.Collections.Generic; +using System.Threading; using Microsoft.ML.Data; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.CpuMath; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; +using Microsoft.ML.TimeSeries; namespace Microsoft.ML.Runtime.TimeSeriesProcessing { @@ -306,10 +308,10 @@ public abstract class AnomalyDetectionStateBase : StateBase protected SequentialAnomalyDetectionTransformBase Parent; // A windowed buffer to cache the update values to the martingale score in the log scale. - private FixedSizeQueue _logMartingaleUpdateBuffer; + private FixedSizeQueue LogMartingaleUpdateBuffer { get; set; } // A windowed buffer to cache the raw anomaly scores for p-value calculation. - private FixedSizeQueue _rawScoreBuffer; + private FixedSizeQueue RawScoreBuffer { get; set; } // The current martingale score in the log scale. private Double _logMartingaleValue; @@ -322,6 +324,45 @@ public abstract class AnomalyDetectionStateBase : StateBase protected Double LatestMartingaleScore => Math.Exp(_logMartingaleValue); + public override void CloneCore(StateBase state) + { + base.CloneCore(state); + Contracts.Assert(state is AnomalyDetectionStateBase); + var stateLocal = state as AnomalyDetectionStateBase; + stateLocal.LogMartingaleUpdateBuffer = (FixedSizeQueue)LogMartingaleUpdateBuffer.Clone(); + stateLocal.RawScoreBuffer = (FixedSizeQueue)RawScoreBuffer.Clone(); + } + + public AnomalyDetectionStateBase(ModelLoadContext ctx) : base(ctx) + { + LogMartingaleUpdateBuffer = new FixedSizeQueue( + ctx.Reader.ReadInt32(), ctx.Reader.ReadInt32(), ctx.Reader.ReadDoubleArray()); + + RawScoreBuffer = new FixedSizeQueue( + ctx.Reader.ReadInt32(), ctx.Reader.ReadInt32(), ctx.Reader.ReadSingleArray()); + + _logMartingaleValue = ctx.Reader.ReadDouble(); + _sumSquaredDist = ctx.Reader.ReadDouble(); + _martingaleAlertCounter = ctx.Reader.ReadInt32(); + } + + public override void Save(ModelSaveContext ctx) + { + base.Save(ctx); + + ctx.Writer.Write(LogMartingaleUpdateBuffer.Capacity); + ctx.Writer.Write(LogMartingaleUpdateBuffer.StartIndex); + ctx.Writer.WriteDoubleArray(LogMartingaleUpdateBuffer.Buffer); + + ctx.Writer.Write(RawScoreBuffer.Capacity); + ctx.Writer.Write(RawScoreBuffer.StartIndex); + ctx.Writer.WriteSingleArray(RawScoreBuffer.Buffer); + + ctx.Writer.Write(_logMartingaleValue); + ctx.Writer.Write(_sumSquaredDist); + ctx.Writer.Write(_martingaleAlertCounter); + } + private protected AnomalyDetectionStateBase() : base() { } @@ -329,7 +370,7 @@ private protected AnomalyDetectionStateBase() : base() private Double ComputeKernelPValue(Double rawScore) { int i; - int n = _rawScoreBuffer.Count; + int n = RawScoreBuffer.Count; if (n == 0) return 0.5; @@ -341,21 +382,21 @@ private Double ComputeKernelPValue(Double rawScore) Double diff; for (i = 0; i < n; ++i) { - diff = rawScore - _rawScoreBuffer[i]; + diff = rawScore - RawScoreBuffer[i]; pValue -= ProbabilityFunctions.Erf(diff / bandWidth); _sumSquaredDist += diff * diff; } pValue = 0.5 + pValue / (2 * n); - if (_rawScoreBuffer.IsFull) + if (RawScoreBuffer.IsFull) { for (i = 1; i < n; ++i) { - diff = _rawScoreBuffer[0] - _rawScoreBuffer[i]; + diff = RawScoreBuffer[0] - RawScoreBuffer[i]; _sumSquaredDist -= diff * diff; } - diff = _rawScoreBuffer[0] - rawScore; + diff = RawScoreBuffer[0] - rawScore; _sumSquaredDist -= diff * diff; } @@ -435,7 +476,7 @@ private protected override sealed void TransformCore(ref TInput input, FixedSize else if (result.Values[2] > MaxPValue) result.Values[2] = MaxPValue; - _rawScoreBuffer.AddLast(rawScore); + RawScoreBuffer.AddLast(rawScore); // Step 3: Computing the martingale value if (Parent.Martingale != MartingaleType.None && Parent.ThresholdScore == AlertingScore.MartingaleScore) @@ -452,17 +493,17 @@ private protected override sealed void TransformCore(ref TInput input, FixedSize break; } - if (_logMartingaleUpdateBuffer.Count == 0) + if (LogMartingaleUpdateBuffer.Count == 0) { - for (int i = 0; i < _logMartingaleUpdateBuffer.Capacity; ++i) - _logMartingaleUpdateBuffer.AddLast(martingaleUpdate); - _logMartingaleValue += _logMartingaleUpdateBuffer.Capacity * martingaleUpdate; + for (int i = 0; i < LogMartingaleUpdateBuffer.Capacity; ++i) + LogMartingaleUpdateBuffer.AddLast(martingaleUpdate); + _logMartingaleValue += LogMartingaleUpdateBuffer.Capacity * martingaleUpdate; } else { _logMartingaleValue += martingaleUpdate; - _logMartingaleValue -= _logMartingaleUpdateBuffer.PeekFirst(); - _logMartingaleUpdateBuffer.AddLast(martingaleUpdate); + _logMartingaleValue -= LogMartingaleUpdateBuffer.PeekFirst(); + LogMartingaleUpdateBuffer.AddLast(martingaleUpdate); } result.Values[3] = Math.Exp(_logMartingaleValue); @@ -473,7 +514,7 @@ private protected override sealed void TransformCore(ref TInput input, FixedSize // Generating alert bool alert = false; - if (_rawScoreBuffer.IsFull) // No alert until the buffer is completely full. + if (RawScoreBuffer.IsFull) // No alert until the buffer is completely full. { switch (Parent.ThresholdScore) { @@ -506,17 +547,23 @@ private protected override sealed void TransformCore(ref TInput input, FixedSize dst = result.Commit(); } - private protected override sealed void InitializeStateCore() + private protected override sealed void InitializeStateCore(bool disk = false) { Parent = (SequentialAnomalyDetectionTransformBase)ParentTransform; Host.Assert(WindowSize >= 0); - if (Parent.Martingale != MartingaleType.None) - _logMartingaleUpdateBuffer = new FixedSizeQueue(WindowSize == 0 ? 1 : WindowSize); + if (disk == false) + { + if (Parent.Martingale != MartingaleType.None) + LogMartingaleUpdateBuffer = new FixedSizeQueue(WindowSize == 0 ? 1 : WindowSize); + else + LogMartingaleUpdateBuffer = new FixedSizeQueue(1); + + RawScoreBuffer = new FixedSizeQueue(WindowSize == 0 ? 1 : WindowSize); - _rawScoreBuffer = new FixedSizeQueue(WindowSize == 0 ? 1 : WindowSize); + _logMartingaleValue = 0; + } - _logMartingaleValue = 0; InitializeAnomalyDetector(); } @@ -536,15 +583,16 @@ private protected override sealed void InitializeStateCore() private protected abstract Double ComputeRawAnomalyScore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration); } - protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(Host, this, schema); + protected override IStatefulRowMapper MakeRowMapper(ISchema schema) => new Mapper(Host, this, schema); - private sealed class Mapper : IRowMapper + private sealed class Mapper : IStatefulRowMapper { private readonly IHost _host; private readonly SequentialAnomalyDetectionTransformBase _parent; private readonly ISchema _parentSchema; private readonly int _inputColumnIndex; private readonly VBuffer> _slotNames; + private TState State { get; set; } public Mapper(IHostEnvironment env, SequentialAnomalyDetectionTransformBase parent, ISchema inputSchema) { @@ -564,6 +612,8 @@ public Mapper(IHostEnvironment env, SequentialAnomalyDetectionTransformBase>(4, new[] { "Alert".AsMemory(), "Raw Score".AsMemory(), "P-Value Score".AsMemory(), "Martingale Score".AsMemory() }); + + State = _parent.StateRef; } public Schema.DetachedColumn[] GetOutputColumns() @@ -592,11 +642,8 @@ public Delegate[] CreateGetters(IRow input, Func activeOutput, out Ac disposer = null; var getters = new Delegate[1]; if (activeOutput(0)) - { - TState state = new TState(); - state.InitState(_parent.WindowSize, _parent.InitialWindowSize, _parent, _host); - getters[0] = MakeGetter(input, state); - } + getters[0] = MakeGetter(input, State); + return getters; } @@ -608,15 +655,46 @@ private Delegate MakeGetter(IRow input, TState state) var srcGetter = input.GetGetter(_inputColumnIndex); ProcessData processData = _parent.WindowSize > 0 ? (ProcessData)state.Process : state.ProcessWithoutBuffer; + ValueGetter> valueGetter = (ref VBuffer dst) => - { + { TInput src = default; srcGetter(ref src); processData(ref src, ref dst); - }; - + }; return valueGetter; } + + public Action CreatePinger(IRow input, Func activeOutput, out Action disposer) + { + disposer = null; + Action pinger = null; + if (activeOutput(0)) + pinger = MakePinger(input, State); + + return pinger; + } + + private Action MakePinger(IRow input, TState state) + { + _host.AssertValue(input); + var srcGetter = input.GetGetter(_inputColumnIndex); + Action pinger = (long rowPosition) => + { + TInput src = default; + srcGetter(ref src); + state.UpdateState(ref src, rowPosition, _parent.WindowSize > 0); + }; + return pinger; + } + + public void CloneState() + { + if (Interlocked.Increment(ref _parent.StateRefCount) > 1) + { + State = (TState)_parent.StateRef.Clone(); + } + } } } } diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs index 1c78af4d27..56e27749eb 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs @@ -8,7 +8,10 @@ using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Core.Data; +using Microsoft.ML.TimeSeries; +using Microsoft.ML.Runtime.Model.Onnx; +using Microsoft.ML.Runtime.Model.Pfa; +using System.Linq; using Microsoft.ML.Data; namespace Microsoft.ML.Runtime.TimeSeriesProcessing @@ -20,13 +23,13 @@ namespace Microsoft.ML.Runtime.TimeSeriesProcessing /// The input type of the sequential processing. /// The dst type of the sequential processing. /// The state type of the sequential processing. Must be a class inherited from StateBase - public abstract class SequentialTransformerBase : ITransformer, ICanSaveModel + public abstract class SequentialTransformerBase : IStatefulTransformer, ICanSaveModel where TState : SequentialTransformerBase.StateBase, new() { /// /// The base class for encapsulating the State object for sequential processing. This class implements a windowed buffer. /// - public abstract class StateBase + public abstract class StateBase : ICanSaveModel, ICloneable { // Ideally this class should be private. However, due to the current constraints with the LambdaTransform, we need to have // access to the state class when inheriting from SequentialTransformerBase. @@ -40,12 +43,12 @@ public abstract class StateBase /// /// The internal windowed buffer for buffering the values in the input sequence. /// - private protected FixedSizeQueue WindowedBuffer; + private protected FixedSizeQueue WindowedBuffer { get; set; } /// /// The buffer used to buffer the training data points. /// - private protected FixedSizeQueue InitialWindowedBuffer; + private protected FixedSizeQueue InitialWindowedBuffer { get; set; } private protected int WindowSize { get; private set; } @@ -66,7 +69,19 @@ protected long IncrementRowCounter() return RowCounter; } - private bool _isIniatilized; + protected long PreviousPosition; + + public StateBase(ModelLoadContext ctx) + { + WindowSize = ctx.Reader.ReadInt32(); + InitialWindowSize = ctx.Reader.ReadInt32(); + } + + public virtual void Save(ModelSaveContext ctx) + { + ctx.Writer.Write(WindowSize); + ctx.Writer.Write(InitialWindowSize); + } /// /// This method sets the window size and initializes the buffer only once. @@ -76,10 +91,10 @@ protected long IncrementRowCounter() /// The size of the windowed initial buffer used for training /// The parent transform of this state object /// The host - public void InitState(int windowSize, int initialWindowSize, SequentialTransformerBase parentTransform, IHost host) + public void InitState(int windowSize, int initialWindowSize, SequentialTransformerBase parentTransform, + IHost host) { Contracts.CheckValue(host, nameof(host), "The host cannot be null."); - host.Check(!_isIniatilized, "The window size can be set only once."); host.CheckValue(parentTransform, nameof(parentTransform)); host.CheckParam(windowSize >= 0, nameof(windowSize), "Must be non-negative."); host.CheckParam(initialWindowSize >= 0, nameof(initialWindowSize), "Must be non-negative."); @@ -93,7 +108,19 @@ public void InitState(int windowSize, int initialWindowSize, SequentialTransform RowCounter = 0; InitializeStateCore(); - _isIniatilized = true; + PreviousPosition = -1; + } + + public void InitState(SequentialTransformerBase parentTransform, IHost host) + { + Contracts.CheckValue(host, nameof(host), "The host cannot be null."); + host.CheckValue(parentTransform, nameof(parentTransform)); + + Host = host; + ParentTransform = parentTransform; + RowCounter = 0; + InitializeStateCore(true); + PreviousPosition = -1; } /// @@ -101,51 +128,74 @@ public void InitState(int windowSize, int initialWindowSize, SequentialTransform /// public virtual void Reset() { - Host.Assert(_isIniatilized); Host.Assert(WindowedBuffer != null); Host.Assert(InitialWindowedBuffer != null); RowCounter = 0; WindowedBuffer.Clear(); InitialWindowedBuffer.Clear(); + PreviousPosition = -1; } - public void Process(ref TInput input, ref TOutput output) + public void UpdateState(ref TInput input, long rowPosition, bool buffer = true) + { + if (rowPosition > PreviousPosition) + { + PreviousPosition = rowPosition; + UpdateStateCore(ref input, buffer); + Consume(input); + } + } + + public void UpdateStateCore(ref TInput input, bool buffer = true) { if (InitialWindowedBuffer.Count < InitialWindowSize) { InitialWindowedBuffer.AddLast(input); - SetNaOutput(ref output); - - if (InitialWindowedBuffer.Count >= InitialWindowSize - WindowSize) + if (InitialWindowedBuffer.Count >= InitialWindowSize - WindowSize && buffer) + WindowedBuffer.AddLast(input); + } + else + { + if (buffer) WindowedBuffer.AddLast(input); + IncrementRowCounter(); + } + } + + public void Process(ref TInput input, ref TOutput output) + { + if (PreviousPosition == -1) + UpdateStateCore(ref input); + + if (InitialWindowedBuffer.Count < InitialWindowSize) + { + SetNaOutput(ref output); + if (InitialWindowedBuffer.Count == InitialWindowSize) LearnStateFromDataCore(InitialWindowedBuffer); } else { TransformCore(ref input, WindowedBuffer, RowCounter - InitialWindowSize, ref output); - WindowedBuffer.AddLast(input); - IncrementRowCounter(); } } public void ProcessWithoutBuffer(ref TInput input, ref TOutput output) { + if (PreviousPosition == -1) + UpdateStateCore(ref input, false); + if (InitialWindowedBuffer.Count < InitialWindowSize) { - InitialWindowedBuffer.AddLast(input); SetNaOutput(ref output); if (InitialWindowedBuffer.Count == InitialWindowSize) LearnStateFromDataCore(InitialWindowedBuffer); } else - { TransformCore(ref input, WindowedBuffer, RowCounter - InitialWindowSize, ref output); - IncrementRowCounter(); - } } /// @@ -166,13 +216,28 @@ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output) /// /// The abstract method that realizes the logic for initializing the state object. /// - private protected abstract void InitializeStateCore(); + private protected abstract void InitializeStateCore(bool disk = false); /// /// The abstract method that realizes the logic for learning the parameters and the initial state object from data. /// /// A queue of data points used for training private protected abstract void LearnStateFromDataCore(FixedSizeQueue data); + + public abstract void Consume(TInput value); + + public object Clone() + { + var clone = (StateBase)MemberwiseClone(); + CloneCore(clone); + return clone; + } + + public virtual void CloneCore(StateBase state) + { + state.WindowedBuffer = (FixedSizeQueue)WindowedBuffer.Clone(); + state.InitialWindowedBuffer = (FixedSizeQueue)InitialWindowedBuffer.Clone(); + } } private protected readonly IHost Host; @@ -193,6 +258,10 @@ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output) public bool IsRowToRowMapper => false; + public TState StateRef{ get; set; } + + public int StateRefCount; + /// /// The main constructor for the sequential transform /// @@ -273,7 +342,7 @@ public virtual void Save(ModelSaveContext ctx) public abstract Schema GetOutputSchema(Schema inputSchema); - protected abstract IRowMapper MakeRowMapper(ISchema schema); + protected abstract IStatefulRowMapper MakeRowMapper(ISchema schema); protected SequentialDataTransform MakeDataTransform(IDataView input) { @@ -288,16 +357,28 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) throw new InvalidOperationException("Not a RowToRowMapper."); } - public sealed class SequentialDataTransform : TransformBase, ITransformTemplate + public IRowToRowMapper GetStatefulRowToRowMapper(Schema inputSchema) { - private readonly IRowMapper _mapper; + Host.CheckValue(inputSchema, nameof(inputSchema)); + return new TimeSeriesRowToRowMapperTransform(Host, new EmptyDataView(Host, inputSchema), MakeRowMapper(inputSchema)); + } + + public virtual IStatefulTransformer Clone() => (SequentialTransformerBase)MemberwiseClone(); + + public sealed class SequentialDataTransform : TransformBase, ITransformTemplate, IRowToRowMapper + { + private readonly IStatefulRowMapper _mapper; private readonly SequentialTransformerBase _parent; private readonly IDataTransform _transform; private readonly ColumnBindings _bindings; - public SequentialDataTransform(IHost host, SequentialTransformerBase parent, IDataView input, IRowMapper mapper) - :base(parent.Host, input) + private MetadataDispatcher Metadata { get; } + + public SequentialDataTransform(IHost host, SequentialTransformerBase parent, + IDataView input, IStatefulRowMapper mapper) + : base(parent.Host, input) { + Metadata = new MetadataDispatcher(1); _parent = parent; _transform = CreateLambdaTransform(_parent.Host, input, _parent.InputColumnName, _parent.OutputColumnName, InitFunction, _parent.WindowSize > 0, _parent.OutputColumnType); @@ -305,6 +386,8 @@ public SequentialDataTransform(IHost host, SequentialTransformerBase _mapper.CloneState(); + private static IDataTransform CreateLambdaTransform(IHost host, IDataView input, string inputColumnName, string outputColumnName, Action initFunction, bool hasBuffer, ColumnType outputColTypeOverride) { @@ -346,7 +429,9 @@ private void InitFunction(TState state) protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) { var srcCursor = _transform.GetRowCursor(predicate, rand); - return new Cursor(Host, this, srcCursor); + var clone = (SequentialDataTransform)MemberwiseClone(); + clone.CloneStateInMapper(); + return new Cursor(Host, clone, srcCursor); } protected override bool? ShouldUseParallelCursors(Func predicate) @@ -377,6 +462,71 @@ public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource) { return new SequentialDataTransform(Contracts.CheckRef(env, nameof(env)).Register("SequentialDataTransform"), _parent, newSource, _mapper); } + + public Schema InputSchema => Source.Schema; + public Func GetDependencies(Func predicate) + { + for (int i = 0; i < Schema.ColumnCount; i++) + { + if (predicate(i)) + return col => true; + } + return col => false; + } + + public IRow GetRow(IRow input, Func active, out Action disposer) => + new Row(_bindings.Schema, input, _mapper.CreateGetters(input, active, out disposer), + _mapper.CreatePinger(input, active, out disposer)); + + } + + public sealed class Row : IStatefulRow + { + private readonly Schema _schema; + private readonly IRow _input; + private readonly Delegate[] _getters; + private readonly Action _pinger; + + public Schema Schema { get { return _schema; } } + + public long Position { get { return _input.Position; } } + + public long Batch { get { return _input.Batch; } } + + public Row(Schema schema, IRow input, Delegate[] getters, Action pinger) + { + Contracts.CheckValue(schema, nameof(schema)); + Contracts.CheckValue(input, nameof(input)); + Contracts.Check(Utils.Size(getters) == schema.ColumnCount); + _schema = schema; + _input = input; + _getters = getters ?? new Delegate[0]; + _pinger = pinger; + } + + public ValueGetter GetIdGetter() + { + return _input.GetIdGetter(); + } + + public ValueGetter GetGetter(int col) + { + Contracts.CheckParam(0 <= col && col < _getters.Length, nameof(col), "Invalid col value in GetGetter"); + Contracts.Check(IsColumnActive(col)); + var fn = _getters[col] as ValueGetter; + if (fn == null) + throw Contracts.Except("Unexpected TValue in GetGetter"); + return fn; + } + + public Action GetPinger() => + _pinger as Action ?? throw Contracts.Except("Invalid TValue in GetPinger: '{0}'", typeof(long)); + + public bool IsColumnActive(int col) + { + Contracts.Check(0 <= col && col < _getters.Length); + return _getters[col] != null; + } } /// @@ -408,4 +558,310 @@ public ValueGetter GetGetter(int col) } } } + + /// + /// This class is a transform that can add any number of output columns, that depend on any number of input columns. + /// It does so with the help of an , that is given a schema in its constructor, and has methods + /// to get the dependencies on input columns and the getters for the output columns, given an active set of output columns. + /// + + public sealed class TimeSeriesRowToRowMapperTransform : RowToRowTransformBase, IStatefulRowToRowMapper, + ITransformCanSaveOnnx, ITransformCanSavePfa + { + private readonly IStatefulRowMapper _mapper; + private readonly ColumnBindings _bindings; + public const string RegistrationName = "TimeSeriesRowToRowMapperTransform"; + public const string LoaderSignature = "TimeSeriesRowToRowMapper"; + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "TS ROW MPPR", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(TimeSeriesRowToRowMapperTransform).Assembly.FullName); + } + + public override Schema Schema => _bindings.Schema; + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false; + + bool ICanSavePfa.CanSavePfa => _mapper is ICanSavePfa pfaMapper ? pfaMapper.CanSavePfa : false; + + public TimeSeriesRowToRowMapperTransform(IHostEnvironment env, IDataView input, IStatefulRowMapper mapper) + : base(env, RegistrationName, input) + { + Contracts.CheckValue(mapper, nameof(mapper)); + _mapper = mapper; + _bindings = new ColumnBindings(Schema.Create(input.Schema), mapper.GetOutputColumns()); + } + + public static Schema GetOutputSchema(ISchema inputSchema, IRowMapper mapper) + { + Contracts.CheckValue(inputSchema, nameof(inputSchema)); + Contracts.CheckValue(mapper, nameof(mapper)); + return new ColumnBindings(Schema.Create(inputSchema), mapper.GetOutputColumns()).Schema; + } + + private TimeSeriesRowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView input) + : base(host, input) + { + // *** Binary format *** + // _mapper + + ctx.LoadModel(host, out _mapper, "Mapper", input.Schema); + _bindings = new ColumnBindings(Schema.Create(input.Schema), _mapper.GetOutputColumns()); + } + + public static TimeSeriesRowToRowMapperTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + var h = env.Register(RegistrationName); + h.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + h.CheckValue(input, nameof(input)); + return h.Apply("Loading Model", ch => new TimeSeriesRowToRowMapperTransform(h, ctx, input)); + } + + public override void Save(ModelSaveContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + // *** Binary format *** + // _mapper + + ctx.SaveModel(_mapper, "Mapper"); + } + + /// + /// Produces the set of active columns for the data view (as a bool[] of length bindings.ColumnCount), + /// a predicate for the needed active input columns, and a predicate for the needed active + /// output columns. + /// + private bool[] GetActive(Func predicate, out Func predicateInput) + { + int n = _bindings.Schema.ColumnCount; + var active = Utils.BuildArray(n, predicate); + Contracts.Assert(active.Length == n); + + var activeInput = _bindings.GetActiveInput(predicate); + Contracts.Assert(activeInput.Length == _bindings.InputSchema.ColumnCount); + + // Get a predicate that determines which outputs are active. + var predicateOut = GetActiveOutputColumns(active); + + // Now map those to active input columns. + var predicateIn = _mapper.GetDependencies(predicateOut); + + // Combine the two sets of input columns. + predicateInput = + col => 0 <= col && col < activeInput.Length && (activeInput[col] || predicateIn(col)); + + return active; + } + + private Func GetActiveOutputColumns(bool[] active) + { + Contracts.AssertValue(active); + Contracts.Assert(active.Length == _bindings.Schema.ColumnCount); + + return + col => + { + Contracts.Assert(0 <= col && col < _bindings.AddedColumnIndices.Count); + return 0 <= col && col < _bindings.AddedColumnIndices.Count && active[_bindings.AddedColumnIndices[col]]; + }; + } + + protected override bool? ShouldUseParallelCursors(Func predicate) + { + Host.AssertValue(predicate, "predicate"); + if (_bindings.AddedColumnIndices.Any(predicate)) + return true; + return null; + } + + protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null) + { + Func predicateInput; + var active = GetActive(predicate, out predicateInput); + return new RowCursor(Host, Source.GetRowCursor(predicateInput, rand), this, active); + } + + public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null) + { + Host.CheckValue(predicate, nameof(predicate)); + Host.CheckValueOrNull(rand); + + Func predicateInput; + var active = GetActive(predicate, out predicateInput); + + var inputs = Source.GetRowCursorSet(out consolidator, predicateInput, n, rand); + Host.AssertNonEmpty(inputs); + + if (inputs.Length == 1 && n > 1 && _bindings.AddedColumnIndices.Any(predicate)) + inputs = DataViewUtils.CreateSplitCursors(out consolidator, Host, inputs[0], n); + Host.AssertNonEmpty(inputs); + + var cursors = new IRowCursor[inputs.Length]; + for (int i = 0; i < inputs.Length; i++) + cursors[i] = new RowCursor(Host, inputs[i], this, active); + return cursors; + } + + void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + if (_mapper is ISaveAsOnnx onnx) + { + Host.Check(onnx.CanSaveOnnx(ctx), "Cannot be saved as ONNX."); + onnx.SaveAsOnnx(ctx); + } + } + + void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + if (_mapper is ISaveAsPfa pfa) + { + Host.Check(pfa.CanSavePfa, "Cannot be saved as PFA."); + pfa.SaveAsPfa(ctx); + } + } + + public Func GetDependencies(Func predicate) + { + Func predicateInput; + GetActive(predicate, out predicateInput); + return predicateInput; + } + + Schema IRowToRowMapper.InputSchema => Source.Schema; + + public IRow GetRow(IRow input, Func active, out Action disposer) + { + Host.CheckValue(input, nameof(input)); + Host.CheckValue(active, nameof(active)); + Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to"); + + disposer = null; + using (var ch = Host.Start("GetEntireRow")) + { + Action disp; + var activeArr = new bool[Schema.ColumnCount]; + for (int i = 0; i < Schema.ColumnCount; i++) + activeArr[i] = active(i); + var pred = GetActiveOutputColumns(activeArr); + var getters = _mapper.CreateGetters(input, pred, out disp); + disposer += disp; + return new StatefulRow(input, this, Schema, getters, + _mapper.CreatePinger(input, pred, out disp)); + } + } + + private sealed class StatefulRow : IStatefulRow + { + private readonly IRow _input; + private readonly Delegate[] _getters; + private readonly Action _pinger; + + private readonly TimeSeriesRowToRowMapperTransform _parent; + + public long Batch { get { return _input.Batch; } } + + public long Position { get { return _input.Position; } } + + public Schema Schema { get; } + + public StatefulRow(IRow input, TimeSeriesRowToRowMapperTransform parent, + Schema schema, Delegate[] getters, Action pinger) + { + _input = input; + _parent = parent; + Schema = schema; + _getters = getters; + _pinger = pinger; + } + + public ValueGetter GetGetter(int col) + { + bool isSrc; + int index = _parent._bindings.MapColumnIndex(out isSrc, col); + if (isSrc) + return _input.GetGetter(index); + + Contracts.Assert(_getters[index] != null); + var fn = _getters[index] as ValueGetter; + if (fn == null) + throw Contracts.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue)); + return fn; + } + + public Action GetPinger() => + _pinger as Action ?? throw Contracts.Except("Invalid TValue in GetPinger: '{0}'", typeof(long)); + + public ValueGetter GetIdGetter() => _input.GetIdGetter(); + + public bool IsColumnActive(int col) + { + bool isSrc; + int index = _parent._bindings.MapColumnIndex(out isSrc, col); + if (isSrc) + return _input.IsColumnActive((index)); + return _getters[index] != null; + } + } + + private sealed class RowCursor : SynchronizedCursorBase, IRowCursor + { + private readonly Delegate[] _getters; + private readonly bool[] _active; + private readonly ColumnBindings _bindings; + private readonly Action _disposer; + + public Schema Schema => _bindings.Schema; + + public RowCursor(IChannelProvider provider, IRowCursor input, TimeSeriesRowToRowMapperTransform parent, bool[] active) + : base(provider, input) + { + var pred = parent.GetActiveOutputColumns(active); + _getters = parent._mapper.CreateGetters(input, pred, out _disposer); + _active = active; + _bindings = parent._bindings; + } + + public bool IsColumnActive(int col) + { + Ch.Check(0 <= col && col < _bindings.Schema.ColumnCount); + return _active[col]; + } + + public ValueGetter GetGetter(int col) + { + Ch.Check(IsColumnActive(col)); + + bool isSrc; + int index = _bindings.MapColumnIndex(out isSrc, col); + if (isSrc) + return Input.GetGetter(index); + + Ch.AssertValue(_getters); + var getter = _getters[index]; + Ch.Assert(getter != null); + var fn = getter as ValueGetter; + if (fn == null) + throw Ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue)); + return fn; + } + + public override void Dispose() + { + _disposer?.Invoke(); + base.Dispose(); + } + } + } + } diff --git a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs index bc8401ef83..e0f96ed74b 100644 --- a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs +++ b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs @@ -105,7 +105,7 @@ public abstract class SsaArguments : ArgumentsBase protected readonly bool IsAdaptive; protected readonly ErrorFunctionUtils.ErrorFunction ErrorFunction; protected readonly Func ErrorFunc; - protected readonly SequenceModelerBase Model; + protected SequenceModelerBase Model; public SsaAnomalyDetectionBase(SsaArguments args, string name, IHostEnvironment env) : base(args.WindowSize, 0, args.Source, args.Name, name, env, args.Side, args.Martingale, args.AlertOn, args.PowerMartingaleEpsilon, args.AlertThreshold) @@ -122,6 +122,9 @@ public SsaAnomalyDetectionBase(SsaArguments args, string name, IHostEnvironment // Creating the master SSA model Model = new AdaptiveSingularSpectrumSequenceModeler(Host, args.InitialWindowSize, SeasonalWindowSize + 1, SeasonalWindowSize, DiscountFactor, AdaptiveSingularSpectrumSequenceModeler.RankSelectionMethod.Exact, null, SeasonalWindowSize / 2, false, false); + + StateRef = new State(); + StateRef.InitState(WindowSize, InitialWindowSize, this, Host); } public SsaAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name) @@ -150,9 +153,11 @@ public SsaAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, strin ErrorFunc = ErrorFunctionUtils.GetErrorFunction(ErrorFunction); IsAdaptive = ctx.Reader.ReadBoolean(); + StateRef = new State(ctx); ctx.LoadModel, SignatureLoadModel>(env, out Model, "SSA"); Host.CheckDecode(Model != null); + StateRef.InitState(this, Host); } public override Schema GetOutputSchema(Schema inputSchema) @@ -186,6 +191,7 @@ public override void Save(ModelSaveContext ctx) // float: _discountFactor // byte: _errorFunction // bool: _isAdaptive + // State: StateRef // AdaptiveSingularSpectrumSequenceModeler: _model base.Save(ctx); @@ -193,6 +199,8 @@ public override void Save(ModelSaveContext ctx) ctx.Writer.Write(DiscountFactor); ctx.Writer.Write((byte)ErrorFunction); ctx.Writer.Write(IsAdaptive); + StateRef.Save(ctx); + ctx.SaveModel(Model, "SSA"); } @@ -201,6 +209,47 @@ public sealed class State : AnomalyDetectionStateBase private SequenceModelerBase _model; private SsaAnomalyDetectionBase _parentAnomalyDetector; + public State() + { + + } + + public State(ModelLoadContext ctx) : base(ctx) + { + WindowedBuffer = new FixedSizeQueue( + ctx.Reader.ReadInt32(), ctx.Reader.ReadInt32(), ctx.Reader.ReadSingleArray()); + + InitialWindowedBuffer = new FixedSizeQueue( + ctx.Reader.ReadInt32(), ctx.Reader.ReadInt32(), ctx.Reader.ReadSingleArray()); + } + + public override void Save(ModelSaveContext ctx) + { + base.Save(ctx); + + ctx.Writer.Write(WindowedBuffer.Capacity); + ctx.Writer.Write(WindowedBuffer.StartIndex); + ctx.Writer.WriteSingleArray(WindowedBuffer.Buffer); + + ctx.Writer.Write(InitialWindowedBuffer.Capacity); + ctx.Writer.Write(InitialWindowedBuffer.StartIndex); + ctx.Writer.WriteSingleArray(InitialWindowedBuffer.Buffer); + } + + public override void CloneCore(StateBase state) + { + base.CloneCore(state); + Contracts.Assert(state is State); + var stateLocal = state as State; + stateLocal.WindowedBuffer = (FixedSizeQueue)WindowedBuffer.Clone(); + stateLocal.InitialWindowedBuffer = (FixedSizeQueue)InitialWindowedBuffer.Clone(); + if (_model != null) + { + _parentAnomalyDetector.Model = _parentAnomalyDetector.Model.Clone(); + _model = _parentAnomalyDetector.Model; + } + } + private protected override void LearnStateFromDataCore(FixedSizeQueue data) { // This method is empty because there is no need to implement a training logic here. @@ -209,7 +258,7 @@ private protected override void LearnStateFromDataCore(FixedSizeQueue da private protected override void InitializeAnomalyDetector() { _parentAnomalyDetector = (SsaAnomalyDetectionBase)Parent; - _model = _parentAnomalyDetector.Model.Clone(); + _model = _parentAnomalyDetector.Model; } private protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue windowedBuffer, long iteration) @@ -218,12 +267,15 @@ private protected override double ComputeRawAnomalyScore(ref Single input, Fixed Single expectedValue = 0; _model.PredictNext(ref expectedValue); - // Feed the current point to the model - _model.Consume(ref input, _parentAnomalyDetector.IsAdaptive); + if (PreviousPosition == -1) + // Feed the current point to the model + _model.Consume(ref input, _parentAnomalyDetector.IsAdaptive); // Return the error as the raw anomaly score return _parentAnomalyDetector.ErrorFunc(input, expectedValue); } + + public override void Consume(Single input) => _model.Consume(ref input, _parentAnomalyDetector.IsAdaptive); } } } diff --git a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs index 32f52ca786..a359aeccfc 100644 --- a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs +++ b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs @@ -12,6 +12,7 @@ using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.TimeSeriesProcessing; +using Microsoft.ML.TimeSeries; using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase; [assembly: LoadableClass(SsaChangePointDetector.Summary, typeof(IDataTransform), typeof(SsaChangePointDetector), typeof(SsaChangePointDetector.Arguments), typeof(SignatureDataTransform), @@ -120,6 +121,15 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData return new SsaChangePointDetector(env, args, input).MakeDataTransform(input); } + public override IStatefulTransformer Clone() + { + var clone = (SsaChangePointDetector)MemberwiseClone(); + clone.Model = clone.Model.Clone(); + clone.StateRef = (State)clone.StateRef.Clone(); + clone.StateRef.InitState(clone, Host); + return clone; + } + internal SsaChangePointDetector(IHostEnvironment env, Arguments args) : base(new BaseArguments(args), LoaderSignature, env) { diff --git a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs index 1fbc36b49d..6bccd51d7a 100644 --- a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs +++ b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs @@ -11,6 +11,7 @@ using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.Model; using Microsoft.ML.Runtime.TimeSeriesProcessing; +using Microsoft.ML.TimeSeries; using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase; [assembly: LoadableClass(SsaSpikeDetector.Summary, typeof(IDataTransform), typeof(SsaSpikeDetector), typeof(SsaSpikeDetector.Arguments), typeof(SignatureDataTransform), @@ -133,6 +134,15 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, return new SsaSpikeDetector(env, ctx).MakeDataTransform(input); } + public override IStatefulTransformer Clone() + { + var clone = (SsaSpikeDetector)MemberwiseClone(); + clone.Model = clone.Model.Clone(); + clone.StateRef = (State)clone.StateRef.Clone(); + clone.StateRef.InitState(clone, Host); + return clone; + } + // Factory method for SignatureLoadModel. private static SsaSpikeDetector Create(IHostEnvironment env, ModelLoadContext ctx) { diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs index 834e3f669f..d81738b476 100644 --- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs +++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs @@ -1,11 +1,16 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; +using System.IO; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.TimeSeriesProcessing; +using Microsoft.ML.TimeSeries; using Xunit; namespace Microsoft.ML.Tests @@ -16,16 +21,28 @@ public sealed class TimeSeries private sealed class Prediction { #pragma warning disable CS0649 - [VectorType(4)] + [VectorType(4)] public double[] Change; -#pragma warning restore CS0649 +#pragma warning restore CS0649 + } + + public class Prediction1 + { + public float Random; } private sealed class Data { + public string Text; + public float Random; public float Value; - public Data(float value) => Value = value; + public Data(float value) + { + Text = "random123value"; + Random = -1; + Value = value; + } } [Fact] @@ -118,5 +135,111 @@ public void ChangePointDetectionWithSeasonality() Assert.Equal(expectedValues[index++], row.Change[3], precision: 7); // Martingale score } } + + [Fact] + public void ChangePointDetectionWithSeasonalityPredictionEngineNoColumn() + { + const int ChangeHistorySize = 10; + const int SeasonalitySize = 10; + const int NumberOfSeasonsInTraining = 5; + const int MaxTrainingSize = NumberOfSeasonsInTraining * SeasonalitySize; + + List data = new List(); + + var ml = new MLContext(seed: 1, conc: 1); + var dataView = ml.CreateStreamingDataView(data); + + for (int j = 0; j < NumberOfSeasonsInTraining; j++) + for (int i = 0; i < SeasonalitySize; i++) + data.Add(new Data(i)); + + for (int i = 0; i < ChangeHistorySize; i++) + data.Add(new Data(i * 100)); + + + // Pipeline. + var pipeline = ml.Transforms.Text.FeaturizeText("Text", "Text_Featurized") + .Append(new SsaChangePointEstimator(ml, new SsaChangePointDetector.Arguments() + { + Confidence = 95, + Source = "Value", + Name = "Change", + ChangeHistoryLength = ChangeHistorySize, + TrainingWindowSize = MaxTrainingSize, + SeasonalWindowSize = SeasonalitySize + })); + + // Train. + var model = pipeline.Fit(dataView); + + //Predict. + var engine = model.CreateTimeSeriesPredictionFunction(ml); + //Even though time series column is not requested it will pass the observation through time series transform. + var prediction = engine.Predict(new Data(1)); + Assert.Equal(-1, prediction.Random); + prediction = engine.Predict(new Data(2)); + Assert.Equal(-1, prediction.Random); + } + + [Fact] + public void ChangePointDetectionWithSeasonalityPredictionEngine() + { + const int ChangeHistorySize = 10; + const int SeasonalitySize = 10; + const int NumberOfSeasonsInTraining = 5; + const int MaxTrainingSize = NumberOfSeasonsInTraining * SeasonalitySize; + + List data = new List(); + + var ml = new MLContext(seed: 1, conc: 1); + var dataView = ml.CreateStreamingDataView(data); + + for (int j = 0; j < NumberOfSeasonsInTraining; j++) + for (int i = 0; i < SeasonalitySize; i++) + data.Add(new Data(i)); + + for (int i = 0; i < ChangeHistorySize; i++) + data.Add(new Data(i * 100)); + + + // Pipeline. + var pipeline = ml.Transforms.Text.FeaturizeText("Text", "Text_Featurized") + .Append(new SsaChangePointEstimator(ml, new SsaChangePointDetector.Arguments() + { + Confidence = 95, + Source = "Value", + Name = "Change", + ChangeHistoryLength = ChangeHistorySize, + TrainingWindowSize = MaxTrainingSize, + SeasonalWindowSize = SeasonalitySize + })); + + // Train. + var model = pipeline.Fit(dataView); + //Predict. + var engine = model.CreateTimeSeriesPredictionFunction(ml); + var prediction = engine.Predict(new Data(1)); + Assert.Equal(0, prediction.Change[0], precision: 7); // Alert + Assert.Equal(1.1661833524703979, prediction.Change[1], precision: 7); // Raw score + Assert.Equal(0.5, prediction.Change[2], precision: 7); // P-Value score + Assert.Equal(5.1200000000000114E-08, prediction.Change[3], precision: 7); // Martingale score + + //Checkpoint. + var modelPath = "temp.zip"; + engine.CheckPoint(ml, modelPath); + + // Load model. + ITransformer model2 = null; + using (var file = File.OpenRead(modelPath)) + model2 = TransformerChain.LoadFrom(ml, file); + + //Predict and expect different result for the same input. + engine = model2.CreateTimeSeriesPredictionFunction(ml); + prediction = engine.Predict(new Data(1)); + Assert.Equal(0, prediction.Change[0], precision: 7); // Alert + Assert.Equal(-0.12883400917053223, prediction.Change[1], precision: 7); // Raw score + Assert.Equal(0.5, prediction.Change[2], precision: 7); // P-Value score + Assert.Equal(2.6214400000000113E-15, prediction.Change[3], precision: 7); // Martingale score + } } }