diff --git a/go/tasks/pluginmachinery/core/template/template_test.go b/go/tasks/pluginmachinery/core/template/template_test.go index 3f1f8460a..7541ada57 100644 --- a/go/tasks/pluginmachinery/core/template/template_test.go +++ b/go/tasks/pluginmachinery/core/template/template_test.go @@ -49,6 +49,10 @@ func (d dummyOutputPaths) GetDeckPath() storage.DataReference { panic("should not be called") } +func (d dummyOutputPaths) GetSpanPath() storage.DataReference { + panic("should not be called") +} + func (d dummyOutputPaths) GetPreviousCheckpointsPrefix() storage.DataReference { return d.prevCheckpointPath } diff --git a/go/tasks/pluginmachinery/io/iface.go b/go/tasks/pluginmachinery/io/iface.go index 15d5adefc..4988d19c9 100644 --- a/go/tasks/pluginmachinery/io/iface.go +++ b/go/tasks/pluginmachinery/io/iface.go @@ -43,6 +43,8 @@ type OutputReader interface { Read(ctx context.Context) (*core.LiteralMap, *ExecutionError, error) // DeckExists checks if the deck file has been generated. DeckExists(ctx context.Context) (bool, error) + // SpanExists checks if the span file has been generated. + SpanExists(ctx context.Context) (bool, error) } // CheckpointPaths provides the paths / keys to input Checkpoints directory and an output checkpoints directory. @@ -81,6 +83,8 @@ type OutputFilePaths interface { GetOutputPath() storage.DataReference // GetDeckPath returns a fully qualified path (URN) to where the framework expects the deck.html to exist in the configured storage backend GetDeckPath() storage.DataReference + // GetSpanPath returns a fully qualified path (URN) to where the framework expects the span.html to exist in the configured storage backend + GetSpanPath() storage.DataReference // GetErrorPath returns a fully qualified path (URN) where the error information should be placed as a protobuf core.ErrorDocument. It is not directly // used by the framework, but could be used in the future GetErrorPath() storage.DataReference diff --git a/go/tasks/pluginmachinery/io/mocks/output_file_paths.go b/go/tasks/pluginmachinery/io/mocks/output_file_paths.go index d1018c8f0..a62cf50c5 100644 --- a/go/tasks/pluginmachinery/io/mocks/output_file_paths.go +++ b/go/tasks/pluginmachinery/io/mocks/output_file_paths.go @@ -235,3 +235,35 @@ func (_m *OutputFilePaths) GetRawOutputPrefix() storage.DataReference { return r0 } + +type OutputFilePaths_GetSpanPath struct { + *mock.Call +} + +func (_m OutputFilePaths_GetSpanPath) Return(_a0 storage.DataReference) *OutputFilePaths_GetSpanPath { + return &OutputFilePaths_GetSpanPath{Call: _m.Call.Return(_a0)} +} + +func (_m *OutputFilePaths) OnGetSpanPath() *OutputFilePaths_GetSpanPath { + c_call := _m.On("GetSpanPath") + return &OutputFilePaths_GetSpanPath{Call: c_call} +} + +func (_m *OutputFilePaths) OnGetSpanPathMatch(matchers ...interface{}) *OutputFilePaths_GetSpanPath { + c_call := _m.On("GetSpanPath", matchers...) + return &OutputFilePaths_GetSpanPath{Call: c_call} +} + +// GetSpanPath provides a mock function with given fields: +func (_m *OutputFilePaths) GetSpanPath() storage.DataReference { + ret := _m.Called() + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func() storage.DataReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + return r0 +} diff --git a/go/tasks/pluginmachinery/io/mocks/output_reader.go b/go/tasks/pluginmachinery/io/mocks/output_reader.go index ed7e671e2..fff729e2a 100644 --- a/go/tasks/pluginmachinery/io/mocks/output_reader.go +++ b/go/tasks/pluginmachinery/io/mocks/output_reader.go @@ -253,3 +253,42 @@ func (_m *OutputReader) ReadError(ctx context.Context) (io.ExecutionError, error return r0, r1 } + +type OutputReader_SpanExists struct { + *mock.Call +} + +func (_m OutputReader_SpanExists) Return(_a0 bool, _a1 error) *OutputReader_SpanExists { + return &OutputReader_SpanExists{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *OutputReader) OnSpanExists(ctx context.Context) *OutputReader_SpanExists { + c_call := _m.On("SpanExists", ctx) + return &OutputReader_SpanExists{Call: c_call} +} + +func (_m *OutputReader) OnSpanExistsMatch(matchers ...interface{}) *OutputReader_SpanExists { + c_call := _m.On("SpanExists", matchers...) + return &OutputReader_SpanExists{Call: c_call} +} + +// SpanExists provides a mock function with given fields: ctx +func (_m *OutputReader) SpanExists(ctx context.Context) (bool, error) { + ret := _m.Called(ctx) + + var r0 bool + if rf, ok := ret.Get(0).(func(context.Context) bool); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/go/tasks/pluginmachinery/io/mocks/output_writer.go b/go/tasks/pluginmachinery/io/mocks/output_writer.go index 038681623..689418fa5 100644 --- a/go/tasks/pluginmachinery/io/mocks/output_writer.go +++ b/go/tasks/pluginmachinery/io/mocks/output_writer.go @@ -240,6 +240,38 @@ func (_m *OutputWriter) GetRawOutputPrefix() storage.DataReference { return r0 } +type OutputWriter_GetSpanPath struct { + *mock.Call +} + +func (_m OutputWriter_GetSpanPath) Return(_a0 storage.DataReference) *OutputWriter_GetSpanPath { + return &OutputWriter_GetSpanPath{Call: _m.Call.Return(_a0)} +} + +func (_m *OutputWriter) OnGetSpanPath() *OutputWriter_GetSpanPath { + c_call := _m.On("GetSpanPath") + return &OutputWriter_GetSpanPath{Call: c_call} +} + +func (_m *OutputWriter) OnGetSpanPathMatch(matchers ...interface{}) *OutputWriter_GetSpanPath { + c_call := _m.On("GetSpanPath", matchers...) + return &OutputWriter_GetSpanPath{Call: c_call} +} + +// GetSpanPath provides a mock function with given fields: +func (_m *OutputWriter) GetSpanPath() storage.DataReference { + ret := _m.Called() + + var r0 storage.DataReference + if rf, ok := ret.Get(0).(func() storage.DataReference); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(storage.DataReference) + } + + return r0 +} + type OutputWriter_Put struct { *mock.Call } diff --git a/go/tasks/pluginmachinery/ioutils/in_memory_output_reader.go b/go/tasks/pluginmachinery/ioutils/in_memory_output_reader.go index 7dd10698d..3d99181e5 100644 --- a/go/tasks/pluginmachinery/ioutils/in_memory_output_reader.go +++ b/go/tasks/pluginmachinery/ioutils/in_memory_output_reader.go @@ -13,6 +13,7 @@ import ( type InMemoryOutputReader struct { literals *core.LiteralMap DeckPath *storage.DataReference + SpanPath *storage.DataReference err *io.ExecutionError } @@ -47,6 +48,11 @@ func (r InMemoryOutputReader) DeckExists(_ context.Context) (bool, error) { return r.DeckPath != nil, nil } +func (r InMemoryOutputReader) SpanExists(_ context.Context) (bool, error) { + return r.SpanPath != nil, nil +} + +// Deprecated: NewInMemoryOutputReader is deprecated. Use NewInMemoryOutputReaderWithSpan instead. func NewInMemoryOutputReader(literals *core.LiteralMap, DeckPath *storage.DataReference, err *io.ExecutionError) InMemoryOutputReader { return InMemoryOutputReader{ literals: literals, @@ -54,3 +60,12 @@ func NewInMemoryOutputReader(literals *core.LiteralMap, DeckPath *storage.DataRe err: err, } } + +func NewInMemoryOutputReaderWithSpan(literals *core.LiteralMap, DeckPath *storage.DataReference, SpanPath *storage.DataReference, err *io.ExecutionError) InMemoryOutputReader { + return InMemoryOutputReader{ + literals: literals, + DeckPath: DeckPath, + SpanPath: SpanPath, + err: err, + } +} diff --git a/go/tasks/pluginmachinery/ioutils/in_memory_output_reader_test.go b/go/tasks/pluginmachinery/ioutils/in_memory_output_reader_test.go index 3cae110fd..050d94d82 100644 --- a/go/tasks/pluginmachinery/ioutils/in_memory_output_reader_test.go +++ b/go/tasks/pluginmachinery/ioutils/in_memory_output_reader_test.go @@ -10,7 +10,8 @@ import ( ) func TestInMemoryOutputReader(t *testing.T) { - deckPath := storage.DataReference("s3://bucket/key") + deckPath := storage.DataReference("s3://bucket/key/deck.html") + spanPath := storage.DataReference("s3://bucket/key/span.pb") lt := map[string]*flyteIdlCore.Literal{ "results": { Value: &flyteIdlCore.Literal_Scalar{ @@ -22,9 +23,10 @@ func TestInMemoryOutputReader(t *testing.T) { }, }, } - or := NewInMemoryOutputReader(&flyteIdlCore.LiteralMap{Literals: lt}, &deckPath, nil) + or := NewInMemoryOutputReaderWithSpan(&flyteIdlCore.LiteralMap{Literals: lt}, &deckPath, &spanPath, nil) assert.Equal(t, &deckPath, or.DeckPath) + assert.Equal(t, &spanPath, or.SpanPath) ctx := context.TODO() ok, err := or.IsError(ctx) diff --git a/go/tasks/pluginmachinery/ioutils/paths.go b/go/tasks/pluginmachinery/ioutils/paths.go index e50aa5484..74dd5b4fb 100644 --- a/go/tasks/pluginmachinery/ioutils/paths.go +++ b/go/tasks/pluginmachinery/ioutils/paths.go @@ -23,6 +23,8 @@ const ( // deckSuffix specifies that deck file are assumed to be written to this "file"/"suffix" under the given prefix // The deck file has a format of HTML deckSuffix = "deck.html" + // spanSuffix specifies that span file are assumed to be written to this "file"/"suffix" under the given prefix + spanSuffix = "span.pb" // ErrorsSuffix specifies that the errors are written to this prefix/file under the given prefix. The Error File // has a format of core.ErrorDocument ErrorsSuffix = "error.pb" diff --git a/go/tasks/pluginmachinery/ioutils/remote_file_output_reader.go b/go/tasks/pluginmachinery/ioutils/remote_file_output_reader.go index 35bd4fd1a..4db622546 100644 --- a/go/tasks/pluginmachinery/ioutils/remote_file_output_reader.go +++ b/go/tasks/pluginmachinery/ioutils/remote_file_output_reader.go @@ -123,6 +123,14 @@ func (r RemoteFileOutputReader) DeckExists(ctx context.Context) (bool, error) { return md.Exists(), nil } +func (r RemoteFileOutputReader) SpanExists(ctx context.Context) (bool, error) { + md, err := r.store.Head(ctx, r.outPath.GetSpanPath()) + if err != nil { + return false, err + } + return md.Exists(), nil +} + func NewRemoteFileOutputReader(_ context.Context, store storage.ComposedProtobufStore, outPaths io.OutputFilePaths, maxDatasetSize int64) RemoteFileOutputReader { return RemoteFileOutputReader{ outPath: outPaths, diff --git a/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go b/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go index ee10638b4..21dac098a 100644 --- a/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go +++ b/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go @@ -32,7 +32,9 @@ func TestReadOrigin(t *testing.T) { opath := &pluginsIOMock.OutputFilePaths{} opath.OnGetErrorPath().Return("") deckPath := "deck.html" + spanPath := "span.pb" opath.OnGetDeckPath().Return(storage.DataReference(deckPath)) + opath.OnGetSpanPath().Return(storage.DataReference(spanPath)) t.Run("user", func(t *testing.T) { errorDoc := &core.ErrorDocument{ @@ -51,7 +53,10 @@ func TestReadOrigin(t *testing.T) { casted.Error = errorDoc.Error }).Return(nil) - store.OnHead(ctx, storage.DataReference("deck.html")).Return(MemoryMetadata{ + store.OnHead(ctx, storage.DataReference(deckPath)).Return(MemoryMetadata{ + exists: true, + }, nil) + store.OnHead(ctx, storage.DataReference(spanPath)).Return(MemoryMetadata{ exists: true, }, nil) @@ -68,6 +73,9 @@ func TestReadOrigin(t *testing.T) { exists, err := r.DeckExists(ctx) assert.NoError(t, err) assert.True(t, exists) + exists, err = r.SpanExists(ctx) + assert.NoError(t, err) + assert.True(t, exists) }) t.Run("system", func(t *testing.T) { diff --git a/go/tasks/pluginmachinery/ioutils/remote_file_output_writer.go b/go/tasks/pluginmachinery/ioutils/remote_file_output_writer.go index d60d0c2b6..02118da7c 100644 --- a/go/tasks/pluginmachinery/ioutils/remote_file_output_writer.go +++ b/go/tasks/pluginmachinery/ioutils/remote_file_output_writer.go @@ -39,6 +39,10 @@ func (w RemoteFileOutputPaths) GetDeckPath() storage.DataReference { return constructPath(w.store, w.outputPrefix, deckSuffix) } +func (w RemoteFileOutputPaths) GetSpanPath() storage.DataReference { + return constructPath(w.store, w.outputPrefix, spanSuffix) +} + func (w RemoteFileOutputPaths) GetErrorPath() storage.DataReference { return constructPath(w.store, w.outputPrefix, ErrorsSuffix) } diff --git a/go/tasks/pluginmachinery/ioutils/remote_file_output_writer_test.go b/go/tasks/pluginmachinery/ioutils/remote_file_output_writer_test.go index ecca892fc..01a18178f 100644 --- a/go/tasks/pluginmachinery/ioutils/remote_file_output_writer_test.go +++ b/go/tasks/pluginmachinery/ioutils/remote_file_output_writer_test.go @@ -33,6 +33,7 @@ func TestRemoteFileOutputWriter(t *testing.T) { assert.Equal(t, constructPath(memStore, rawOutputPrefix, CheckpointPrefix), checkpointPath.GetCheckpointPrefix()) assert.Equal(t, constructPath(memStore, outputPrefix, OutputsSuffix), checkpointPath.GetOutputPath()) assert.Equal(t, constructPath(memStore, outputPrefix, deckSuffix), checkpointPath.GetDeckPath()) + assert.Equal(t, constructPath(memStore, outputPrefix, spanSuffix), checkpointPath.GetSpanPath()) assert.Equal(t, constructPath(memStore, outputPrefix, ErrorsSuffix), checkpointPath.GetErrorPath()) assert.Equal(t, constructPath(memStore, outputPrefix, FuturesSuffix), checkpointPath.GetFuturesPath()) }) @@ -43,6 +44,7 @@ func TestRemoteFileOutputWriter(t *testing.T) { assert.Equal(t, constructPath(memStore, rawOutputPrefix, CheckpointPrefix), p.GetCheckpointPrefix()) assert.Equal(t, constructPath(memStore, outputPrefix, OutputsSuffix), p.GetOutputPath()) assert.Equal(t, constructPath(memStore, outputPrefix, deckSuffix), p.GetDeckPath()) + assert.Equal(t, constructPath(memStore, outputPrefix, spanSuffix), p.GetSpanPath()) assert.Equal(t, constructPath(memStore, outputPrefix, ErrorsSuffix), p.GetErrorPath()) }) } diff --git a/go/tasks/plugins/array/awsbatch/monitor.go b/go/tasks/plugins/array/awsbatch/monitor.go index 7d99f9be3..bec34ab25 100644 --- a/go/tasks/plugins/array/awsbatch/monitor.go +++ b/go/tasks/plugins/array/awsbatch/monitor.go @@ -107,7 +107,7 @@ func CheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionContext, job return nil, err } - if err = ow.Put(ctx, ioutils.NewInMemoryOutputReader(nil, nil, &io.ExecutionError{ + if err = ow.Put(ctx, ioutils.NewInMemoryOutputReaderWithSpan(nil, nil, nil, &io.ExecutionError{ ExecutionError: &core2.ExecutionError{ Code: "", Message: subJob.Status.Message, diff --git a/go/tasks/plugins/array/k8s/management.go b/go/tasks/plugins/array/k8s/management.go index 32a3970b4..7d820426d 100644 --- a/go/tasks/plugins/array/k8s/management.go +++ b/go/tasks/plugins/array/k8s/management.go @@ -224,7 +224,7 @@ func LaunchAndCheckSubTasksState(ctx context.Context, tCtx core.TaskExecutionCon return currentState, externalResources, err } - if err = ow.Put(ctx, ioutils.NewInMemoryOutputReader(nil, nil, &io.ExecutionError{ + if err = ow.Put(ctx, ioutils.NewInMemoryOutputReaderWithSpan(nil, nil, nil, &io.ExecutionError{ ExecutionError: phaseInfo.Err(), IsRecoverable: phaseInfo.Phase() != core.PhasePermanentFailure, })); err != nil { diff --git a/go/tasks/plugins/array/outputs.go b/go/tasks/plugins/array/outputs.go index ed9b08f39..5107bdde3 100644 --- a/go/tasks/plugins/array/outputs.go +++ b/go/tasks/plugins/array/outputs.go @@ -107,7 +107,7 @@ func (w assembleOutputsWorker) Process(ctx context.Context, workItem workqueue.W } ow := ioutils.NewRemoteFileOutputWriter(ctx, i.dataStore, i.outputPaths) - if err = ow.Put(ctx, ioutils.NewInMemoryOutputReader(finalOutputs, nil, nil)); err != nil { + if err = ow.Put(ctx, ioutils.NewInMemoryOutputReaderWithSpan(finalOutputs, nil, nil, nil)); err != nil { return workqueue.WorkStatusNotDone, err } @@ -313,7 +313,7 @@ func (a assembleErrorsWorker) Process(ctx context.Context, workItem workqueue.Wo } ow := ioutils.NewRemoteFileOutputWriter(ctx, w.dataStore, w.outputPaths) - if err = ow.Put(ctx, ioutils.NewInMemoryOutputReader(nil, nil, &io.ExecutionError{ + if err = ow.Put(ctx, ioutils.NewInMemoryOutputReaderWithSpan(nil, nil, nil, &io.ExecutionError{ ExecutionError: &core.ExecutionError{ Code: "", Message: msg, diff --git a/go/tasks/plugins/hive/execution_state.go b/go/tasks/plugins/hive/execution_state.go index d0f86f73d..bc0e637ca 100644 --- a/go/tasks/plugins/hive/execution_state.go +++ b/go/tasks/plugins/hive/execution_state.go @@ -515,7 +515,7 @@ func WriteOutputs(ctx context.Context, tCtx core.TaskExecutionContext, currentSt return currentState, errors.Errorf(errors.BadTaskSpecification, "A non-SchemaType was found [%v]", results.GetType()) } logger.Debugf(ctx, "Writing outputs file for Hive task at [%s]", tCtx.OutputWriter().GetOutputPrefixPath()) - err = tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader( + err = tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReaderWithSpan( &idlCore.LiteralMap{ Literals: map[string]*idlCore.Literal{ "results": { @@ -530,7 +530,7 @@ func WriteOutputs(ctx context.Context, tCtx core.TaskExecutionContext, currentSt }, }, }, - }, nil, nil)) + }, nil, nil, nil)) if err != nil { logger.Errorf(ctx, "Error writing outputs file: [%s]", err) return currentState, err diff --git a/go/tasks/plugins/k8s/sagemaker/builtin_training.go b/go/tasks/plugins/k8s/sagemaker/builtin_training.go index 0116b4ebf..571064790 100644 --- a/go/tasks/plugins/k8s/sagemaker/builtin_training.go +++ b/go/tasks/plugins/k8s/sagemaker/builtin_training.go @@ -230,7 +230,7 @@ func (m awsSagemakerPlugin) getTaskPhaseForTrainingJob( return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to create outputs for the task") } // Instantiate a output reader with the literal map, and write the output to the remote location referred to by the OutputWriter - if err := pluginContext.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(outputLiteralMap, nil, nil)); err != nil { + if err := pluginContext.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReaderWithSpan(outputLiteralMap, nil, nil, nil)); err != nil { return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "Unable to write output to the remote location") } logger.Debugf(ctx, "Successfully produced and returned outputs") diff --git a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go index dc6e02eed..bd4ed3dcb 100644 --- a/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go +++ b/go/tasks/plugins/k8s/sagemaker/hyperparameter_tuning.go @@ -257,7 +257,7 @@ func (m awsSagemakerPlugin) getTaskPhaseForHyperparameterTuningJob( logger.Errorf(ctx, "Failed to create outputs, err: %s", err) return pluginsCore.PhaseInfoUndefined, pluginErrors.Wrapf(pluginErrors.BadTaskSpecification, err, "failed to create outputs for the task") } - if err := pluginContext.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(out, nil, nil)); err != nil { + if err := pluginContext.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReaderWithSpan(out, nil, nil, nil)); err != nil { return pluginsCore.PhaseInfoUndefined, err } logger.Debugf(ctx, "Successfully produced and returned outputs") diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 1803d3ded..c5e823b4e 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -447,7 +447,7 @@ func writeOutput(ctx context.Context, tCtx core.TaskExecutionContext, externalLo results := taskTemplate.Interface.Outputs.Variables["results"] - return tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader( + return tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReaderWithSpan( &pb.LiteralMap{ Literals: map[string]*pb.Literal{ "results": { @@ -462,7 +462,7 @@ func writeOutput(ctx context.Context, tCtx core.TaskExecutionContext, externalLo }, }, }, - }, nil, nil)) + }, nil, nil, nil)) } // The 'PhaseInfoRunning' occurs 15 times (3 for each of the 5 Presto queries that get run for every Presto task) which diff --git a/go/tasks/plugins/webapi/agent/plugin.go b/go/tasks/plugins/webapi/agent/plugin.go index 70a335021..d29022491 100644 --- a/go/tasks/plugins/webapi/agent/plugin.go +++ b/go/tasks/plugins/webapi/agent/plugin.go @@ -134,7 +134,7 @@ func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase return core.PhaseInfoRetryableFailure(pluginErrors.TaskFailedWithError, "failed to run the job", taskInfo), nil case admin.State_SUCCEEDED: if resource.Outputs != nil { - err := taskCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(resource.Outputs, nil, nil)) + err := taskCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReaderWithSpan(resource.Outputs, nil, nil, nil)) if err != nil { return core.PhaseInfoUndefined, err } diff --git a/go/tasks/plugins/webapi/athena/utils.go b/go/tasks/plugins/webapi/athena/utils.go index b7f9fd696..13164cbc0 100644 --- a/go/tasks/plugins/webapi/athena/utils.go +++ b/go/tasks/plugins/webapi/athena/utils.go @@ -33,7 +33,7 @@ func writeOutput(ctx context.Context, tCtx webapi.StatusContext, externalLocatio return nil } - return tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader( + return tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReaderWithSpan( &pb.LiteralMap{ Literals: map[string]*pb.Literal{ "results": { @@ -48,7 +48,7 @@ func writeOutput(ctx context.Context, tCtx webapi.StatusContext, externalLocatio }, }, }, - }, nil, nil)) + }, nil, nil, nil)) } type QueryInfo struct { diff --git a/go/tasks/plugins/webapi/athena/utils_test.go b/go/tasks/plugins/webapi/athena/utils_test.go index 368644185..609a2982c 100644 --- a/go/tasks/plugins/webapi/athena/utils_test.go +++ b/go/tasks/plugins/webapi/athena/utils_test.go @@ -93,7 +93,7 @@ func Test_writeOutput(t *testing.T) { ow := &mocks3.OutputWriter{} externalLocation := "s3://my-external-bucket/key" - ow.OnPut(ctx, ioutils.NewInMemoryOutputReader( + ow.OnPut(ctx, ioutils.NewInMemoryOutputReaderWithSpan( &pb.LiteralMap{ Literals: map[string]*pb.Literal{ "results": { @@ -111,7 +111,7 @@ func Test_writeOutput(t *testing.T) { }, }, }, - }, nil, nil)).Return(nil) + }, nil, nil, nil)).Return(nil) statusContext.OnOutputWriter().Return(ow) err = writeOutput(context.Background(), statusContext, externalLocation) diff --git a/go/tasks/plugins/webapi/bigquery/config_flags.go b/go/tasks/plugins/webapi/bigquery/config_flags.go index ca6eaf2f1..765620e2f 100755 --- a/go/tasks/plugins/webapi/bigquery/config_flags.go +++ b/go/tasks/plugins/webapi/bigquery/config_flags.go @@ -58,7 +58,10 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "webApi.caching.resyncInterval"), defaultConfig.WebAPI.Caching.ResyncInterval.String(), "Defines the sync interval.") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "webApi.caching.workers"), defaultConfig.WebAPI.Caching.Workers, "Defines the number of workers to start up to process items.") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "webApi.caching.maxSystemFailures"), defaultConfig.WebAPI.Caching.MaxSystemFailures, "Defines the number of failures to fetch a task before failing the task.") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "googleTokenSource.type"), defaultConfig.GoogleTokenSource.Type, "Defines type of TokenSourceFactory, possible values are 'default'") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "googleTokenSource.type"), defaultConfig.GoogleTokenSource.Type, "Defines type of TokenSourceFactory, possible values are 'default' and 'gke-task-workload-identity'") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "googleTokenSource.gke-task-workload-identity.remoteClusterConfig.name"), defaultConfig.GoogleTokenSource.GkeTaskWorkloadIdentityTokenSourceFactoryConfig.RemoteClusterConfig.Name, "Friendly name of the remote cluster") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "googleTokenSource.gke-task-workload-identity.remoteClusterConfig.endpoint"), defaultConfig.GoogleTokenSource.GkeTaskWorkloadIdentityTokenSourceFactoryConfig.RemoteClusterConfig.Endpoint, " Remote K8s cluster endpoint") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "googleTokenSource.gke-task-workload-identity.remoteClusterConfig.enabled"), defaultConfig.GoogleTokenSource.GkeTaskWorkloadIdentityTokenSourceFactoryConfig.RemoteClusterConfig.Enabled, " Boolean flag to enable or disable") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "bigQueryEndpoint"), defaultConfig.bigQueryEndpoint, "") return cmdFlags } diff --git a/go/tasks/plugins/webapi/bigquery/config_flags_test.go b/go/tasks/plugins/webapi/bigquery/config_flags_test.go index fd07a03fb..37f881e64 100755 --- a/go/tasks/plugins/webapi/bigquery/config_flags_test.go +++ b/go/tasks/plugins/webapi/bigquery/config_flags_test.go @@ -225,6 +225,48 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_googleTokenSource.gke-task-workload-identity.remoteClusterConfig.name", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("googleTokenSource.gke-task-workload-identity.remoteClusterConfig.name", testValue) + if vString, err := cmdFlags.GetString("googleTokenSource.gke-task-workload-identity.remoteClusterConfig.name"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.GoogleTokenSource.GkeTaskWorkloadIdentityTokenSourceFactoryConfig.RemoteClusterConfig.Name) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_googleTokenSource.gke-task-workload-identity.remoteClusterConfig.endpoint", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("googleTokenSource.gke-task-workload-identity.remoteClusterConfig.endpoint", testValue) + if vString, err := cmdFlags.GetString("googleTokenSource.gke-task-workload-identity.remoteClusterConfig.endpoint"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.GoogleTokenSource.GkeTaskWorkloadIdentityTokenSourceFactoryConfig.RemoteClusterConfig.Endpoint) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_googleTokenSource.gke-task-workload-identity.remoteClusterConfig.enabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("googleTokenSource.gke-task-workload-identity.remoteClusterConfig.enabled", testValue) + if vBool, err := cmdFlags.GetBool("googleTokenSource.gke-task-workload-identity.remoteClusterConfig.enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.GoogleTokenSource.GkeTaskWorkloadIdentityTokenSourceFactoryConfig.RemoteClusterConfig.Enabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_bigQueryEndpoint", func(t *testing.T) { t.Run("Override", func(t *testing.T) { diff --git a/go/tasks/plugins/webapi/bigquery/plugin.go b/go/tasks/plugins/webapi/bigquery/plugin.go index bc1f6df83..29fcc2ec1 100644 --- a/go/tasks/plugins/webapi/bigquery/plugin.go +++ b/go/tasks/plugins/webapi/bigquery/plugin.go @@ -333,7 +333,7 @@ func writeOutput(ctx context.Context, tCtx webapi.StatusContext, OutputLocation logger.Infof(ctx, "The task declares no outputs. Skipping writing the outputs.") return nil } - return tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader( + return tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReaderWithSpan( &flyteIdlCore.LiteralMap{ Literals: map[string]*flyteIdlCore.Literal{ "results": { @@ -351,7 +351,7 @@ func writeOutput(ctx context.Context, tCtx webapi.StatusContext, OutputLocation }, }, }, - }, nil, nil)) + }, nil, nil, nil)) } func handleCreateError(createError *googleapi.Error, taskInfo *core.TaskInfo) core.PhaseInfo {