Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prediction engine for time series. #1618

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
a87d836
prediction engine for time series.
codemzs Nov 14, 2018
227adb6
clean up.
codemzs Nov 14, 2018
7a7fe64
clean up.
codemzs Nov 14, 2018
8fc8307
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 15, 2018
cae6f89
refactor code.
codemzs Nov 15, 2018
8381205
refactor code.
codemzs Nov 15, 2018
ccaabd4
refactor code.
codemzs Nov 15, 2018
11b42e4
refactor code.
codemzs Nov 15, 2018
7d21d78
refactor code.
codemzs Nov 16, 2018
878e87d
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 16, 2018
745ff1c
refactor code.
codemzs Nov 16, 2018
6adadee
PR feedback.
codemzs Nov 16, 2018
31b91e7
PR feedback.
codemzs Nov 16, 2018
75ec0b7
PR feedback.
codemzs Nov 16, 2018
0efc46d
checkpoint.
codemzs Nov 21, 2018
a870984
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 21, 2018
b7db101
checkpoint.
codemzs Nov 21, 2018
5cb319a
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 21, 2018
c66940e
checkpointing.
codemzs Nov 21, 2018
237aab4
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 21, 2018
e8d10e5
misc fixes.
codemzs Nov 23, 2018
54593d3
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 23, 2018
7bf4c39
misc fixes.
codemzs Nov 23, 2018
dbf3a20
PR feedback.
codemzs Nov 26, 2018
a6e6d27
cleanup.
codemzs Nov 26, 2018
ffcc045
revert libmf.
codemzs Nov 26, 2018
a28d461
PR feedback.
codemzs Nov 26, 2018
3247537
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 26, 2018
4c27aa1
PR feedback.
codemzs Nov 26, 2018
1a660e4
PR feedback.
codemzs Nov 26, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/Microsoft.ML.Api/ApiUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@

namespace Microsoft.ML.Runtime.Api
{
/// <summary>
/// Delegate that receives a row, a <c>long</c> position and a by-reference value slot, and returns nothing.
/// <typeparamref name="TRow"/> is declared contravariant.
/// NOTE(review): presumably implementations copy one member of the row into the value slot - confirm against the code that creates these delegates.
/// </summary>
[BestFriend]
internal delegate void Peek<in TRow, TValue>(TRow row, long position, ref TValue value);

/// <summary>
/// Delegate that receives a destination row and a source value, and returns nothing.
/// NOTE(review): presumably implementations store the source value into one member of the destination row - confirm against the code that creates these delegates.
/// </summary>
[BestFriend]
internal delegate void Poke<TRow, TValue>(TRow dst, TValue src);

[BestFriend]
internal static class ApiUtils
{
private static OpCode GetAssignmentOpCode(Type t)
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.Api/DataViewConstructionUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace Microsoft.ML.Runtime.Api
/// <summary>
/// A helper class to create data views based on the user-provided types.
/// </summary>
[BestFriend]
internal static class DataViewConstructionUtils
{
public static IDataView CreateFromList<TRow>(IHostEnvironment env, IList<TRow> data,
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.ML.Api/InternalSchemaDefinition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace Microsoft.ML.Runtime.Api
/// <summary>
/// An internal class that holds the (already validated) mapping between a custom type and an IDataView schema.
/// </summary>
[BestFriend]
internal sealed class InternalSchemaDefinition
codemzs marked this conversation as resolved.
Show resolved Hide resolved
{
public readonly Column[] Columns;
Expand Down
10 changes: 10 additions & 0 deletions src/Microsoft.ML.Api/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Runtime.CompilerServices;
using Microsoft.ML;

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)]

[assembly: WantsToBeBestFriends]
2 changes: 2 additions & 0 deletions src/Microsoft.ML.Core/Data/MetadataUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ public static class Kinds
/// slots within that column.
/// </summary>
public const string CategoricalSlotRanges = "CategoricalSlotRanges";

/// <summary>
/// Metadata kind name for time series columns.
/// NOTE(review): presumably attached to columns produced by time series transforms so that consumers can recognise them - confirm with the code that writes and reads this metadata.
/// </summary>
public const string TimeSeriesColumn = "TimeSeriesColumn";
codemzs marked this conversation as resolved.
Show resolved Hide resolved
}

/// <summary>
Expand Down
4 changes: 4 additions & 0 deletions src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using Microsoft.ML.Runtime.Internal.Utilities;

namespace Microsoft.ML.Runtime.Data
Expand All @@ -18,6 +19,8 @@ public sealed class CompositeRowToRowMapper : IRowToRowMapper
public Schema InputSchema { get; }
public Schema Schema { get; }

/// <summary>
/// The inner mappers held by this composite. The backing array is returned directly, without copying.
/// </summary>
public IRowToRowMapper[] InnerMappers
{
    get { return _innerMappers; }
}
codemzs marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Out of a series of mappers, construct a seemingly unitary mapper that is able to apply them in sequence.
/// </summary>
Expand Down Expand Up @@ -84,6 +87,7 @@ public IRow GetRow(IRow input, Func<int, bool> active, out Action disposer)
// We want the last disposer to be called first, so the order of the addition here is important.
}
}

return result;
}

Expand Down
253 changes: 253 additions & 0 deletions src/Microsoft.ML.TimeSeries/PredictionEngine.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;

namespace Microsoft.ML.TimeSeries
{
/// <summary>
/// A class that runs the previously trained model (and the preceding transform pipeline) on the
/// in-memory data, one example at a time.
/// This can also be used with trained pipelines that do not end with a predictor: in this case, the
/// 'prediction' will be just the outcome of all the transformations.
/// </summary>
/// <typeparam name="TSrc">The user-defined type that holds the example.</typeparam>
/// <typeparam name="TDst">The user-defined type that holds the prediction.</typeparam>
public sealed class PredictionEngine<TSrc, TDst>
codemzs marked this conversation as resolved.
Show resolved Hide resolved
where TSrc : class
where TDst : class, new()
{
private readonly DataViewConstructionUtils.InputRow<TSrc> _inputRow;
private readonly IRowReadableAs<TDst> _outputRow;
private readonly IStatefulRowReadableAs<TDst>[] _statefulRows;
private readonly Action _disposer;

/// <summary>
/// Creates a prediction engine from a model stream. The stream is checked by StreamChecker, and the
/// mapper factory it returns is forwarded to the constructor that takes a mapper factory.
/// </summary>
internal PredictionEngine(
    IHostEnvironment env,
    Stream modelStream,
    bool ignoreMissingColumns,
    SchemaDefinition inputSchemaDefinition = null,
    SchemaDefinition outputSchemaDefinition = null)
    : this(
        env,
        StreamChecker(env, modelStream),
        ignoreMissingColumns,
        inputSchemaDefinition,
        outputSchemaDefinition)
{
}

/// <summary>
/// Checks that <paramref name="modelStream"/> is not null and returns a factory that, given an input schema,
/// loads the pipeline from the stream over an empty data view of that schema and returns its row-to-row mapper.
/// </summary>
/// <param name="env">The host environment used for checks and for loading the model.</param>
/// <param name="modelStream">The stream the model is loaded from when the factory is invoked.</param>
/// <returns>A delegate that maps an input schema to a row-to-row mapper.</returns>
private static Func<Schema, IRowToRowMapper> StreamChecker(IHostEnvironment env, Stream modelStream)
{
    env.CheckValue(modelStream, nameof(modelStream));

    // The stream is only read when the returned factory is invoked, not here.
    IRowToRowMapper MakeMapper(Schema schema)
    {
        var emptyInput = new EmptyDataView(env, schema);
        var loadedPipe = DataViewConstructionUtils.LoadPipeWithPredictor(env, modelStream, emptyInput);
        var transformer = new TransformWrapper(env, loadedPipe);
        env.CheckParam(transformer.IsRowToRowMapper, nameof(transformer), "Must be a row to row mapper");
        return transformer.GetRowToRowMapper(schema);
    }

    return MakeMapper;
}

/// <summary>
/// Creates a prediction engine from a data pipe. The pipe is checked for null, wrapped in a
/// <see cref="TransformWrapper"/>, and forwarded to another constructor overload.
/// </summary>
internal PredictionEngine(
    IHostEnvironment env,
    IDataView dataPipe,
    bool ignoreMissingColumns,
    SchemaDefinition inputSchemaDefinition = null,
    SchemaDefinition outputSchemaDefinition = null)
    : this(
        env,
        new TransformWrapper(env, env.CheckRef(dataPipe, nameof(dataPipe))),
        ignoreMissingColumns,
        inputSchemaDefinition,
        outputSchemaDefinition)
{
}

/// <summary>
/// Creates a prediction engine from a transformer. The transformer is validated by TransformerChecker, and the
/// mapper factory it returns is forwarded to the constructor that takes a mapper factory.
/// </summary>
internal PredictionEngine(
    IHostEnvironment env,
    ITransformer transformer,
    bool ignoreMissingColumns,
    SchemaDefinition inputSchemaDefinition = null,
    SchemaDefinition outputSchemaDefinition = null)
    : this(
        env,
        TransformerChecker(env, transformer),
        ignoreMissingColumns,
        inputSchemaDefinition,
        outputSchemaDefinition)
{
}

/// <summary>
/// Checks that <paramref name="transformer"/> is not null and reports itself as a row-to-row mapper, then
/// returns its <c>GetRowToRowMapper</c> method as the mapper factory.
/// </summary>
/// <param name="ectx">The exception context used to raise check failures.</param>
/// <param name="transformer">The transformer to validate.</param>
/// <returns>A delegate that maps an input schema to a row-to-row mapper.</returns>
private static Func<Schema, IRowToRowMapper> TransformerChecker(IExceptionContext ectx, ITransformer transformer)
{
    ectx.CheckValue(transformer, nameof(transformer));

    bool isRowToRow = transformer.IsRowToRowMapper;
    ectx.CheckParam(isRowToRow, nameof(transformer), "Must be a row to row mapper");

    Func<Schema, IRowToRowMapper> makeMapper = transformer.GetRowToRowMapper;
    return makeMapper;
}

/// <summary>
/// Applies the chain of mappers behind <paramref name="mapper"/> to <paramref name="input"/>, one after another,
/// and records every row produced along the way that implements <see cref="IStatefulRow"/>.
/// A <see cref="CompositeRowToRowMapper"/> is expanded into its inner mappers (recursively for nested composites);
/// any other mapper is treated as a chain of exactly one mapper.
/// </summary>
/// <param name="input">The row fed into the first mapper of the chain.</param>
/// <param name="mapper">The mapper (possibly composite) to expand and apply.</param>
/// <param name="active">Predicate selecting the output columns required from the last mapper.</param>
/// <param name="rows">Receives, in chain order, each stateful row that was created. Must not be null.</param>
/// <param name="disposer">Combined disposer of all created rows, or <c>null</c> if none of them supplied one.
/// The disposer of the last created row is invoked first.</param>
/// <param name="outRow">The row produced by the last mapper, or <paramref name="input"/> itself when
/// <paramref name="mapper"/> is a composite without inner mappers.</param>
public void GetStatefulRows(IRow input, IRowToRowMapper mapper, Func<int, bool> active,
    List<IStatefulRow> rows, out Action disposer, out IRow outRow)
{
    Contracts.CheckValue(input, nameof(input));
    Contracts.CheckValue(active, nameof(active));
    Contracts.CheckValue(mapper, nameof(mapper));
    Contracts.CheckValue(rows, nameof(rows));

    disposer = null;

    IRowToRowMapper[] innerMappers = mapper is CompositeRowToRowMapper composite
        ? composite.InnerMappers
        : new[] { mapper };

    if (innerMappers.Length == 0)
    {
        // An empty composite is the identity mapping: the input row is the output row, so every
        // requested column must already be active on it.
        for (int c = 0; c < input.Schema.ColumnCount; ++c)
        {
            if (active(c) && !input.IsColumnActive(c))
                throw Contracts.ExceptParam(nameof(input), $"Mapper required column '{input.Schema.GetColumnName(c)}' active but it was not.");
        }

        outRow = input;
        // Nothing left to chain; falling through would index into an empty dependency array.
        return;
    }

    // For each of the inner mappers, we will be calling their GetRow method, but to do so we need to know
    // what we need from them. The last one will just have the input, but the rest will need to be
    // computed based on the dependencies of the next one in the chain.
    var deps = new Func<int, bool>[innerMappers.Length];
    deps[deps.Length - 1] = active;
    for (int i = deps.Length - 1; i >= 1; --i)
        deps[i - 1] = innerMappers[i].GetDependencies(deps[i]);

    IRow result = input;
    for (int i = 0; i < innerMappers.Length; ++i)
    {
        Action localDisp;
        if (innerMappers[i] is CompositeRowToRowMapper)
        {
            // The recursive call records the stateful rows of the nested chain itself, so the row it
            // returns must not be recorded a second time here.
            GetStatefulRows(result, innerMappers[i], deps[i], rows, out localDisp, out result);
        }
        else
        {
            result = innerMappers[i].GetRow(result, deps[i], out localDisp);

            // Record the row that was just created, not the row that was passed in.
            if (result is IStatefulRow statefulRow)
                rows.Add(statefulRow);
        }

        // We want the last disposer to be called first, so the order of the addition here is important.
        // Adding a null delegate on the right-hand side simply yields localDisp.
        if (localDisp != null)
            disposer = localDisp + disposer;
    }

    outRow = result;
}

/// <summary>
/// Core constructor used by the other overloads. It creates the typed input row, asks <paramref name="makeMapper"/>
/// for a mapper over the input row schema, gathers the stateful rows, and wraps both the stateful rows and the
/// output row so that they can be read as <typeparamref name="TDst"/>.
/// </summary>
// NOTE(review): _disposer is passed as an out argument twice, first to GetStatefulRows and then to mapper.GetRow.
// The second call overwrites the first value, so the disposer obtained from GetStatefulRows appears to be dropped - confirm.
// NOTE(review): the rows gathered by GetStatefulRows and the row returned by mapper.GetRow come from two separate
// calls; presumably they are meant to belong to the same chain of rows - confirm.
private PredictionEngine(IHostEnvironment env, Func<Schema, IRowToRowMapper> makeMapper, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition)
{
Copy link
Contributor

@Zruty0 Zruty0 Nov 21, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

{ [](start = 8, length = 1)

I think you can just merge them all onto one delegate, there's no need to have arrays of arrays of delegates #Resolved

Contracts.CheckValue(env, nameof(env));
env.AssertValue(makeMapper);

_inputRow = DataViewConstructionUtils.CreateInputRow<TSrc>(env, inputSchemaDefinition);
var mapper = makeMapper(_inputRow.Schema);

List<IStatefulRow> rows = new List<IStatefulRow>();
if (mapper is CompositeRowToRowMapper)
GetStatefulRows(_inputRow, mapper, col => true, rows, out _disposer, out var outRow);

var cursorable = TypedCursorable<TDst>.Create(env, new EmptyDataView(env, mapper.Schema), ignoreMissingColumns, outputSchemaDefinition);
var outputRow = mapper.GetRow(_inputRow, col => true, out _disposer);

if (rows.Count == 0 && outputRow is IStatefulRow)
rows.Add((IStatefulRow)outputRow);

Copy link
Contributor

@Zruty0 Zruty0 Nov 14, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this feels like an unroll of 1 level of recursion. Can't GetStatefulRows also handle case where mapper is not a CompositeRowToRowMapper? #Resolved

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still active


In reply to: 233642937 [](ancestors = 233642937)

_statefulRows = rows.Select(row => cursorable.GetRow(row)).ToArray();
codemzs marked this conversation as resolved.
Show resolved Hide resolved

_outputRow = cursorable.GetRow(outputRow);
}

/// <summary>
/// Finalizer: invokes the stored disposer delegate when one was captured at construction time.
/// </summary>
~PredictionEngine() => _disposer?.Invoke();

/// <summary>
/// Scores a single example and returns the outcome in a freshly allocated <typeparamref name="TDst"/>.
/// </summary>
/// <param name="example">The example to score.</param>
/// <returns>The prediction. A new object is allocated on every call.</returns>
public TDst Predict(TSrc example)
{
    TDst prediction = new TDst();
    Predict(example, ref prediction);
    return prediction;
}

/// <summary>
/// Scores a single example, writing the outcome into <paramref name="prediction"/>.
/// </summary>
/// <param name="example">The example to score. Must not be <c>null</c>.</param>
/// <param name="prediction">The object that receives the prediction. When it is <c>null</c> a new instance is
/// created and assigned; otherwise the supplied instance is reused.</param>
public void Predict(TSrc example, ref TDst prediction)
{
    Contracts.CheckValue(example, nameof(example));
    _inputRow.ExtractValues(example);

    if (prediction == null)
    {
        prediction = new TDst();
    }

    // Every stateful row is pinged, in order, before the output row is read.
    for (int i = 0; i < _statefulRows.Length; i++)
    {
        _statefulRows[i].PingValues(prediction);
    }

    _outputRow.FillValues(prediction);
}
}

/// <summary>
/// Wraps a time series <see cref="PredictionEngine{TSrc, TDst}"/> built from a transformer: it accepts instances
/// of <typeparamref name="TSrc"/> and produces instances of <typeparamref name="TDst"/>.
/// </summary>
public sealed class PredictionFunction<TSrc, TDst>
    where TSrc : class
    where TDst : class, new()
{
    private readonly PredictionEngine<TSrc, TDst> _engine;

    /// <summary>
    /// Creates a <see cref="PredictionFunction{TSrc, TDst}"/> over the given transformer.
    /// </summary>
    /// <param name="env">The host environment. Must not be <c>null</c>.</param>
    /// <param name="transformer">The model (transformer) that performs the prediction. Must not be <c>null</c>.</param>
    public PredictionFunction(IHostEnvironment env, ITransformer transformer)
    {
        Contracts.CheckValue(env, nameof(env));
        env.CheckValue(transformer, nameof(transformer));

        // The resulting view is not read afterwards; the call itself is kept as is.
        IDataView emptyView = env.CreateDataView(new TSrc[0]);
        _engine = env.CreateTimeSeriesPredictionEngine<TSrc, TDst>(transformer);
    }

    /// <summary>
    /// Runs one prediction and returns the result.
    /// </summary>
    /// <param name="example">The object holding the values to predict from.</param>
    /// <returns>The object populated with the prediction results.</returns>
    public TDst Predict(TSrc example)
    {
        return _engine.Predict(example);
    }

    /// <summary>
    /// Runs one prediction, writing the result into a caller-supplied object so that it can be reused across calls.
    /// </summary>
    /// <param name="example">The object holding the values to predict from.</param>
    /// <param name="prediction">The object that receives the prediction. When it is <c>null</c> a new object is
    /// created; otherwise the provided object is used.</param>
    public void Predict(TSrc example, ref TDst prediction)
    {
        _engine.Predict(example, ref prediction);
    }
}

/// <summary>
/// Extension methods that create time series prediction engines and prediction functions.
/// </summary>
public static class PredictionFunctionExtensions
{
    /// <summary>
    /// Creates a time series <see cref="PredictionEngine{TSrc, TDst}"/> for <paramref name="transformer"/>.
    /// </summary>
    /// <param name="env">The host environment. Must not be <c>null</c>.</param>
    /// <param name="transformer">The transformer to run. Must not be <c>null</c>.</param>
    /// <param name="ignoreMissingColumns">Forwarded unchanged to the prediction engine.</param>
    /// <param name="inputSchemaDefinition">Optional schema definition for <typeparamref name="TSrc"/>.</param>
    /// <param name="outputSchemaDefinition">Optional schema definition for <typeparamref name="TDst"/>.</param>
    /// <returns>The newly created prediction engine.</returns>
    public static PredictionEngine<TSrc, TDst> CreateTimeSeriesPredictionEngine<TSrc, TDst>(
        this IHostEnvironment env,
        ITransformer transformer,
        bool ignoreMissingColumns = false,
        SchemaDefinition inputSchemaDefinition = null,
        SchemaDefinition outputSchemaDefinition = null)
        where TSrc : class
        where TDst : class, new()
    {
        Contracts.CheckValue(env, nameof(env));
        env.CheckValue(transformer, nameof(transformer));
        env.CheckValueOrNull(inputSchemaDefinition);
        env.CheckValueOrNull(outputSchemaDefinition);

        var engine = new PredictionEngine<TSrc, TDst>(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition);
        return engine;
    }

    /// <summary>
    /// Creates a 'prediction function' (or 'prediction machine') from the model denoted by
    /// <paramref name="transformer"/>. It accepts instances of <typeparamref name="TSrc"/> as input and
    /// produces instances of <typeparamref name="TDst"/> as output.
    /// </summary>
    public static PredictionFunction<TSrc, TDst> MakeTimeSeriesPredictionFunction<TSrc, TDst>(this ITransformer transformer, IHostEnvironment env)
        where TSrc : class
        where TDst : class, new()
    {
        return new PredictionFunction<TSrc, TDst>(env, transformer);
    }
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

Expand Down Expand Up @@ -574,6 +574,11 @@ public Schema.Column[] GetOutputColumns()
{
var meta = new Schema.Metadata.Builder();
meta.AddSlotNames(_parent._outputLength, GetSlotNames);
ValueGetter<bool> getter = (ref bool dst) =>
{
dst = true;
};
meta.Add(new Schema.Column(MetadataUtils.Kinds.TimeSeriesColumn, BoolType.Instance, null), getter);
var info = new Schema.Column[1];
info[0] = new Schema.Column(_parent.OutputColumnName, new VectorType(NumberType.R8, _parent._outputLength), meta.GetMetadata());
return info;
Expand Down Expand Up @@ -612,13 +617,14 @@ private Delegate MakeGetter(IRow input, TState state)
var srcGetter = input.GetGetter<TInput>(_inputColumnIndex);
ProcessData processData = _parent.WindowSize > 0 ?
(ProcessData)state.Process : state.ProcessWithoutBuffer;
ValueGetter<VBuffer<double>> valueGetter = (ref VBuffer<double> dst) =>
{
TInput src = default;
srcGetter(ref src);
processData(ref src, ref dst);
};

state.Row = input;
ValueGetter <VBuffer<double>> valueGetter = (ref VBuffer<double> dst) =>
{
TInput src = default;
srcGetter(ref src);
processData(ref src, ref dst);
};
return valueGetter;
}
}
Expand Down
Loading