diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/AzPredictorServiceTests.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/AzPredictorServiceTests.cs index cd209ec4ba84..262279c87903 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/AzPredictorServiceTests.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/AzPredictorServiceTests.cs @@ -1,4 +1,4 @@ -// ---------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------- // // Copyright Microsoft Corporation // Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,7 +13,10 @@ // ---------------------------------------------------------------------------------- using Microsoft.Azure.PowerShell.Tools.AzPredictor.Test.Mocks; +using System; +using System.Collections.Generic; using System.Linq; +using System.Management.Automation.Language; using System.Management.Automation.Subsystem; using System.Threading; using Xunit; @@ -26,10 +29,36 @@ namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Test [Collection("Model collection")] public class AzPredictorServiceTests { + private class PredictiveSuggestionComparer : EqualityComparer + { + public override bool Equals(PredictiveSuggestion first, PredictiveSuggestion second) + { + if ((first == null) && (second == null)) + { + return true; + } + else if ((first == null) || (second == null)) + { + return false; + } + + return string.Equals(first.SuggestionText, second.SuggestionText, StringComparison.Ordinal); + } + + public override int GetHashCode(PredictiveSuggestion suggestion) + { + return suggestion.SuggestionText.GetHashCode(); + } + } + private readonly ModelFixture _fixture; private readonly AzPredictorService _service; - private readonly Predictor _suggestionsPredictor; - private readonly Predictor _commandsPredictor; + private readonly CommandLinePredictor _commandBasedPredictor; + private readonly CommandLinePredictor _fallbackPredictor; + + private readonly AzPredictorService _noFallbackPredictorService; + private readonly AzPredictorService _noCommandBasedPredictorService; + private readonly AzPredictorService _noPredictorService; /// /// Constructs a new instance of @@ -39,15 +68,36 @@ public AzPredictorServiceTests(ModelFixture fixture) { this._fixture = fixture; var startHistory = $"{AzPredictorConstants.CommandPlaceholder}{AzPredictorConstants.CommandConcatenator}{AzPredictorConstants.CommandPlaceholder}"; - this._suggestionsPredictor = new Predictor(this._fixture.PredictionCollection[startHistory], null); - this._commandsPredictor = new Predictor(this._fixture.CommandCollection, null); + this._commandBasedPredictor = new CommandLinePredictor(this._fixture.PredictionCollection[startHistory], null); + this._fallbackPredictor = new CommandLinePredictor(this._fixture.CommandCollection, null); this._service = new MockAzPredictorService(startHistory, this._fixture.PredictionCollection[startHistory], this._fixture.CommandCollection); + + this._noFallbackPredictorService = new MockAzPredictorService(startHistory, this._fixture.PredictionCollection[startHistory], null); + this._noCommandBasedPredictorService = new MockAzPredictorService(null, null, this._fixture.CommandCollection); + this._noPredictorService = new MockAzPredictorService(null, null, null); } + /// + /// Verify the method checks parameter values. + /// + [Fact] + public void VerifyParameterValues() + { + var predictionContext = PredictionContext.Create("Get-AzContext"); + + Action actual = () => this._service.GetSuggestion(null, 1, 1, CancellationToken.None); + Assert.Throws(actual); + + actual = () => this._service.GetSuggestion(predictionContext.InputAst, 0, 1, CancellationToken.None); + Assert.Throws(actual); + + actual = () => this._service.GetSuggestion(predictionContext.InputAst, 1, 0, CancellationToken.None); + Assert.Throws(actual); + } /// - /// Verifies that the prediction comes from the suggestions list, not the command list. + /// Verifies that the prediction comes from the command based list, not the fallback list. /// [Theory] [InlineData("CONNECT-AZACCOUNT")] @@ -59,52 +109,126 @@ public AzPredictorServiceTests(ModelFixture fixture) [InlineData("new-azresourcegroup -name hello")] [InlineData("Get-AzContext -Name")] [InlineData("Get-AzContext -ErrorAction")] - public void VerifyUsingSuggestion(string userInput) + public void VerifyUsingCommandBasedPredictor(string userInput) { var predictionContext = PredictionContext.Create(userInput); - var presentCommands = new System.Collections.Generic.Dictionary(); - var expected = this._suggestionsPredictor.Query(predictionContext.InputAst, presentCommands, 1, 1, CancellationToken.None); + var commandAst = predictionContext.InputAst.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst; + var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value; + var inputParameterSet = new ParameterSet(commandAst); + var rawUserInput = predictionContext.InputAst.Extent.Text; + var presentCommands = new Dictionary(); + var expected = this._commandBasedPredictor.GetSuggestion(commandName, + inputParameterSet, + rawUserInput, + presentCommands, + 1, + 1, + CancellationToken.None); + var actual = this._service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None); - Assert.NotEmpty(actual); - Assert.NotNull(actual.First().Item1); - Assert.Equal(expected.Item1.First().Key, actual.First().Item1); - Assert.Equal(PredictionSource.CurrentCommand, actual.First().Item3); + Assert.NotNull(actual); + Assert.True(actual.Count > 0); + Assert.NotNull(actual.PredictiveSuggestions.First()); + Assert.NotNull(actual.PredictiveSuggestions.First().SuggestionText); + Assert.Equal(expected.Count, actual.Count); + Assert.Equal(expected.PredictiveSuggestions, actual.PredictiveSuggestions, new PredictiveSuggestionComparer()); + Assert.Equal(expected.SourceTexts, actual.SourceTexts); + Assert.All(actual.SuggestionSources, (source) => Assert.Equal(SuggestionSource.CurrentCommand, source)); + + actual = this._noFallbackPredictorService.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None); + Assert.NotNull(actual); + Assert.True(actual.Count > 0); + Assert.NotNull(actual.PredictiveSuggestions.First()); + Assert.NotNull(actual.PredictiveSuggestions.First().SuggestionText); + Assert.Equal(expected.Count, actual.Count); + Assert.Equal(expected.PredictiveSuggestions, actual.PredictiveSuggestions, new PredictiveSuggestionComparer()); + Assert.Equal(expected.SourceTexts, actual.SourceTexts); + Assert.All(actual.SuggestionSources, (source) => Assert.Equal(SuggestionSource.CurrentCommand, source)); } /// - /// Verifies that when no prediction is in the suggestion list, we'll use the command list. + /// Verifies that when no prediction is in the command based list, we'll use the fallback list. /// [Theory] [InlineData("Get-AzResource -Name hello -Pre")] [InlineData("Get-AzADServicePrincipal -ApplicationObject")] - public void VerifyUsingCommand(string userInput) + public void VerifyUsingFallbackPredictor(string userInput) { var predictionContext = PredictionContext.Create(userInput); - var presentCommands = new System.Collections.Generic.Dictionary(); - var expected = this._commandsPredictor.Query(predictionContext.InputAst, presentCommands, 1, 1, CancellationToken.None); + var commandAst = predictionContext.InputAst.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst; + var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value; + var inputParameterSet = new ParameterSet(commandAst); + var rawUserInput = predictionContext.InputAst.Extent.Text; + var presentCommands = new Dictionary(); + var expected = this._fallbackPredictor.GetSuggestion(commandName, + inputParameterSet, + rawUserInput, + presentCommands, + 1, + 1, + CancellationToken.None); + var actual = this._service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None); - Assert.NotEmpty(actual); - Assert.NotNull(actual.First().Item1); - Assert.Equal(expected.Item1.First().Key, actual.First().Item1); - Assert.Equal(PredictionSource.StaticCommands, actual.First().Item3); + Assert.NotNull(actual); + Assert.True(actual.Count > 0); + Assert.NotNull(actual.PredictiveSuggestions.First()); + Assert.NotNull(actual.PredictiveSuggestions.First().SuggestionText); + Assert.Equal(expected.Count, actual.Count); + Assert.Equal(expected.PredictiveSuggestions, actual.PredictiveSuggestions, new PredictiveSuggestionComparer()); + Assert.Equal(expected.SourceTexts, actual.SourceTexts); + Assert.All(actual.SuggestionSources, (source) => Assert.Equal(SuggestionSource.StaticCommands, source)); + + actual = this._noCommandBasedPredictorService.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None); + Assert.NotNull(actual); + Assert.True(actual.Count > 0); + Assert.NotNull(actual.PredictiveSuggestions.First()); + Assert.NotNull(actual.PredictiveSuggestions.First().SuggestionText); + Assert.Equal(expected.Count, actual.Count); + Assert.Equal(expected.PredictiveSuggestions, actual.PredictiveSuggestions, new PredictiveSuggestionComparer()); + Assert.Equal(expected.SourceTexts, actual.SourceTexts); + Assert.All(actual.SuggestionSources, (source) => Assert.Equal(SuggestionSource.StaticCommands, source)); } /// - /// Verify that no prediction for the user input, meaning it's not in the prediction list or the command list. + /// Verify that no prediction for the user input, meaning it's not in the command based list or the fallback list. /// [Theory] [InlineData(AzPredictorConstants.CommandPlaceholder)] - [InlineData("git status")] [InlineData("Get-ChildItem")] [InlineData("new-azresourcegroup -NoExistingParam")] [InlineData("get-azaccount ")] - [InlineData("Get-AzContext Name")] [InlineData("NEW-AZCONTEXT")] public void VerifyNoPrediction(string userInput) { var predictionContext = PredictionContext.Create(userInput); var actual = this._service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None); - Assert.Empty(actual); + Assert.Equal(0, actual.Count); + + actual = this._noFallbackPredictorService.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None); + Assert.Equal(0, actual.Count); + + actual = this._noCommandBasedPredictorService.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None); + Assert.Equal(0, actual.Count); + + actual = this._noPredictorService.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None); + Assert.Null(actual); + } + + /// + /// Verify when we cannot parse the user input correctly. + /// + /// + /// When we can parse them correctly, please move the InlineData to the corresponding test methods, for example, "git status" + /// doesn't have any prediction so it should move to . + /// + [Theory] + [InlineData("git status")] + [InlineData("Get-AzContext Name")] + public void VerifyMalFormattedCommandLine(string userInput) + { + var predictionContext = PredictionContext.Create(userInput); + Action actual = () => this._service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None); + _ = Assert.Throws(actual); } } } diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/AzPredictorTests.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/AzPredictorTests.cs index b8dd8a854d6c..4032b733f235 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/AzPredictorTests.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/AzPredictorTests.cs @@ -45,6 +45,7 @@ public AzPredictorTests(ModelFixture modelFixture) this._azPredictor = new AzPredictor(this._service, this._telemetryClient, new Settings() { SuggestionCount = 1, + MaxAllowedCommandDuplicate = 1, }, null); } @@ -134,7 +135,6 @@ public void VerifySupportedCommandMasked() /// Verifies AzPredictor returns the same value as AzPredictorService for the prediction. /// [Theory] - [InlineData("git status")] [InlineData("new-azresourcegroup -name hello")] [InlineData("Get-AzContext -Name")] [InlineData("Get-AzContext -ErrorAction")] @@ -145,7 +145,8 @@ public void VerifySuggestion(string userInput) var expected = this._service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None); var actual = this._azPredictor.GetSuggestion(predictionContext, CancellationToken.None); - Assert.Equal(expected.Select(e => e.Item1), actual.Select(a => a.SuggestionText)); + Assert.Equal(expected.Count, actual.Count); + Assert.Equal(expected.PredictiveSuggestions.First().SuggestionText, actual.First().SuggestionText); } /// @@ -158,6 +159,7 @@ public void VerifySuggestionOnIncompleteCommand() var localAzPredictor = new AzPredictor(this._service, this._telemetryClient, new Settings() { SuggestionCount = 7, + MaxAllowedCommandDuplicate = 1, }, null); @@ -169,5 +171,23 @@ public void VerifySuggestionOnIncompleteCommand() Assert.Equal(expected, actual.First().SuggestionText); } + + + /// + /// Verify when we cannot parse the user input correctly. + /// + /// + /// When we can parse them correctly, please move the InlineData to the corresponding test methods, for example, "git status" + /// can be moved to . + /// + [Theory] + [InlineData("git status")] + public void VerifyMalFormattedCommandLine(string userInput) + { + var predictionContext = PredictionContext.Create(userInput); + var actual = this._azPredictor.GetSuggestion(predictionContext, CancellationToken.None); + + Assert.Empty(actual); + } } } diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/CommandLinePredictorTests.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/CommandLinePredictorTests.cs new file mode 100644 index 000000000000..dd2bd9a6c32d --- /dev/null +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/CommandLinePredictorTests.cs @@ -0,0 +1,266 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Management.Automation.Language; +using System.Management.Automation.Subsystem; +using System.Threading; +using Xunit; + +namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Test +{ + /// + /// Test cases for + /// + [Collection("Model collection")] + public class CommandLinePredictorTests + { + private readonly ModelFixture _fixture; + private readonly CommandLinePredictor _predictor; + + /// + /// Constructs a new instance of + /// + public CommandLinePredictorTests(ModelFixture fixture) + { + this._fixture = fixture; + var startHistory = $"{AzPredictorConstants.CommandPlaceholder}{AzPredictorConstants.CommandConcatenator}{AzPredictorConstants.CommandPlaceholder}"; + this._predictor = new CommandLinePredictor(this._fixture.PredictionCollection[startHistory], null); + } + + /// + /// Verify the method checks parameter values. + /// + [Fact] + public void VerifyParameterValues() + { + var predictionContext = PredictionContext.Create("Get-AzContext"); + var commandAst = predictionContext.InputAst.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst; + var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value; + var inputParameterSet = new ParameterSet(commandAst); + var rawUserInput = predictionContext.InputAst.Extent.Text; + var presentCommands = new Dictionary(); + + Action actual = () => this._predictor.GetSuggestion(null, + inputParameterSet, + rawUserInput, + presentCommands, + 1, + 1, + CancellationToken.None); + Assert.Throws(actual); + + actual = () => this._predictor.GetSuggestion(commandName, + null, + rawUserInput, + presentCommands, + 1, + 1, + CancellationToken.None); + Assert.Throws(actual); + + actual = () => this._predictor.GetSuggestion(commandName, + inputParameterSet, + null, + presentCommands, + 1, + 1, + CancellationToken.None); + Assert.Throws(actual); + + actual = () => this._predictor.GetSuggestion(commandName, + inputParameterSet, + rawUserInput, + null, + 1, + 1, + CancellationToken.None); + Assert.Throws(actual); + + actual = () => this._predictor.GetSuggestion(commandName, + inputParameterSet, + rawUserInput, + presentCommands, + 0, + 1, + CancellationToken.None); + Assert.Throws(actual); + + actual = () => this._predictor.GetSuggestion(commandName, + inputParameterSet, + rawUserInput, + presentCommands, + 1, + 0, + CancellationToken.None); + Assert.Throws(actual); + } + + /// + /// Tests in the case there is no prediction for the user input or the user input matches exact what we have in the model. + /// + [Theory] + [InlineData("NEW-AZCONTEXT")] + [InlineData("get-azaccount ")] + [InlineData(AzPredictorConstants.CommandPlaceholder)] + [InlineData("Get-ChildItem")] + public void GetNoPredictionWithCommandName(string userInput) + { + var predictionContext = PredictionContext.Create(userInput); + var commandAst = predictionContext.InputAst.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst; + var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value; + var inputParameterSet = new ParameterSet(commandAst); + var rawUserInput = predictionContext.InputAst.Extent.Text; + var presentCommands = new Dictionary(); + var result = this._predictor.GetSuggestion(commandName, + inputParameterSet, + rawUserInput, + presentCommands, + 1, + 1, + CancellationToken.None); + Assert.Equal(0, result.Count); + } + + /// + /// Tests in the case there are no az commands in the history. + /// + [Theory] + [InlineData("New-AzKeyVault ")] + [InlineData("CONNECT-AZACCOUNT")] + [InlineData("set-azstorageaccount ")] + [InlineData("Get-AzResourceG")] + [InlineData("Get-AzStorageAcco")] // an imcomplete command and there is a record "Get-AzStorageAccount" in the model. + public void GetPredictionWithCommandName(string userInput) + { + var predictionContext = PredictionContext.Create(userInput); + var commandAst = predictionContext.InputAst.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst; + var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value; + var inputParameterSet = new ParameterSet(commandAst); + var rawUserInput = predictionContext.InputAst.Extent.Text; + var presentCommands = new Dictionary(); + var result = this._predictor.GetSuggestion(commandName, + inputParameterSet, + rawUserInput, + presentCommands, + 1, + 1, + CancellationToken.None); + Assert.True(result.Count > 0); + } + + /// + /// Tests in the case when the user inputs the command name and parameters. + /// + [Theory] + [InlineData("Get-AzKeyVault -VaultName")] + [InlineData("GET-AZSTORAGEACCOUNTKEY -NAME ")] + [InlineData("new-azresourcegroup -name hello")] + [InlineData("Get-AzContext -Name")] + public void GetPredictionWithCommandNameParameters(string userInput) + { + var predictionContext = PredictionContext.Create(userInput); + var commandAst = predictionContext.InputAst.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst; + var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value; + var inputParameterSet = new ParameterSet(commandAst); + var rawUserInput = predictionContext.InputAst.Extent.Text; + var presentCommands = new Dictionary(); + var result = this._predictor.GetSuggestion(commandName, + inputParameterSet, + rawUserInput, + presentCommands, + 1, + 1, + CancellationToken.None); + Assert.True(result.Count > 0); + } + + /// + /// Tests in the case when the user inputs the command name and parameters. + /// + [Theory] + [InlineData("Get-AzResource -Name hello -Pre")] + [InlineData("Get-AzADServicePrincipal -ApplicationObject")] // Doesn't exist + [InlineData("new-azresourcegroup -NoExistingParam")] + [InlineData("Set-StorageAccount -WhatIf")] + // Enable "git status" and "Get-AzContext Name" when ParameterSet can parse this format of command + // [InlineData("git status")] + // [InlineData("Get-AzContext Name")] // a wrong command + public void GetNoPredictionWithCommandNameParameters(string userInput) + { + var predictionContext = PredictionContext.Create(userInput); + var commandAst = predictionContext.InputAst.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst; + var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value; + var inputParameterSet = new ParameterSet(commandAst); + var rawUserInput = predictionContext.InputAst.Extent.Text; + var presentCommands = new Dictionary(); + var result = this._predictor.GetSuggestion(commandName, + inputParameterSet, + rawUserInput, + presentCommands, + 1, + 1, + CancellationToken.None); + Assert.Equal(0, result.Count); + } + + /// + /// Verify that the prediction for the command (without parameter) has the right parameters. + /// + [Fact] + public void VerifyPredictionForCommand() + { + var predictionContext = PredictionContext.Create("Connect-AzAccount"); + var commandAst = predictionContext.InputAst.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst; + var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value; + var inputParameterSet = new ParameterSet(commandAst); + var rawUserInput = predictionContext.InputAst.Extent.Text; + var presentCommands = new Dictionary(); + var result = this._predictor.GetSuggestion(commandName, + inputParameterSet, + rawUserInput, + presentCommands, + 1, + 1, + CancellationToken.None); + + Assert.Equal("Connect-AzAccount -Credential -ServicePrincipal -Tenant <>", result.PredictiveSuggestions.First().SuggestionText); + } + + /// + /// Verify that the prediction for the command (with parameter) has the right parameters. + /// + [Fact] + public void VerifyPredictionForCommandAndParameters() + { + var predictionContext = PredictionContext.Create("GET-AZSTORAGEACCOUNTKEY -NAME"); + var commandAst = predictionContext.InputAst.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst; + var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value; + var inputParameterSet = new ParameterSet(commandAst); + var rawUserInput = predictionContext.InputAst.Extent.Text; + var presentCommands = new Dictionary(); + var result = this._predictor.GetSuggestion(commandName, + inputParameterSet, + rawUserInput, + presentCommands, + 1, + 1, + CancellationToken.None); + + Assert.Equal("Get-AzStorageAccountKey -Name 'ContosoStorage' -ResourceGroupName 'ContosoGroup02'", result.PredictiveSuggestions.First().SuggestionText); + } + } +} diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/Mocks/MockAzPredictorService.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/Mocks/MockAzPredictorService.cs index 8c9fd8e5ff27..334aa5d01ed6 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/Mocks/MockAzPredictorService.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/Mocks/MockAzPredictorService.cs @@ -34,9 +34,20 @@ sealed class MockAzPredictorService : AzPredictorService /// The commands collection public MockAzPredictorService(string history, IList suggestions, IList commands) { - SetPredictionCommand(history); - SetCommandsPredictor(commands); - SetSuggestionPredictor(history, suggestions); + if (history != null) + { + SetCommandToRequestPrediction(history); + + if (suggestions != null) + { + SetCommandBasedPreditor(history, suggestions); + } + } + + if (commands != null) + { + SetFallbackPredictor(commands); + } } /// @@ -46,7 +57,7 @@ public override void RequestPredictions(IEnumerable history) } /// - protected override void RequestCommands() + protected override void RequestAllPredictiveCommands() { // Do nothing since we've set the command and suggestion predictors. } diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/Mocks/MockAzPredictorTelemetryClient.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/Mocks/MockAzPredictorTelemetryClient.cs index cd51a155ebe8..1728ac6ea0fc 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/Mocks/MockAzPredictorTelemetryClient.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/Mocks/MockAzPredictorTelemetryClient.cs @@ -13,7 +13,7 @@ // ---------------------------------------------------------------------------------- using System; -using System.Collections.Generic; +using Microsoft.Azure.PowerShell.Tools.AzPredictor.Telemetry; namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Test.Mocks { @@ -34,38 +34,29 @@ public class RecordedSuggestionForHistory public int SuggestionAccepted { get; set; } /// - public void OnHistory(string historyLine) + public void OnHistory(HistoryTelemetryData telemetryData) { this.RecordedSuggestion = new RecordedSuggestionForHistory() { - HistoryLine = historyLine, + HistoryLine = telemetryData.Command, }; } /// - public void OnRequestPrediction(string command) + public void OnRequestPrediction(RequestPredictionTelemetryData telemetryData) { } /// - public void OnRequestPredictionError(string command, Exception e) - { - } - - /// - public void OnSuggestionAccepted(string acceptedSuggestion) + public void OnSuggestionAccepted(SuggestionAcceptedTelemetryData telemetryData) { ++this.SuggestionAccepted; } /// - public void OnGetSuggestion(string maskedUserInput, IEnumerable> suggestions, bool isCancelled) + public void OnGetSuggestion(GetSuggestionTelemetryData telemetryData) { } - /// - public void OnGetSuggestionError(Exception e) - { - } } } diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/ModelFixture.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/ModelFixture.cs index 3f029d6d6272..74538f27a8f6 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/ModelFixture.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/ModelFixture.cs @@ -12,11 +12,12 @@ // limitations under the License. // ---------------------------------------------------------------------------------- -using Newtonsoft.Json; +using Microsoft.Azure.PowerShell.Tools.AzPredictor.Utilities; using System; using System.Collections.Generic; using System.IO; using System.IO.Compression; +using System.Text.Json; using Xunit; namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Test @@ -54,8 +55,8 @@ public ModelFixture() var commandsModel = ModelFixture.ReadZipEntry(Path.Join(dataDirectory, ModelFixture.CommandsModelZip), ModelFixture.CommandsModelJson); var predictionsModel = ModelFixture.ReadZipEntry(Path.Join(dataDirectory, ModelFixture.PredictionsModelZip), ModelFixture.PredictionsModelJson); - this.CommandCollection = JsonConvert.DeserializeObject>(commandsModel); - this.PredictionCollection = JsonConvert.DeserializeObject>>(predictionsModel); + this.CommandCollection = JsonSerializer.Deserialize>(commandsModel, JsonUtilities.DefaultSerializerOptions); + this.PredictionCollection = JsonSerializer.Deserialize>>(predictionsModel, JsonUtilities.DefaultSerializerOptions); } /// diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/PredictorTests.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/PredictorTests.cs deleted file mode 100644 index 37f7862ed376..000000000000 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/PredictorTests.cs +++ /dev/null @@ -1,134 +0,0 @@ -// ---------------------------------------------------------------------------------- -// -// Copyright Microsoft Corporation -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ---------------------------------------------------------------------------------- - -using System.Linq; -using System.Management.Automation.Subsystem; -using System.Threading; -using Xunit; - -namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Test -{ - /// - /// Test cases for - /// - [Collection("Model collection")] - public class PredictorTests - { - private readonly ModelFixture _fixture; - private readonly Predictor _predictor; - - /// - /// Constructs a new instance of - /// - public PredictorTests(ModelFixture fixture) - { - this._fixture = fixture; - var startHistory = $"{AzPredictorConstants.CommandPlaceholder}{AzPredictorConstants.CommandConcatenator}{AzPredictorConstants.CommandPlaceholder}"; - this._predictor = new Predictor(this._fixture.PredictionCollection[startHistory], null); - } - - /// - /// Tests in the case there is no prediction for the user input or the user input matches exact what we have in the model. - /// - [Theory] - [InlineData("NEW-AZCONTEXT")] - [InlineData("get-azaccount ")] - [InlineData(AzPredictorConstants.CommandPlaceholder)] - [InlineData("git status")] - [InlineData("Get-ChildItem")] - public void GetNullPredictionWithCommandName(string userInput) - { - var predictionContext = PredictionContext.Create(userInput); - var presentCommands = new System.Collections.Generic.Dictionary(); - var result = this._predictor.Query(predictionContext.InputAst, presentCommands, 1, 1, CancellationToken.None); - Assert.Empty(result.Item1); - } - - /// - /// Tests in the case there are no az commands in the history. - /// - [Theory] - [InlineData("New-AzKeyVault ")] - [InlineData("CONNECT-AZACCOUNT")] - [InlineData("set-azstorageaccount ")] - [InlineData("Get-AzResourceG")] - [InlineData("Get-AzStorageAcco")] // an imcomplete command and there is a record "Get-AzStorageAccount" in the model. - public void GetPredictionWithCommandName(string userInput) - { - var predictionContext = PredictionContext.Create(userInput); - var presentCommands = new System.Collections.Generic.Dictionary(); - var result = this._predictor.Query(predictionContext.InputAst, presentCommands, 1, 1, CancellationToken.None); - Assert.NotEmpty(result.Item1); - } - - /// - /// Tests in the case when the user inputs the command name and parameters. - /// - [Theory] - [InlineData("Get-AzKeyVault -VaultName")] - [InlineData("GET-AZSTORAGEACCOUNTKEY -NAME ")] - [InlineData("new-azresourcegroup -name hello")] - [InlineData("Get-AzContext -Name")] - public void GetPredictionWithCommandNameParameters(string userInput) - { - var predictionContext = PredictionContext.Create(userInput); - var presentCommands = new System.Collections.Generic.Dictionary(); - var result = this._predictor.Query(predictionContext.InputAst, presentCommands, 1, 1, CancellationToken.None); - Assert.NotEmpty(result.Item1); - } - - /// - /// Tests in the case when the user inputs the command name and parameters. - /// - [Theory] - [InlineData("Get-AzResource -Name hello -Pre")] - [InlineData("Get-AzADServicePrincipal -ApplicationObject")] // Doesn't exist - [InlineData("new-azresourcegroup -NoExistingParam")] - [InlineData("Set-StorageAccount -WhatIf")] - [InlineData("Get-AzContext Name")] // a wrong command - public void GetNullPredictionWithCommandNameParameters(string userInput) - { - var predictionContext = PredictionContext.Create(userInput); - var presentCommands = new System.Collections.Generic.Dictionary(); - var result = this._predictor.Query(predictionContext.InputAst, presentCommands, 1, 1, CancellationToken.None); - Assert.Empty(result.Item1); - } - - /// - /// Verify that the prediction for the command (without parameter) has the right parameters. - /// - [Fact] - public void VerifyPredictionForCommand() - { - var predictionContext = PredictionContext.Create("Connect-AzAccount"); - var presentCommands = new System.Collections.Generic.Dictionary(); - var result = this._predictor.Query(predictionContext.InputAst, presentCommands, 1, 1, CancellationToken.None); - - Assert.Equal("Connect-AzAccount -Credential -ServicePrincipal -Tenant <>", result.Item1.First().Key); - } - - /// - /// Verify that the prediction for the command (with parameter) has the right parameters. - /// - [Fact] - public void VerifyPredictionForCommandAndParameters() - { - var predictionContext = PredictionContext.Create("GET-AZSTORAGEACCOUNTKEY -NAME"); - var presentCommands = new System.Collections.Generic.Dictionary(); - var result = this._predictor.Query(predictionContext.InputAst, presentCommands, 1, 1, CancellationToken.None); - - Assert.Equal("Get-AzStorageAccountKey -Name 'ContosoStorage' -ResourceGroupName 'ContosoGroup02'", result.Item1.First().Key); - } - } -} diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzContext.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzContext.cs index b2e929c5be71..a363280c66be 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzContext.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzContext.cs @@ -171,9 +171,9 @@ private static string GenerateSha256HashString(string originInput) } /// - /// Get the MAC address of the default NIC, or null if none can be found + /// Get the MAC address of the default NIC, or null if none can be found. /// - /// The MAC address of the defautl nic, or null if none is found + /// The MAC address of the defautl nic, or null if none is found. private static string GetMACAddress() { return NetworkInterface.GetAllNetworkInterfaces()? diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictor.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictor.cs index dd5be6e10b80..a03eed3391a2 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictor.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictor.cs @@ -1,4 +1,4 @@ -// ---------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------- // // Copyright Microsoft Corporation // Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,6 +12,8 @@ // limitations under the License. // ---------------------------------------------------------------------------------- +using Microsoft.Azure.PowerShell.Tools.AzPredictor.Telemetry; +using Microsoft.Azure.PowerShell.Tools.AzPredictor.Utilities; using System; using System.Collections.Generic; using System.Linq; @@ -19,7 +21,6 @@ using System.Management.Automation.Language; using System.Management.Automation.Subsystem; using System.Runtime.CompilerServices; -using System.Text; using System.Threading; [assembly:InternalsVisibleTo("Microsoft.Azure.PowerShell.Tools.AzPredictor.Test")] @@ -27,7 +28,7 @@ namespace Microsoft.Azure.PowerShell.Tools.AzPredictor { /// - /// The implementation of a to provide suggestion in PSReadLine. + /// The implementation of a to provide suggestions in PSReadLine. /// internal sealed class AzPredictor : ICommandPredictor { @@ -49,8 +50,6 @@ internal sealed class AzPredictor : ICommandPredictor internal static readonly Guid Identifier = new Guid("599d1760-4ee1-4ed2-806e-f2a1b1a0ba4d"); private const int SuggestionCountForTelemetry = 5; - private const string ParameterValueMask = "***"; - private const char ParameterValueSeperator = ':'; private static readonly string[] CommonParameters = new string[] { "location" }; @@ -61,33 +60,26 @@ internal sealed class AzPredictor : ICommandPredictor private Queue _lastTwoMaskedCommands = new Queue(AzPredictorConstants.CommandHistoryCountToProcess); - // This contains the user modified texts and the original suggestion. - private Dictionary _userAcceptedAndSuggestion = new Dictionary(); - /// - /// Constructs a new instance of + /// Constructs a new instance of . /// - /// The service that provides the suggestion - /// The client to collect telemetry - /// The settings of the service - /// The Az context which this module runs with + /// The service that provides the suggestion. + /// The client to collect telemetry. + /// The settings for . + /// The Az context which this module runs in. public AzPredictor(IAzPredictorService service, ITelemetryClient telemetryClient, Settings settings, IAzContext azContext) { - this._service = service; - this._telemetryClient = telemetryClient; - this._settings = settings; - this._azContext = azContext; + _service = service; + _telemetryClient = telemetryClient; + _settings = settings; + _azContext = azContext; } /// public void StartEarlyProcessing(IReadOnlyList history) { // The context only changes when the user executes the corresponding command. - this._azContext?.UpdateContext(); - lock (_userAcceptedAndSuggestion) - { - _userAcceptedAndSuggestion.Clear(); - } + _azContext?.UpdateContext(); if (history.Count > 0) { @@ -125,7 +117,7 @@ public void StartEarlyProcessing(IReadOnlyList history) _service.RecordHistory(lastCommand.Item1); } - _telemetryClient.OnHistory(lastCommand.Item2); + _telemetryClient.OnHistory(new HistoryTelemetryData(lastCommand.Item2)); _service.RequestPredictions(_lastTwoMaskedCommands); } @@ -140,7 +132,7 @@ ValueTuple GetAstAndMaskedCommandLine(string commandLine) if (_service.IsSupportedCommand(commandName)) { - maskedCommandLine = AzPredictor.MaskCommandLine(commandAst); + maskedCommandLine = CommandLineUtilities.MaskCommandLine(commandAst); } else { @@ -154,129 +146,42 @@ ValueTuple GetAstAndMaskedCommandLine(string commandLine) /// public void OnSuggestionAccepted(string acceptedSuggestion) { - IDictionary localSuggestedTexts = null; - lock (_userAcceptedAndSuggestion) - { - localSuggestedTexts = _userAcceptedAndSuggestion; - } - - if (localSuggestedTexts.TryGetValue(acceptedSuggestion, out var suggestedText)) - { - _telemetryClient.OnSuggestionAccepted(suggestedText); - } - else - { - _telemetryClient.OnSuggestionAccepted("NoRecord"); - } + _telemetryClient.OnSuggestionAccepted(new SuggestionAcceptedTelemetryData(acceptedSuggestion)); } /// public List GetSuggestion(PredictionContext context, CancellationToken cancellationToken) { - var localCancellationToken = Settings.ContinueOnTimeout ? CancellationToken.None : cancellationToken; + if (_settings.SuggestionCount.Value <= 0) + { + return new List(); + } - IEnumerable> suggestions = Enumerable.Empty>(); - string maskedUserInput = string.Empty; - // This is the list of records of the source suggestion and the prediction source. - var telemetryData = new List>(); + Exception exception = null; + CommandLineSuggestion suggestions = null; try { - maskedUserInput = AzPredictor.MaskCommandLine(context.InputAst.FindAll((ast) => ast is CommandAst, true).LastOrDefault() as CommandAst); + var localCancellationToken = Settings.ContinueOnTimeout ? CancellationToken.None : cancellationToken; suggestions = _service.GetSuggestion(context.InputAst, _settings.SuggestionCount.Value, _settings.MaxAllowedCommandDuplicate.Value, localCancellationToken); - localCancellationToken.ThrowIfCancellationRequested(); - - var userAcceptedAndSuggestion = new Dictionary(); - - foreach (var s in suggestions) - { - telemetryData.Add(ValueTuple.Create(s.Item2, s.Item3)); - userAcceptedAndSuggestion[s.Item1] = s.Item2; - } - - lock (_userAcceptedAndSuggestion) - { - foreach (var u in userAcceptedAndSuggestion) - { - _userAcceptedAndSuggestion[u.Key] = u.Value; - } - } - - localCancellationToken.ThrowIfCancellationRequested(); - - var returnedValue = suggestions.Select((r, index) => - { - return new PredictiveSuggestion(r.Item1); - }) - .ToList(); - - _telemetryClient.OnGetSuggestion(maskedUserInput, - telemetryData, - cancellationToken.IsCancellationRequested); - - return returnedValue; - + var returnedValue = suggestions?.PredictiveSuggestions?.ToList(); + return returnedValue ?? new List(); } catch (Exception e) when (!(e is OperationCanceledException)) { - this._telemetryClient.OnGetSuggestionError(e); - } - - return new List(); - } - - /// - /// Masks the user input of any data, like names and locations. - /// Also alphabetizes the parameters to normalize them before sending - /// them to the model. - /// e.g., Get-AzContext -Name Hello -Location 'EastUS' => Get-AzContext -Location *** -Name *** - /// - /// The last user input command - private static string MaskCommandLine(CommandAst cmdAst) - { - var commandElements = cmdAst?.CommandElements; - - if (commandElements == null) - { - return null; + exception = e; + return new List(); } - - if (commandElements.Count == 1) + finally { - return cmdAst.Extent.Text; - } - - var sb = new StringBuilder(cmdAst.Extent.Text.Length); - _ = sb.Append(commandElements[0].ToString()); - var parameters = commandElements - .Skip(1) - .Where(element => element is CommandParameterAst) - .Cast() - .OrderBy(ast => ast.ParameterName); - foreach (CommandParameterAst param in parameters) - { - _ = sb.Append(AzPredictorConstants.CommandParameterSeperator); - if (param.Argument != null) - { - // Parameter is in the form of `-Name:name` - _ = sb.Append(AzPredictorConstants.ParameterIndicator) - .Append(param.ParameterName) - .Append(AzPredictor.ParameterValueSeperator) - .Append(AzPredictor.ParameterValueMask); - } - else - { - // Parameter is in the form of `-Name` - _ = sb.Append(AzPredictorConstants.ParameterIndicator) - .Append(param.ParameterName) - .Append(AzPredictorConstants.CommandParameterSeperator) - .Append(AzPredictor.ParameterValueMask); - } + _telemetryClient.OnGetSuggestion(new GetSuggestionTelemetryData(context.InputAst, + suggestions, + cancellationToken.IsCancellationRequested, + exception)); } - return sb.ToString(); } } diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorConstants.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorConstants.cs index 98582e59a5b4..84537e52ea46 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorConstants.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorConstants.cs @@ -59,6 +59,16 @@ internal static class AzPredictorConstants /// public const char ParameterIndicator = '-'; + /// + /// The seperator used in parameter name and value pair which is in the form -Name:Value. + /// + public const char ParameterValueSeperator = ':'; + + /// + /// The substitute for the parameter value. + /// + public const string ParameterValueMask = "***"; + /// /// The setting file name. /// diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorService.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorService.cs index 6be87ae33847..59605d701768 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorService.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorService.cs @@ -12,8 +12,8 @@ // limitations under the License. // ---------------------------------------------------------------------------------- -using Newtonsoft.Json; -using Newtonsoft.Json.Serialization; +using Microsoft.Azure.PowerShell.Tools.AzPredictor.Telemetry; +using Microsoft.Azure.PowerShell.Tools.AzPredictor.Utilities; using System; using System.Collections.Generic; using System.Linq; @@ -21,19 +21,19 @@ using System.Net.Http; using System.Net.Http.Headers; using System.Text; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; namespace Microsoft.Azure.PowerShell.Tools.AzPredictor { /// - /// A service that talk to Aladdin endpoints to get the commands and predictions. + /// A service that connects to Aladdin endpoints to get the model and provides suggestions to PSReadLine. /// internal class AzPredictorService : IAzPredictorService, IDisposable { private const string ClientType = "AzurePowerShell"; - [JsonObject(NamingStrategyType = typeof(CamelCaseNamingStrategy))] private sealed class PredictionRequestBody { public sealed class RequestContext @@ -48,10 +48,9 @@ public sealed class RequestContext public string ClientType { get; set; } = AzPredictorService.ClientType; public RequestContext Context { get; set; } = new RequestContext(); - public PredictionRequestBody(string command) => this.History = command; + public PredictionRequestBody(string command) => History = command; }; - [JsonObject(NamingStrategyType = typeof(CamelCaseNamingStrategy))] private sealed class CommandRequestContext { public Version VersionNumber{ get; set; } = new Version(0, 0); @@ -61,10 +60,26 @@ private sealed class CommandRequestContext private readonly HttpClient _client; private readonly string _commandsEndpoint; private readonly string _predictionsEndpoint; - private volatile Tuple _commandSuggestions; // The command and the prediction for that. - private volatile Predictor _commands; - private volatile string _commandForPrediction; - private HashSet _commandSet; + + /// + /// The history command line and the predictor based on that. + /// + private volatile Tuple _commandBasedPredictor; + + /// + /// The predictor to used when doesn't return enough suggestions. + /// + private volatile CommandLinePredictor _fallbackPredictor; + + /// + /// The history command line that we request prediction for. + /// + private volatile string _commandToRequestPrediction; + + /// + /// All the command lines we can provide as suggestions. + /// + private HashSet _allPredictiveCommands; private CancellationTokenSource _predictionRequestCancellationSource; private readonly ParameterValuePredictor _parameterValuePredictor = new ParameterValuePredictor(); @@ -72,31 +87,33 @@ private sealed class CommandRequestContext private readonly IAzContext _azContext; /// - /// The AzPredictor service interacts with the Aladdin service specified in serviceUri. - /// At initialization, it requests a list of the popular commands. + /// Creates a new instance of . /// /// The URI of the Aladdin service. /// The telemetry client. - /// The Az context which this module runs with + /// The Az context which this module runs in. public AzPredictorService(string serviceUri, ITelemetryClient telemetryClient, IAzContext azContext) { - this._commandsEndpoint = $"{serviceUri}{AzPredictorConstants.CommandsEndpoint}?clientType={AzPredictorService.ClientType}&context={JsonConvert.SerializeObject(new CommandRequestContext())}"; - this._predictionsEndpoint = serviceUri + AzPredictorConstants.PredictionsEndpoint; - this._telemetryClient = telemetryClient; - this._azContext = azContext; + Validation.CheckArgument(!string.IsNullOrWhiteSpace(serviceUri), $"{nameof(serviceUri)} cannot be null or whitespace."); + Validation.CheckArgument(telemetryClient, $"{nameof(telemetryClient)} cannot be null."); + Validation.CheckArgument(azContext, $"{nameof(azContext)} cannot be null."); - this._client = new HttpClient(); - this._client.DefaultRequestHeaders?.Add(AzPredictorService.ThrottleByIdHeader, this._azContext.UserId); + _commandsEndpoint = $"{serviceUri}{AzPredictorConstants.CommandsEndpoint}?clientType={AzPredictorService.ClientType}&context={JsonSerializer.Serialize(new CommandRequestContext(), JsonUtilities.DefaultSerializerOptions)}"; + _predictionsEndpoint = serviceUri + AzPredictorConstants.PredictionsEndpoint; + _telemetryClient = telemetryClient; + _azContext = azContext; - RequestCommands(); + _client = new HttpClient(); + + RequestAllPredictiveCommands(); } /// - /// A default constructor for the derived class. + /// A default constructor for the derived class. This is used in test cases. /// protected AzPredictorService() { - RequestCommands(); + RequestAllPredictiveCommands(); } /// @@ -106,160 +123,230 @@ public void Dispose() } /// - /// Dispose the object + /// Dispose the object. /// - /// Indicate if this is called from + /// Indicate if this is called from . protected virtual void Dispose(bool disposing) { if (disposing) { - if (this._predictionRequestCancellationSource != null) + if (_predictionRequestCancellationSource != null) { - this._predictionRequestCancellationSource.Dispose(); - this._predictionRequestCancellationSource = null; + _predictionRequestCancellationSource.Dispose(); + _predictionRequestCancellationSource = null; } } } /// /// - /// Queries the Predictor with the user input if predictions are available, otherwise uses commands + /// Tries to get the suggestions for the user input from the command history. If that doesn't find + /// suggestions, it'll fallback to find the suggestion regardless of command history. /// - public IEnumerable> GetSuggestion(Ast input, int suggestionCount, int maxAllowedCommandDuplicate, CancellationToken cancellationToken) + public CommandLineSuggestion GetSuggestion(Ast input, int suggestionCount, int maxAllowedCommandDuplicate, CancellationToken cancellationToken) { - var commandSuggestions = this._commandSuggestions; - var command = this._commandForPrediction; + Validation.CheckArgument(input, $"{nameof(input)} cannot be null"); + Validation.CheckArgument(suggestionCount > 0, $"{nameof(suggestionCount)} must be larger than 0."); + Validation.CheckArgument(maxAllowedCommandDuplicate > 0, $"{nameof(maxAllowedCommandDuplicate)} must be larger than 0."); - IList> results = new List>(); - var presentCommands = new System.Collections.Generic.Dictionary(); - var resultsFromSuggestionTuple = commandSuggestions?.Item2?.Query(input, presentCommands, suggestionCount, maxAllowedCommandDuplicate, cancellationToken); - var resultsFromSuggestion = resultsFromSuggestionTuple.Item1; - presentCommands = resultsFromSuggestionTuple.Item2.ToDictionary(kvp => kvp.Key, kvp => kvp.Value); + var commandAst = input.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst; + var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value; + if (string.IsNullOrWhiteSpace(commandName)) + { + return null; + } - if (resultsFromSuggestion != null) + var inputParameterSet = new ParameterSet(commandAst); + var rawUserInput = input.Extent.Text; + var presentCommands = new Dictionary(); + var commandBasedPredictor = _commandBasedPredictor; + var commandToRequestPrediction = _commandToRequestPrediction; + + var result = commandBasedPredictor?.Item2?.GetSuggestion(commandName, + inputParameterSet, + rawUserInput, + presentCommands, + suggestionCount, + maxAllowedCommandDuplicate, + cancellationToken); + + if ((result != null) && (result.Count > 0)) { - var predictionSource = PredictionSource.None; + var suggestionSource = SuggestionSource.PreviousCommand; - if (string.Equals(command, commandSuggestions?.Item1, StringComparison.Ordinal)) + if (string.Equals(commandToRequestPrediction, commandBasedPredictor?.Item1, StringComparison.Ordinal)) { - predictionSource = PredictionSource.CurrentCommand; - } - else - { - predictionSource = PredictionSource.PreviousCommand; + suggestionSource = SuggestionSource.CurrentCommand; } - if (resultsFromSuggestion != null) + for (var i = 0; i < result.Count; ++i) { - foreach (var r in resultsFromSuggestion) - { - results.Add(ValueTuple.Create(r.Key, r.Value, predictionSource)); - } + result.UpdateSuggestionSource(i, suggestionSource); } } - if ((resultsFromSuggestion == null) || (resultsFromSuggestion.Count() < suggestionCount)) + if ((result == null) || (result.Count < suggestionCount)) { - var commands = this._commands; - var resultsFromCommandsTuple = commands?.Query(input, presentCommands,suggestionCount - resultsFromSuggestion.Count(), maxAllowedCommandDuplicate, cancellationToken); - var resultsFromCommands = resultsFromCommandsTuple.Item1; - presentCommands = resultsFromCommandsTuple.Item2.ToDictionary(kvp => kvp.Key, kvp => kvp.Value); + var fallbackPredictor = _fallbackPredictor; + var suggestionCountToRequest = (result == null) ? suggestionCount : suggestionCount - result.Count; + var resultsFromFallback = fallbackPredictor?.GetSuggestion(commandName, + inputParameterSet, + rawUserInput, + presentCommands, + suggestionCountToRequest, + maxAllowedCommandDuplicate, + cancellationToken); + + if ((result == null) && (resultsFromFallback != null)) + { + result = resultsFromFallback; - if (resultsFromCommands != null) + for (var i = 0; i < result.Count; ++i) + { + result.UpdateSuggestionSource(i, SuggestionSource.StaticCommands); + } + } + else if ((resultsFromFallback != null) && (resultsFromFallback.Count > 0)) { - foreach (var r in resultsFromCommands) + for (var i = 0; i < resultsFromFallback.Count; ++i) { - if (resultsFromSuggestion?.ContainsKey(r.Key) == true) + if (result.SourceTexts.Contains(resultsFromFallback.SourceTexts[i])) { continue; } - results.Add(ValueTuple.Create(r.Key, r.Value, PredictionSource.StaticCommands)); + result.AddSuggestion(resultsFromFallback.PredictiveSuggestions[i], resultsFromFallback.SourceTexts[i], SuggestionSource.StaticCommands); } } } - return results; + return result; } /// public virtual void RequestPredictions(IEnumerable commands) { - AzPredictorService.ReplaceThrottleUserIdToHeader(this._client?.DefaultRequestHeaders, this._azContext.UserId); + Validation.CheckArgument(commands, $"{nameof(commands)} cannot be null."); + var localCommands= string.Join(AzPredictorConstants.CommandConcatenator, commands); - this._telemetryClient.OnRequestPrediction(localCommands); + bool postSuccess = false; + Exception exception = null; + bool startRequestTask = false; - if (string.Equals(localCommands, this._commandForPrediction, StringComparison.Ordinal)) + try { - // It's the same history we've already requested the prediction for last time, skip it. - return; - } - else - { - this.SetPredictionCommand(localCommands); + if (string.Equals(localCommands, _commandToRequestPrediction, StringComparison.Ordinal)) + { + // It's the same history we've already requested the prediction for last time, skip it. + return; + } - // When it's called multiple times, we only need to keep the one for the latest command. + if (commands.Any()) + { + SetCommandToRequestPrediction(localCommands); - this._predictionRequestCancellationSource?.Cancel(); - this._predictionRequestCancellationSource = new CancellationTokenSource(); + // When it's called multiple times, we only need to keep the one for the latest command. - var cancellationToken = this._predictionRequestCancellationSource.Token; + _predictionRequestCancellationSource?.Cancel(); + _predictionRequestCancellationSource = new CancellationTokenSource(); - // We don't need to block on the task. We send the HTTP request and update prediction list at the background. - Task.Run(async () => { - try - { - var requestContext = new PredictionRequestBody.RequestContext() - { - SessionId = this._telemetryClient.SessionId, - CorrelationId = this._telemetryClient.CorrelationId, - }; - var requestBody = new PredictionRequestBody(localCommands) + var cancellationToken = _predictionRequestCancellationSource.Token; + + // We don't need to block on the task. We send the HTTP request and update prediction list at the background. + startRequestTask = true; + Task.Run(async () => { + try { - Context = requestContext, - }; + AzPredictorService.ReplaceThrottleUserIdToHeader(_client?.DefaultRequestHeaders, _azContext.UserId); - var requestBodyString = JsonConvert.SerializeObject(requestBody); - var httpResponseMessage = await _client.PostAsync(this._predictionsEndpoint, new StringContent(requestBodyString, Encoding.UTF8, "application/json"), cancellationToken); + var requestContext = new PredictionRequestBody.RequestContext() + { + SessionId = _telemetryClient.SessionId, + CorrelationId = _telemetryClient.CorrelationId, + }; - var reply = await httpResponseMessage.Content.ReadAsStringAsync(cancellationToken); - var suggestionsList = JsonConvert.DeserializeObject>(reply); + var requestBody = new PredictionRequestBody(localCommands) + { + Context = requestContext, + }; - this.SetSuggestionPredictor(localCommands, suggestionsList); - } - catch (Exception e) when (!(e is OperationCanceledException)) - { - this._telemetryClient.OnRequestPredictionError(localCommands, e); - } - }, - cancellationToken); + var requestBodyString = JsonSerializer.Serialize(requestBody, JsonUtilities.DefaultSerializerOptions); + var httpResponseMessage = await _client.PostAsync(_predictionsEndpoint, new StringContent(requestBodyString, Encoding.UTF8, "application/json"), cancellationToken); + postSuccess = true; + + httpResponseMessage.EnsureSuccessStatusCode(); + var reply = await httpResponseMessage.Content.ReadAsStreamAsync(cancellationToken); + var suggestionsList = await JsonSerializer.DeserializeAsync>(reply, JsonUtilities.DefaultSerializerOptions); + + SetCommandBasedPreditor(localCommands, suggestionsList); + } + catch (Exception e) when (!(e is OperationCanceledException)) + { + exception = e; + } + finally + { + _telemetryClient.OnRequestPrediction(new RequestPredictionTelemetryData(localCommands, postSuccess, exception)); + } + }, + cancellationToken); + } + } + catch (Exception e) + { + exception = e; + } + finally + { + if (!startRequestTask) + { + _telemetryClient.OnRequestPrediction(new RequestPredictionTelemetryData(localCommands, hasSentHttpRequest: false, exception: exception)); + } } } /// public virtual void RecordHistory(CommandAst history) { - this._parameterValuePredictor.ProcessHistoryCommand(history); + Validation.CheckArgument(history, $"{nameof(history)} cannot be null."); + + _parameterValuePredictor.ProcessHistoryCommand(history); } /// - public bool IsSupportedCommand(string cmd) => !string.IsNullOrWhiteSpace(cmd) && (_commandSet?.Contains(cmd) == true); + public bool IsSupportedCommand(string cmd) => !string.IsNullOrWhiteSpace(cmd) && (_allPredictiveCommands?.Contains(cmd) == true); /// - /// Requests a list of popular commands from service. These commands are used as fallback suggestion + /// Requests a list of popular commands from service. These commands are used as fall back suggestion /// if none of the predictions fit for the current input. This method should be called once per session. /// - protected virtual void RequestCommands() + protected virtual void RequestAllPredictiveCommands() { // We don't need to block on the task. We send the HTTP request and update commands and predictions list at the background. Task.Run(async () => { - var httpResponseMessage = await this._client.GetAsync(this._commandsEndpoint); + Exception exception = null; - var reply = await httpResponseMessage.Content.ReadAsStringAsync(); - var commands_reply = JsonConvert.DeserializeObject>(reply); - this.SetCommandsPredictor(commands_reply); + try + { + _client.DefaultRequestHeaders?.Add(AzPredictorService.ThrottleByIdHeader, _azContext.UserId); + + var httpResponseMessage = await _client.GetAsync(_commandsEndpoint); + + httpResponseMessage.EnsureSuccessStatusCode(); + var reply = await httpResponseMessage.Content.ReadAsStringAsync(); + var commandsReply = JsonSerializer.Deserialize>(reply, JsonUtilities.DefaultSerializerOptions); + SetFallbackPredictor(commandsReply); + } + catch (Exception e) + { + exception = e; + } + finally + { + _telemetryClient.OnRequestPrediction(new RequestPredictionTelemetryData("request_commands", hasSentHttpRequest: true, exception: exception)); + } // Initialize predictions RequestPredictions(new string[] { @@ -269,32 +356,39 @@ protected virtual void RequestCommands() } /// - /// Sets the commands predictor. + /// Sets the fallback predictor. /// /// The command collection to set the predictor - protected void SetCommandsPredictor(IList commands) + protected void SetFallbackPredictor(IList commands) { - this._commands = new Predictor(commands, this._parameterValuePredictor); - this._commandSet = commands.Select(x => AzPredictorService.GetCommandName(x)).ToHashSet(StringComparer.OrdinalIgnoreCase); // this could be slow + Validation.CheckArgument(commands, $"{nameof(commands)} cannot be null."); + + _fallbackPredictor = new CommandLinePredictor(commands, _parameterValuePredictor); + _allPredictiveCommands = commands.Select(x => AzPredictorService.GetCommandName(x)).ToHashSet(StringComparer.OrdinalIgnoreCase); // this could be slow } /// - /// Sets the suggestiosn predictor. + /// Sets the predictor based on the command history. /// /// The commands that the suggestions are for /// The suggestion collection to set the predictor - protected void SetSuggestionPredictor(string commands, IList suggestions) + protected void SetCommandBasedPreditor(string commands, IList suggestions) { - this._commandSuggestions = Tuple.Create(commands, new Predictor(suggestions, this._parameterValuePredictor)); + Validation.CheckArgument(!string.IsNullOrWhiteSpace(commands), $"{nameof(commands)} cannot be null or whitespace."); + Validation.CheckArgument(suggestions, $"{nameof(suggestions)} cannot be null."); + + _commandBasedPredictor = Tuple.Create(commands, new CommandLinePredictor(suggestions, _parameterValuePredictor)); } /// /// Updates the command for prediction. /// /// The command for the new prediction - protected void SetPredictionCommand(string command) + protected void SetCommandToRequestPrediction(string command) { - this._commandForPrediction = command; + Validation.CheckArgument(!string.IsNullOrWhiteSpace(command), $"{nameof(command)} cannot be null or whitespace."); + + _commandToRequestPrediction = command; } /// diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorTelemetryClient.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorTelemetryClient.cs deleted file mode 100644 index d8c627a43202..000000000000 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorTelemetryClient.cs +++ /dev/null @@ -1,202 +0,0 @@ -// ---------------------------------------------------------------------------------- -// -// Copyright Microsoft Corporation -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ---------------------------------------------------------------------------------- - -using Microsoft.ApplicationInsights; -using Microsoft.ApplicationInsights.Extensibility; -using Microsoft.Azure.PowerShell.Tools.AzPredictor.Profile; -using Newtonsoft.Json; -using System; -using System.Collections.Generic; -using System.Globalization; -using System.Linq; - -namespace Microsoft.Azure.PowerShell.Tools.AzPredictor -{ - /// - /// A telemetry client implementation to collect the telemetry data for AzPredictor - /// - sealed class AzPredictorTelemetryClient : ITelemetryClient - { - private const string TelemetryEventPrefix = "Az.Tools.Predictor"; - - /// - public string SessionId { get; } = Guid.NewGuid().ToString(); - - /// - public string CorrelationId { get; private set; } = Guid.NewGuid().ToString(); - - private readonly TelemetryClient _telemetryClient; - private readonly IAzContext _azContext; - private Tuple, string> _cachedAzModulesVersions = Tuple.Create, string>(null, null); - - /// - /// Constructs a new instance of - /// - /// The Az context which this module runs with - public AzPredictorTelemetryClient(IAzContext azContext) - { - TelemetryConfiguration configuration = TelemetryConfiguration.CreateDefault(); - configuration.InstrumentationKey = "7df6ff70-8353-4672-80d6-568517fed090"; // Use Azuer-PowerShell instrumentation key. see https://github.com/Azure/azure-powershell-common/blob/master/src/Common/AzurePSCmdlet.cs - _telemetryClient = new TelemetryClient(configuration); - _telemetryClient.Context.Location.Ip = "0.0.0.0"; - _telemetryClient.Context.Cloud.RoleInstance = "placeholderdon'tuse"; - _telemetryClient.Context.Cloud.RoleName = "placeholderdon'tuse"; - _azContext = azContext; - } - - /// - public void OnHistory(string historyLine) - { - if (!IsDataCollectionAllowed()) - { - return; - } - - var properties = CreateProperties(); - properties.Add("History", historyLine); - - _telemetryClient.TrackEvent($"{AzPredictorTelemetryClient.TelemetryEventPrefix}/CommandHistory", properties); - -#if TELEMETRY_TRACE && DEBUG - Console.WriteLine("Recording CommandHistory"); -#endif - } - - /// - public void OnRequestPrediction(string command) - { - if (!IsDataCollectionAllowed()) - { - return; - } - - CorrelationId = Guid.NewGuid().ToString(); - - var properties = CreateProperties(); - properties.Add("Command", command); - - _telemetryClient.TrackEvent($"{AzPredictorTelemetryClient.TelemetryEventPrefix}/RequestPrediction", properties); - -#if TELEMETRY_TRACE && DEBUG - Console.WriteLine("Recording RequestPrediction"); -#endif - } - - /// - public void OnRequestPredictionError(string command, Exception e) - { - if (!IsDataCollectionAllowed()) - { - return; - } - - var properties = CreateProperties(); - properties.Add("Command", command); - properties.Add("Exception", e.ToString()); - - _telemetryClient.TrackEvent($"{AzPredictorTelemetryClient.TelemetryEventPrefix}/RequestPredictionError", properties); - -#if TELEMETRY_TRACE && DEBUG - Console.WriteLine("Recording RequestPredictionError"); -#endif - } - - /// - public void OnSuggestionAccepted(string acceptedSuggestion) - { - if (!IsDataCollectionAllowed()) - { - return; - } - - var properties = CreateProperties(); - properties.Add("AcceptedSuggestion", acceptedSuggestion); - - _telemetryClient.TrackEvent($"{AzPredictorTelemetryClient.TelemetryEventPrefix}/AcceptSuggestion", properties); - -#if TELEMETRY_TRACE && DEBUG - Console.WriteLine("Recording AcceptSuggestion"); -#endif - } - - /// - public void OnGetSuggestion(string maskedUserInput, IEnumerable> suggestions, bool isCancelled) - { - if (!IsDataCollectionAllowed()) - { - return; - } - - var properties = CreateProperties(); - properties.Add("UserInput", maskedUserInput); - properties.Add("Suggestion", JsonConvert.SerializeObject(suggestions)); - properties.Add("IsCancelled", isCancelled.ToString(CultureInfo.InvariantCulture)); - - _telemetryClient.TrackEvent($"{AzPredictorTelemetryClient.TelemetryEventPrefix}/GetSuggestion", properties); - -#if TELEMETRY_TRACE && DEBUG - Console.WriteLine("Recording GetSuggestion"); -#endif - } - - /// - public void OnGetSuggestionError(Exception e) - { - if (!IsDataCollectionAllowed()) - { - return; - } - - var properties = CreateProperties(); - properties.Add("Exception", e.ToString()); - - _telemetryClient.TrackEvent($"{AzPredictorTelemetryClient.TelemetryEventPrefix}/GetSuggestionError", properties); - -#if TELEMETRY_TRACE && DEBUG - Console.WriteLine("Recording GetSuggestioinError"); -#endif - } - - /// - /// Check whether the data collection is opted in from user - /// - /// true if allowed - private bool IsDataCollectionAllowed() - { - if (AzurePSDataCollectionProfile.Instance.EnableAzureDataCollection == true) - { - return true; - } - - return false; - } - - /// - /// Add the common properties to the telemetry event. - /// - private IDictionary CreateProperties() - { - return new Dictionary() - { - { "SessionId", SessionId }, - { "CorrelationId", CorrelationId }, - { "UserId", _azContext.UserId }, - { "HashMacAddress", _azContext.MacAddress }, - { "PowerShellVersion", _azContext.PowerShellVersion.ToString() }, - { "ModuleVersion", _azContext.ModuleVersion.ToString() }, - { "OS", _azContext.OSVersion }, - }; - } - } -} diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Prediction.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/CommandLine.cs similarity index 52% rename from tools/Az.Tools.Predictor/Az.Tools.Predictor/Prediction.cs rename to tools/Az.Tools.Predictor/Az.Tools.Predictor/CommandLine.cs index da065032efc1..991249765eb0 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Prediction.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/CommandLine.cs @@ -12,38 +12,37 @@ // limitations under the License. // ---------------------------------------------------------------------------------- -using System.Collections.Generic; +using Microsoft.Azure.PowerShell.Tools.AzPredictor.Utilities; namespace Microsoft.Azure.PowerShell.Tools.AzPredictor { /// - /// A prediction candidate consists of the command name and list of parameter sets, - /// where each parameter set is a set of parameters (order independent) that go along with the command. + /// A command line consists of the command name and the parameter set, + /// where the parameter set is a set of parameters (order independent) that go along with the command. /// - sealed class Prediction + sealed class CommandLine { /// - /// Gets the command name + /// Gets the command name. /// - public string Command { get; } + public string Name { get; } /// - /// Gets the list of + /// Gets the . /// - public IList ParameterSets { get; } + public ParameterSet ParameterSet { get; } /// - /// Create a new instance of with the command and parameter set. + /// Create a new instance of with the command name and parameter set. /// - /// The command name - /// The parameter set - public Prediction(string command, ParameterSet parameters) + /// The command name. + /// The parameter set. + public CommandLine(string name, ParameterSet parameterSet) { - this.Command = command; - ParameterSets = new List - { - parameters - }; + Validation.CheckArgument(!string.IsNullOrWhiteSpace(name), $"{nameof(name)} must not be null or whitespace."); + + Name = name; + ParameterSet = parameterSet; } } } diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/CommandLinePredictor.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/CommandLinePredictor.cs new file mode 100644 index 000000000000..2a193f134ada --- /dev/null +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/CommandLinePredictor.cs @@ -0,0 +1,285 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using Microsoft.Azure.PowerShell.Tools.AzPredictor.Utilities; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Management.Automation.Language; +using System.Management.Automation.Subsystem; +using System.Text; +using System.Threading; + +namespace Microsoft.Azure.PowerShell.Tools.AzPredictor +{ + /// + /// This class query the command model from Aladdin service and return suggestions based on user input, for example, + /// when the user inputs "Connec", it returns "Connect-AzAccount". + /// + /// + /// The suggestion returned to PSReadLine may not be the same as the model to generate the suggestion. The suggestion may + /// be adjusted based on user input. + /// + internal sealed class CommandLinePredictor + { + private readonly IList _commandLinePredictions = new List(); + private readonly ParameterValuePredictor _parameterValuePredictor; + + /// + /// Creates a new instance of . + /// + /// List of suggestions from the model, sorted by frequency (most to least). + /// Provide the prediction to the parameter values. + public CommandLinePredictor(IList modelPredictions, ParameterValuePredictor parameterValuePredictor) + { + Validation.CheckArgument(modelPredictions, $"{nameof(modelPredictions)} cannot be null."); + + _parameterValuePredictor = parameterValuePredictor; + var commnadLines = new List(); + + foreach (var predictionTextRaw in modelPredictions ?? Enumerable.Empty()) + { + var predictionText = CommandLineUtilities.EscapePredictionText(predictionTextRaw); + Ast ast = Parser.ParseInput(predictionText, out Token[] tokens, out _); + var commandAst = (ast.Find((ast) => ast is CommandAst, searchNestedScriptBlocks: false) as CommandAst); + + if (commandAst?.CommandElements[0] is StringConstantExpressionAst commandName) + { + var parameterSet = new ParameterSet(commandAst); + this._commandLinePredictions.Add(new CommandLine(commandName.Value, parameterSet)); + } + } + } + + /// + /// Returns suggestions given the user input. + /// + /// The command name extracted from the user input. + /// The parameter set extracted from the user input. + /// The string format of the command line from user input. + /// Commands already present. Contents may be added to this collection. + /// The number of suggestions to return. + /// The maximum amount of the same commands in the list of predictions. + /// The cancellation token + /// The collections of suggestions. + public CommandLineSuggestion GetSuggestion(string inputCommandName, + ParameterSet inputParameterSet, + string rawUserInput, + IDictionary presentCommands, + int suggestionCount, + int maxAllowedCommandDuplicate, + CancellationToken cancellationToken) + { + Validation.CheckArgument(!string.IsNullOrWhiteSpace(inputCommandName), $"{nameof(inputCommandName)} cannot be null or whitespace."); + Validation.CheckArgument(inputParameterSet, $"{nameof(inputParameterSet)} cannot be null."); + Validation.CheckArgument(!string.IsNullOrWhiteSpace(rawUserInput), $"{nameof(rawUserInput)} cannot be null or whitespace."); + Validation.CheckArgument(presentCommands, $"{nameof(presentCommands)} cannot be null."); + Validation.CheckArgument(suggestionCount > 0, $"{nameof(suggestionCount)} must be larger than 0."); + Validation.CheckArgument(maxAllowedCommandDuplicate > 0, $"{nameof(maxAllowedCommandDuplicate)} must be larger than 0."); + + const int commandCollectionCapacity = 10; + CommandLineSuggestion result = new(); + var resultsTemp = new Dictionary(commandCollectionCapacity, StringComparer.OrdinalIgnoreCase); + + var isCommandNameComplete = inputParameterSet.Parameters.Any() || rawUserInput.EndsWith(' '); + + Func commandNameQuery = (command) => command.Equals(inputCommandName, StringComparison.OrdinalIgnoreCase); + if (!isCommandNameComplete) + { + commandNameQuery = (command) => command.StartsWith(inputCommandName, StringComparison.OrdinalIgnoreCase); + } + + // Try to find the matching command and arrange the parameters in the order of the input. + // + // Predictions should be flexible, e.g. if "Command -Name N -Location L" is a possibility, + // then "Command -Location L -Name N" should also be possible. + // + // resultBuilder and usedParams are used to store the information to construct the result. + // We want to avoid too much heap allocation for the performance purpose. + + const int parameterCollectionCapacity = 10; + var resultBuilder = new StringBuilder(); + var usedParams = new HashSet(parameterCollectionCapacity); + var sourceBuilder = new StringBuilder(); + + for (var i = 0; i < _commandLinePredictions.Count && result.Count < suggestionCount; ++i) + { + if (commandNameQuery(_commandLinePredictions[i].Name)) + { + cancellationToken.ThrowIfCancellationRequested(); + + resultBuilder.Clear(); + resultBuilder.Append(_commandLinePredictions[i].Name); + usedParams.Clear(); + + if (DoesPredictionParameterSetMatchInput(resultBuilder, inputParameterSet, _commandLinePredictions[i].ParameterSet, usedParams)) + { + PredictRestOfParameters(resultBuilder, _commandLinePredictions[i].ParameterSet.Parameters, usedParams); + + if (resultBuilder.Length <= rawUserInput.Length) + { + continue; + } + + var prediction = resultBuilder.ToString(); + + sourceBuilder.Clear(); + sourceBuilder.Append(_commandLinePredictions[i].Name); + + foreach (var p in _commandLinePredictions[i].ParameterSet.Parameters) + { + AppendParameterNameAndValue(sourceBuilder, p.Name, p.Value); + } + + if (!presentCommands.ContainsKey(_commandLinePredictions[i].Name)) + { + result.AddSuggestion(new PredictiveSuggestion(prediction), sourceBuilder.ToString()); + presentCommands.Add(_commandLinePredictions[i].Name, 1); + } + else if (presentCommands[_commandLinePredictions[i].Name] < maxAllowedCommandDuplicate) + { + result.AddSuggestion(new PredictiveSuggestion(prediction), sourceBuilder.ToString()); + presentCommands[_commandLinePredictions[i].Name] += 1; + } + else + { + _ = resultsTemp.TryAdd(prediction, sourceBuilder.ToString()); + } + } + } + } + + var resultCount = result.Count; + + if ((resultCount < suggestionCount) && (resultsTemp.Count > 0)) + { + foreach (var temp in resultsTemp.Take(suggestionCount - resultCount)) + { + result.AddSuggestion(new PredictiveSuggestion(temp.Key), temp.Value); + } + } + + return result; + } + + /// + /// Appends unused parameters to the builder. + /// + /// StringBuilder that aggregates the prediction text output. + /// Chosen prediction parameters. + /// Set of used parameters for set. + private void PredictRestOfParameters(StringBuilder builder, IReadOnlyList parameters, HashSet usedParams) + { + for (var j = 0; j < parameters.Count; j++) + { + if (!usedParams.Contains(j)) + { + BuildParameterValue(builder, parameters[j]); + } + } + } + + /// + /// Determines if parameter set contains all of the parameters of the input. + /// + /// StringBuilder that aggregates the prediction text output. + /// Parsed ParameterSet from the user input AST. + /// Candidate prediction parameter set. + /// Set of used parameters for set. + private bool DoesPredictionParameterSetMatchInput(StringBuilder builder, ParameterSet inputParameters, ParameterSet predictionParameters, HashSet usedParams) + { + foreach (var inputParameter in inputParameters.Parameters) + { + var matchIndex = FindParameterPositionInSet(inputParameter, predictionParameters, usedParams); + if (matchIndex == -1) + { + return false; + } + else + { + usedParams.Add(matchIndex); + if (inputParameter.Value != null) + { + AppendParameterNameAndValue(builder, predictionParameters.Parameters[matchIndex].Name, inputParameter.Value); + } + else + { + BuildParameterValue(builder, predictionParameters.Parameters[matchIndex]); + } + } + } + + return true; + } + + /// + /// Create the parameter values from the history commandlines. + /// + /// For example: + /// history command line + /// > New-AzVM -Name "TestVM" ... + /// prediction: + /// > Get-AzVM -VMName <TestVM> + /// "TestVM" is predicted for Get-AzVM. + /// + /// The string builder to create the whole predicted command line. + /// The parameter name and value from prediction. + private void BuildParameterValue(StringBuilder builder, Parameter parameter) + { + var parameterName = parameter.Name; + var parameterValue = this._parameterValuePredictor?.GetParameterValueFromAzCommand(parameterName); + + if (string.IsNullOrWhiteSpace(parameterValue)) + { + parameterValue = parameter.Value; + } + + AppendParameterNameAndValue(builder, parameterName, parameterValue); + } + + /// + /// Determines the index of the given parameter in the parameter set. + /// + /// The parameter name and its value. + /// Prediction parameter set to find parameter position in. + /// Set of used parameters for set. + private static int FindParameterPositionInSet(Parameter parameter, ParameterSet predictionSet, HashSet usedParams) + { + for (var k = 0; k < predictionSet.Parameters.Count; k++) + { + var isPrefixed = predictionSet.Parameters[k].Name.StartsWith(parameter.Name, StringComparison.OrdinalIgnoreCase); + var hasNotBeenUsed = !usedParams.Contains(k); + if (isPrefixed && hasNotBeenUsed) + { + return k; + } + } + + return -1; + } + + private static void AppendParameterNameAndValue(StringBuilder builder, string name, string value) + { + _ = builder.Append(AzPredictorConstants.CommandParameterSeperator); + _ = builder.Append(AzPredictorConstants.ParameterIndicator); + _ = builder.Append(name); + + if (!string.IsNullOrWhiteSpace(value)) + { + _ = builder.Append(AzPredictorConstants.CommandParameterSeperator); + _ = builder.Append(value); + } + } + } +} diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/CommandLineSuggestion.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/CommandLineSuggestion.cs new file mode 100644 index 000000000000..fcb3afc7b3a2 --- /dev/null +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/CommandLineSuggestion.cs @@ -0,0 +1,108 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using Microsoft.Azure.PowerShell.Tools.AzPredictor.Utilities; +using System; +using System.Collections.Generic; +using System.Management.Automation.Subsystem; + +namespace Microsoft.Azure.PowerShell.Tools.AzPredictor +{ + /// + /// Represents the suggestions to show to the user and the related information of the suggestions. + /// + /// + /// Because the performance requirement in , + /// it contains lists of each piece of information, for example, a collection of predictive suggestion and a list of + /// suggestion sources. Note that the count of each list should be the same. And each element in the list corresonds to + /// the element in other list at the same index. + /// + public sealed class CommandLineSuggestion + { + /// + /// Since PSReadLine can accept at most 10 suggestions, we pre-allocate that many items in the collection to avoid + /// re-allocation when we try to find the suggestion to return. + /// + private const int CollectionDefaultCapacity = 10; + + private readonly List _predictiveSuggestions = new List(CommandLineSuggestion.CollectionDefaultCapacity); + + /// + /// Gets the suggestions returned to show to the user. This can be adjusted from based on + /// the user input. + /// + public IReadOnlyList PredictiveSuggestions { get { return _predictiveSuggestions; } } + + private readonly List _sourceTexts = new List(CommandLineSuggestion.CollectionDefaultCapacity); + /// + /// Gets the texts that is based on. + /// + public IReadOnlyList SourceTexts { get { return _sourceTexts; } } + + private readonly List _suggestionSources = new List(CommandLineSuggestion.CollectionDefaultCapacity); + /// + /// Gets or sets the sources where the text is from. + /// + public IReadOnlyList SuggestionSources { get { return _suggestionSources; } } + + /// + /// Gets the number of suggestions. + /// + public int Count { get { return _suggestionSources.Count; } } + + /// + /// Adds a new suggestion. + /// + /// The suggestion to show to the user. + /// The text that used to construct . + public void AddSuggestion(PredictiveSuggestion predictiveSuggestion, string sourceText) => AddSuggestion(predictiveSuggestion, sourceText, SuggestionSource.None); + + /// + /// Adds a new suggestion. + /// + /// The suggestion to show to the user. + /// The text that used to construct . + /// The source where the suggestion is from. + public void AddSuggestion(PredictiveSuggestion predictiveSuggestion, string sourceText, SuggestionSource suggestionSource) + { + Validation.CheckArgument(predictiveSuggestion, $"{nameof(predictiveSuggestion)} cannot be null."); + Validation.CheckArgument(!string.IsNullOrWhiteSpace(predictiveSuggestion.SuggestionText), $"{nameof(predictiveSuggestion)} cannot have a null or whitespace suggestion text."); + Validation.CheckArgument(!string.IsNullOrWhiteSpace(sourceText), $"{nameof(sourceText)} cannot be null or whitespace."); + + _predictiveSuggestions.Add(predictiveSuggestion); + _sourceTexts.Add(sourceText); + _suggestionSources.Add(suggestionSource); + + CheckObjectInvariant(); + } + + /// + /// Updates the suggestion source of a suggestion. + /// + /// The index of a suggestion. + /// The new suggestion source. + public void UpdateSuggestionSource(int index, SuggestionSource suggestionSource) + { + Validation.CheckArgument((index >= 0) && (index < _suggestionSources.Count), $"{nameof(index)} is out of range."); + + _suggestionSources[index] = suggestionSource; + CheckObjectInvariant(); + } + + private void CheckObjectInvariant() + { + Validation.CheckInvariant(_predictiveSuggestions.Count == _sourceTexts.Count && _predictiveSuggestions.Count == _suggestionSources.Count); + } + } +} diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/IAzPredictorService.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/IAzPredictorService.cs index 1a8c25c0e1c8..038935825aa0 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor/IAzPredictorService.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/IAzPredictorService.cs @@ -12,7 +12,6 @@ // limitations under the License. // ---------------------------------------------------------------------------------- -using System; using System.Collections.Generic; using System.Management.Automation.Language; using System.Threading; @@ -27,23 +26,23 @@ public interface IAzPredictorService /// /// Gest the suggestions for the user input. /// - /// User input from PSReadLine + /// User input from PSReadLine. /// The number of suggestion to return. - /// The maximum amount of the same commnds in the list of predictions. /// The cancellation token - /// The list of suggestions for and the source that create the suggestion. The maximum number of suggestion is - public IEnumerable> GetSuggestion(Ast input, int suggestionCount, int maxAllowedCommandDuplicate, CancellationToken cancellationToken); + /// The maximum amount of the same commnds in the list of predictions. + /// The suggestions for . The maximum number of suggestions is . + public CommandLineSuggestion GetSuggestion(Ast input, int suggestionCount, int maxAllowedCommandDuplicate, CancellationToken cancellationToken); /// /// Requests predictions, given a command string. /// - /// A list of commands + /// A list of commands. public void RequestPredictions(IEnumerable commands); /// /// Record the history from PSReadLine. /// - /// The last command in history + /// The last command in history. public void RecordHistory(CommandAst history); /// diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Parameter.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Parameter.cs new file mode 100644 index 000000000000..112d31c7f910 --- /dev/null +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Parameter.cs @@ -0,0 +1,48 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using Microsoft.Azure.PowerShell.Tools.AzPredictor.Utilities; + +namespace Microsoft.Azure.PowerShell.Tools.AzPredictor +{ + /// + /// The class represents a name-value pair of a parameter. + /// + struct Parameter + { + /// + /// Gets the name of the parameter. + /// + public string Name { get; } + + /// + /// Gets or sets the valus of the parameter. + /// null if there is no valud is expected or set for this parameter. + /// + public string Value { get; set; } + + /// + /// Creates a new instance of + /// + /// The name of the parameter + /// The value of the parameter. If the parameter is a switch parameter, it's null. + public Parameter(string name, string value) + { + Validation.CheckArgument(!string.IsNullOrWhiteSpace(name), $"{nameof(name)} cannot be null or whitespace"); + + Name = name; + Value = value; + } + } +} diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/ParameterSet.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/ParameterSet.cs index ad0de6d019e1..57f0f82c91e8 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor/ParameterSet.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/ParameterSet.cs @@ -12,7 +12,7 @@ // limitations under the License. // ---------------------------------------------------------------------------------- -using System; +using Microsoft.Azure.PowerShell.Tools.AzPredictor.Utilities; using System.Collections.Generic; using System.Linq; using System.Management.Automation.Language; @@ -26,23 +26,25 @@ namespace Microsoft.Azure.PowerShell.Tools.AzPredictor /// sealed class ParameterSet { - public IList> Parameters { get; } + /// + /// Gets the list of the parameters with their names and values. + /// + public IReadOnlyList Parameters { get; } public ParameterSet(CommandAst commandAst) { - Parameters = new List>(); + Validation.CheckArgument(commandAst, $"{nameof(commandAst)} cannot be null."); + + var parameters = new List(); var elements = commandAst.CommandElements.Skip(1); - Ast param = null; + CommandParameterAst param = null; Ast arg = null; foreach (Ast elem in elements) { - if (elem is CommandParameterAst) + if (elem is CommandParameterAst p) { - if (param != null) - { - Parameters.Add(new Tuple(param.ToString(), arg?.ToString())); - } - param = elem; + AddParameter(param, arg); + param = p; arg = null; } else if (AzPredictorConstants.ParameterIndicator == elem?.ToString().Trim().FirstOrDefault()) @@ -50,11 +52,7 @@ public ParameterSet(CommandAst commandAst) // We have an incomplete command line such as // `New-AzResourceGroup -Name ResourceGroup01 -Location WestUS -` // We'll ignore the incomplete parameter. - if (param != null) - { - Parameters.Add(new Tuple(param.ToString(), arg?.ToString())); - } - + AddParameter(param, arg); param = null; arg = null; } @@ -64,14 +62,18 @@ public ParameterSet(CommandAst commandAst) } } + Validation.CheckInvariant((param != null) || (arg == null)); - if (param != null) - { - Parameters.Add(new Tuple(param.ToString(), arg?.ToString())); - } - else if (arg != null) + AddParameter(param, arg); + + Parameters = parameters; + + void AddParameter(CommandParameterAst parameterName, Ast parameterValue) { - throw new InvalidOperationException(); + if (parameterName != null) + { + parameters.Add(new Parameter(parameterName.ParameterName, (parameterValue == null) ? null : CommandLineUtilities.UnescapePredictionText(parameterValue.ToString()))); + } } } } diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/ParameterValuePredictor.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/ParameterValuePredictor.cs index d42710a69205..703f85641cc1 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor/ParameterValuePredictor.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/ParameterValuePredictor.cs @@ -42,7 +42,6 @@ public void ProcessHistoryCommand(CommandAst command) /// The parameter value from the history command. Null if that is not available. public string GetParameterValueFromAzCommand(string parameterName) { - parameterName = parameterName.TrimStart(AzPredictorConstants.ParameterIndicator); if (_localParameterValues.TryGetValue(parameterName.ToUpper(), out var value)) { return value; @@ -100,9 +99,9 @@ private void ExtractLocalParameters(System.Collections.ObjectModel.ReadOnlyColle for (int i = 2; i < command.Count; i += 2) { - if (command[i - 1] is CommandParameterAst && command[i] is StringConstantExpressionAst) + if (command[i - 1] is CommandParameterAst parameterAst && command[i] is StringConstantExpressionAst) { - var parameterName = command[i - 1].ToString().TrimStart(AzPredictorConstants.ParameterIndicator); + var parameterName = parameterAst.ParameterName; var key = ParameterValuePredictor.GetLocalParameterKey(commandNoun, parameterName); var parameterValue = command[i].ToString(); this._localParameterValues.AddOrUpdate(key, parameterValue, (k, v) => parameterValue); diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Predictor.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Predictor.cs deleted file mode 100644 index 3c85c1c8e602..000000000000 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Predictor.cs +++ /dev/null @@ -1,302 +0,0 @@ -// ---------------------------------------------------------------------------------- -// -// Copyright Microsoft Corporation -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ---------------------------------------------------------------------------------- - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Management.Automation.Language; -using System.Text; -using System.Threading; - -namespace Microsoft.Azure.PowerShell.Tools.AzPredictor -{ - /// - /// Caches predictions from Aladdin service, queries user input, e.g. "Connec" and returns autocompleted version, or null. - /// - internal sealed class Predictor - { - private readonly IList _predictions; - private readonly ParameterValuePredictor _parameterValuePredictor; - - /// - /// Predictor must be initialized with a list of string suggestions. - /// - /// List of suggestions from the model, sorted by frequency (most to least) - /// Provide the prediction to the parameter values. - public Predictor(IList modelPredictions, ParameterValuePredictor parameterValuePredictor) - { - this._parameterValuePredictor = parameterValuePredictor; - this._predictions = new List(); - - foreach (var predictionTextRaw in modelPredictions ?? Enumerable.Empty()) - { - var predictionText = EscapePredictionText(predictionTextRaw); - Ast ast = Parser.ParseInput(predictionText, out Token[] tokens, out _); - var commandAst = (ast.Find((ast) => ast is CommandAst, searchNestedScriptBlocks: false) as CommandAst); - - if (commandAst?.CommandElements[0] is StringConstantExpressionAst commandName) - { - var parameterSet = new ParameterSet(commandAst); - this._predictions.Add(new Prediction(commandName.Value, parameterSet)); - } - } - } - - /// - /// Given a user input PowerShell AST, returns prediction text. - /// - /// PowerShell AST input of the user, generated by PSReadLine - /// Commands already present. - /// The number of suggestion to return. - /// The maximum amount of the same commnds in the list of predictions. - /// The cancellation token - /// The collection of suggestions. The key is the predicted text adjusted based on . The - /// value is the original text to create the adjusted text. - public Tuple, IDictionary> Query(Ast input, IDictionary presentCommands, int suggestionCount, int maxAllowedCommandDuplicate, CancellationToken cancellationToken) - { - if (suggestionCount <= 0) - { - return null; - } - - var commandAst = input.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst; - var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value; - - if (string.IsNullOrWhiteSpace(commandName)) - { - return null; - } - - var results = new Dictionary(StringComparer.OrdinalIgnoreCase); - var resultsTemp = new Dictionary(StringComparer.OrdinalIgnoreCase); - - try - { - var inputParameterSet = new ParameterSet(commandAst); - - - var isCommandNameComplete = (((commandAst?.CommandElements != null) && (commandAst.CommandElements.Count > 1)) || ((input as ScriptBlockAst)?.Extent?.Text?.EndsWith(' ') == true)); - - Func commandNameQuery = (command) => command.Equals(commandName, StringComparison.OrdinalIgnoreCase); - if (!isCommandNameComplete) - { - commandNameQuery = (command) => command.StartsWith(commandName, StringComparison.OrdinalIgnoreCase); - } - - // Try to find the matching command and arrange the parameters in the order of the input. - // - // Predictions should be flexible, e.g. if "Command -Name N -Location L" is a possibility, - // then "Command -Location L -Name N" should also be possible. - // - // resultBuilder and usedParams are used to store the information to construct the result. - // We want to avoid too much heap allocation for the performance purpose. - - var resultBuilder = new StringBuilder(); - var usedParams = new HashSet(); - var sourceBuilder = new StringBuilder(); - - for (var i = 0; i < _predictions.Count && results.Count < suggestionCount; ++i) - { - if (commandNameQuery(_predictions[i].Command)) - { - foreach (var parameterSet in _predictions[i].ParameterSets) - { - cancellationToken.ThrowIfCancellationRequested(); - - resultBuilder.Clear(); - resultBuilder.Append(_predictions[i].Command); - usedParams.Clear(); - - if (DoesPredictionParameterSetMatchInput(resultBuilder, inputParameterSet, parameterSet, usedParams)) - { - PredictRestOfParameters(resultBuilder, parameterSet.Parameters, usedParams); - var prediction = UnescapePredictionText(resultBuilder); - - if (prediction.Length <= input.Extent.Text.Length) - { - continue; - } - - sourceBuilder.Clear(); - sourceBuilder.Append(_predictions[i].Command); - - foreach (var p in parameterSet.Parameters) - { - _ = sourceBuilder.Append(AzPredictorConstants.CommandParameterSeperator); - _ = sourceBuilder.Append(p.Item1); - - if (!string.IsNullOrWhiteSpace(p.Item2)) - { - _ = sourceBuilder.Append(AzPredictorConstants.CommandParameterSeperator); - _ = sourceBuilder.Append(p.Item2); - } - } - - - if (!presentCommands.ContainsKey(_predictions[i].Command)) - { - results.Add(prediction.ToString(), sourceBuilder.ToString()); - presentCommands.Add(_predictions[i].Command, 1); - } - else if (presentCommands[_predictions[i].Command] < maxAllowedCommandDuplicate) - { - results.Add(prediction.ToString(), sourceBuilder.ToString()); - presentCommands[_predictions[i].Command] += 1; - } - else - { - resultsTemp.Add(prediction.ToString(), sourceBuilder.ToString()); - } - - if (results.Count == suggestionCount) - { - break; - } - } - } - } - } - } - catch - { - } - if ((results.Count < suggestionCount) && (resultsTemp.Count >0)) - { - resultsTemp.ToList().GetRange(0, suggestionCount - results.Count).ForEach(x => results.Add(x.Key,x.Value)); - } - return new Tuple, IDictionary>(results, presentCommands); - } - - /// - /// Appends unused parameters to the builder. - /// - /// StringBuilder that aggregates the prediction text output - /// Chosen prediction parameters. - /// Set of used parameters for set. - private void PredictRestOfParameters(StringBuilder builder, IList> parameters, HashSet usedParams) - { - for (var j = 0; j < parameters.Count; j++) - { - if (!usedParams.Contains(j)) - { - BuildParameterValue(builder, parameters[j]); - } - } - } - - /// - /// Determines if parameter set contains all of the parameters of the input. - /// - /// StringBuilder that aggregates the prediction text output - /// Parsed ParameterSet from the user input AST - /// Candidate prediction parameter set. - /// Set of used parameters for set. - private bool DoesPredictionParameterSetMatchInput(StringBuilder builder, ParameterSet inputParameters, ParameterSet predictionParameters, HashSet usedParams) - { - foreach (var inputParameter in inputParameters.Parameters) - { - var matchIndex = FindParameterPositionInSet(inputParameter, predictionParameters, usedParams); - if (matchIndex == -1) - { - return false; - } - else - { - usedParams.Add(matchIndex); - if (inputParameter.Item2 != null) - { - _ = builder.Append(AzPredictorConstants.CommandParameterSeperator); - _ = builder.Append(predictionParameters.Parameters[matchIndex].Item1); - - _ = builder.Append(AzPredictorConstants.CommandParameterSeperator); - _ = builder.Append(inputParameter.Item2); - } - else - { - BuildParameterValue(builder, predictionParameters.Parameters[matchIndex]); - } - } - } - return true; - } - - /// - /// Create the parameter values from the history commandlines. - /// - /// For example: - /// history command line - /// > New-AzVM -Name "TestVM" ... - /// prediction: - /// > Get-AzVM -VMName <TestVM> - /// "TestVM" is predicted for Get-AzVM. - /// - /// The string builder to create the whole predicted command line. - /// The parameter name and vlaue from prediction - private void BuildParameterValue(StringBuilder builder, Tuple parameter) - { - var parameterName = parameter.Item1; - _ = builder.Append(AzPredictorConstants.CommandParameterSeperator); - _ = builder.Append(parameterName); - - string parameterValue = this._parameterValuePredictor?.GetParameterValueFromAzCommand(parameterName); - - if (string.IsNullOrWhiteSpace(parameterValue)) - { - parameterValue = parameter.Item2; - } - - if (!string.IsNullOrWhiteSpace(parameterValue)) - { - _ = builder.Append(AzPredictorConstants.CommandParameterSeperator); - _ = builder.Append(parameterValue); - } - } - - /// - /// Determines the index of the given parameter in the parameter set. - /// - /// A tuple, parameter AST, and argument AST (or null), representing the parameter. - /// Prediction parameter setto find parameter position in. - /// Set of used parameters for set. - private static int FindParameterPositionInSet(Tuple parameter, ParameterSet predictionSet, HashSet usedParams) - { - for (var k = 0; k < predictionSet.Parameters.Count; k++) - { - var isPrefixed = predictionSet.Parameters[k].Item1.StartsWith(parameter.Item1, StringComparison.OrdinalIgnoreCase); - var hasNotBeenUsed = !usedParams.Contains(k); - if (isPrefixed && hasNotBeenUsed) - { - return k; - } - } - - return -1; - } - - /// - /// Escaping the prediction text is necessary because KnowledgeBase predicted suggestions - /// such as "<PSSubnetConfig>" are incorrectly identified as pipe operators - /// - /// The text to escape. - private static string EscapePredictionText(string text) - { - return text.Replace("<", "'<").Replace(">", ">'"); - } - - private static StringBuilder UnescapePredictionText(StringBuilder text) - { - return text.Replace("'<", "<").Replace(">'", ">"); - } - } -} diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Profile/AzurePSDataCollectionProfile.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Profile/AzurePSDataCollectionProfile.cs index a31b7c1d319a..ab4fc91e4aa4 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Profile/AzurePSDataCollectionProfile.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Profile/AzurePSDataCollectionProfile.cs @@ -12,10 +12,10 @@ // limitations under the License. // ---------------------------------------------------------------------------------- -using Microsoft.Azure.PowerShell.Tools.AzPredictor; -using Newtonsoft.Json; +using Microsoft.Azure.PowerShell.Tools.AzPredictor.Utilities; using System; using System.IO; +using System.Text.Json; namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Profile { @@ -24,7 +24,7 @@ namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Profile /// /// The profile about data collection in Azure PowerShell /// - public sealed class AzurePSDataCollectionProfile + internal sealed class AzurePSDataCollectionProfile { private const string EnvironmentVariableName = "Azure_PS_Data_Collection"; private const string DefaultFileName = "AzurePSDataCollectionProfile.json"; @@ -59,10 +59,9 @@ private AzurePSDataCollectionProfile(bool enable) } /// - /// Gets if the data collection is enabled. + /// Gets or sets if the data collection is enabled. /// - [JsonProperty(PropertyName = "enableAzureDataCollection")] - public bool? EnableAzureDataCollection { get; private set; } + public bool? EnableAzureDataCollection { get; set; } private static AzurePSDataCollectionProfile CreateInstance() { @@ -90,7 +89,7 @@ private static AzurePSDataCollectionProfile CreateInstance() if (File.Exists(dataPath)) { string contents = File.ReadAllText(dataPath); - var localResult = JsonConvert.DeserializeObject(contents); + var localResult = JsonSerializer.Deserialize(contents, JsonUtilities.DefaultSerializerOptions); if (localResult != null && localResult.EnableAzureDataCollection.HasValue) { result = localResult; diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Settings.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Settings.cs index 4acf2c409a50..4dfabbe10a9b 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Settings.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Settings.cs @@ -37,7 +37,7 @@ sealed class Settings public string ServiceUri { get; set; } /// - /// The number of suggestions to return to PSReadLine + /// The number of suggestions to return to PSReadLine. /// public int? SuggestionCount { get; set; } public int? MaxAllowedCommandDuplicate { get; set; } @@ -115,12 +115,12 @@ private void OverrideSettingsFromProfile() if (!string.IsNullOrWhiteSpace(profileSettings.ServiceUri)) { - this.ServiceUri = profileSettings.ServiceUri; + ServiceUri = profileSettings.ServiceUri; } if (profileSettings.SuggestionCount.HasValue && (profileSettings.SuggestionCount.Value > 0)) { - this.SuggestionCount = profileSettings.SuggestionCount; + SuggestionCount = profileSettings.SuggestionCount; } if (profileSettings.MaxAllowedCommandDuplicate.HasValue && (profileSettings.MaxAllowedCommandDuplicate.Value > 0)) @@ -137,11 +137,11 @@ private void OverrideSettingsFromProfile() private void OverrideSettingsFromEnv() { - var serviceUri = System.Environment.GetEnvironmentVariable("ServiceUri"); + var serviceUri = System.Environment.GetEnvironmentVariable("AzPredictorServiceUri"); if (!string.IsNullOrWhiteSpace(serviceUri)) { - this.ServiceUri = serviceUri; + ServiceUri = serviceUri; } } } diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/PredictionSource.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/SuggestionSource.cs similarity index 74% rename from tools/Az.Tools.Predictor/Az.Tools.Predictor/PredictionSource.cs rename to tools/Az.Tools.Predictor/Az.Tools.Predictor/SuggestionSource.cs index 801e0796d18c..f4304b99f8f6 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor/PredictionSource.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/SuggestionSource.cs @@ -12,16 +12,12 @@ // limitations under the License. // ---------------------------------------------------------------------------------- -using Newtonsoft.Json; -using Newtonsoft.Json.Converters; - namespace Microsoft.Azure.PowerShell.Tools.AzPredictor { /// - /// An enum for the source where we get the prediction. + /// An enum for the source where we get the suggestion. /// - [JsonConverter(typeof(StringEnumConverter))] - public enum PredictionSource + public enum SuggestionSource { /// /// There is no predictions. @@ -29,17 +25,17 @@ public enum PredictionSource None, /// - /// The prediction is from the static command list. + /// The suggestion is from the static command list. This doesn't take command history into account. /// StaticCommands, /// - /// The prediction is from the list for the older command. + /// The suggestion is from the list for outdated command history. /// PreviousCommand, /// - /// The prediction is from the list for the currentc command. + /// The suggestion is from the list for latest command history. /// CurrentCommand } diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/AzPredictorTelemetryClient.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/AzPredictorTelemetryClient.cs new file mode 100644 index 000000000000..52c1378600cc --- /dev/null +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/AzPredictorTelemetryClient.cs @@ -0,0 +1,278 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using Microsoft.ApplicationInsights; +using Microsoft.ApplicationInsights.Extensibility; +using Microsoft.Azure.PowerShell.Tools.AzPredictor.Profile; +using Microsoft.Azure.PowerShell.Tools.AzPredictor.Utilities; +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Management.Automation.Language; +using System.Text.Json; +using System.Threading.Tasks.Dataflow; + +namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Telemetry +{ + /// + /// A telemetry client implementation to collect the telemetry data for AzPredictor. + /// + sealed class AzPredictorTelemetryClient : ITelemetryClient + { + private const string TelemetryEventPrefix = "Az.Tools.Predictor"; + + /// + public string SessionId { get; } = Guid.NewGuid().ToString(); + + /// + public string CorrelationId { get; private set; } = Guid.NewGuid().ToString(); + + /// + /// The client that sends the telemetry to the server. + /// + private readonly TelemetryClient _telemetryClient; + + private readonly IAzContext _azContext; + + /// + /// The action to handle the in a thread pool. + /// + private readonly ActionBlock _telemetryDispatcher; + + /// + /// The adjusted texts and the source text for the suggestion. + /// + /// + /// We only access it in the thread pool which is to handle the data at the target side of the data flow. + /// We only handle one item at a time so there is no race condition. + /// + private IDictionary _userAcceptedAndSuggestion = new Dictionary(); + + /// + /// Constructs a new instance of . + /// + /// The Az context which this module runs with. + public AzPredictorTelemetryClient(IAzContext azContext) + { + TelemetryConfiguration configuration = TelemetryConfiguration.CreateDefault(); + configuration.InstrumentationKey = "7df6ff70-8353-4672-80d6-568517fed090"; // Use Azuer-PowerShell instrumentation key. see https://github.com/Azure/azure-powershell-common/blob/master/src/Common/AzurePSCmdlet.cs + _telemetryClient = new TelemetryClient(configuration); + _telemetryClient.Context.Location.Ip = "0.0.0.0"; + _telemetryClient.Context.Cloud.RoleInstance = "placeholderdon'tuse"; + _telemetryClient.Context.Cloud.RoleName = "placeholderdon'tuse"; + _azContext = azContext; + _telemetryDispatcher = new ActionBlock( + (telemetryData) => DispatchTelemetryData(telemetryData)); + } + + /// + public void OnHistory(HistoryTelemetryData telemetryData) + { + if (!IsDataCollectionAllowed()) + { + return; + } + + telemetryData.SessionId = SessionId; + telemetryData.CorrelationId = CorrelationId; + + _telemetryDispatcher.Post(telemetryData); + +#if TELEMETRY_TRACE && DEBUG + System.Diagnostics.Trace.WriteLine("Recording CommandHistory"); +#endif + } + + /// + public void OnRequestPrediction(RequestPredictionTelemetryData telemetryData) + { + if (!IsDataCollectionAllowed()) + { + return; + } + + CorrelationId = Guid.NewGuid().ToString(); + + telemetryData.SessionId = SessionId; + telemetryData.CorrelationId = CorrelationId; + + _telemetryDispatcher.Post(telemetryData); + +#if TELEMETRY_TRACE && DEBUG + System.Diagnostics.Trace.WriteLine("Recording RequestPrediction"); +#endif + } + + /// + public void OnSuggestionAccepted(SuggestionAcceptedTelemetryData telemetryData) + { + if (!IsDataCollectionAllowed()) + { + return; + } + + telemetryData.SessionId = SessionId; + telemetryData.CorrelationId = CorrelationId; + + _telemetryDispatcher.Post(telemetryData); + +#if TELEMETRY_TRACE && DEBUG + System.Diagnostics.Trace.WriteLine("Recording AcceptSuggestion"); +#endif + } + + /// + public void OnGetSuggestion(GetSuggestionTelemetryData telemetryData) + { + if (!IsDataCollectionAllowed()) + { + return; + } + + telemetryData.SessionId = SessionId; + telemetryData.CorrelationId = CorrelationId; + + _telemetryDispatcher.Post(telemetryData); + +#if TELEMETRY_TRACE && DEBUG + System.Diagnostics.Trace.WriteLine("Recording GetSuggestion"); +#endif + } + + /// + /// Check whether the data collection is opted in from user. + /// + /// true if allowed. + private static bool IsDataCollectionAllowed() + { + if (AzurePSDataCollectionProfile.Instance.EnableAzureDataCollection == true) + { + return true; + } + + return false; + } + + /// + /// Dispatches according to its implementation. + /// + private void DispatchTelemetryData(ITelemetryData telemetryData) + { + switch (telemetryData) + { + case HistoryTelemetryData history: + SendTelemetry(history); + break; + case RequestPredictionTelemetryData requestPrediction: + SendTelemetry(requestPrediction); + break; + case GetSuggestionTelemetryData getSuggestion: + SendTelemetry(getSuggestion); + break; + case SuggestionAcceptedTelemetryData suggestionAccepted: + SendTelemetry(suggestionAccepted); + break; + default: + throw new NotImplementedException(); + } + } + + /// + /// Sends the telemetry with the command history. + /// + private void SendTelemetry(HistoryTelemetryData telemetryData) + { + var properties = CreateProperties(telemetryData); + properties.Add("History", telemetryData.Command); + + _telemetryClient.TrackEvent($"{AzPredictorTelemetryClient.TelemetryEventPrefix}/CommandHistory", properties); + } + + /// + /// Sends the telemetry with the commands for prediction. + /// + private void SendTelemetry(RequestPredictionTelemetryData telemetryData) + { + _userAcceptedAndSuggestion.Clear(); + + var properties = CreateProperties(telemetryData); + properties.Add("Command", telemetryData.Commands ?? string.Empty); + properties.Add("HttpRequestSent", telemetryData.HasSentHttpRequest.ToString(CultureInfo.InvariantCulture)); + properties.Add("Exception", telemetryData.Exception?.ToString() ?? string.Empty); + + _telemetryClient.TrackEvent($"{AzPredictorTelemetryClient.TelemetryEventPrefix}/RequestPrediction", properties); + } + + /// + /// Sends the telemetry with the suggestion returned to the user. + /// + private void SendTelemetry(GetSuggestionTelemetryData telemetryData) + { + var suggestions = telemetryData.Suggestion?.PredictiveSuggestions; + var suggestionSource = telemetryData.Suggestion?.SuggestionSources; + var sourceTexts = telemetryData.Suggestion?.SourceTexts; + var maskedUserInput = CommandLineUtilities.MaskCommandLine(telemetryData.UserInput?.FindAll((ast) => ast is CommandAst, true).LastOrDefault() as CommandAst); + + if ((suggestions != null) && (sourceTexts != null)) + { + for (int i = 0; i < suggestions.Count; ++i) + { + _userAcceptedAndSuggestion[suggestions[i].SuggestionText] = sourceTexts[i]; + } + } + + var properties = CreateProperties(telemetryData); + properties.Add("UserInput", maskedUserInput ?? string.Empty); + properties.Add("Suggestion", sourceTexts != null ? JsonSerializer.Serialize(sourceTexts.Zip(suggestionSource).Select((s) => Tuple.Create(s.First, s.Second)), JsonUtilities.TelemetrySerializerOptions) : string.Empty); + properties.Add("IsCancelled", telemetryData.IsCancellationRequested.ToString(CultureInfo.InvariantCulture)); + properties.Add("Exception", telemetryData.Exception?.ToString() ?? string.Empty); + + _telemetryClient.TrackEvent($"{AzPredictorTelemetryClient.TelemetryEventPrefix}/GetSuggestion", properties); + } + + /// + /// Sends the telemetry with the suggestion returned to the user. + /// + private void SendTelemetry(SuggestionAcceptedTelemetryData telemetryData) + { + if (!_userAcceptedAndSuggestion.TryGetValue(telemetryData.Suggestion, out var suggestion)) + { + suggestion = "NoRecord"; + } + + var properties = CreateProperties(telemetryData); + properties.Add("AcceptedSuggestion", suggestion); + + _telemetryClient.TrackEvent($"{AzPredictorTelemetryClient.TelemetryEventPrefix}/AcceptSuggestion", properties); + } + + /// + /// Add the common properties to the telemetry event. + /// + private IDictionary CreateProperties(ITelemetryData telemetryData) + { + return new Dictionary() + { + { "SessionId", telemetryData.SessionId }, + { "CorrelationId", telemetryData.CorrelationId }, + { "UserId", _azContext.UserId }, + { "HashMacAddress", _azContext.MacAddress }, + { "PowerShellVersion", _azContext.PowerShellVersion.ToString() }, + { "ModuleVersion", _azContext.ModuleVersion.ToString() }, + { "OS", _azContext.OSVersion }, + }; + } + } +} diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/GetSuggestionTelemetryData.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/GetSuggestionTelemetryData.cs new file mode 100644 index 000000000000..9d2908c13690 --- /dev/null +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/GetSuggestionTelemetryData.cs @@ -0,0 +1,69 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using System; +using System.Management.Automation.Language; + +namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Telemetry +{ + /// + /// The data to collect in . + /// + public sealed class GetSuggestionTelemetryData : ITelemetryData + { + /// + public string SessionId { get; internal set; } + + /// + public string CorrelationId { get; internal set; } + + /// + /// Gets the user input. + /// + public Ast UserInput { get; } + + /// + /// Gets the suggestions to return to the user. + /// + public CommandLineSuggestion Suggestion { get; } + + /// + /// Gets whether the cancellation request is already set. + /// + public bool IsCancellationRequested { get; } + + /// + /// Gets the exception if there is an error during the operation. + /// + /// + /// OperationCanceledException isn't considered an error. + /// + public Exception Exception { get; } + + /// + /// Creates a new instance of . + /// + /// The user input that the is for. + /// The suggestions returned for the . + /// Indicates if the cancellation has been requested. + /// The exception that is thrown if there is an error. + public GetSuggestionTelemetryData(Ast userInput, CommandLineSuggestion suggestion, bool isCancellationRequested, Exception exception) + { + UserInput = userInput; + Suggestion = suggestion; + IsCancellationRequested = isCancellationRequested; + Exception = exception; + } + } +} diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/HistoryTelemetryData.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/HistoryTelemetryData.cs new file mode 100644 index 000000000000..91b94209be40 --- /dev/null +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/HistoryTelemetryData.cs @@ -0,0 +1,39 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Telemetry +{ + /// + /// The data to collect in . + /// + public sealed class HistoryTelemetryData : ITelemetryData + { + /// + public string SessionId { get; internal set; } + + /// + public string CorrelationId { get; internal set; } + + /// + /// Gets the history command line. + /// + public string Command { get; } + + /// + /// Creates a new instance of . + /// + /// The history command line. + public HistoryTelemetryData(string command) => Command = command; + } +} diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/ITelemetryClient.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/ITelemetryClient.cs similarity index 50% rename from tools/Az.Tools.Predictor/Az.Tools.Predictor/ITelemetryClient.cs rename to tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/ITelemetryClient.cs index 53379b5393d2..037102beaa58 100644 --- a/tools/Az.Tools.Predictor/Az.Tools.Predictor/ITelemetryClient.cs +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/ITelemetryClient.cs @@ -12,10 +12,7 @@ // limitations under the License. // ---------------------------------------------------------------------------------- -using System; -using System.Collections.Generic; - -namespace Microsoft.Azure.PowerShell.Tools.AzPredictor +namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Telemetry { /// /// The telemetry client that collects and sends the telemetry data. @@ -35,41 +32,25 @@ public interface ITelemetryClient /// /// Collects the event of the history command. /// - /// The history command from PSReadLine. - public void OnHistory(string historyLine); + /// The data to collect. + public void OnHistory(HistoryTelemetryData telemetryData); /// /// Collects the event when a prediction is requested. /// - /// The command to that we request the prediction for. - public void OnRequestPrediction(string command); - - /// - /// Collects the event when we fail to get the prediction for the command - /// - /// The command to that we request the prediction for. - /// The exception - public void OnRequestPredictionError(string command, Exception e); + /// The data to collect. + public void OnRequestPrediction(RequestPredictionTelemetryData telemetryData); /// /// Collects when a suggestion is accepted. /// - /// The suggestion that's accepted by the user. - public void OnSuggestionAccepted(string acceptedSuggestion); + /// The data to collect. + public void OnSuggestionAccepted(SuggestionAcceptedTelemetryData telemetryData); /// /// Collects when we return a suggestion /// - /// The user input that the suggestions are for - /// The list of suggestion and its source - /// Indicates whether the caller has cancelled the call to get suggestion. Usually that's because of time out - public void OnGetSuggestion(string maskedUserInput, IEnumerable> suggestions, bool isCancelled); - - /// - /// Collects when an exception is thrown when we return a suggestion. - /// - /// The exception - - public void OnGetSuggestionError(Exception e); + /// The data to collect. + public void OnGetSuggestion(GetSuggestionTelemetryData telemetryData); } } diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/ITelemetryData.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/ITelemetryData.cs new file mode 100644 index 000000000000..a326a328cbe8 --- /dev/null +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/ITelemetryData.cs @@ -0,0 +1,33 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + + +namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Telemetry +{ + /// + /// An interface that all telemetry data class should implement. + /// + public interface ITelemetryData + { + /// + /// Gets the session id. + /// + string SessionId { get; } + + /// + /// Gets the correlation id. + /// + string CorrelationId { get; } + } +} diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/RequestPredictionTelemetryData.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/RequestPredictionTelemetryData.cs new file mode 100644 index 000000000000..dec60ac2f21c --- /dev/null +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/RequestPredictionTelemetryData.cs @@ -0,0 +1,61 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using System; + +namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Telemetry +{ + /// + /// The data to collect in . + /// + public sealed class RequestPredictionTelemetryData : ITelemetryData + { + /// + public string SessionId { get; internal set; } + + /// + public string CorrelationId { get; internal set; } + + /// + /// Gets the masked command lines that are used to request prediction. + /// + public string Commands { get; } // "Get-AzContext\nGet-AzVM" /predictions + + /// + /// Gets whether the http request to the service is sent. + /// + public bool HasSentHttpRequest { get; } + + /// + /// Gets the exception if there is an error during the operation. + /// + /// + /// OperationCanceledException isn't considered an error. + /// + public Exception Exception { get; } + + /// + /// Creates an instance of . + /// + /// The commands to request prediction for. + /// The flag to indicate whether the http request is canceled. + /// The exception that may be thrown. + public RequestPredictionTelemetryData(string commands, bool hasSentHttpRequest, Exception exception) + { + Commands = commands; + HasSentHttpRequest = hasSentHttpRequest; + Exception = exception; + } + } +} diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/SuggestionAcceptedTelemetryData.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/SuggestionAcceptedTelemetryData.cs new file mode 100644 index 000000000000..133e97edfcbb --- /dev/null +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Telemetry/SuggestionAcceptedTelemetryData.cs @@ -0,0 +1,38 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Telemetry +{ + /// + /// The data to collect in . + /// + public sealed class SuggestionAcceptedTelemetryData : ITelemetryData + { + /// + public string SessionId { get; internal set; } + + /// + public string CorrelationId { get; internal set; } + + /// + /// Gets the suggestion that's accepted by the user. + /// + public string Suggestion { get; } + + /// + /// Creates a new instance of . + /// + public SuggestionAcceptedTelemetryData(string suggestion) => Suggestion = suggestion; + } +} diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Utilities/CommandLineUtilities.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Utilities/CommandLineUtilities.cs new file mode 100644 index 000000000000..06ae716a0b47 --- /dev/null +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Utilities/CommandLineUtilities.cs @@ -0,0 +1,98 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using System.Linq; +using System.Management.Automation.Language; +using System.Text; + +namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Utilities +{ + /// + /// A utility class for Az command line. + /// + internal static class CommandLineUtilities + { + /// + /// Masks the user input of any data, like names and locations. + /// Also alphabetizes the parameters to normalize them before sending + /// them to the model. + /// e.g., Get-AzContext -Name Hello -Location 'EastUS' => Get-AzContext -Location *** -Name *** + /// + /// The last user input command. + public static string MaskCommandLine(CommandAst cmdAst) + { + var commandElements = cmdAst?.CommandElements; + + if (commandElements == null) + { + return null; + } + + if (commandElements.Count == 1) + { + return cmdAst.Extent.Text; + } + + var sb = new StringBuilder(cmdAst.Extent.Text.Length); + _ = sb.Append(commandElements[0].ToString()); + var parameters = commandElements + .Skip(1) + .Where(element => element is CommandParameterAst) + .Cast() + .OrderBy(ast => ast.ParameterName); + + foreach (CommandParameterAst param in parameters) + { + _ = sb.Append(AzPredictorConstants.CommandParameterSeperator); + if (param.Argument != null) + { + // Parameter is in the form of `-Name:value` + _ = sb.Append(AzPredictorConstants.ParameterIndicator) + .Append(param.ParameterName) + .Append(AzPredictorConstants.ParameterValueSeperator) + .Append(AzPredictorConstants.ParameterValueMask); + } + else + { + // Parameter is in the form of `-Name value` + _ = sb.Append(AzPredictorConstants.ParameterIndicator) + .Append(param.ParameterName) + .Append(AzPredictorConstants.CommandParameterSeperator) + .Append(AzPredictorConstants.ParameterValueMask); + } + } + return sb.ToString(); + } + + /// + /// Escaping the prediction text is necessary because KnowledgeBase predicted suggestions. + /// such as "<PSSubnetConfig>" are incorrectly identified as pipe operators. + /// + /// The text to escape. + public static string EscapePredictionText(string text) + { + return text.Replace("<", "'<").Replace(">", ">'"); + } + + /// + /// Unescape the prediction text from . + /// We don't want to show the escaped one to the user. + /// + /// The text to unescape. + public static string UnescapePredictionText(string text) + { + return text.Replace("'<", "<").Replace(">'", ">"); + } + } +} diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Utilities/JsonUtilities.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Utilities/JsonUtilities.cs new file mode 100644 index 000000000000..2b8f5a0f22c9 --- /dev/null +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Utilities/JsonUtilities.cs @@ -0,0 +1,73 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using System; +using System.Text.Encodings.Web; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Utilities +{ + /// + /// A utility class for json serialization/deserialization. + /// + internal static class JsonUtilities + { + private sealed class VersionConverter : JsonConverter + { + public override Version Read (ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + if (Version.TryParse(reader.GetString(), out var version)) + { + return version; + } + + throw new JsonException(); + } + + public override void Write (Utf8JsonWriter writer, Version value, JsonSerializerOptions options) + { + writer.WriteStringValue(value.ToString()); + } + } + + /// + /// The default serialization options: + /// 1. Use camel case in the naming. + /// 2. Use string instead of number for enums. + /// + public static readonly JsonSerializerOptions DefaultSerializerOptions = new JsonSerializerOptions() + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + Converters = + { + new JsonStringEnumConverter(JsonNamingPolicy.CamelCase), + new VersionConverter(), + }, + }; + + /// + /// The serialization options for sending the telemetry. + /// + /// + /// The options are based on except: + /// 1. Uses The result is treated as a string in the + /// telemetry and we don't want to use the default encoder which escape characters such as ', ", <, >, +. + /// + public static readonly JsonSerializerOptions TelemetrySerializerOptions = new JsonSerializerOptions(JsonUtilities.DefaultSerializerOptions) + { + Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping, + }; + } +} diff --git a/tools/Az.Tools.Predictor/Az.Tools.Predictor/Utilities/Validation.cs b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Utilities/Validation.cs new file mode 100644 index 000000000000..f90016554830 --- /dev/null +++ b/tools/Az.Tools.Predictor/Az.Tools.Predictor/Utilities/Validation.cs @@ -0,0 +1,77 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using System; + +namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Utilities +{ + internal static class Validation + { + public static void CheckArgument(bool argumentCondition, string message = null) + { + if (!argumentCondition) + { + throw new ArgumentException(message); + } + } + + public static void CheckArgument(T arg, string message = null) where T : class + { + if (arg == null) + { + throw new ArgumentNullException(message); + } + } + + public static void CheckArgument(bool argumentCondition, string message = null) where TException : Exception, new() + { + if (!argumentCondition) + { + Validation.Throw(message); + } + } + + public static void CheckInvariant(bool variantCondition, string message = null) + { + Validation.CheckInvariant(variantCondition, message); + } + + public static void CheckInvariant(bool variantCondition, string message = null) where TException : Exception, new() + { + if (!variantCondition) + { + Validation.Throw(message); + } + } + + private static void Throw(string message) where TException : Exception, new() + { + if (string.IsNullOrEmpty(message)) + { + throw new TException(); + } + else + { + Type exType = typeof(TException); + Exception exception = Activator.CreateInstance(exType, message) as Exception; + + if (exception != null) + { + throw exception; + } + } + } + } +} + diff --git a/tools/Az.Tools.Predictor/MockPSConsole/Microsoft.PowerShell.PSReadLine.Polyfiller.dll b/tools/Az.Tools.Predictor/MockPSConsole/Microsoft.PowerShell.PSReadLine.Polyfiller.dll index 8831829a56e1..b3d88d992a05 100644 Binary files a/tools/Az.Tools.Predictor/MockPSConsole/Microsoft.PowerShell.PSReadLine.Polyfiller.dll and b/tools/Az.Tools.Predictor/MockPSConsole/Microsoft.PowerShell.PSReadLine.Polyfiller.dll differ diff --git a/tools/Az.Tools.Predictor/MockPSConsole/Microsoft.PowerShell.PSReadLine2.dll b/tools/Az.Tools.Predictor/MockPSConsole/Microsoft.PowerShell.PSReadLine2.dll index 3c35b62ccd6c..3b368278919a 100644 Binary files a/tools/Az.Tools.Predictor/MockPSConsole/Microsoft.PowerShell.PSReadLine2.dll and b/tools/Az.Tools.Predictor/MockPSConsole/Microsoft.PowerShell.PSReadLine2.dll differ