Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prediction engine for time series. #1618

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
a87d836
prediction engine for time series.
codemzs Nov 14, 2018
227adb6
clean up.
codemzs Nov 14, 2018
7a7fe64
clean up.
codemzs Nov 14, 2018
8fc8307
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 15, 2018
cae6f89
refactor code.
codemzs Nov 15, 2018
8381205
refactor code.
codemzs Nov 15, 2018
ccaabd4
refactor code.
codemzs Nov 15, 2018
11b42e4
refactor code.
codemzs Nov 15, 2018
7d21d78
refactor code.
codemzs Nov 16, 2018
878e87d
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 16, 2018
745ff1c
refactor code.
codemzs Nov 16, 2018
6adadee
PR feedback.
codemzs Nov 16, 2018
31b91e7
PR feedback.
codemzs Nov 16, 2018
75ec0b7
PR feedback.
codemzs Nov 16, 2018
0efc46d
checkpoint.
codemzs Nov 21, 2018
a870984
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 21, 2018
b7db101
checkpoint.
codemzs Nov 21, 2018
5cb319a
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 21, 2018
c66940e
checkpointing.
codemzs Nov 21, 2018
237aab4
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 21, 2018
e8d10e5
misc fixes.
codemzs Nov 23, 2018
54593d3
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 23, 2018
7bf4c39
misc fixes.
codemzs Nov 23, 2018
dbf3a20
PR feedback.
codemzs Nov 26, 2018
a6e6d27
cleanup.
codemzs Nov 26, 2018
ffcc045
revert libmf.
codemzs Nov 26, 2018
a28d461
PR feedback.
codemzs Nov 26, 2018
3247537
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 26, 2018
4c27aa1
PR feedback.
codemzs Nov 26, 2018
1a660e4
PR feedback.
codemzs Nov 26, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Microsoft.ML.Api/DataViewConstructionUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace Microsoft.ML.Runtime.Api
/// <summary>
/// A helper class to create data views based on the user-provided types.
/// </summary>
[BestFriend]
internal static class DataViewConstructionUtils
{
public static IDataView CreateFromList<TRow>(IHostEnvironment env, IList<TRow> data,
Expand Down
67 changes: 51 additions & 16 deletions src/Microsoft.ML.Api/PredictionEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,33 @@ public void Reset()
}
}

public sealed class PredictionEngine<TSrc, TDst> : PredictionEngineBase<TSrc, TDst>
where TSrc : class
where TDst : class, new()
{
internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
: base(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
{
}

/// <summary>
/// Run prediction pipeline on one example.
/// </summary>
/// <param name="example">The example to run on.</param>
/// <param name="prediction">The object to store the prediction in. If it's <c>null</c>, a new one will be created, otherwise the old one
/// is reused.</param>
public override void Predict(TSrc example, ref TDst prediction)
{
Contracts.CheckValue(example, nameof(example));
ExtractValues(example);
if (prediction == null)
prediction = new TDst();

FillValues(prediction);
}
}

/// <summary>
/// A class that runs the previously trained model (and the preceding transform pipeline) on the
/// in-memory data, one example at a time.
Expand All @@ -130,14 +157,17 @@ public void Reset()
/// </summary>
/// <typeparam name="TSrc">The user-defined type that holds the example.</typeparam>
/// <typeparam name="TDst">The user-defined type that holds the prediction.</typeparam>
public sealed class PredictionEngine<TSrc, TDst>
public abstract class PredictionEngineBase<TSrc, TDst>
where TSrc : class
where TDst : class, new()
{
private readonly DataViewConstructionUtils.InputRow<TSrc> _inputRow;
private readonly IRowReadableAs<TDst> _outputRow;
private readonly Action _disposer;
[BestFriend]
private protected ITransformer Transformer { get; }

[BestFriend]
codemzs marked this conversation as resolved.
Show resolved Hide resolved
private static Func<Schema, IRowToRowMapper> StreamChecker(IHostEnvironment env, Stream modelStream)
{
env.CheckValue(modelStream, nameof(modelStream));
Expand All @@ -150,29 +180,37 @@ private static Func<Schema, IRowToRowMapper> StreamChecker(IHostEnvironment env,
};
}

internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
[BestFriend]
private protected PredictionEngineBase(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
{
Contracts.CheckValue(env, nameof(env));
env.AssertValue(transformer);
Transformer = transformer;
var makeMapper = TransformerChecker(env, transformer);
env.AssertValue(makeMapper);

_inputRow = DataViewConstructionUtils.CreateInputRow<TSrc>(env, inputSchemaDefinition);
var mapper = makeMapper(_inputRow.Schema);
PredictionEngineCore(env, _inputRow, makeMapper(_inputRow.Schema), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, out _disposer, out _outputRow);
}

internal virtual void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow<TSrc> inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs<TDst> outputRow)
{
var cursorable = TypedCursorable<TDst>.Create(env, new EmptyDataView(env, mapper.Schema), ignoreMissingColumns, outputSchemaDefinition);
var outputRow = mapper.GetRow(_inputRow, col => true, out _disposer);
_outputRow = cursorable.GetRow(outputRow);
var outputRowLocal = mapper.GetRow(_inputRow, col => true, out disposer);
outputRow = cursorable.GetRow(outputRowLocal);
}

private static Func<Schema, IRowToRowMapper> TransformerChecker(IExceptionContext ectx, ITransformer transformer)
protected virtual Func<Schema, IRowToRowMapper> TransformerChecker(IExceptionContext ectx, ITransformer transformer)
{
ectx.CheckValue(transformer, nameof(transformer));
ectx.CheckParam(transformer.IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper");
return transformer.GetRowToRowMapper;
}

~PredictionEngine()
internal virtual ITransformer ProcessTransformer(ITransformer transformer) => transformer;

~PredictionEngineBase()
{
_disposer?.Invoke();
}
Expand All @@ -189,19 +227,16 @@ public TDst Predict(TSrc example)
return result;
}

protected void ExtractValues(TSrc example) => _inputRow.ExtractValues(example);

protected void FillValues(TDst prediction) => _outputRow.FillValues(prediction);

/// <summary>
/// Run prediction pipeline on one example.
/// </summary>
/// <param name="example">The example to run on.</param>
/// <param name="prediction">The object to store the prediction in. If it's <c>null</c>, a new one will be created, otherwise the old one
/// is reused.</param>
public void Predict(TSrc example, ref TDst prediction)
{
Contracts.CheckValue(example, nameof(example));
_inputRow.ExtractValues(example);
if (prediction == null)
prediction = new TDst();
_outputRow.FillValues(prediction);
}
public abstract void Predict(TSrc example, ref TDst prediction);
}
}
10 changes: 10 additions & 0 deletions src/Microsoft.ML.Api/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Runtime.CompilerServices;
using Microsoft.ML;

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)]

[assembly: WantsToBeBestFriends]
1 change: 1 addition & 0 deletions src/Microsoft.ML.Api/TypedCursor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ public interface ICursorable<TRow>
/// Similarly to the 'DataView{T}, this class uses IL generation to create the 'poke' methods that
/// write directly into the fields of the user-defined type.
/// </summary>
[BestFriend]
internal sealed class TypedCursorable<TRow> : ICursorable<TRow>
where TRow : class
{
Expand Down
16 changes: 15 additions & 1 deletion src/Microsoft.ML.Core/Utilities/FixedSizeQueue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,30 @@ namespace Microsoft.ML.Runtime.Internal.Utilities
/// an item will result in discarding the least recently added item.
/// </summary>
[BestFriend]
internal sealed class FixedSizeQueue<T>
internal sealed class FixedSizeQueue<T> : ICloneable
{
private readonly T[] _array;
private int _startIndex;
private int _count;

public int StartIndex => _startIndex;

public T[] Buffer => _array;

public FixedSizeQueue(int capacity)
{
Contracts.Assert(capacity > 0, "Array capacity should be greater than zero");
_array = new T[capacity];
AssertValid();
}

public FixedSizeQueue(int capacity, int startIndex, T[] buffer) : this(capacity)
{
_startIndex = startIndex;
for (int index = 0; index < capacity; index++)
_array[index] = buffer[index];
}

[Conditional("DEBUG")]
private void AssertValid()
{
Expand Down Expand Up @@ -136,5 +147,8 @@ public void Clear()
_count = 0;
AssertValid();
}

public object Clone() => new FixedSizeQueue<T>(Capacity, StartIndex, Buffer);

}
}
75 changes: 44 additions & 31 deletions src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,33 @@ public enum TransformerScope
Everything = Training | Testing | Scoring
}

[BestFriend]
internal interface ITransformerAccessor
{
ITransformer[] Transformers { get; }
TransformerScope[] Scopes { get; }
}

/// <summary>
/// A chain of transformers (possibly empty) that end with a <typeparamref name="TLastTransformer"/>.
/// For an empty chain, <typeparamref name="TLastTransformer"/> is always <see cref="ITransformer"/>.
/// </summary>
public sealed class TransformerChain<TLastTransformer> : ITransformer, ICanSaveModel, IEnumerable<ITransformer>
public sealed class TransformerChain<TLastTransformer> : ITransformer, ICanSaveModel, IEnumerable<ITransformer>, ITransformerAccessor
where TLastTransformer : class, ITransformer
Copy link
Contributor

@Zruty0 Zruty0 Nov 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is confusing. You now can access the inner transformers in 3 ways: via foreach (as an IEnumerable), via the internal field Transformers and via ITransformerAccessor.Transformers. Let's reduce it to one. Maybe make in an IEnumerable of pairs, if you need access to scopes #Resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or you could make Transformers and Scopes public and IReadOnlyLists


In reply to: 236007659 [](ancestors = 236007659)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created this interface so that I can access transformers and scope from an ITransfomer reference because casting ITransfomer to TransformerChain is messy.


In reply to: 236007746 [](ancestors = 236007746,236007659)

{
private readonly ITransformer[] _transformers;
private readonly TransformerScope[] _scopes;
[BestFriend]
internal readonly ITransformer[] Transformers;
[BestFriend]
internal readonly TransformerScope[] Scopes;
public readonly TLastTransformer LastTransformer;

private const string TransformDirTemplate = "Transform_{0:000}";

public bool IsRowToRowMapper => _transformers.All(t => t.IsRowToRowMapper);
public bool IsRowToRowMapper => Transformers.All(t => t.IsRowToRowMapper);

ITransformer[] ITransformerAccessor.Transformers => Transformers;

TransformerScope[] ITransformerAccessor.Scopes => Scopes;

private static VersionInfo GetVersionInfo()
{
Expand All @@ -72,12 +85,12 @@ public TransformerChain(IEnumerable<ITransformer> transformers, IEnumerable<Tran
Contracts.CheckValueOrNull(transformers);
Contracts.CheckValueOrNull(scopes);

_transformers = transformers?.ToArray() ?? new ITransformer[0];
_scopes = scopes?.ToArray() ?? new TransformerScope[0];
Transformers = transformers?.ToArray() ?? new ITransformer[0];
Scopes = scopes?.ToArray() ?? new TransformerScope[0];
LastTransformer = transformers.LastOrDefault() as TLastTransformer;

Contracts.Check((_transformers.Length > 0) == (LastTransformer != null));
Contracts.Check(_transformers.Length == _scopes.Length);
Contracts.Check((Transformers.Length > 0) == (LastTransformer != null));
Contracts.Check(Transformers.Length == Scopes.Length);
}

/// <summary>
Expand All @@ -91,14 +104,14 @@ public TransformerChain(params ITransformer[] transformers)

if (Utils.Size(transformers) == 0)
{
_transformers = new ITransformer[0];
_scopes = new TransformerScope[0];
Transformers = new ITransformer[0];
Scopes = new TransformerScope[0];
LastTransformer = null;
}
else
{
_transformers = transformers.ToArray();
_scopes = transformers.Select(x => TransformerScope.Everything).ToArray();
Transformers = transformers.ToArray();
Scopes = transformers.Select(x => TransformerScope.Everything).ToArray();
LastTransformer = transformers.Last() as TLastTransformer;
Contracts.Check(LastTransformer != null);
}
Expand All @@ -109,7 +122,7 @@ public Schema GetOutputSchema(Schema inputSchema)
Contracts.CheckValue(inputSchema, nameof(inputSchema));

var s = inputSchema;
foreach (var xf in _transformers)
foreach (var xf in Transformers)
s = xf.GetOutputSchema(s);
return s;
}
Expand All @@ -123,7 +136,7 @@ public IDataView Transform(IDataView input)
GetOutputSchema(input.Schema);

var dv = input;
foreach (var xf in _transformers)
foreach (var xf in Transformers)
dv = xf.Transform(dv);
return dv;
}
Expand All @@ -132,12 +145,12 @@ public TransformerChain<ITransformer> GetModelFor(TransformerScope scopeFilter)
{
var xfs = new List<ITransformer>();
var scopes = new List<TransformerScope>();
for (int i = 0; i < _transformers.Length; i++)
for (int i = 0; i < Transformers.Length; i++)
{
if ((_scopes[i] & scopeFilter) != TransformerScope.None)
if ((Scopes[i] & scopeFilter) != TransformerScope.None)
{
xfs.Add(_transformers[i]);
scopes.Add(_scopes[i]);
xfs.Add(Transformers[i]);
scopes.Add(Scopes[i]);
}
}
return new TransformerChain<ITransformer>(xfs.ToArray(), scopes.ToArray());
Expand All @@ -147,21 +160,21 @@ public TransformerChain<TNewLast> Append<TNewLast>(TNewLast transformer, Transfo
where TNewLast : class, ITransformer
{
Contracts.CheckValue(transformer, nameof(transformer));
return new TransformerChain<TNewLast>(_transformers.AppendElement(transformer), _scopes.AppendElement(scope));
return new TransformerChain<TNewLast>(Transformers.AppendElement(transformer), Scopes.AppendElement(scope));
}

public void Save(ModelSaveContext ctx)
{
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());

ctx.Writer.Write(_transformers.Length);
ctx.Writer.Write(Transformers.Length);

for (int i = 0; i < _transformers.Length; i++)
for (int i = 0; i < Transformers.Length; i++)
{
ctx.Writer.Write((int)_scopes[i]);
ctx.Writer.Write((int)Scopes[i]);
var dirName = string.Format(TransformDirTemplate, i);
ctx.SaveModel(_transformers[i], dirName);
ctx.SaveModel(Transformers[i], dirName);
}
}

Expand All @@ -171,16 +184,16 @@ public void Save(ModelSaveContext ctx)
internal TransformerChain(IHostEnvironment env, ModelLoadContext ctx)
{
int len = ctx.Reader.ReadInt32();
_transformers = new ITransformer[len];
_scopes = new TransformerScope[len];
Transformers = new ITransformer[len];
Scopes = new TransformerScope[len];
for (int i = 0; i < len; i++)
{
_scopes[i] = (TransformerScope)(ctx.Reader.ReadInt32());
Scopes[i] = (TransformerScope)(ctx.Reader.ReadInt32());
var dirName = string.Format(TransformDirTemplate, i);
ctx.LoadModel<ITransformer, SignatureLoadModel>(env, out _transformers[i], dirName);
ctx.LoadModel<ITransformer, SignatureLoadModel>(env, out Transformers[i], dirName);
}
if (len > 0)
LastTransformer = _transformers[len - 1] as TLastTransformer;
LastTransformer = Transformers[len - 1] as TLastTransformer;
else
LastTransformer = null;
}
Expand All @@ -198,7 +211,7 @@ public void SaveTo(IHostEnvironment env, Stream outputStream)
}
}

public IEnumerator<ITransformer> GetEnumerator() => ((IEnumerable<ITransformer>)_transformers).GetEnumerator();
public IEnumerator<ITransformer> GetEnumerator() => ((IEnumerable<ITransformer>)Transformers).GetEnumerator();

IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

Expand All @@ -207,11 +220,11 @@ public IRowToRowMapper GetRowToRowMapper(Schema inputSchema)
Contracts.CheckValue(inputSchema, nameof(inputSchema));
Contracts.Check(IsRowToRowMapper, nameof(GetRowToRowMapper) + " method called despite " + nameof(IsRowToRowMapper) + " being false.");

IRowToRowMapper[] mappers = new IRowToRowMapper[_transformers.Length];
IRowToRowMapper[] mappers = new IRowToRowMapper[Transformers.Length];
Schema schema = inputSchema;
for (int i = 0; i < mappers.Length; ++i)
{
mappers[i] = _transformers[i].GetRowToRowMapper(schema);
mappers[i] = Transformers[i].GetRowToRowMapper(schema);
schema = mappers[i].Schema;
}
return new CompositeRowToRowMapper(inputSchema, mappers);
Expand Down
6 changes: 5 additions & 1 deletion src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

Expand All @@ -19,6 +19,9 @@ public sealed class CompositeRowToRowMapper : IRowToRowMapper
public Schema InputSchema { get; }
public Schema Schema { get; }

[BestFriend]
internal IRowToRowMapper[] InnerMappers => _innerMappers;

/// <summary>
/// Out of a series of mappers, construct a seemingly unitary mapper that is able to apply them in sequence.
/// </summary>
Expand Down Expand Up @@ -85,6 +88,7 @@ public IRow GetRow(IRow input, Func<int, bool> active, out Action disposer)
// We want the last disposer to be called first, so the order of the addition here is important.
}
}

return result;
}

Expand Down
Loading