diff --git a/flyteadmin/pkg/manager/impl/execution_manager.go b/flyteadmin/pkg/manager/impl/execution_manager.go index e857842633..cfe757b053 100644 --- a/flyteadmin/pkg/manager/impl/execution_manager.go +++ b/flyteadmin/pkg/manager/impl/execution_manager.go @@ -1173,7 +1173,7 @@ func (m *ExecutionManager) GetExecutionData( } maxDataSize := m.config.ApplicationConfiguration().GetRemoteDataConfig().MaxSizeInBytes remoteDataScheme := m.config.ApplicationConfiguration().GetRemoteDataConfig().Scheme - if remoteDataScheme == common.Local || inputsURLBlob.Bytes < maxDataSize { + if util.ShouldFetchData(m.config.ApplicationConfiguration().GetRemoteDataConfig(), inputsURLBlob) { var fullInputs core.LiteralMap err := m.storageClient.ReadProtobuf(ctx, executionModel.InputsURI, &fullInputs) if err != nil { @@ -1181,7 +1181,7 @@ func (m *ExecutionManager) GetExecutionData( } response.FullInputs = &fullInputs } - if remoteDataScheme == common.Local || (signedOutputsURLBlob.Bytes < maxDataSize && execution.Closure.GetOutputs() != nil) { + if remoteDataScheme == common.Local || remoteDataScheme == common.None || (signedOutputsURLBlob.Bytes < maxDataSize && execution.Closure.GetOutputs() != nil) { var fullOutputs core.LiteralMap outputsURI := execution.Closure.GetOutputs().GetUri() err := m.storageClient.ReadProtobuf(ctx, storage.DataReference(outputsURI), &fullOutputs) diff --git a/flyteadmin/pkg/manager/impl/node_execution_manager.go b/flyteadmin/pkg/manager/impl/node_execution_manager.go index f2799b2ade..41f09f728b 100644 --- a/flyteadmin/pkg/manager/impl/node_execution_manager.go +++ b/flyteadmin/pkg/manager/impl/node_execution_manager.go @@ -452,9 +452,7 @@ func (m *NodeExecutionManager) GetNodeExecutionData( Inputs: &signedInputsURLBlob, Outputs: &signedOutputsURLBlob, } - maxDataSize := m.config.ApplicationConfiguration().GetRemoteDataConfig().MaxSizeInBytes - remoteDataScheme := m.config.ApplicationConfiguration().GetRemoteDataConfig().Scheme - if remoteDataScheme == common.Local || signedInputsURLBlob.Bytes < maxDataSize { + if util.ShouldFetchData(m.config.ApplicationConfiguration().GetRemoteDataConfig(), signedInputsURLBlob) { var fullInputs core.LiteralMap err := m.storageClient.ReadProtobuf(ctx, storage.DataReference(nodeExecution.InputUri), &fullInputs) if err != nil { @@ -462,7 +460,8 @@ func (m *NodeExecutionManager) GetNodeExecutionData( } response.FullInputs = &fullInputs } - if remoteDataScheme == common.Local || (signedOutputsURLBlob.Bytes < maxDataSize && len(nodeExecution.Closure.GetOutputUri()) > 0) { + if util.ShouldFetchOutputData(m.config.ApplicationConfiguration().GetRemoteDataConfig(), signedOutputsURLBlob, + nodeExecution.Closure.GetOutputUri()) { var fullOutputs core.LiteralMap err := m.storageClient.ReadProtobuf(ctx, storage.DataReference(nodeExecution.Closure.GetOutputUri()), &fullOutputs) if err != nil { diff --git a/flyteadmin/pkg/manager/impl/task_execution_manager.go b/flyteadmin/pkg/manager/impl/task_execution_manager.go index 425efaf638..28166b9547 100644 --- a/flyteadmin/pkg/manager/impl/task_execution_manager.go +++ b/flyteadmin/pkg/manager/impl/task_execution_manager.go @@ -302,9 +302,7 @@ func (m *TaskExecutionManager) GetTaskExecutionData( Inputs: &signedInputsURLBlob, Outputs: &signedOutputsURLBlob, } - maxDataSize := m.config.ApplicationConfiguration().GetRemoteDataConfig().MaxSizeInBytes - remoteDataScheme := m.config.ApplicationConfiguration().GetRemoteDataConfig().Scheme - if remoteDataScheme == common.Local || signedInputsURLBlob.Bytes < maxDataSize { + if util.ShouldFetchData(m.config.ApplicationConfiguration().GetRemoteDataConfig(), signedInputsURLBlob) { var fullInputs core.LiteralMap err := m.storageClient.ReadProtobuf(ctx, storage.DataReference(taskExecution.InputUri), &fullInputs) if err != nil { @@ -312,7 +310,8 @@ func (m *TaskExecutionManager) GetTaskExecutionData( } response.FullInputs = &fullInputs } - if remoteDataScheme == common.Local || (signedOutputsURLBlob.Bytes < maxDataSize && len(taskExecution.Closure.GetOutputUri()) > 0) { + if util.ShouldFetchOutputData(m.config.ApplicationConfiguration().GetRemoteDataConfig(), signedOutputsURLBlob, + taskExecution.Closure.GetOutputUri()) { var fullOutputs core.LiteralMap err := m.storageClient.ReadProtobuf(ctx, storage.DataReference(taskExecution.Closure.GetOutputUri()), &fullOutputs) if err != nil { diff --git a/flyteadmin/pkg/manager/impl/util/data.go b/flyteadmin/pkg/manager/impl/util/data.go new file mode 100644 index 0000000000..0fb65e55b0 --- /dev/null +++ b/flyteadmin/pkg/manager/impl/util/data.go @@ -0,0 +1,15 @@ +package util + +import ( + "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" +) + +func ShouldFetchData(config *interfaces.RemoteDataConfig, urlBlob admin.UrlBlob) bool { + return config.Scheme == common.Local || config.Scheme == common.None || urlBlob.Bytes < config.MaxSizeInBytes +} + +func ShouldFetchOutputData(config *interfaces.RemoteDataConfig, urlBlob admin.UrlBlob, outputURI string) bool { + return ShouldFetchData(config, urlBlob) && len(outputURI) > 0 +} diff --git a/flyteadmin/pkg/manager/impl/util/data_test.go b/flyteadmin/pkg/manager/impl/util/data_test.go new file mode 100644 index 0000000000..d9c8a6b9dd --- /dev/null +++ b/flyteadmin/pkg/manager/impl/util/data_test.go @@ -0,0 +1,72 @@ +package util + +import ( + "testing" + + "github.com/flyteorg/flyteadmin/pkg/common" + "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/stretchr/testify/assert" +) + +func TestShouldFetchData(t *testing.T) { + t.Run("local config", func(t *testing.T) { + assert.True(t, ShouldFetchData(&interfaces.RemoteDataConfig{ + Scheme: common.Local, + MaxSizeInBytes: 100, + }, admin.UrlBlob{ + Bytes: 200, + })) + }) + t.Run("no config", func(t *testing.T) { + assert.True(t, ShouldFetchData(&interfaces.RemoteDataConfig{ + Scheme: common.None, + MaxSizeInBytes: 100, + }, admin.UrlBlob{ + Bytes: 200, + })) + }) + t.Run("max size under limit", func(t *testing.T) { + assert.True(t, ShouldFetchData(&interfaces.RemoteDataConfig{ + Scheme: common.AWS, + MaxSizeInBytes: 1000, + }, admin.UrlBlob{ + Bytes: 200, + })) + }) + t.Run("max size over limit", func(t *testing.T) { + assert.False(t, ShouldFetchData(&interfaces.RemoteDataConfig{ + Scheme: common.AWS, + MaxSizeInBytes: 100, + }, admin.UrlBlob{ + Bytes: 200, + })) + }) +} + +func TestShouldFetchOutputData(t *testing.T) { + t.Run("local config", func(t *testing.T) { + assert.True(t, ShouldFetchOutputData(&interfaces.RemoteDataConfig{ + Scheme: common.Local, + MaxSizeInBytes: 100, + }, admin.UrlBlob{ + Bytes: 200, + }, "s3://foo/bar.txt")) + }) + t.Run("max size under limit", func(t *testing.T) { + assert.True(t, ShouldFetchOutputData(&interfaces.RemoteDataConfig{ + Scheme: common.AWS, + MaxSizeInBytes: 1000, + }, admin.UrlBlob{ + Bytes: 200, + }, "s3://foo/bar.txt")) + }) + t.Run("output uri empty", func(t *testing.T) { + assert.False(t, ShouldFetchOutputData(&interfaces.RemoteDataConfig{ + Scheme: common.AWS, + MaxSizeInBytes: 1000, + }, admin.UrlBlob{ + Bytes: 200, + }, "")) + }) +}