diff --git a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs
index c4dc167979..253c86aa5b 100644
--- a/src/Microsoft.ML.Api/DataViewConstructionUtils.cs
+++ b/src/Microsoft.ML.Api/DataViewConstructionUtils.cs
@@ -17,6 +17,7 @@ namespace Microsoft.ML.Runtime.Api
///
/// A helper class to create data views based on the user-provided types.
///
+ [BestFriend]
internal static class DataViewConstructionUtils
{
public static IDataView CreateFromList(IHostEnvironment env, IList data,
diff --git a/src/Microsoft.ML.Api/PredictionEngine.cs b/src/Microsoft.ML.Api/PredictionEngine.cs
index 9c4522722d..c8d6826610 100644
--- a/src/Microsoft.ML.Api/PredictionEngine.cs
+++ b/src/Microsoft.ML.Api/PredictionEngine.cs
@@ -122,6 +122,33 @@ public void Reset()
}
}
+ /// <summary>
+ /// Prediction engine that scores one in-memory example at a time: the example is extracted into
+ /// the input row and the prediction object is filled from the output row.
+ /// </summary>
+ public sealed class PredictionEngine : PredictionEngineBase
+     where TSrc : class
+     where TDst : class, new()
+ {
+     internal PredictionEngine(
+         IHostEnvironment env,
+         ITransformer transformer,
+         bool ignoreMissingColumns,
+         SchemaDefinition inputSchemaDefinition = null,
+         SchemaDefinition outputSchemaDefinition = null)
+         : base(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
+     {
+     }
+
+     /// <summary>
+     /// Run prediction pipeline on one example.
+     /// </summary>
+     /// <param name="example">The example to run on. Must not be null.</param>
+     /// <param name="prediction">The object to store the prediction in. If it's null, a new one
+     /// will be created, otherwise the old one is reused.</param>
+     public override void Predict(TSrc example, ref TDst prediction)
+     {
+         Contracts.CheckValue(example, nameof(example));
+
+         // Load the example first; the output row is then read into the prediction object.
+         ExtractValues(example);
+
+         if (prediction == null)
+         {
+             prediction = new TDst();
+         }
+
+         FillValues(prediction);
+     }
+ }
+
///
/// A class that runs the previously trained model (and the preceding transform pipeline) on the
/// in-memory data, one example at a time.
@@ -130,14 +157,17 @@ public void Reset()
///
/// The user-defined type that holds the example.
/// The user-defined type that holds the prediction.
- public sealed class PredictionEngine
+ public abstract class PredictionEngineBase
where TSrc : class
where TDst : class, new()
{
private readonly DataViewConstructionUtils.InputRow _inputRow;
private readonly IRowReadableAs _outputRow;
private readonly Action _disposer;
+ [BestFriend]
+ private protected ITransformer Transformer { get; }
+ [BestFriend]
private static Func StreamChecker(IHostEnvironment env, Stream modelStream)
{
env.CheckValue(modelStream, nameof(modelStream));
@@ -150,29 +180,37 @@ private static Func StreamChecker(IHostEnvironment env,
};
}
- internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
+ [BestFriend]
+ private protected PredictionEngineBase(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
{
Contracts.CheckValue(env, nameof(env));
env.AssertValue(transformer);
+ Transformer = transformer;
var makeMapper = TransformerChecker(env, transformer);
env.AssertValue(makeMapper);
-
_inputRow = DataViewConstructionUtils.CreateInputRow(env, inputSchemaDefinition);
- var mapper = makeMapper(_inputRow.Schema);
+ PredictionEngineCore(env, _inputRow, makeMapper(_inputRow.Schema), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, out _disposer, out _outputRow);
+ }
+
+ /// <summary>
+ /// Builds the output side of the engine: creates a typed cursorable over an empty data view that
+ /// has the mapper's schema, asks the mapper for a row over the input row with every column active,
+ /// and wraps that row so it can be read into the prediction type.
+ /// This method is invoked from the base constructor, so an override runs before the derived
+ /// class's own constructor body has executed.
+ /// </summary>
+ /// <param name="env">The host environment passed to the cursorable and the empty data view.</param>
+ /// <param name="inputRow">The row the mapper reads its input values from.</param>
+ /// <param name="mapper">The row-to-row mapper that produces the output row.</param>
+ /// <param name="ignoreMissingColumns">Forwarded to the typed cursorable.</param>
+ /// <param name="inputSchemaDefinition">Not used by this implementation.</param>
+ /// <param name="outputSchemaDefinition">Forwarded to the typed cursorable.</param>
+ /// <param name="disposer">Receives the disposer returned by the mapper.</param>
+ /// <param name="outputRow">Receives the readable output row.</param>
+ internal virtual void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns,
+ SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs outputRow)
+ {
var cursorable = TypedCursorable.Create(env, new EmptyDataView(env, mapper.Schema), ignoreMissingColumns, outputSchemaDefinition);
- var outputRow = mapper.GetRow(_inputRow, col => true, out _disposer);
- _outputRow = cursorable.GetRow(outputRow);
+ // Use the inputRow argument instead of the _inputRow field so the method honors its own parameter.
+ var outputRowLocal = mapper.GetRow(inputRow, col => true, out disposer);
+ outputRow = cursorable.GetRow(outputRowLocal);
}
- private static Func TransformerChecker(IExceptionContext ectx, ITransformer transformer)
+ protected virtual Func TransformerChecker(IExceptionContext ectx, ITransformer transformer)
{
ectx.CheckValue(transformer, nameof(transformer));
ectx.CheckParam(transformer.IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper");
return transformer.GetRowToRowMapper;
}
- ~PredictionEngine()
+ /// <summary>
+ /// Overridable hook over a transformer; this base implementation returns its argument unchanged.
+ /// </summary>
+ internal virtual ITransformer ProcessTransformer(ITransformer transformer)
+ {
+     return transformer;
+ }
+
+ ~PredictionEngineBase()
{
_disposer?.Invoke();
}
@@ -189,19 +227,16 @@ public TDst Predict(TSrc example)
return result;
}
+ /// <summary>
+ /// Extracts the values of <paramref name="example"/> into the engine's input row.
+ /// </summary>
+ protected void ExtractValues(TSrc example)
+ {
+     _inputRow.ExtractValues(example);
+ }
+
+ /// <summary>
+ /// Fills <paramref name="prediction"/> from the engine's output row.
+ /// </summary>
+ protected void FillValues(TDst prediction)
+ {
+     _outputRow.FillValues(prediction);
+ }
+
///
/// Run prediction pipeline on one example.
///
/// The example to run on.
/// The object to store the prediction in. If it's null, a new one will be created, otherwise the old one
/// is reused.
- public void Predict(TSrc example, ref TDst prediction)
- {
- Contracts.CheckValue(example, nameof(example));
- _inputRow.ExtractValues(example);
- if (prediction == null)
- prediction = new TDst();
- _outputRow.FillValues(prediction);
- }
+ public abstract void Predict(TSrc example, ref TDst prediction);
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Api/Properties/AssemblyInfo.cs b/src/Microsoft.ML.Api/Properties/AssemblyInfo.cs
new file mode 100644
index 0000000000..05db89ef9f
--- /dev/null
+++ b/src/Microsoft.ML.Api/Properties/AssemblyInfo.cs
@@ -0,0 +1,10 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Runtime.CompilerServices;
+using Microsoft.ML;
+
+[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)]
+
+[assembly: WantsToBeBestFriends]
diff --git a/src/Microsoft.ML.Api/TypedCursor.cs b/src/Microsoft.ML.Api/TypedCursor.cs
index 38f4965cea..c7944c7704 100644
--- a/src/Microsoft.ML.Api/TypedCursor.cs
+++ b/src/Microsoft.ML.Api/TypedCursor.cs
@@ -77,6 +77,7 @@ public interface ICursorable
/// Similarly to the 'DataView{T}, this class uses IL generation to create the 'poke' methods that
/// write directly into the fields of the user-defined type.
///
+ [BestFriend]
internal sealed class TypedCursorable : ICursorable
where TRow : class
{
diff --git a/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs b/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs
index bdcea0c0ae..1efb2f7e18 100644
--- a/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs
+++ b/src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs
@@ -13,12 +13,16 @@ namespace Microsoft.ML.Runtime.Internal.Utilities
/// an item will result in discarding the least recently added item.
///
[BestFriend]
- internal sealed class FixedSizeQueue
+ internal sealed class FixedSizeQueue : ICloneable
{
private readonly T[] _array;
private int _startIndex;
private int _count;
+ /// <summary>
+ /// The current value of the queue's start index field.
+ /// </summary>
+ public int StartIndex
+ {
+     get { return _startIndex; }
+ }
+
+ /// <summary>
+ /// The queue's backing array itself (not a copy), so writes through it are visible to the queue.
+ /// </summary>
+ public T[] Buffer
+ {
+     get { return _array; }
+ }
+
public FixedSizeQueue(int capacity)
{
Contracts.Assert(capacity > 0, "Array capacity should be greater than zero");
@@ -26,6 +30,13 @@ public FixedSizeQueue(int capacity)
AssertValid();
}
+ /// <summary>
+ /// Creates a queue of the given capacity and seeds it from raw state: the first
+ /// <paramref name="capacity"/> slots of <paramref name="buffer"/> are copied into the backing
+ /// array and the start index is set to <paramref name="startIndex"/>.
+ /// </summary>
+ /// <param name="capacity">Capacity handed to the single-argument constructor; also the number of slots copied.</param>
+ /// <param name="startIndex">Value stored as the queue's start index.</param>
+ /// <param name="buffer">Source slots; read at indices 0 to capacity - 1.</param>
+ public FixedSizeQueue(int capacity, int startIndex, T[] buffer)
+     : this(capacity)
+ {
+     // NOTE(review): the element count is not assigned here, so unless the chained constructor
+     // sets it the copied slots are not counted as queued items -- confirm this is intended.
+     int slot = 0;
+     while (slot < capacity)
+     {
+         _array[slot] = buffer[slot];
+         slot++;
+     }
+
+     _startIndex = startIndex;
+ }
+
[Conditional("DEBUG")]
private void AssertValid()
{
@@ -136,5 +147,8 @@ public void Clear()
_count = 0;
AssertValid();
}
+
+ /// <summary>
+ /// Creates a copy of this queue: same capacity, same start index, the backing slots copied one
+ /// by one into the new instance, and the same element count.
+ /// </summary>
+ /// <returns>The new queue, typed as <see cref="object"/> as required by <see cref="ICloneable"/>.</returns>
+ public object Clone()
+ {
+     var copy = new FixedSizeQueue(Capacity, StartIndex, Buffer);
+
+     // The three-argument constructor restores only the slots and the start index; without the
+     // count the copy would report itself as empty even though its slots are populated.
+     copy._count = _count;
+     return copy;
+ }
+
}
}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
index f6c0a4bfd4..73c9be64ba 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
@@ -36,20 +36,33 @@ public enum TransformerScope
Everything = Training | Testing | Scoring
}
+ /// <summary>
+ /// Internal view over a chain of transformers that exposes the transformer array together with
+ /// the scope array. TransformerChain implements it explicitly by returning its own
+ /// Transformers and Scopes fields.
+ /// </summary>
+ [BestFriend]
+ internal interface ITransformerAccessor
+ {
+ /// <summary>The transformers of the chain.</summary>
+ ITransformer[] Transformers { get; }
+ /// <summary>The scopes of the chain; TransformerChain checks this has the same length as the transformer array.</summary>
+ TransformerScope[] Scopes { get; }
+ }
+
///
/// A chain of transformers (possibly empty) that end with a .
/// For an empty chain, is always .
///
- public sealed class TransformerChain : ITransformer, ICanSaveModel, IEnumerable
+ public sealed class TransformerChain : ITransformer, ICanSaveModel, IEnumerable, ITransformerAccessor
where TLastTransformer : class, ITransformer
{
- private readonly ITransformer[] _transformers;
- private readonly TransformerScope[] _scopes;
+ [BestFriend]
+ internal readonly ITransformer[] Transformers;
+ [BestFriend]
+ internal readonly TransformerScope[] Scopes;
public readonly TLastTransformer LastTransformer;
private const string TransformDirTemplate = "Transform_{0:000}";
- public bool IsRowToRowMapper => _transformers.All(t => t.IsRowToRowMapper);
+ public bool IsRowToRowMapper => Transformers.All(t => t.IsRowToRowMapper);
+
+ ITransformer[] ITransformerAccessor.Transformers => Transformers;
+
+ TransformerScope[] ITransformerAccessor.Scopes => Scopes;
private static VersionInfo GetVersionInfo()
{
@@ -72,12 +85,12 @@ public TransformerChain(IEnumerable transformers, IEnumerable 0) == (LastTransformer != null));
- Contracts.Check(_transformers.Length == _scopes.Length);
+ Contracts.Check((Transformers.Length > 0) == (LastTransformer != null));
+ Contracts.Check(Transformers.Length == Scopes.Length);
}
///
@@ -91,14 +104,14 @@ public TransformerChain(params ITransformer[] transformers)
if (Utils.Size(transformers) == 0)
{
- _transformers = new ITransformer[0];
- _scopes = new TransformerScope[0];
+ Transformers = new ITransformer[0];
+ Scopes = new TransformerScope[0];
LastTransformer = null;
}
else
{
- _transformers = transformers.ToArray();
- _scopes = transformers.Select(x => TransformerScope.Everything).ToArray();
+ Transformers = transformers.ToArray();
+ Scopes = transformers.Select(x => TransformerScope.Everything).ToArray();
LastTransformer = transformers.Last() as TLastTransformer;
Contracts.Check(LastTransformer != null);
}
@@ -109,7 +122,7 @@ public Schema GetOutputSchema(Schema inputSchema)
Contracts.CheckValue(inputSchema, nameof(inputSchema));
var s = inputSchema;
- foreach (var xf in _transformers)
+ foreach (var xf in Transformers)
s = xf.GetOutputSchema(s);
return s;
}
@@ -123,7 +136,7 @@ public IDataView Transform(IDataView input)
GetOutputSchema(input.Schema);
var dv = input;
- foreach (var xf in _transformers)
+ foreach (var xf in Transformers)
dv = xf.Transform(dv);
return dv;
}
@@ -132,12 +145,12 @@ public TransformerChain GetModelFor(TransformerScope scopeFilter)
{
var xfs = new List();
var scopes = new List();
- for (int i = 0; i < _transformers.Length; i++)
+ for (int i = 0; i < Transformers.Length; i++)
{
- if ((_scopes[i] & scopeFilter) != TransformerScope.None)
+ if ((Scopes[i] & scopeFilter) != TransformerScope.None)
{
- xfs.Add(_transformers[i]);
- scopes.Add(_scopes[i]);
+ xfs.Add(Transformers[i]);
+ scopes.Add(Scopes[i]);
}
}
return new TransformerChain(xfs.ToArray(), scopes.ToArray());
@@ -147,7 +160,7 @@ public TransformerChain Append(TNewLast transformer, Transfo
where TNewLast : class, ITransformer
{
Contracts.CheckValue(transformer, nameof(transformer));
- return new TransformerChain(_transformers.AppendElement(transformer), _scopes.AppendElement(scope));
+ return new TransformerChain(Transformers.AppendElement(transformer), Scopes.AppendElement(scope));
}
public void Save(ModelSaveContext ctx)
@@ -155,13 +168,13 @@ public void Save(ModelSaveContext ctx)
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
- ctx.Writer.Write(_transformers.Length);
+ ctx.Writer.Write(Transformers.Length);
- for (int i = 0; i < _transformers.Length; i++)
+ for (int i = 0; i < Transformers.Length; i++)
{
- ctx.Writer.Write((int)_scopes[i]);
+ ctx.Writer.Write((int)Scopes[i]);
var dirName = string.Format(TransformDirTemplate, i);
- ctx.SaveModel(_transformers[i], dirName);
+ ctx.SaveModel(Transformers[i], dirName);
}
}
@@ -171,16 +184,16 @@ public void Save(ModelSaveContext ctx)
internal TransformerChain(IHostEnvironment env, ModelLoadContext ctx)
{
int len = ctx.Reader.ReadInt32();
- _transformers = new ITransformer[len];
- _scopes = new TransformerScope[len];
+ Transformers = new ITransformer[len];
+ Scopes = new TransformerScope[len];
for (int i = 0; i < len; i++)
{
- _scopes[i] = (TransformerScope)(ctx.Reader.ReadInt32());
+ Scopes[i] = (TransformerScope)(ctx.Reader.ReadInt32());
var dirName = string.Format(TransformDirTemplate, i);
- ctx.LoadModel(env, out _transformers[i], dirName);
+ ctx.LoadModel(env, out Transformers[i], dirName);
}
if (len > 0)
- LastTransformer = _transformers[len - 1] as TLastTransformer;
+ LastTransformer = Transformers[len - 1] as TLastTransformer;
else
LastTransformer = null;
}
@@ -198,7 +211,7 @@ public void SaveTo(IHostEnvironment env, Stream outputStream)
}
}
- public IEnumerator GetEnumerator() => ((IEnumerable)_transformers).GetEnumerator();
+ public IEnumerator GetEnumerator() => ((IEnumerable)Transformers).GetEnumerator();
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
@@ -207,11 +220,11 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema)
Contracts.CheckValue(inputSchema, nameof(inputSchema));
Contracts.Check(IsRowToRowMapper, nameof(GetRowToRowMapper) + " method called despite " + nameof(IsRowToRowMapper) + " being false.");
- IRowToRowMapper[] mappers = new IRowToRowMapper[_transformers.Length];
+ IRowToRowMapper[] mappers = new IRowToRowMapper[Transformers.Length];
Schema schema = inputSchema;
for (int i = 0; i < mappers.Length; ++i)
{
- mappers[i] = _transformers[i].GetRowToRowMapper(schema);
+ mappers[i] = Transformers[i].GetRowToRowMapper(schema);
schema = mappers[i].Schema;
}
return new CompositeRowToRowMapper(inputSchema, mappers);
diff --git a/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs b/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs
index f07595e52a..914dcf2a95 100644
--- a/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs
+++ b/src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs
@@ -1,4 +1,4 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
@@ -19,6 +19,9 @@ public sealed class CompositeRowToRowMapper : IRowToRowMapper
public Schema InputSchema { get; }
public Schema Schema { get; }
+ [BestFriend]
+ internal IRowToRowMapper[] InnerMappers => _innerMappers;
+
///
/// Out of a series of mappers, construct a seemingly unitary mapper that is able to apply them in sequence.
///
@@ -85,6 +88,7 @@ public IRow GetRow(IRow input, Func active, out Action disposer)
// We want the last disposer to be called first, so the order of the addition here is important.
}
}
+
return result;
}
diff --git a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs
index c60717bce1..d98cc8e89b 100644
--- a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs
+++ b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs
@@ -454,6 +454,7 @@ public AdaptiveSingularSpectrumSequenceModeler(IHostEnvironment env, ModelLoadCo
_wTrans = new CpuAlignedMatrixRow(_rank, _windowSize, SseUtils.CbAlign);
int i = 0;
_wTrans.CopyFrom(tempArray, ref i);
+ _y = new CpuAlignedVector(_rank, SseUtils.CbAlign);
}
_buffer = new FixedSizeQueue(_seriesLength);
diff --git a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs
index f5f4af3104..bfc5e7b874 100644
--- a/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs
+++ b/src/Microsoft.ML.TimeSeries/IidAnomalyDetectionBase.cs
@@ -20,12 +20,16 @@ public IidAnomalyDetectionBase(ArgumentsBase args, string name, IHostEnvironment
: base(args, name, env)
{
InitialWindowSize = 0;
+ StateRef = new State();
+ StateRef.InitState(WindowSize, InitialWindowSize, this, Host);
}
public IidAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name)
: base(env, ctx, name)
{
Host.CheckDecode(InitialWindowSize == 0);
+ StateRef = new State(ctx);
+ StateRef.InitState(this, Host);
}
public override Schema GetOutputSchema(Schema inputSchema)
@@ -46,15 +50,52 @@ public override void Save(ModelSaveContext ctx)
{
ctx.CheckAtModel();
Host.Assert(InitialWindowSize == 0);
+ base.Save(ctx);
// *** Binary format ***
//
-
- base.Save(ctx);
+ // State: StateRef
+ StateRef.Save(ctx);
}
public sealed class State : AnomalyDetectionStateBase
{
+ /// <summary>
+ /// Creates a state with nothing loaded; no fields are assigned here.
+ /// </summary>
+ public State()
+ {
+ }
+
+ /// <summary>
+ /// Loads a state from a model context: after the base state is read, the windowed buffer and
+ /// then the initial windowed buffer are each read as capacity, start index and a single-precision array.
+ /// </summary>
+ /// <param name="ctx">The load context whose reader supplies the values.</param>
+ public State(ModelLoadContext ctx) : base(ctx)
+ {
+     // The reads below happen in exactly the order Save writes the fields.
+     int windowCapacity = ctx.Reader.ReadInt32();
+     int windowStart = ctx.Reader.ReadInt32();
+     var windowSlots = ctx.Reader.ReadSingleArray();
+     WindowedBuffer = new FixedSizeQueue(windowCapacity, windowStart, windowSlots);
+
+     int initialCapacity = ctx.Reader.ReadInt32();
+     int initialStart = ctx.Reader.ReadInt32();
+     var initialSlots = ctx.Reader.ReadSingleArray();
+     InitialWindowedBuffer = new FixedSizeQueue(initialCapacity, initialStart, initialSlots);
+ }
+
+ /// <summary>
+ /// Saves the state: the base state first, then for the windowed buffer and the initial windowed
+ /// buffer (in that order) the capacity, the start index and the backing array.
+ /// </summary>
+ /// <param name="ctx">The save context whose writer receives the values.</param>
+ public override void Save(ModelSaveContext ctx)
+ {
+     base.Save(ctx);
+
+     var windowed = WindowedBuffer;
+     ctx.Writer.Write(windowed.Capacity);
+     ctx.Writer.Write(windowed.StartIndex);
+     ctx.Writer.WriteSingleArray(windowed.Buffer);
+
+     var initial = InitialWindowedBuffer;
+     ctx.Writer.Write(initial.Capacity);
+     ctx.Writer.Write(initial.StartIndex);
+     ctx.Writer.WriteSingleArray(initial.Buffer);
+ }
+
+ /// <summary>
+ /// Completes a clone: after the base class has done its part, gives <paramref name="state"/>
+ /// its own clones of this state's windowed buffer and initial windowed buffer.
+ /// </summary>
+ /// <param name="state">The state being populated; asserted to be a <see cref="State"/>.</param>
+ public override void CloneCore(StateBase state)
+ {
+     base.CloneCore(state);
+     Contracts.Assert(state is State);
+
+     // This instance is the source; the argument is the destination.
+     var target = state as State;
+     target.WindowedBuffer = (FixedSizeQueue)WindowedBuffer.Clone();
+     target.InitialWindowedBuffer = (FixedSizeQueue)InitialWindowedBuffer.Clone();
+ }
+
private protected override void LearnStateFromDataCore(FixedSizeQueue data)
{
// This method is empty because there is no need for initial tuning for this transform.
@@ -70,6 +111,10 @@ private protected override double ComputeRawAnomalyScore(ref Single input, Fixed
// This transform treats the input sequenence as the raw anomaly score.
return (double)input;
}
+
+ /// <summary>
+ /// Empty override: the supplied value is ignored and no state is changed.
+ /// </summary>
+ public override void Consume(float value)
+ {
+ }
}
}
}
diff --git a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs
index 1fb5b2e62e..741bba0601 100644
--- a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs
+++ b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs
@@ -12,6 +12,7 @@
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.TimeSeriesProcessing;
+using Microsoft.ML.TimeSeries;
using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase;
[assembly: LoadableClass(IidChangePointDetector.Summary, typeof(IDataTransform), typeof(IidChangePointDetector), typeof(IidChangePointDetector.Arguments), typeof(SignatureDataTransform),
@@ -110,6 +111,14 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData
return new IidChangePointDetector(env, args).MakeDataTransform(input);
}
+ /// <summary>
+ /// Creates a copy of this detector that has its own state: the detector is copied memberwise,
+ /// its state reference is replaced with a clone of the state, and the cloned state is
+ /// re-initialized against the copy and this detector's host.
+ /// </summary>
+ public override IStatefulTransformer Clone()
+ {
+     var copy = (IidChangePointDetector)MemberwiseClone();
+
+     // The memberwise copy still points at this instance's state, so swap in a clone of it.
+     copy.StateRef = (State)copy.StateRef.Clone();
+     copy.StateRef.InitState(copy, Host);
+
+     return copy;
+ }
+
internal IidChangePointDetector(IHostEnvironment env, Arguments args)
: base(new BaseArguments(args), LoaderSignature, env)
{
diff --git a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs
index 9ccb302eb6..5e93116f42 100644
--- a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs
+++ b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs
@@ -11,6 +11,7 @@
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.TimeSeriesProcessing;
+using Microsoft.ML.TimeSeries;
using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase;
[assembly: LoadableClass(IidSpikeDetector.Summary, typeof(IDataTransform), typeof(IidSpikeDetector), typeof(IidSpikeDetector.Arguments), typeof(SignatureDataTransform),
@@ -106,6 +107,14 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData
return new IidSpikeDetector(env, args).MakeDataTransform(input);
}
+ /// <summary>
+ /// Creates a copy of this detector that has its own state: the detector is copied memberwise,
+ /// its state reference is replaced with a clone of the state, and the cloned state is
+ /// re-initialized against the copy and this detector's host.
+ /// </summary>
+ public override IStatefulTransformer Clone()
+ {
+     var copy = (IidSpikeDetector)MemberwiseClone();
+
+     // The memberwise copy still points at this instance's state, so swap in a clone of it.
+     copy.StateRef = (State)copy.StateRef.Clone();
+     copy.StateRef.InitState(copy, Host);
+
+     return copy;
+ }
+
internal IidSpikeDetector(IHostEnvironment env, Arguments args)
: base(new BaseArguments(args), LoaderSignature, env)
{
diff --git a/src/Microsoft.ML.TimeSeries/PredictionFunction.cs b/src/Microsoft.ML.TimeSeries/PredictionFunction.cs
new file mode 100644
index 0000000000..ac502a9e31
--- /dev/null
+++ b/src/Microsoft.ML.TimeSeries/PredictionFunction.cs
@@ -0,0 +1,263 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Data;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+
+namespace Microsoft.ML.TimeSeries
+{
+ /// <summary>
+ /// A row-to-row mapper marked as stateful. The interface declares no members of its own.
+ /// </summary>
+ public interface IStatefulRowToRowMapper : IRowToRowMapper
+ {
+ }
+
+ /// <summary>
+ /// A transformer that can provide a stateful row-to-row mapper and can be cloned.
+ /// TimeSeriesPredictionFunction clones every transformer of this kind before using it and
+ /// asks it for its stateful mapper instead of the regular one.
+ /// </summary>
+ public interface IStatefulTransformer : ITransformer
+ {
+ /// <summary>Returns the stateful row-to-row mapper for the given input schema.</summary>
+ IRowToRowMapper GetStatefulRowToRowMapper(Schema inputSchema);
+
+ /// <summary>Returns a copy of this transformer.</summary>
+ IStatefulTransformer Clone();
+ }
+
+ /// <summary>
+ /// A row that exposes a pinger delegate. TimeSeriesPredictionFunction collects these rows and
+ /// invokes each pinger with the current row position before reading a prediction.
+ /// </summary>
+ public interface IStatefulRow : IRow
+ {
+ /// <summary>Returns the pinger delegate of this row.</summary>
+ Action GetPinger();
+ }
+
+ /// <summary>
+ /// A row mapper with state: it can clone that state and create a pinger over an input row.
+ /// </summary>
+ public interface IStatefulRowMapper : IRowMapper
+ {
+ /// <summary>Clones the mapper's state.</summary>
+ void CloneState();
+
+ /// <summary>
+ /// Creates a pinger for <paramref name="input"/> given the active-output selector, and hands
+ /// back a disposer through <paramref name="disposer"/>.
+ /// </summary>
+ Action CreatePinger(IRow input, Func activeOutput, out Action disposer);
+ }
+
+ /// <summary>
+ /// A class that runs the previously trained model (and the preceding transform pipeline) on the
+ /// in-memory data, one example at a time.
+ /// This can also be used with trained pipelines that do not end with a predictor: in this case, the
+ /// 'prediction' will be just the outcome of all the transformations.
+ /// </summary>
+ /// <typeparam name="TSrc">The user-defined type that holds the example.</typeparam>
+ /// <typeparam name="TDst">The user-defined type that holds the prediction.</typeparam>
+ public sealed class TimeSeriesPredictionFunction : PredictionEngineBase
+ where TSrc : class
+ where TDst : class, new()
+ {
+ private Action _pinger;
+ private long _rowPosition;
+ private ITransformer InputTransformer { get; set; }
+
+ /// <summary>
+ /// Saves the engine's transformer to a newly created file at <paramref name="modelPath"/>.
+ /// A transformer that exposes <see cref="ITransformerAccessor"/> is saved as a new chain built
+ /// from its transformers and scopes; any other transformer is saved directly.
+ /// </summary>
+ /// <param name="env">The host environment handed to the save call.</param>
+ /// <param name="modelPath">Path of the file to create.</param>
+ public void CheckPoint(IHostEnvironment env, string modelPath)
+ {
+     using (var file = File.Create(modelPath))
+     {
+         var accessor = Transformer as ITransformerAccessor;
+         if (accessor == null)
+         {
+             Transformer.SaveTo(env, file);
+             return;
+         }
+
+         new TransformerChain(accessor.Transformers, accessor.Scopes).SaveTo(env, file);
+     }
+ }
+
+ /// <summary>
+ /// Produces a transformer in which every <see cref="IStatefulTransformer"/> has been replaced
+ /// by its clone. For a transformer exposing <see cref="ITransformerAccessor"/> a new chain is
+ /// built over copies of the transformer and scope arrays; non-stateful transformers are reused.
+ /// </summary>
+ /// <param name="transformer">A single transformer or a chain of transformers.</param>
+ /// <returns>The new chain, the clone of a single stateful transformer, or the argument itself.</returns>
+ private static ITransformer CloneTransformers(ITransformer transformer)
+ {
+     var accessor = transformer as ITransformerAccessor;
+     if (accessor == null)
+     {
+         var single = transformer as IStatefulTransformer;
+         return single != null ? single.Clone() : transformer;
+     }
+
+     // Work on fresh arrays so the accessor's own arrays are left untouched.
+     ITransformer[] transformersClone = accessor.Transformers.Select(x => x).ToArray();
+     TransformerScope[] scopeClone = accessor.Scopes.Select(x => x).ToArray();
+     for (int i = 0; i < transformersClone.Length; i++)
+     {
+         var member = transformersClone[i] as IStatefulTransformer;
+         if (member != null)
+             transformersClone[i] = member.Clone();
+     }
+
+     return new TransformerChain(transformersClone, scopeClone);
+ }
+
+ /// <summary>
+ /// Creates the prediction function. The base class receives the result of CloneTransformers
+ /// applied to <paramref name="transformer"/>, not the transformer itself; every other argument
+ /// is forwarded to the base constructor unchanged.
+ /// </summary>
+ public TimeSeriesPredictionFunction(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
+ SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null) :
+ base(env, CloneTransformers(transformer), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
+ {
+ }
+
+ /// <summary>
+ /// Returns the result of CloneTransformers for the given transformer.
+ /// </summary>
+ internal override ITransformer ProcessTransformer(ITransformer transformer)
+ {
+     return CloneTransformers(transformer);
+ }
+
+ /// <summary>
+ /// Gets the output row of <paramref name="mapper"/> over <paramref name="input"/> and records
+ /// every <see cref="IStatefulRow"/> produced along the way in <paramref name="rows"/>.
+ /// A <see cref="CompositeRowToRowMapper"/> with inner mappers is unrolled: each inner mapper is
+ /// processed recursively, feeding its output row to the next one. Any other mapper is asked for
+ /// its row directly.
+ /// </summary>
+ /// <param name="input">The row the mapper reads from. Must not be null.</param>
+ /// <param name="mapper">The mapper to apply. Must not be null.</param>
+ /// <param name="active">Predicate over column indices selecting the columns that must be active. Must not be null.</param>
+ /// <param name="rows">Receives the stateful rows; each row is registered at most once by the unrolling loop.</param>
+ /// <param name="disposer">Combined disposer of the mappers used, or null if none supplied one.
+ /// Disposers of later mappers are invoked before those of earlier ones.</param>
+ /// <returns>The output row of the last mapper applied.</returns>
+ public IRow GetStatefulRows(IRow input, IRowToRowMapper mapper, Func active,
+     List rows, out Action disposer)
+ {
+     Contracts.CheckValue(input, nameof(input));
+     // A null mapper could only fail later with a NullReferenceException; report it as a bad argument.
+     Contracts.CheckValue(mapper, nameof(mapper));
+     Contracts.CheckValue(active, nameof(active));
+
+     disposer = null;
+     IRowToRowMapper[] innerMappers = new IRowToRowMapper[0];
+     if (mapper is CompositeRowToRowMapper)
+         innerMappers = ((CompositeRowToRowMapper)mapper).InnerMappers;
+
+     if (innerMappers.Length == 0)
+     {
+         // Every column the caller wants active must already be active on the input row.
+         for (int c = 0; c < input.Schema.ColumnCount; ++c)
+         {
+             bool wantsActive = active(c);
+             bool isActive = input.IsColumnActive(c);
+
+             if (wantsActive && !isActive)
+                 throw Contracts.ExceptParam(nameof(input), $"Mapper required column '{input.Schema.GetColumnName(c)}' active but it was not.");
+         }
+
+         var row = mapper.GetRow(input, active, out disposer);
+         if (row is IStatefulRow)
+             rows.Add((IStatefulRow)row);
+
+         return row;
+     }
+
+     // For each of the inner mappers, we will be calling their GetRow method, but to do so we need to know
+     // what we need from them. The last one will just have the input, but the rest will need to be
+     // computed based on the dependencies of the next one in the chain.
+     var deps = new Func[innerMappers.Length];
+     deps[deps.Length - 1] = active;
+     for (int i = deps.Length - 1; i >= 1; --i)
+         deps[i - 1] = innerMappers[i].GetDependencies(deps[i]);
+
+     IRow result = input;
+     for (int i = 0; i < innerMappers.Length; ++i)
+     {
+         Action localDisp;
+         result = GetStatefulRows(result, innerMappers[i], deps[i], rows, out localDisp);
+
+         // The recursive call has normally registered a stateful result already. Adding it again
+         // would make the pinger built from this list fire twice per example for the same row,
+         // so only add it when it is not in the list yet.
+         var statefulResult = result as IStatefulRow;
+         if (statefulResult != null && !rows.Contains(statefulResult))
+             rows.Add(statefulResult);
+
+         if (localDisp != null)
+         {
+             // We want the last disposer to be called first, so the order of the addition here is important.
+             if (disposer == null)
+                 disposer = localDisp;
+             else
+                 disposer = localDisp + disposer;
+         }
+     }
+
+     return result;
+ }
+
+ /// <summary>
+ /// Collects the pinger of every row in <paramref name="rows"/> and returns one delegate that
+ /// invokes them all, in list order, with the position it is given.
+ /// </summary>
+ /// <param name="rows">The stateful rows whose pingers are combined.</param>
+ private Action CreatePinger(List rows)
+ {
+     // The pingers are captured once here, not looked up again on every call.
+     var pingers = new Action[rows.Count];
+     for (int i = 0; i < rows.Count; i++)
+         pingers[i] = rows[i].GetPinger();
+
+     return (long position) =>
+     {
+         for (int i = 0; i < pingers.Length; i++)
+             pingers[i](position);
+     };
+ }
+
+ /// <summary>
+ /// Builds the output row through GetStatefulRows so that the stateful rows of the pipeline are
+ /// collected, stores a pinger over those rows in the _pinger field, and wraps the resulting row
+ /// so it can be read into the prediction type.
+ /// </summary>
+ /// <param name="env">The host environment passed to the cursorable and the empty data view.</param>
+ /// <param name="inputRow">The row the mapper reads its input values from.</param>
+ /// <param name="mapper">The row-to-row mapper that produces the output row.</param>
+ /// <param name="ignoreMissingColumns">Forwarded to the typed cursorable.</param>
+ /// <param name="inputSchemaDefinition">Not used by this implementation.</param>
+ /// <param name="outputSchemaDefinition">Forwarded to the typed cursorable.</param>
+ /// <param name="disposer">Receives the combined disposer from GetStatefulRows.</param>
+ /// <param name="outputRow">Receives the readable output row.</param>
+ internal override void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns,
+     SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs outputRow)
+ {
+     List rows = new List();
+     // Single assignment; the original statement assigned the variable to itself as well.
+     IRow outputRowLocal = GetStatefulRows(inputRow, mapper, col => true, rows, out disposer);
+     var cursorable = TypedCursorable.Create(env, new EmptyDataView(env, mapper.Schema), ignoreMissingColumns, outputSchemaDefinition);
+     _pinger = CreatePinger(rows);
+     outputRow = cursorable.GetRow(outputRowLocal);
+ }
+
+ /// <summary>
+ /// Tells whether the transformer can be used row by row: it, or every transformer it exposes
+ /// through <see cref="ITransformerAccessor"/>, must either report itself as a row-to-row mapper
+ /// or be an <see cref="IStatefulTransformer"/>.
+ /// </summary>
+ private bool IsRowToRowMapper(ITransformer transformer)
+ {
+     var accessor = transformer as ITransformerAccessor;
+     if (accessor != null)
+         return accessor.Transformers.All(t => t.IsRowToRowMapper || t is IStatefulTransformer);
+
+     return transformer.IsRowToRowMapper || transformer is IStatefulTransformer;
+ }
+
+ /// <summary>
+ /// Builds the row-to-row mapper for the stored input transformer. Stateful transformers supply
+ /// their stateful mapper, all others their regular one. For a transformer exposing
+ /// <see cref="ITransformerAccessor"/> the mappers are created in order, each over the output
+ /// schema of the previous one, and combined into a <see cref="CompositeRowToRowMapper"/>.
+ /// </summary>
+ /// <param name="inputSchema">The schema of the input rows. Must not be null.</param>
+ private IRowToRowMapper GetRowToRowMapper(Schema inputSchema)
+ {
+     Contracts.CheckValue(inputSchema, nameof(inputSchema));
+     Contracts.Check(IsRowToRowMapper(InputTransformer), nameof(GetRowToRowMapper) +
+         " method called despite " + nameof(IsRowToRowMapper) + " being false. or transformer not being " + nameof(IStatefulTransformer));
+
+     var accessor = InputTransformer as ITransformerAccessor;
+     if (accessor == null)
+     {
+         // A single transformer: no chain to walk.
+         var single = InputTransformer as IStatefulTransformer;
+         if (single != null)
+             return single.GetStatefulRowToRowMapper(inputSchema);
+
+         return InputTransformer.GetRowToRowMapper(inputSchema);
+     }
+
+     var transformers = accessor.Transformers;
+     var mappers = new IRowToRowMapper[transformers.Length];
+     Schema schema = inputSchema;
+     for (int i = 0; i < mappers.Length; ++i)
+     {
+         var member = transformers[i] as IStatefulTransformer;
+         mappers[i] = member != null
+             ? member.GetStatefulRowToRowMapper(schema)
+             : transformers[i].GetRowToRowMapper(schema);
+
+         // The next mapper consumes what this one produces.
+         schema = mappers[i].Schema;
+     }
+
+     return new CompositeRowToRowMapper(inputSchema, mappers);
+ }
+
+ /// <summary>
+ /// Validates the transformer and returns the delegate used to create its row-to-row mapper.
+ /// The transformer must be non-null and pass IsRowToRowMapper. It is stored in InputTransformer
+ /// because the returned delegate, GetRowToRowMapper, reads it from there.
+ /// </summary>
+ protected override Func TransformerChecker(IExceptionContext ectx, ITransformer transformer)
+ {
+ ectx.CheckValue(transformer, nameof(transformer));
+ ectx.CheckParam(IsRowToRowMapper(transformer), nameof(transformer), "Must be a row to row mapper or " + nameof(IStatefulTransformer));
+ // Must be assigned before the delegate below is invoked.
+ InputTransformer = transformer;
+ return GetRowToRowMapper;
+ }
+
+ /// <summary>
+ /// Run prediction pipeline on one example. The pinger is invoked with the current row position
+ /// before the prediction is filled, and the position is advanced by one afterwards.
+ /// </summary>
+ /// <param name="example">The example to run on. Must not be null.</param>
+ /// <param name="prediction">The object to store the prediction in. If it's null, a new one
+ /// will be created, otherwise the old one is reused.</param>
+ public override void Predict(TSrc example, ref TDst prediction)
+ {
+     Contracts.CheckValue(example, nameof(example));
+     ExtractValues(example);
+
+     if (prediction == null)
+     {
+         prediction = new TDst();
+     }
+
+     // Update state first, then read the prediction.
+     _pinger(_rowPosition);
+     FillValues(prediction);
+
+     // Only reached when both steps above succeeded.
+     _rowPosition++;
+ }
+ }
+
+ /// <summary>
+ /// Extension methods for creating a <see cref="TimeSeriesPredictionFunction"/> from a transformer.
+ /// </summary>
+ public static class PredictionFunctionExtensions
+ {
+     /// <summary>
+     /// Validates the arguments and creates a <see cref="TimeSeriesPredictionFunction"/> over
+     /// <paramref name="transformer"/>.
+     /// </summary>
+     /// <param name="transformer">The transformer to run. Must not be null.</param>
+     /// <param name="env">The host environment. Must not be null.</param>
+     /// <param name="ignoreMissingColumns">Forwarded to the prediction function.</param>
+     /// <param name="inputSchemaDefinition">Optional input schema definition; may be null.</param>
+     /// <param name="outputSchemaDefinition">Optional output schema definition; may be null.</param>
+     public static TimeSeriesPredictionFunction CreateTimeSeriesPredictionFunction(
+         this ITransformer transformer,
+         IHostEnvironment env,
+         bool ignoreMissingColumns = false,
+         SchemaDefinition inputSchemaDefinition = null,
+         SchemaDefinition outputSchemaDefinition = null)
+         where TSrc : class
+         where TDst : class, new()
+     {
+         // env is checked first because the remaining checks are made through it.
+         Contracts.CheckValue(env, nameof(env));
+         env.CheckValue(transformer, nameof(transformer));
+         env.CheckValueOrNull(inputSchemaDefinition);
+         env.CheckValueOrNull(outputSchemaDefinition);
+
+         var function = new TimeSeriesPredictionFunction(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
+         return function;
+     }
+ }
+}
diff --git a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs
index 398fb86f19..b28ce37329 100644
--- a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs
+++ b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs
@@ -1,15 +1,17 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections.Generic;
+using System.Threading;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Internal.CpuMath;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
+using Microsoft.ML.TimeSeries;
namespace Microsoft.ML.Runtime.TimeSeriesProcessing
{
@@ -306,10 +308,10 @@ public abstract class AnomalyDetectionStateBase : StateBase
protected SequentialAnomalyDetectionTransformBase Parent;
// A windowed buffer to cache the update values to the martingale score in the log scale.
- private FixedSizeQueue _logMartingaleUpdateBuffer;
+ private FixedSizeQueue LogMartingaleUpdateBuffer { get; set; }
// A windowed buffer to cache the raw anomaly scores for p-value calculation.
- private FixedSizeQueue _rawScoreBuffer;
+ private FixedSizeQueue RawScoreBuffer { get; set; }
// The current martingale score in the log scale.
private Double _logMartingaleValue;
@@ -322,6 +324,45 @@ public abstract class AnomalyDetectionStateBase : StateBase
protected Double LatestMartingaleScore => Math.Exp(_logMartingaleValue);
+ /// <summary>
+ /// Completes a clone of this state: after the base class has handled its own members, the clone's
+ /// martingale-update and raw-score queues are replaced with the result of calling Clone() on this
+ /// instance's queues.
+ /// </summary>
+ /// <param name="state">The freshly created copy to fix up; must be an AnomalyDetectionStateBase.</param>
+ public override void CloneCore(StateBase state)
+ {
+     base.CloneCore(state);
+     Contracts.Assert(state is AnomalyDetectionStateBase);
+     // Direct cast instead of 'as': a wrong type then fails right here with an InvalidCastException
+     // instead of surfacing as a null dereference on the next line.
+     var stateLocal = (AnomalyDetectionStateBase)state;
+     stateLocal.LogMartingaleUpdateBuffer = (FixedSizeQueue)LogMartingaleUpdateBuffer.Clone();
+     stateLocal.RawScoreBuffer = (FixedSizeQueue)RawScoreBuffer.Clone();
+ }
+
+ /// <summary>
+ /// Restores the anomaly-detection state from <paramref name="ctx"/>. Values are consumed in the same
+ /// order Save writes them: for each queue its capacity, start index and backing array, then the
+ /// log-martingale value, the sum of squared distances and the martingale alert counter.
+ /// </summary>
+ public AnomalyDetectionStateBase(ModelLoadContext ctx) : base(ctx)
+ {
+     var reader = ctx.Reader;
+
+     // Martingale-update queue (double values).
+     int martingaleCapacity = reader.ReadInt32();
+     int martingaleStartIndex = reader.ReadInt32();
+     var martingaleValues = reader.ReadDoubleArray();
+     LogMartingaleUpdateBuffer = new FixedSizeQueue(martingaleCapacity, martingaleStartIndex, martingaleValues);
+
+     // Raw-score queue (single-precision values).
+     int rawScoreCapacity = reader.ReadInt32();
+     int rawScoreStartIndex = reader.ReadInt32();
+     var rawScoreValues = reader.ReadSingleArray();
+     RawScoreBuffer = new FixedSizeQueue(rawScoreCapacity, rawScoreStartIndex, rawScoreValues);
+
+     _logMartingaleValue = reader.ReadDouble();
+     _sumSquaredDist = reader.ReadDouble();
+     _martingaleAlertCounter = reader.ReadInt32();
+ }
+
+ /// <summary>
+ /// Writes the base-class state followed by this class's state: each queue as capacity, start index
+ /// and backing array, then the log-martingale value, the sum of squared distances and the martingale
+ /// alert counter. The ModelLoadContext constructor reads the values back in this order.
+ /// </summary>
+ public override void Save(ModelSaveContext ctx)
+ {
+     base.Save(ctx);
+
+     var writer = ctx.Writer;
+
+     var martingaleQueue = LogMartingaleUpdateBuffer;
+     writer.Write(martingaleQueue.Capacity);
+     writer.Write(martingaleQueue.StartIndex);
+     writer.WriteDoubleArray(martingaleQueue.Buffer);
+
+     var rawScoreQueue = RawScoreBuffer;
+     writer.Write(rawScoreQueue.Capacity);
+     writer.Write(rawScoreQueue.StartIndex);
+     writer.WriteSingleArray(rawScoreQueue.Buffer);
+
+     writer.Write(_logMartingaleValue);
+     writer.Write(_sumSquaredDist);
+     writer.Write(_martingaleAlertCounter);
+ }
+
private protected AnomalyDetectionStateBase() : base()
{
}
@@ -329,7 +370,7 @@ private protected AnomalyDetectionStateBase() : base()
private Double ComputeKernelPValue(Double rawScore)
{
int i;
- int n = _rawScoreBuffer.Count;
+ int n = RawScoreBuffer.Count;
if (n == 0)
return 0.5;
@@ -341,21 +382,21 @@ private Double ComputeKernelPValue(Double rawScore)
Double diff;
for (i = 0; i < n; ++i)
{
- diff = rawScore - _rawScoreBuffer[i];
+ diff = rawScore - RawScoreBuffer[i];
pValue -= ProbabilityFunctions.Erf(diff / bandWidth);
_sumSquaredDist += diff * diff;
}
pValue = 0.5 + pValue / (2 * n);
- if (_rawScoreBuffer.IsFull)
+ if (RawScoreBuffer.IsFull)
{
for (i = 1; i < n; ++i)
{
- diff = _rawScoreBuffer[0] - _rawScoreBuffer[i];
+ diff = RawScoreBuffer[0] - RawScoreBuffer[i];
_sumSquaredDist -= diff * diff;
}
- diff = _rawScoreBuffer[0] - rawScore;
+ diff = RawScoreBuffer[0] - rawScore;
_sumSquaredDist -= diff * diff;
}
@@ -435,7 +476,7 @@ private protected override sealed void TransformCore(ref TInput input, FixedSize
else if (result.Values[2] > MaxPValue)
result.Values[2] = MaxPValue;
- _rawScoreBuffer.AddLast(rawScore);
+ RawScoreBuffer.AddLast(rawScore);
// Step 3: Computing the martingale value
if (Parent.Martingale != MartingaleType.None && Parent.ThresholdScore == AlertingScore.MartingaleScore)
@@ -452,17 +493,17 @@ private protected override sealed void TransformCore(ref TInput input, FixedSize
break;
}
- if (_logMartingaleUpdateBuffer.Count == 0)
+ if (LogMartingaleUpdateBuffer.Count == 0)
{
- for (int i = 0; i < _logMartingaleUpdateBuffer.Capacity; ++i)
- _logMartingaleUpdateBuffer.AddLast(martingaleUpdate);
- _logMartingaleValue += _logMartingaleUpdateBuffer.Capacity * martingaleUpdate;
+ for (int i = 0; i < LogMartingaleUpdateBuffer.Capacity; ++i)
+ LogMartingaleUpdateBuffer.AddLast(martingaleUpdate);
+ _logMartingaleValue += LogMartingaleUpdateBuffer.Capacity * martingaleUpdate;
}
else
{
_logMartingaleValue += martingaleUpdate;
- _logMartingaleValue -= _logMartingaleUpdateBuffer.PeekFirst();
- _logMartingaleUpdateBuffer.AddLast(martingaleUpdate);
+ _logMartingaleValue -= LogMartingaleUpdateBuffer.PeekFirst();
+ LogMartingaleUpdateBuffer.AddLast(martingaleUpdate);
}
result.Values[3] = Math.Exp(_logMartingaleValue);
@@ -473,7 +514,7 @@ private protected override sealed void TransformCore(ref TInput input, FixedSize
// Generating alert
bool alert = false;
- if (_rawScoreBuffer.IsFull) // No alert until the buffer is completely full.
+ if (RawScoreBuffer.IsFull) // No alert until the buffer is completely full.
{
switch (Parent.ThresholdScore)
{
@@ -506,17 +547,23 @@ private protected override sealed void TransformCore(ref TInput input, FixedSize
dst = result.Commit();
}
- private protected override sealed void InitializeStateCore()
+ private protected override sealed void InitializeStateCore(bool disk = false)
{
Parent = (SequentialAnomalyDetectionTransformBase)ParentTransform;
Host.Assert(WindowSize >= 0);
- if (Parent.Martingale != MartingaleType.None)
- _logMartingaleUpdateBuffer = new FixedSizeQueue(WindowSize == 0 ? 1 : WindowSize);
+ if (disk == false)
+ {
+ if (Parent.Martingale != MartingaleType.None)
+ LogMartingaleUpdateBuffer = new FixedSizeQueue(WindowSize == 0 ? 1 : WindowSize);
+ else
+ LogMartingaleUpdateBuffer = new FixedSizeQueue(1);
+
+ RawScoreBuffer = new FixedSizeQueue(WindowSize == 0 ? 1 : WindowSize);
- _rawScoreBuffer = new FixedSizeQueue(WindowSize == 0 ? 1 : WindowSize);
+ _logMartingaleValue = 0;
+ }
- _logMartingaleValue = 0;
InitializeAnomalyDetector();
}
@@ -536,15 +583,16 @@ private protected override sealed void InitializeStateCore()
private protected abstract Double ComputeRawAnomalyScore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration);
}
- protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(Host, this, schema);
+ protected override IStatefulRowMapper MakeRowMapper(ISchema schema) => new Mapper(Host, this, schema);
- private sealed class Mapper : IRowMapper
+ private sealed class Mapper : IStatefulRowMapper
{
private readonly IHost _host;
private readonly SequentialAnomalyDetectionTransformBase _parent;
private readonly ISchema _parentSchema;
private readonly int _inputColumnIndex;
private readonly VBuffer> _slotNames;
+ private TState State { get; set; }
public Mapper(IHostEnvironment env, SequentialAnomalyDetectionTransformBase parent, ISchema inputSchema)
{
@@ -564,6 +612,8 @@ public Mapper(IHostEnvironment env, SequentialAnomalyDetectionTransformBase>(4, new[] { "Alert".AsMemory(), "Raw Score".AsMemory(),
"P-Value Score".AsMemory(), "Martingale Score".AsMemory() });
+
+ State = _parent.StateRef;
}
public Schema.DetachedColumn[] GetOutputColumns()
@@ -592,11 +642,8 @@ public Delegate[] CreateGetters(IRow input, Func activeOutput, out Ac
disposer = null;
var getters = new Delegate[1];
if (activeOutput(0))
- {
- TState state = new TState();
- state.InitState(_parent.WindowSize, _parent.InitialWindowSize, _parent, _host);
- getters[0] = MakeGetter(input, state);
- }
+ getters[0] = MakeGetter(input, State);
+
return getters;
}
@@ -608,15 +655,46 @@ private Delegate MakeGetter(IRow input, TState state)
var srcGetter = input.GetGetter(_inputColumnIndex);
ProcessData processData = _parent.WindowSize > 0 ?
(ProcessData)state.Process : state.ProcessWithoutBuffer;
+
ValueGetter> valueGetter = (ref VBuffer dst) =>
- {
+ {
TInput src = default;
srcGetter(ref src);
processData(ref src, ref dst);
- };
-
+ };
return valueGetter;
}
+
+ /// <summary>
+ /// Returns the state-updating pinger for <paramref name="input"/> when output column 0 is active,
+ /// otherwise null. No disposer is needed, so <paramref name="disposer"/> is always set to null.
+ /// </summary>
+ public Action CreatePinger(IRow input, Func activeOutput, out Action disposer)
+ {
+     disposer = null;
+     return activeOutput(0) ? MakePinger(input, State) : null;
+ }
+
+ /// <summary>
+ /// Builds a delegate that, given a row position, reads the current value of the input column and
+ /// passes it to <paramref name="state"/>.UpdateState. The last argument tells the state whether
+ /// windowed buffering is in use (the parent's WindowSize is positive).
+ /// </summary>
+ private Action MakePinger(IRow input, TState state)
+ {
+     _host.AssertValue(input);
+     var readSource = input.GetGetter(_inputColumnIndex);
+     return (long rowPosition) =>
+     {
+         TInput current = default;
+         readSource(ref current);
+         state.UpdateState(ref current, rowPosition, _parent.WindowSize > 0);
+     };
+ }
+
+ /// <summary>
+ /// Atomically increments the parent's StateRefCount. The first caller keeps the state this mapper
+ /// already holds; every later caller switches to its own clone of the parent's StateRef.
+ /// </summary>
+ public void CloneState()
+ {
+     int users = Interlocked.Increment(ref _parent.StateRefCount);
+     if (users <= 1)
+         return;
+
+     State = (TState)_parent.StateRef.Clone();
+ }
}
}
}
diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs
index 1c78af4d27..56e27749eb 100644
--- a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs
+++ b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs
@@ -8,7 +8,10 @@
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.Api;
-using Microsoft.ML.Core.Data;
+using Microsoft.ML.TimeSeries;
+using Microsoft.ML.Runtime.Model.Onnx;
+using Microsoft.ML.Runtime.Model.Pfa;
+using System.Linq;
using Microsoft.ML.Data;
namespace Microsoft.ML.Runtime.TimeSeriesProcessing
@@ -20,13 +23,13 @@ namespace Microsoft.ML.Runtime.TimeSeriesProcessing
/// The input type of the sequential processing.
/// The dst type of the sequential processing.
/// The state type of the sequential processing. Must be a class inherited from StateBase
- public abstract class SequentialTransformerBase : ITransformer, ICanSaveModel
+ public abstract class SequentialTransformerBase : IStatefulTransformer, ICanSaveModel
where TState : SequentialTransformerBase.StateBase, new()
{
///
/// The base class for encapsulating the State object for sequential processing. This class implements a windowed buffer.
///
- public abstract class StateBase
+ public abstract class StateBase : ICanSaveModel, ICloneable
{
// Ideally this class should be private. However, due to the current constraints with the LambdaTransform, we need to have
// access to the state class when inheriting from SequentialTransformerBase.
@@ -40,12 +43,12 @@ public abstract class StateBase
///
/// The internal windowed buffer for buffering the values in the input sequence.
///
- private protected FixedSizeQueue WindowedBuffer;
+ private protected FixedSizeQueue WindowedBuffer { get; set; }
///
/// The buffer used to buffer the training data points.
///
- private protected FixedSizeQueue InitialWindowedBuffer;
+ private protected FixedSizeQueue InitialWindowedBuffer { get; set; }
private protected int WindowSize { get; private set; }
@@ -66,7 +69,19 @@ protected long IncrementRowCounter()
return RowCounter;
}
- private bool _isIniatilized;
+ protected long PreviousPosition;
+
+ /// <summary>
+ /// Restores the window size and the initial window size, in that order, from <paramref name="ctx"/>.
+ /// </summary>
+ public StateBase(ModelLoadContext ctx)
+ {
+     var reader = ctx.Reader;
+     WindowSize = reader.ReadInt32();
+     InitialWindowSize = reader.ReadInt32();
+ }
+
+ /// <summary>
+ /// Writes the window size followed by the initial window size to <paramref name="ctx"/>.
+ /// </summary>
+ public virtual void Save(ModelSaveContext ctx)
+ {
+     var writer = ctx.Writer;
+     writer.Write(WindowSize);
+     writer.Write(InitialWindowSize);
+ }
///
/// This method sets the window size and initializes the buffer only once.
@@ -76,10 +91,10 @@ protected long IncrementRowCounter()
/// The size of the windowed initial buffer used for training
/// The parent transform of this state object
/// The host
- public void InitState(int windowSize, int initialWindowSize, SequentialTransformerBase parentTransform, IHost host)
+ public void InitState(int windowSize, int initialWindowSize, SequentialTransformerBase parentTransform,
+ IHost host)
{
Contracts.CheckValue(host, nameof(host), "The host cannot be null.");
- host.Check(!_isIniatilized, "The window size can be set only once.");
host.CheckValue(parentTransform, nameof(parentTransform));
host.CheckParam(windowSize >= 0, nameof(windowSize), "Must be non-negative.");
host.CheckParam(initialWindowSize >= 0, nameof(initialWindowSize), "Must be non-negative.");
@@ -93,7 +108,19 @@ public void InitState(int windowSize, int initialWindowSize, SequentialTransform
RowCounter = 0;
InitializeStateCore();
- _isIniatilized = true;
+ PreviousPosition = -1;
+ }
+
+ // Initializes a state object without touching the window sizes or buffers: only the host, the parent
+ // transform, the row counter and the previous position are set here. InitializeStateCore is invoked
+ // with disk == true.
+ // NOTE(review): presumably intended for states built through the ModelLoadContext constructor, which
+ // reads WindowSize and InitialWindowSize itself - confirm against callers.
+ public void InitState(SequentialTransformerBase parentTransform, IHost host)
+ {
+ Contracts.CheckValue(host, nameof(host), "The host cannot be null.");
+ host.CheckValue(parentTransform, nameof(parentTransform));
+
+ Host = host;
+ ParentTransform = parentTransform;
+ RowCounter = 0;
+ // Runs after Host/ParentTransform are assigned and before PreviousPosition is reset.
+ InitializeStateCore(true);
+ // -1 marks "no row position seen yet"; Process/ProcessWithoutBuffer test for this value.
+ PreviousPosition = -1;
 }
///
@@ -101,51 +128,74 @@ public void InitState(int windowSize, int initialWindowSize, SequentialTransform
///
public virtual void Reset()
{
- Host.Assert(_isIniatilized);
Host.Assert(WindowedBuffer != null);
Host.Assert(InitialWindowedBuffer != null);
RowCounter = 0;
WindowedBuffer.Clear();
InitialWindowedBuffer.Clear();
+ PreviousPosition = -1;
}
- public void Process(ref TInput input, ref TOutput output)
+ /// <summary>
+ /// Feeds <paramref name="input"/> into the state, but only when <paramref name="rowPosition"/> is
+ /// strictly greater than the last position seen; repeated or earlier positions are ignored.
+ /// </summary>
+ /// <param name="input">The value observed at the given row.</param>
+ /// <param name="rowPosition">Position of the row the value came from.</param>
+ /// <param name="buffer">Forwarded to UpdateStateCore to control whether the windowed buffer is filled.</param>
+ public void UpdateState(ref TInput input, long rowPosition, bool buffer = true)
+ {
+     if (rowPosition <= PreviousPosition)
+         return;
+
+     PreviousPosition = rowPosition;
+     UpdateStateCore(ref input, buffer);
+     Consume(input);
+ }
+
+ public void UpdateStateCore(ref TInput input, bool buffer = true)
{
if (InitialWindowedBuffer.Count < InitialWindowSize)
{
InitialWindowedBuffer.AddLast(input);
- SetNaOutput(ref output);
-
- if (InitialWindowedBuffer.Count >= InitialWindowSize - WindowSize)
+ if (InitialWindowedBuffer.Count >= InitialWindowSize - WindowSize && buffer)
+ WindowedBuffer.AddLast(input);
+ }
+ else
+ {
+ if (buffer)
WindowedBuffer.AddLast(input);
+ IncrementRowCounter();
+ }
+ }
+
+ public void Process(ref TInput input, ref TOutput output)
+ {
+ if (PreviousPosition == -1)
+ UpdateStateCore(ref input);
+
+ if (InitialWindowedBuffer.Count < InitialWindowSize)
+ {
+ SetNaOutput(ref output);
+
if (InitialWindowedBuffer.Count == InitialWindowSize)
LearnStateFromDataCore(InitialWindowedBuffer);
}
else
{
TransformCore(ref input, WindowedBuffer, RowCounter - InitialWindowSize, ref output);
- WindowedBuffer.AddLast(input);
- IncrementRowCounter();
}
}
public void ProcessWithoutBuffer(ref TInput input, ref TOutput output)
{
+ if (PreviousPosition == -1)
+ UpdateStateCore(ref input, false);
+
if (InitialWindowedBuffer.Count < InitialWindowSize)
{
- InitialWindowedBuffer.AddLast(input);
SetNaOutput(ref output);
if (InitialWindowedBuffer.Count == InitialWindowSize)
LearnStateFromDataCore(InitialWindowedBuffer);
}
else
- {
TransformCore(ref input, WindowedBuffer, RowCounter - InitialWindowSize, ref output);
- IncrementRowCounter();
- }
}
///
@@ -166,13 +216,28 @@ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output)
///
/// The abstract method that realizes the logic for initializing the state object.
///
- private protected abstract void InitializeStateCore();
+ private protected abstract void InitializeStateCore(bool disk = false);
///
/// The abstract method that realizes the logic for learning the parameters and the initial state object from data.
///
/// A queue of data points used for training
private protected abstract void LearnStateFromDataCore(FixedSizeQueue data);
+
+ public abstract void Consume(TInput value);
+
+ /// <summary>
+ /// Creates a copy of this state: a member-wise copy first, which CloneCore then adjusts.
+ /// </summary>
+ public object Clone()
+ {
+     var copy = (StateBase)MemberwiseClone();
+     CloneCore(copy);
+     return copy;
+ }
+
+ /// <summary>
+ /// Gives <paramref name="state"/> its own windowed buffer and initial windowed buffer by assigning it
+ /// the result of Clone() on this instance's queues. Overrides are expected to call the base first.
+ /// </summary>
+ public virtual void CloneCore(StateBase state)
+ {
+     var windowedCopy = (FixedSizeQueue)WindowedBuffer.Clone();
+     state.WindowedBuffer = windowedCopy;
+
+     var initialCopy = (FixedSizeQueue)InitialWindowedBuffer.Clone();
+     state.InitialWindowedBuffer = initialCopy;
+ }
}
private protected readonly IHost Host;
@@ -193,6 +258,10 @@ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output)
public bool IsRowToRowMapper => false;
+ public TState StateRef{ get; set; }
+
+ public int StateRefCount;
+
///
/// The main constructor for the sequential transform
///
@@ -273,7 +342,7 @@ public virtual void Save(ModelSaveContext ctx)
public abstract Schema GetOutputSchema(Schema inputSchema);
- protected abstract IRowMapper MakeRowMapper(ISchema schema);
+ protected abstract IStatefulRowMapper MakeRowMapper(ISchema schema);
protected SequentialDataTransform MakeDataTransform(IDataView input)
{
@@ -288,16 +357,28 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema)
throw new InvalidOperationException("Not a RowToRowMapper.");
}
- public sealed class SequentialDataTransform : TransformBase, ITransformTemplate
+ public IRowToRowMapper GetStatefulRowToRowMapper(Schema inputSchema)
{
- private readonly IRowMapper _mapper;
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+ return new TimeSeriesRowToRowMapperTransform(Host, new EmptyDataView(Host, inputSchema), MakeRowMapper(inputSchema));
+ }
+
+ public virtual IStatefulTransformer Clone() => (SequentialTransformerBase)MemberwiseClone();
+
+ public sealed class SequentialDataTransform : TransformBase, ITransformTemplate, IRowToRowMapper
+ {
+ private readonly IStatefulRowMapper _mapper;
private readonly SequentialTransformerBase _parent;
private readonly IDataTransform _transform;
private readonly ColumnBindings _bindings;
- public SequentialDataTransform(IHost host, SequentialTransformerBase parent, IDataView input, IRowMapper mapper)
- :base(parent.Host, input)
+ private MetadataDispatcher Metadata { get; }
+
+ public SequentialDataTransform(IHost host, SequentialTransformerBase parent,
+ IDataView input, IStatefulRowMapper mapper)
+ : base(parent.Host, input)
{
+ Metadata = new MetadataDispatcher(1);
_parent = parent;
_transform = CreateLambdaTransform(_parent.Host, input, _parent.InputColumnName,
_parent.OutputColumnName, InitFunction, _parent.WindowSize > 0, _parent.OutputColumnType);
@@ -305,6 +386,8 @@ public SequentialDataTransform(IHost host, SequentialTransformerBase _mapper.CloneState();
+
private static IDataTransform CreateLambdaTransform(IHost host, IDataView input, string inputColumnName, string outputColumnName,
Action initFunction, bool hasBuffer, ColumnType outputColTypeOverride)
{
@@ -346,7 +429,9 @@ private void InitFunction(TState state)
protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null)
{
var srcCursor = _transform.GetRowCursor(predicate, rand);
- return new Cursor(Host, this, srcCursor);
+ var clone = (SequentialDataTransform)MemberwiseClone();
+ clone.CloneStateInMapper();
+ return new Cursor(Host, clone, srcCursor);
}
protected override bool? ShouldUseParallelCursors(Func predicate)
@@ -377,6 +462,71 @@ public IDataTransform ApplyToData(IHostEnvironment env, IDataView newSource)
{
return new SequentialDataTransform(Contracts.CheckRef(env, nameof(env)).Register("SequentialDataTransform"), _parent, newSource, _mapper);
}
+
+ public Schema InputSchema => Source.Schema;
+ /// <summary>
+ /// Returns a predicate over input columns that answers true for every column when at least one
+ /// column of this transform's schema satisfies <paramref name="predicate"/>, and false for every
+ /// column otherwise.
+ /// </summary>
+ public Func GetDependencies(Func predicate)
+ {
+     bool anyRequested = false;
+     for (int column = 0; column < Schema.ColumnCount && !anyRequested; column++)
+         anyRequested = predicate(column);
+
+     if (anyRequested)
+         return col => true;
+     return col => false;
+ }
+
+ /// <summary>
+ /// Wraps <paramref name="input"/> in a stateful row backed by the mapper's getters and pinger.
+ /// </summary>
+ /// <param name="input">The row to read source values from.</param>
+ /// <param name="active">Predicate passed to the mapper to select active outputs.</param>
+ /// <param name="disposer">Receives the combined clean-up actions of the getters and the pinger
+ /// (null when neither needs one).</param>
+ public IRow GetRow(IRow input, Func active, out Action disposer)
+ {
+     // Each call reports its own disposer; previously both wrote to the same out argument, so the one
+     // reported by CreateGetters was overwritten by CreatePinger.
+     var getters = _mapper.CreateGetters(input, active, out Action getterDisposer);
+     var pinger = _mapper.CreatePinger(input, active, out Action pingerDisposer);
+     disposer = getterDisposer + pingerDisposer;
+     return new Row(_bindings.Schema, input, getters, pinger);
+ }
+
+ }
+
+ /// <summary>
+ /// A stateful row that takes position, batch and id from an input row and exposes a fixed array of
+ /// getters together with a pinger.
+ /// </summary>
+ public sealed class Row : IStatefulRow
+ {
+     private readonly Schema _schema;
+     private readonly IRow _input;
+     private readonly Delegate[] _getters;
+     private readonly Action _pinger;
+
+     public Schema Schema => _schema;
+
+     public long Position => _input.Position;
+
+     public long Batch => _input.Batch;
+
+     /// <summary>
+     /// Creates the row. The number of getters must equal the number of columns in the schema.
+     /// </summary>
+     public Row(Schema schema, IRow input, Delegate[] getters, Action pinger)
+     {
+         Contracts.CheckValue(schema, nameof(schema));
+         Contracts.CheckValue(input, nameof(input));
+         Contracts.Check(Utils.Size(getters) == schema.ColumnCount);
+         _schema = schema;
+         _input = input;
+         _getters = getters ?? new Delegate[0];
+         _pinger = pinger;
+     }
+
+     public ValueGetter GetIdGetter() => _input.GetIdGetter();
+
+     /// <summary>
+     /// Returns the getter stored for column <paramref name="col"/>; throws when the index is out of
+     /// range, the column is inactive, or the stored delegate is not of the requested type.
+     /// </summary>
+     public ValueGetter GetGetter(int col)
+     {
+         Contracts.CheckParam(0 <= col && col < _getters.Length, nameof(col), "Invalid col value in GetGetter");
+         Contracts.Check(IsColumnActive(col));
+         if (_getters[col] is ValueGetter typedGetter)
+             return typedGetter;
+         throw Contracts.Except("Unexpected TValue in GetGetter");
+     }
+
+     /// <summary>
+     /// Returns the pinger supplied at construction; throws when none was supplied.
+     /// </summary>
+     public Action GetPinger()
+     {
+         if (_pinger is Action typedPinger)
+             return typedPinger;
+         throw Contracts.Except("Invalid TValue in GetPinger: '{0}'", typeof(long));
+     }
+
+     /// <summary>
+     /// A column is active when a getter was supplied for it.
+     /// </summary>
+     public bool IsColumnActive(int col)
+     {
+         Contracts.Check(0 <= col && col < _getters.Length);
+         var getter = _getters[col];
+         return getter != null;
+     }
 }
///
@@ -408,4 +558,310 @@ public ValueGetter GetGetter(int col)
}
}
}
+
+ ///
+ /// This class is a transform that can add any number of output columns, that depend on any number of input columns.
+ /// It does so with the help of an <see cref="IStatefulRowMapper"/>, that is given a schema in its constructor, and has methods
+ /// to get the dependencies on input columns and the getters for the output columns, given an active set of output columns.
+ ///
+
+ public sealed class TimeSeriesRowToRowMapperTransform : RowToRowTransformBase, IStatefulRowToRowMapper,
+ ITransformCanSaveOnnx, ITransformCanSavePfa
+ {
+ private readonly IStatefulRowMapper _mapper;
+ private readonly ColumnBindings _bindings;
+ public const string RegistrationName = "TimeSeriesRowToRowMapperTransform";
+ public const string LoaderSignature = "TimeSeriesRowToRowMapper";
+ /// <summary>
+ /// Version information written to and checked against serialized models of this transform.
+ /// </summary>
+ private static VersionInfo GetVersionInfo() =>
+     new VersionInfo(
+         modelSignature: "TS ROW MPPR",
+         verWrittenCur: 0x00010001, // Initial
+         verReadableCur: 0x00010001,
+         verWeCanReadBack: 0x00010001,
+         loaderSignature: LoaderSignature,
+         loaderAssemblyName: typeof(TimeSeriesRowToRowMapperTransform).Assembly.FullName);
+
+ public override Schema Schema => _bindings.Schema;
+ bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _mapper is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false;
+
+ bool ICanSavePfa.CanSavePfa => _mapper is ICanSavePfa pfaMapper ? pfaMapper.CanSavePfa : false;
+
+ /// <summary>
+ /// Creates the transform over <paramref name="input"/>, appending the output columns reported by
+ /// <paramref name="mapper"/> to the input schema.
+ /// </summary>
+ public TimeSeriesRowToRowMapperTransform(IHostEnvironment env, IDataView input, IStatefulRowMapper mapper)
+     : base(env, RegistrationName, input)
+ {
+     Contracts.CheckValue(mapper, nameof(mapper));
+     _mapper = mapper;
+
+     var inputSchema = Schema.Create(input.Schema);
+     var addedColumns = mapper.GetOutputColumns();
+     _bindings = new ColumnBindings(inputSchema, addedColumns);
+ }
+
+ /// <summary>
+ /// Computes the schema that results from appending the output columns of <paramref name="mapper"/>
+ /// to <paramref name="inputSchema"/>.
+ /// </summary>
+ public static Schema GetOutputSchema(ISchema inputSchema, IRowMapper mapper)
+ {
+     Contracts.CheckValue(inputSchema, nameof(inputSchema));
+     Contracts.CheckValue(mapper, nameof(mapper));
+
+     var bindings = new ColumnBindings(Schema.Create(inputSchema), mapper.GetOutputColumns());
+     return bindings.Schema;
+ }
+
+ /// <summary>
+ /// Deserializing constructor: loads the mapper stored under "Mapper" and rebuilds the column
+ /// bindings from the input schema and the mapper's output columns.
+ /// </summary>
+ private TimeSeriesRowToRowMapperTransform(IHost host, ModelLoadContext ctx, IDataView input)
+     : base(host, input)
+ {
+     // *** Binary format ***
+     // _mapper
+
+     ctx.LoadModel(host, out _mapper, "Mapper", input.Schema);
+
+     var inputSchema = Schema.Create(input.Schema);
+     var addedColumns = _mapper.GetOutputColumns();
+     _bindings = new ColumnBindings(inputSchema, addedColumns);
+ }
+
+ /// <summary>
+ /// Loads a transform from <paramref name="ctx"/> and binds it to <paramref name="input"/>, after
+ /// verifying that the context is positioned at a model with this transform's version info.
+ /// </summary>
+ public static TimeSeriesRowToRowMapperTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ {
+     Contracts.CheckValue(env, nameof(env));
+     var host = env.Register(RegistrationName);
+     host.CheckValue(ctx, nameof(ctx));
+     ctx.CheckAtModel(GetVersionInfo());
+     host.CheckValue(input, nameof(input));
+
+     return host.Apply("Loading Model", ch => new TimeSeriesRowToRowMapperTransform(host, ctx, input));
+ }
+
+ /// <summary>
+ /// Serializes the transform: records the version info and stores the mapper under "Mapper".
+ /// </summary>
+ public override void Save(ModelSaveContext ctx)
+ {
+     Host.CheckValue(ctx, nameof(ctx));
+     ctx.CheckAtModel();
+
+     var versionInfo = GetVersionInfo();
+     ctx.SetVersionInfo(versionInfo);
+
+     // *** Binary format ***
+     // _mapper
+     ctx.SaveModel(_mapper, "Mapper");
+ }
+
+ /// <summary>
+ /// Produces the set of active columns for the data view (as a bool[] of length bindings.ColumnCount)
+ /// and, through <paramref name="predicateInput"/>, a predicate for the input columns that must be
+ /// active: those requested directly plus those the mapper depends on for the requested outputs.
+ /// </summary>
+ private bool[] GetActive(Func predicate, out Func predicateInput)
+ {
+     int columnCount = _bindings.Schema.ColumnCount;
+     var active = Utils.BuildArray(columnCount, predicate);
+     Contracts.Assert(active.Length == columnCount);
+
+     var requestedInputs = _bindings.GetActiveInput(predicate);
+     Contracts.Assert(requestedInputs.Length == _bindings.InputSchema.ColumnCount);
+
+     // Which outputs are active, and which input columns the mapper needs in order to produce them.
+     var mapperNeeds = _mapper.GetDependencies(GetActiveOutputColumns(active));
+
+     // An input column is needed when it is requested directly or required by the mapper.
+     predicateInput = col =>
+     {
+         if (col < 0 || col >= requestedInputs.Length)
+             return false;
+         return requestedInputs[col] || mapperNeeds(col);
+     };
+
+     return active;
+ }
+
+ /// <summary>
+ /// Converts an activity mask over all columns of the output schema into a predicate over the
+ /// mapper's added columns: added column i is active when the mask is set at its schema index.
+ /// </summary>
+ private Func GetActiveOutputColumns(bool[] active)
+ {
+     Contracts.AssertValue(active);
+     Contracts.Assert(active.Length == _bindings.Schema.ColumnCount);
+
+     return col =>
+     {
+         var added = _bindings.AddedColumnIndices;
+         bool inRange = 0 <= col && col < added.Count;
+         Contracts.Assert(inRange);
+         return inRange && active[added[col]];
+     };
+ }
+
+ /// <summary>
+ /// Answers true when any of the added columns is requested by <paramref name="predicate"/>;
+ /// otherwise expresses no preference by returning null.
+ /// </summary>
+ protected override bool? ShouldUseParallelCursors(Func predicate)
+ {
+     Host.AssertValue(predicate, "predicate");
+     return _bindings.AddedColumnIndices.Any(predicate) ? true : (bool?)null;
+ }
+
+ /// <summary>
+ /// Opens a cursor on the source restricted to the needed input columns and wraps it in a cursor
+ /// that also serves the mapper's output columns.
+ /// </summary>
+ protected override IRowCursor GetRowCursorCore(Func predicate, IRandom rand = null)
+ {
+     var active = GetActive(predicate, out var inputPredicate);
+     var sourceCursor = Source.GetRowCursor(inputPredicate, rand);
+     return new RowCursor(Host, sourceCursor, this, active);
+ }
+
+ /// <summary>
+ /// Opens a set of cursors over the source and wraps each of them. When the source returns a single
+ /// cursor although more were requested and an added column is needed, that cursor is split into
+ /// <paramref name="n"/> cursors first.
+ /// </summary>
+ public override IRowCursor[] GetRowCursorSet(out IRowCursorConsolidator consolidator, Func predicate, int n, IRandom rand = null)
+ {
+     Host.CheckValue(predicate, nameof(predicate));
+     Host.CheckValueOrNull(rand);
+
+     var active = GetActive(predicate, out var inputPredicate);
+
+     var sourceCursors = Source.GetRowCursorSet(out consolidator, inputPredicate, n, rand);
+     Host.AssertNonEmpty(sourceCursors);
+
+     bool sourceDidNotSplit = sourceCursors.Length == 1 && n > 1;
+     if (sourceDidNotSplit && _bindings.AddedColumnIndices.Any(predicate))
+         sourceCursors = DataViewUtils.CreateSplitCursors(out consolidator, Host, sourceCursors[0], n);
+     Host.AssertNonEmpty(sourceCursors);
+
+     var wrapped = new IRowCursor[sourceCursors.Length];
+     int slot = 0;
+     foreach (var sourceCursor in sourceCursors)
+         wrapped[slot++] = new RowCursor(Host, sourceCursor, this, active);
+     return wrapped;
+ }
+
+ /// <summary>
+ /// Delegates ONNX export to the mapper when it implements ISaveAsOnnx; does nothing otherwise.
+ /// </summary>
+ void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx)
+ {
+     Host.CheckValue(ctx, nameof(ctx));
+     if (!(_mapper is ISaveAsOnnx onnx))
+         return;
+
+     Host.Check(onnx.CanSaveOnnx(ctx), "Cannot be saved as ONNX.");
+     onnx.SaveAsOnnx(ctx);
+ }
+
+ /// <summary>
+ /// Delegates PFA export to the mapper when it implements ISaveAsPfa; does nothing otherwise.
+ /// </summary>
+ void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx)
+ {
+     Host.CheckValue(ctx, nameof(ctx));
+     if (!(_mapper is ISaveAsPfa pfa))
+         return;
+
+     Host.Check(pfa.CanSavePfa, "Cannot be saved as PFA.");
+     pfa.SaveAsPfa(ctx);
+ }
+
+ /// <summary>
+ /// Returns the predicate over input columns that are needed to serve the columns selected by
+ /// <paramref name="predicate"/>.
+ /// </summary>
+ public Func GetDependencies(Func predicate)
+ {
+     GetActive(predicate, out var inputPredicate);
+     return inputPredicate;
+ }
+
+ Schema IRowToRowMapper.InputSchema => Source.Schema;
+
+ /// <summary>
+ /// Builds a stateful row over <paramref name="input"/> that exposes the mapper's getters and
+ /// pinger for the output columns selected by <paramref name="active"/>.
+ /// </summary>
+ /// <param name="input">Row to read from; its schema must be the schema this transform is bound to.</param>
+ /// <param name="active">Predicate over the columns of the output schema.</param>
+ /// <param name="disposer">Receives the combined clean-up actions reported by the mapper for the
+ /// getters and for the pinger (null when neither reports one).</param>
+ public IRow GetRow(IRow input, Func active, out Action disposer)
+ {
+     Host.CheckValue(input, nameof(input));
+     Host.CheckValue(active, nameof(active));
+     Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to");
+
+     disposer = null;
+     using (var ch = Host.Start("GetEntireRow"))
+     {
+         Action disp;
+         var activeArr = new bool[Schema.ColumnCount];
+         for (int i = 0; i < Schema.ColumnCount; i++)
+             activeArr[i] = active(i);
+         var pred = GetActiveOutputColumns(activeArr);
+
+         var getters = _mapper.CreateGetters(input, pred, out disp);
+         disposer += disp;
+
+         // The pinger may report its own clean-up action; it was previously dropped.
+         var pinger = _mapper.CreatePinger(input, pred, out disp);
+         disposer += disp;
+
+         return new StatefulRow(input, this, Schema, getters, pinger);
+     }
+ }
+
+ /// <summary>
+ /// Row view that serves source columns straight from the input row and added columns from the
+ /// mapper's getters, and carries the pinger.
+ /// </summary>
+ private sealed class StatefulRow : IStatefulRow
+ {
+     private readonly IRow _input;
+     private readonly Delegate[] _getters;
+     private readonly Action _pinger;
+
+     private readonly TimeSeriesRowToRowMapperTransform _parent;
+
+     public long Batch => _input.Batch;
+
+     public long Position => _input.Position;
+
+     public Schema Schema { get; }
+
+     public StatefulRow(IRow input, TimeSeriesRowToRowMapperTransform parent,
+         Schema schema, Delegate[] getters, Action pinger)
+     {
+         _input = input;
+         _parent = parent;
+         Schema = schema;
+         _getters = getters;
+         _pinger = pinger;
+     }
+
+     /// <summary>
+     /// Source columns are forwarded to the input row; added columns use the stored getter and
+     /// throw when it is not of the requested type.
+     /// </summary>
+     public ValueGetter GetGetter(int col)
+     {
+         int index = _parent._bindings.MapColumnIndex(out bool fromSource, col);
+         if (fromSource)
+             return _input.GetGetter(index);
+
+         Contracts.Assert(_getters[index] != null);
+         if (_getters[index] is ValueGetter typedGetter)
+             return typedGetter;
+         throw Contracts.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
+     }
+
+     /// <summary>
+     /// Returns the pinger supplied at construction; throws when none was supplied.
+     /// </summary>
+     public Action GetPinger()
+     {
+         if (_pinger is Action typedPinger)
+             return typedPinger;
+         throw Contracts.Except("Invalid TValue in GetPinger: '{0}'", typeof(long));
+     }
+
+     public ValueGetter GetIdGetter() => _input.GetIdGetter();
+
+     /// <summary>
+     /// Source columns report the input row's activity; added columns are active when a getter exists.
+     /// </summary>
+     public bool IsColumnActive(int col)
+     {
+         int index = _parent._bindings.MapColumnIndex(out bool fromSource, col);
+         return fromSource
+             ? _input.IsColumnActive((index))
+             : _getters[index] != null;
+     }
+ }
+
+ /// <summary>
+ /// Cursor that follows the input cursor and adds the mapper's output columns. The getters are
+ /// created once in the constructor; the disposer the mapper reports is invoked on Dispose.
+ /// </summary>
+ private sealed class RowCursor : SynchronizedCursorBase, IRowCursor
+ {
+     private readonly Delegate[] _getters;
+     private readonly bool[] _active;
+     private readonly ColumnBindings _bindings;
+     private readonly Action _disposer;
+
+     public Schema Schema => _bindings.Schema;
+
+     public RowCursor(IChannelProvider provider, IRowCursor input, TimeSeriesRowToRowMapperTransform parent, bool[] active)
+         : base(provider, input)
+     {
+         var activeOutputs = parent.GetActiveOutputColumns(active);
+         _getters = parent._mapper.CreateGetters(input, activeOutputs, out _disposer);
+         _active = active;
+         _bindings = parent._bindings;
+     }
+
+     public bool IsColumnActive(int col)
+     {
+         Ch.Check(0 <= col && col < _bindings.Schema.ColumnCount);
+         return _active[col];
+     }
+
+     /// <summary>
+     /// Source columns are forwarded to the input cursor; added columns use the getter created in the
+     /// constructor and throw when it is not of the requested type.
+     /// </summary>
+     public ValueGetter GetGetter(int col)
+     {
+         Ch.Check(IsColumnActive(col));
+
+         int index = _bindings.MapColumnIndex(out bool fromSource, col);
+         if (fromSource)
+             return Input.GetGetter(index);
+
+         Ch.AssertValue(_getters);
+         var getter = _getters[index];
+         Ch.Assert(getter != null);
+         if (getter is ValueGetter typedGetter)
+             return typedGetter;
+         throw Ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
+     }
+
+     public override void Dispose()
+     {
+         if (_disposer != null)
+             _disposer();
+         base.Dispose();
+     }
+ }
+ }
+
}
diff --git a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs
index bc8401ef83..e0f96ed74b 100644
--- a/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs
+++ b/src/Microsoft.ML.TimeSeries/SsaAnomalyDetectionBase.cs
@@ -105,7 +105,7 @@ public abstract class SsaArguments : ArgumentsBase
protected readonly bool IsAdaptive;
protected readonly ErrorFunctionUtils.ErrorFunction ErrorFunction;
protected readonly Func ErrorFunc;
- protected readonly SequenceModelerBase Model;
+ protected SequenceModelerBase Model;
public SsaAnomalyDetectionBase(SsaArguments args, string name, IHostEnvironment env)
: base(args.WindowSize, 0, args.Source, args.Name, name, env, args.Side, args.Martingale, args.AlertOn, args.PowerMartingaleEpsilon, args.AlertThreshold)
@@ -122,6 +122,9 @@ public SsaAnomalyDetectionBase(SsaArguments args, string name, IHostEnvironment
// Creating the master SSA model
Model = new AdaptiveSingularSpectrumSequenceModeler(Host, args.InitialWindowSize, SeasonalWindowSize + 1, SeasonalWindowSize,
DiscountFactor, AdaptiveSingularSpectrumSequenceModeler.RankSelectionMethod.Exact, null, SeasonalWindowSize / 2, false, false);
+
+ StateRef = new State();
+ StateRef.InitState(WindowSize, InitialWindowSize, this, Host);
}
public SsaAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name)
@@ -150,9 +153,11 @@ public SsaAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, strin
ErrorFunc = ErrorFunctionUtils.GetErrorFunction(ErrorFunction);
IsAdaptive = ctx.Reader.ReadBoolean();
+ StateRef = new State(ctx);
ctx.LoadModel, SignatureLoadModel>(env, out Model, "SSA");
Host.CheckDecode(Model != null);
+ StateRef.InitState(this, Host);
}
public override Schema GetOutputSchema(Schema inputSchema)
@@ -186,6 +191,7 @@ public override void Save(ModelSaveContext ctx)
// float: _discountFactor
// byte: _errorFunction
// bool: _isAdaptive
+ // State: StateRef
// AdaptiveSingularSpectrumSequenceModeler: _model
base.Save(ctx);
@@ -193,6 +199,8 @@ public override void Save(ModelSaveContext ctx)
ctx.Writer.Write(DiscountFactor);
ctx.Writer.Write((byte)ErrorFunction);
ctx.Writer.Write(IsAdaptive);
+ StateRef.Save(ctx);
+
ctx.SaveModel(Model, "SSA");
}
@@ -201,6 +209,47 @@ public sealed class State : AnomalyDetectionStateBase
private SequenceModelerBase _model;
private SsaAnomalyDetectionBase _parentAnomalyDetector;
+ public State()
+ {
+
+ }
+
+ public State(ModelLoadContext ctx) : base(ctx)
+ {
+ WindowedBuffer = new FixedSizeQueue(
+ ctx.Reader.ReadInt32(), ctx.Reader.ReadInt32(), ctx.Reader.ReadSingleArray());
+
+ InitialWindowedBuffer = new FixedSizeQueue(
+ ctx.Reader.ReadInt32(), ctx.Reader.ReadInt32(), ctx.Reader.ReadSingleArray());
+ }
+
+ public override void Save(ModelSaveContext ctx)
+ {
+ base.Save(ctx);
+
+ ctx.Writer.Write(WindowedBuffer.Capacity);
+ ctx.Writer.Write(WindowedBuffer.StartIndex);
+ ctx.Writer.WriteSingleArray(WindowedBuffer.Buffer);
+
+ ctx.Writer.Write(InitialWindowedBuffer.Capacity);
+ ctx.Writer.Write(InitialWindowedBuffer.StartIndex);
+ ctx.Writer.WriteSingleArray(InitialWindowedBuffer.Buffer);
+ }
+
+ public override void CloneCore(StateBase state)
+ {
+ base.CloneCore(state);
+ Contracts.Assert(state is State);
+ var stateLocal = state as State;
+ stateLocal.WindowedBuffer = (FixedSizeQueue)WindowedBuffer.Clone();
+ stateLocal.InitialWindowedBuffer = (FixedSizeQueue)InitialWindowedBuffer.Clone();
+ if (_model != null)
+ {
+ _parentAnomalyDetector.Model = _parentAnomalyDetector.Model.Clone();
+ _model = _parentAnomalyDetector.Model;
+ }
+ }
+
private protected override void LearnStateFromDataCore(FixedSizeQueue data)
{
// This method is empty because there is no need to implement a training logic here.
@@ -209,7 +258,7 @@ private protected override void LearnStateFromDataCore(FixedSizeQueue da
private protected override void InitializeAnomalyDetector()
{
_parentAnomalyDetector = (SsaAnomalyDetectionBase)Parent;
- _model = _parentAnomalyDetector.Model.Clone();
+ _model = _parentAnomalyDetector.Model;
}
private protected override double ComputeRawAnomalyScore(ref Single input, FixedSizeQueue windowedBuffer, long iteration)
@@ -218,12 +267,15 @@ private protected override double ComputeRawAnomalyScore(ref Single input, Fixed
Single expectedValue = 0;
_model.PredictNext(ref expectedValue);
- // Feed the current point to the model
- _model.Consume(ref input, _parentAnomalyDetector.IsAdaptive);
+ if (PreviousPosition == -1)
+ // Feed the current point to the model
+ _model.Consume(ref input, _parentAnomalyDetector.IsAdaptive);
// Return the error as the raw anomaly score
return _parentAnomalyDetector.ErrorFunc(input, expectedValue);
}
+
+ public override void Consume(Single input) => _model.Consume(ref input, _parentAnomalyDetector.IsAdaptive);
}
}
}
diff --git a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs
index 32f52ca786..a359aeccfc 100644
--- a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs
+++ b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs
@@ -12,6 +12,7 @@
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.TimeSeriesProcessing;
+using Microsoft.ML.TimeSeries;
using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase;
[assembly: LoadableClass(SsaChangePointDetector.Summary, typeof(IDataTransform), typeof(SsaChangePointDetector), typeof(SsaChangePointDetector.Arguments), typeof(SignatureDataTransform),
@@ -120,6 +121,15 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData
return new SsaChangePointDetector(env, args, input).MakeDataTransform(input);
}
+ public override IStatefulTransformer Clone()
+ {
+ var clone = (SsaChangePointDetector)MemberwiseClone();
+ clone.Model = clone.Model.Clone();
+ clone.StateRef = (State)clone.StateRef.Clone();
+ clone.StateRef.InitState(clone, Host);
+ return clone;
+ }
+
internal SsaChangePointDetector(IHostEnvironment env, Arguments args)
: base(new BaseArguments(args), LoaderSignature, env)
{
diff --git a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs
index 1fbc36b49d..6bccd51d7a 100644
--- a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs
+++ b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs
@@ -11,6 +11,7 @@
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Model;
using Microsoft.ML.Runtime.TimeSeriesProcessing;
+using Microsoft.ML.TimeSeries;
using static Microsoft.ML.Runtime.TimeSeriesProcessing.SequentialAnomalyDetectionTransformBase;
[assembly: LoadableClass(SsaSpikeDetector.Summary, typeof(IDataTransform), typeof(SsaSpikeDetector), typeof(SsaSpikeDetector.Arguments), typeof(SignatureDataTransform),
@@ -133,6 +134,15 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
return new SsaSpikeDetector(env, ctx).MakeDataTransform(input);
}
+ public override IStatefulTransformer Clone()
+ {
+ var clone = (SsaSpikeDetector)MemberwiseClone();
+ clone.Model = clone.Model.Clone();
+ clone.StateRef = (State)clone.StateRef.Clone();
+ clone.StateRef.InitState(clone, Host);
+ return clone;
+ }
+
// Factory method for SignatureLoadModel.
private static SsaSpikeDetector Create(IHostEnvironment env, ModelLoadContext ctx)
{
diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs
index 834e3f669f..d81738b476 100644
--- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs
+++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs
@@ -1,11 +1,16 @@
-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using System;
using System.Collections.Generic;
+using System.IO;
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Data;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.TimeSeriesProcessing;
+using Microsoft.ML.TimeSeries;
using Xunit;
namespace Microsoft.ML.Tests
@@ -16,16 +21,28 @@ public sealed class TimeSeries
private sealed class Prediction
{
#pragma warning disable CS0649
- [VectorType(4)]
+ [VectorType(4)]
public double[] Change;
-#pragma warning restore CS0649
+#pragma warning restore CS0649
+ }
+
+ public class Prediction1
+ {
+ public float Random;
}
private sealed class Data
{
+ public string Text;
+ public float Random;
public float Value;
- public Data(float value) => Value = value;
+ public Data(float value)
+ {
+ Text = "random123value";
+ Random = -1;
+ Value = value;
+ }
}
[Fact]
@@ -118,5 +135,111 @@ public void ChangePointDetectionWithSeasonality()
Assert.Equal(expectedValues[index++], row.Change[3], precision: 7); // Martingale score
}
}
+
+ [Fact]
+ public void ChangePointDetectionWithSeasonalityPredictionEngineNoColumn()
+ {
+ const int ChangeHistorySize = 10;
+ const int SeasonalitySize = 10;
+ const int NumberOfSeasonsInTraining = 5;
+ const int MaxTrainingSize = NumberOfSeasonsInTraining * SeasonalitySize;
+
+ List data = new List();
+
+ var ml = new MLContext(seed: 1, conc: 1);
+ var dataView = ml.CreateStreamingDataView(data);
+
+ for (int j = 0; j < NumberOfSeasonsInTraining; j++)
+ for (int i = 0; i < SeasonalitySize; i++)
+ data.Add(new Data(i));
+
+ for (int i = 0; i < ChangeHistorySize; i++)
+ data.Add(new Data(i * 100));
+
+
+ // Pipeline.
+ var pipeline = ml.Transforms.Text.FeaturizeText("Text", "Text_Featurized")
+ .Append(new SsaChangePointEstimator(ml, new SsaChangePointDetector.Arguments()
+ {
+ Confidence = 95,
+ Source = "Value",
+ Name = "Change",
+ ChangeHistoryLength = ChangeHistorySize,
+ TrainingWindowSize = MaxTrainingSize,
+ SeasonalWindowSize = SeasonalitySize
+ }));
+
+ // Train.
+ var model = pipeline.Fit(dataView);
+
+ //Predict.
+ var engine = model.CreateTimeSeriesPredictionFunction(ml);
+ // Even though the time series column is not requested, the observation is still passed through the time series transform.
+ var prediction = engine.Predict(new Data(1));
+ Assert.Equal(-1, prediction.Random);
+ prediction = engine.Predict(new Data(2));
+ Assert.Equal(-1, prediction.Random);
+ }
+
+ [Fact]
+ public void ChangePointDetectionWithSeasonalityPredictionEngine()
+ {
+ const int ChangeHistorySize = 10;
+ const int SeasonalitySize = 10;
+ const int NumberOfSeasonsInTraining = 5;
+ const int MaxTrainingSize = NumberOfSeasonsInTraining * SeasonalitySize;
+
+ List data = new List();
+
+ var ml = new MLContext(seed: 1, conc: 1);
+ var dataView = ml.CreateStreamingDataView(data);
+
+ for (int j = 0; j < NumberOfSeasonsInTraining; j++)
+ for (int i = 0; i < SeasonalitySize; i++)
+ data.Add(new Data(i));
+
+ for (int i = 0; i < ChangeHistorySize; i++)
+ data.Add(new Data(i * 100));
+
+
+ // Pipeline.
+ var pipeline = ml.Transforms.Text.FeaturizeText("Text", "Text_Featurized")
+ .Append(new SsaChangePointEstimator(ml, new SsaChangePointDetector.Arguments()
+ {
+ Confidence = 95,
+ Source = "Value",
+ Name = "Change",
+ ChangeHistoryLength = ChangeHistorySize,
+ TrainingWindowSize = MaxTrainingSize,
+ SeasonalWindowSize = SeasonalitySize
+ }));
+
+ // Train.
+ var model = pipeline.Fit(dataView);
+ //Predict.
+ var engine = model.CreateTimeSeriesPredictionFunction(ml);
+ var prediction = engine.Predict(new Data(1));
+ Assert.Equal(0, prediction.Change[0], precision: 7); // Alert
+ Assert.Equal(1.1661833524703979, prediction.Change[1], precision: 7); // Raw score
+ Assert.Equal(0.5, prediction.Change[2], precision: 7); // P-Value score
+ Assert.Equal(5.1200000000000114E-08, prediction.Change[3], precision: 7); // Martingale score
+
+ //Checkpoint.
+ var modelPath = "temp.zip";
+ engine.CheckPoint(ml, modelPath);
+
+ // Load model.
+ ITransformer model2 = null;
+ using (var file = File.OpenRead(modelPath))
+ model2 = TransformerChain.LoadFrom(ml, file);
+
+ // Predict and expect a different result for the same input.
+ engine = model2.CreateTimeSeriesPredictionFunction(ml);
+ prediction = engine.Predict(new Data(1));
+ Assert.Equal(0, prediction.Change[0], precision: 7); // Alert
+ Assert.Equal(-0.12883400917053223, prediction.Change[1], precision: 7); // Raw score
+ Assert.Equal(0.5, prediction.Change[2], precision: 7); // P-Value score
+ Assert.Equal(2.6214400000000113E-15, prediction.Change[3], precision: 7); // Martingale score
+ }
}
}