Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prediction engine for time series. #1618

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
a87d836
prediction engine for time series.
codemzs Nov 14, 2018
227adb6
clean up.
codemzs Nov 14, 2018
7a7fe64
clean up.
codemzs Nov 14, 2018
8fc8307
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 15, 2018
cae6f89
refactor code.
codemzs Nov 15, 2018
8381205
refactor code.
codemzs Nov 15, 2018
ccaabd4
refactor code.
codemzs Nov 15, 2018
11b42e4
refactor code.
codemzs Nov 15, 2018
7d21d78
refactor code.
codemzs Nov 16, 2018
878e87d
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 16, 2018
745ff1c
refactor code.
codemzs Nov 16, 2018
6adadee
PR feedback.
codemzs Nov 16, 2018
31b91e7
PR feedback.
codemzs Nov 16, 2018
75ec0b7
PR feedback.
codemzs Nov 16, 2018
0efc46d
checkpoint.
codemzs Nov 21, 2018
a870984
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 21, 2018
b7db101
checkpoint.
codemzs Nov 21, 2018
5cb319a
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 21, 2018
c66940e
checkpointing.
codemzs Nov 21, 2018
237aab4
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 21, 2018
e8d10e5
misc fixes.
codemzs Nov 23, 2018
54593d3
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 23, 2018
7bf4c39
misc fixes.
codemzs Nov 23, 2018
dbf3a20
PR feedback.
codemzs Nov 26, 2018
a6e6d27
cleanup.
codemzs Nov 26, 2018
ffcc045
revert libmf.
codemzs Nov 26, 2018
a28d461
PR feedback.
codemzs Nov 26, 2018
3247537
Merge branch 'master' of https://github.com/dotnet/machinelearning in…
codemzs Nov 26, 2018
4c27aa1
PR feedback.
codemzs Nov 26, 2018
1a660e4
PR feedback.
codemzs Nov 26, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Microsoft.ML.Api/DataViewConstructionUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace Microsoft.ML.Runtime.Api
/// <summary>
/// A helper class to create data views based on the user-provided types.
/// </summary>
[BestFriend]
internal static class DataViewConstructionUtils
{
public static IDataView CreateFromList<TRow>(IHostEnvironment env, IList<TRow> data,
Expand Down
63 changes: 50 additions & 13 deletions src/Microsoft.ML.Api/PredictionEngine.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,29 @@ public void Reset()
}
}

/// <summary>
/// Sealed, concrete prediction engine. It declares no members beyond its constructors:
/// each one forwards its arguments unchanged to the matching
/// <see cref="PredictionEngineBase{TSrc, TDst}"/> constructor.
/// </summary>
/// <typeparam name="TSrc">The user-defined type that holds the example.</typeparam>
/// <typeparam name="TDst">The user-defined type that holds the prediction.</typeparam>
public sealed class PredictionEngine<TSrc, TDst> : PredictionEngineBase<TSrc, TDst>
where TSrc : class
where TDst : class, new()
{
// Overload taking the model as a Stream; forwards to the base constructor.
internal PredictionEngine(IHostEnvironment env, Stream modelStream, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
: base(env, modelStream, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
{
}

// Overload taking an IDataView pipeline; forwards to the base constructor.
internal PredictionEngine(IHostEnvironment env, IDataView dataPipe, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
: base(env, dataPipe, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
{
}

// Overload taking an ITransformer; forwards to the base constructor.
internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
codemzs marked this conversation as resolved.
Show resolved Hide resolved
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
: base(env, transformer, ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
{
}
}

/// <summary>
/// A class that runs the previously trained model (and the preceding transform pipeline) on the
/// in-memory data, one example at a time.
Expand All @@ -132,20 +155,22 @@ public void Reset()
/// </summary>
/// <typeparam name="TSrc">The user-defined type that holds the example.</typeparam>
/// <typeparam name="TDst">The user-defined type that holds the prediction.</typeparam>
public sealed class PredictionEngine<TSrc, TDst>
public abstract class PredictionEngineBase<TSrc, TDst>
where TSrc : class
where TDst : class, new()
{
private readonly DataViewConstructionUtils.InputRow<TSrc> _inputRow;
private readonly IRowReadableAs<TDst> _outputRow;
private readonly Action _disposer;

internal PredictionEngine(IHostEnvironment env, Stream modelStream, bool ignoreMissingColumns,
[BestFriend]
internal PredictionEngineBase(IHostEnvironment env, Stream modelStream, bool ignoreMissingColumns,
codemzs marked this conversation as resolved.
Show resolved Hide resolved
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
: this(env, StreamChecker(env, modelStream), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
{
}

[BestFriend]
codemzs marked this conversation as resolved.
Show resolved Hide resolved
private static Func<Schema, IRowToRowMapper> StreamChecker(IHostEnvironment env, Stream modelStream)
{
env.CheckValue(modelStream, nameof(modelStream));
Expand All @@ -158,13 +183,15 @@ private static Func<Schema, IRowToRowMapper> StreamChecker(IHostEnvironment env,
};
}

internal PredictionEngine(IHostEnvironment env, IDataView dataPipe, bool ignoreMissingColumns,
[BestFriend]
internal PredictionEngineBase(IHostEnvironment env, IDataView dataPipe, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
: this(env, new TransformWrapper(env, env.CheckRef(dataPipe, nameof(dataPipe))), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
{
}

internal PredictionEngine(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
[BestFriend]
internal PredictionEngineBase(IHostEnvironment env, ITransformer transformer, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition = null, SchemaDefinition outputSchemaDefinition = null)
: this(env, TransformerChecker(env, transformer), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition)
{
Expand All @@ -177,20 +204,25 @@ private static Func<Schema, IRowToRowMapper> TransformerChecker(IExceptionContex
return transformer.GetRowToRowMapper;
}

private PredictionEngine(IHostEnvironment env, Func<Schema, IRowToRowMapper> makeMapper, bool ignoreMissingColumns,
internal PredictionEngineBase(IHostEnvironment env, Func<Schema, IRowToRowMapper> makeMapper, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition)
{
Contracts.CheckValue(env, nameof(env));
env.AssertValue(makeMapper);

_inputRow = DataViewConstructionUtils.CreateInputRow<TSrc>(env, inputSchemaDefinition);
var mapper = makeMapper(_inputRow.Schema);

PredictionEngineCore(env, _inputRow, makeMapper(_inputRow.Schema), ignoreMissingColumns, inputSchemaDefinition, outputSchemaDefinition, out _disposer, out _outputRow);
}

internal virtual void PredictionEngineCore(IHostEnvironment env, DataViewConstructionUtils.InputRow<TSrc> inputRow, IRowToRowMapper mapper, bool ignoreMissingColumns,
SchemaDefinition inputSchemaDefinition, SchemaDefinition outputSchemaDefinition, out Action disposer, out IRowReadableAs<TDst> outputRow)
{
var cursorable = TypedCursorable<TDst>.Create(env, new EmptyDataView(env, mapper.Schema), ignoreMissingColumns, outputSchemaDefinition);
var outputRow = mapper.GetRow(_inputRow, col => true, out _disposer);
_outputRow = cursorable.GetRow(outputRow);
var outputRowLocal = mapper.GetRow(_inputRow, col => true, out disposer);
outputRow = cursorable.GetRow(outputRowLocal);
}

~PredictionEngine()
~PredictionEngineBase()
{
_disposer?.Invoke();
}
Expand All @@ -207,19 +239,24 @@ public TDst Predict(TSrc example)
return result;
}

/// <summary>
/// Forwards <paramref name="example"/> to <c>_inputRow.ExtractValues</c>.
/// </summary>
/// <param name="example">The example instance handed to the input row.</param>
public void ExtractValues(TSrc example) => _inputRow.ExtractValues(example);
Copy link
Contributor

@Zruty0 Zruty0 Nov 21, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

public [](start = 8, length = 6)

that's weird, why do you need it public? #Resolved


/// <summary>
/// Forwards <paramref name="prediction"/> to <c>_outputRow.FillValues</c>.
/// </summary>
/// <param name="prediction">The prediction instance handed to the output row.</param>
public void FillValues(TDst prediction) => _outputRow.FillValues(prediction);

/// <summary>
/// Run prediction pipeline on one example.
/// </summary>
/// <param name="example">The example to run on.</param>
/// <param name="prediction">The object to store the prediction in. If it's <c>null</c>, a new one will be created, otherwise the old one
/// is reused.</param>
public void Predict(TSrc example, ref TDst prediction)
public virtual void Predict(TSrc example, ref TDst prediction)
Copy link
Contributor

@Zruty0 Zruty0 Nov 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

virtual [](start = 15, length = 7)

If 2 existing implementations are dissimilar, what's the value in having a virtual method? Make it abstract #Resolved

{
Contracts.CheckValue(example, nameof(example));
_inputRow.ExtractValues(example);
ExtractValues(example);
if (prediction == null)
prediction = new TDst();
_outputRow.FillValues(prediction);

FillValues(prediction);
}
}

Expand Down
18 changes: 14 additions & 4 deletions src/Microsoft.ML.Api/PredictionFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,40 @@

namespace Microsoft.ML.Runtime.Data
{
public sealed class PredictionFunction<TSrc, TDst> : PredictionFunctionBase<TSrc, TDst>
Copy link
Contributor

@Zruty0 Zruty0 Nov 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PredictionFunction [](start = 24, length = 18)

I honestly don't understand the purpose of PredictionFunction. Why introduce a hierarchy of useless classes? I would much rather rename PredictionEngine to PredictionFunction and be done with it.

It doesn't have to be done in this PR, but could you just use PredictionEngine only, to simplify that future change? #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm with you 100% on this. I had the same thought when I was making this change.


In reply to: 236007234 [](ancestors = 236007234)

where TSrc : class
where TDst : class, new()
{
public PredictionFunction(IHostEnvironment env, ITransformer transformer) : base(env, transformer) { }
}

/// <summary>
/// A prediction engine class, that takes instances of <typeparamref name="TSrc"/> through
/// the transformer pipeline and produces instances of <typeparamref name="TDst"/> as outputs.
/// </summary>
public sealed class PredictionFunction<TSrc, TDst>
public abstract class PredictionFunctionBase<TSrc, TDst>
where TSrc : class
where TDst : class, new()
{
private readonly PredictionEngine<TSrc, TDst> _engine;
private readonly PredictionEngineBase<TSrc, TDst> _engine;

/// <summary>
/// Create an instance of <see cref="PredictionFunction{TSrc, TDst}"/>.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="transformer">The model (transformer) to use for prediction.</param>
public PredictionFunction(IHostEnvironment env, ITransformer transformer)
public PredictionFunctionBase(IHostEnvironment env, ITransformer transformer)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(transformer, nameof(transformer));

IDataView dv = env.CreateDataView(new TSrc[0]);
_engine = env.CreatePredictionEngine<TSrc, TDst>(transformer);
CreatePredictionEngine(env, transformer, out _engine);
}

internal virtual void CreatePredictionEngine(IHostEnvironment env, ITransformer transformer, out PredictionEngineBase<TSrc, TDst> engine) =>
Copy link
Contributor

@Zruty0 Zruty0 Nov 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CreatePredictionEngine [](start = 30, length = 22)

use return value instead of out parameter #Resolved

engine = env.CreatePredictionEngine<TSrc, TDst>(transformer);

/// <summary>
/// Perform one prediction using the model.
/// </summary>
Expand Down
10 changes: 10 additions & 0 deletions src/Microsoft.ML.Api/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Runtime.CompilerServices;
using Microsoft.ML;

// Grants the Microsoft.ML.TimeSeries assembly (name suffixed with PublicKey.Value)
// access to this assembly's internal members.
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries" + PublicKey.Value)]

// NOTE(review): presumably opts this assembly into the [BestFriend] convention applied to
// internal members elsewhere in this change — confirm against the attribute's definition.
[assembly: WantsToBeBestFriends]
1 change: 1 addition & 0 deletions src/Microsoft.ML.Api/TypedCursor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ public interface ICursorable<TRow>
/// Similarly to the 'DataView{T}, this class uses IL generation to create the 'poke' methods that
/// write directly into the fields of the user-defined type.
/// </summary>
[BestFriend]
internal sealed class TypedCursorable<TRow> : ICursorable<TRow>
where TRow : class
{
Expand Down
5 changes: 5 additions & 0 deletions src/Microsoft.ML.Data/DataView/CompositeRowToRowMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using Microsoft.ML.Runtime.Internal.Utilities;

namespace Microsoft.ML.Runtime.Data
Expand All @@ -18,6 +19,9 @@ public sealed class CompositeRowToRowMapper : IRowToRowMapper
public Schema InputSchema { get; }
public Schema Schema { get; }

/// <summary>
/// The mappers held in <c>_innerMappers</c>. The field itself is returned, not a copy.
/// </summary>
[BestFriend]
internal IRowToRowMapper[] InnerMappers => _innerMappers;

/// <summary>
/// Out of a series of mappers, construct a seemingly unitary mapper that is able to apply them in sequence.
/// </summary>
Expand Down Expand Up @@ -84,6 +88,7 @@ public IRow GetRow(IRow input, Func<int, bool> active, out Action disposer)
// We want the last disposer to be called first, so the order of the addition here is important.
}
}

return result;
}

Expand Down
Loading