diff --git a/common/searchattribute/encode_value.go b/common/searchattribute/encode_value.go index a49fd99963d..e1e2da9d401 100644 --- a/common/searchattribute/encode_value.go +++ b/common/searchattribute/encode_value.go @@ -25,8 +25,10 @@ package searchattribute import ( + "errors" "fmt" "time" + "unicode/utf8" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" @@ -34,6 +36,8 @@ import ( "go.temporal.io/server/common/payload" ) +var ErrInvalidString = errors.New("SearchAttribute value is not a valid UTF-8 string") + // EncodeValue encodes search attribute value and IndexedValueType to Payload. func EncodeValue(val interface{}, t enumspb.IndexedValueType) (*commonpb.Payload, error) { valPayload, err := payload.Encode(val) @@ -72,16 +76,37 @@ func DecodeValue( case enumspb.INDEXED_VALUE_TYPE_INT: return decodeValueTyped[int64](value, allowList) case enumspb.INDEXED_VALUE_TYPE_KEYWORD: - return decodeValueTyped[string](value, allowList) + return validateStrings(decodeValueTyped[string](value, allowList)) case enumspb.INDEXED_VALUE_TYPE_TEXT: - return decodeValueTyped[string](value, allowList) + return validateStrings(decodeValueTyped[string](value, allowList)) case enumspb.INDEXED_VALUE_TYPE_KEYWORD_LIST: - return decodeValueTyped[[]string](value, false) + return validateStrings(decodeValueTyped[[]string](value, false)) default: return nil, fmt.Errorf("%w: %v", ErrInvalidType, t) } } +func validateStrings(anyValue any, err error) (any, error) { + if err != nil { + return anyValue, err + } + + // validate strings + switch value := anyValue.(type) { + case string: + if !utf8.ValidString(value) { + return nil, fmt.Errorf("%w: %s", ErrInvalidString, value) + } + case []string: + for _, item := range value { + if !utf8.ValidString(item) { + return nil, fmt.Errorf("%w: %s", ErrInvalidString, item) + } + } + } + return anyValue, err +} + // decodeValueTyped tries to decode to the given type. // If the input is a list and allowList is false, then it will return only the first element. // If the input is a list and allowList is true, then it will return the decoded list. diff --git a/common/searchattribute/encode_value_test.go b/common/searchattribute/encode_value_test.go index a3416983171..c8b2d543abb 100644 --- a/common/searchattribute/encode_value_test.go +++ b/common/searchattribute/encode_value_test.go @@ -25,6 +25,7 @@ package searchattribute import ( + "errors" "testing" "time" @@ -372,3 +373,21 @@ func Test_EncodeValue(t *testing.T) { "Datetime Search Attribute is expected to be encoded in RFC 3339 format") s.Equal("Datetime", string(encodedPayload.Metadata["type"])) } + +func Test_ValidateStrings(t *testing.T) { + _, err := validateStrings("anything here", errors.New("test error")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "test error") + + _, err = validateStrings("\x87\x01", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "is not a valid UTF-8 string") + + value, err := validateStrings("anything here", nil) + assert.Nil(t, err) + assert.Equal(t, "anything here", value) + + _, err = validateStrings([]string{"abc", "\x87\x01"}, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "is not a valid UTF-8 string") +} diff --git a/common/util.go b/common/util.go index 1c1e6419909..843d8160b36 100644 --- a/common/util.go +++ b/common/util.go @@ -32,6 +32,7 @@ import ( "strings" "sync" "time" + "unicode/utf8" "github.com/dgryski/go-farm" commonpb "go.temporal.io/api/common/v1" @@ -818,3 +819,10 @@ func OverrideWorkflowTaskTimeout( func CloneProto[T proto.Message](v T) T { return proto.Clone(v).(T) } + +func ValidateUTF8String(fieldName string, strValue string) error { + if !utf8.ValidString(strValue) { + return serviceerror.NewInvalidArgument(fmt.Sprintf("%s %v is not a valid UTF-8 string", fieldName, strValue)) + } + return nil +} diff --git a/service/frontend/workflow_handler.go b/service/frontend/workflow_handler.go index b2e3f37caef..88c6a373c6f 100644 --- a/service/frontend/workflow_handler.go +++ b/service/frontend/workflow_handler.go @@ -32,7 +32,6 @@ import ( "strings" "sync/atomic" "time" - "unicode/utf8" "github.com/pborman/uuid" batchpb "go.temporal.io/api/batch/v1" @@ -377,22 +376,20 @@ func (wh *WorkflowHandler) StartWorkflowExecution(ctx context.Context, request * return nil, errWorkflowTypeTooLong } - if err := wh.validateTaskQueue(request.TaskQueue, namespaceName); err != nil { + if err := common.ValidateUTF8String("WorkflowType", request.WorkflowType.GetName()); err != nil { return nil, err } - if err := wh.validateStartWorkflowTimeouts(request); err != nil { + if err := wh.validateTaskQueue(request.TaskQueue, namespaceName); err != nil { return nil, err } - if request.GetRequestId() == "" { - // For easy direct API use, we default the request ID here but expect all - // SDKs and other auto-retrying clients to set it - request.RequestId = uuid.New() + if err := wh.validateStartWorkflowTimeouts(request); err != nil { + return nil, err } - if len(request.GetRequestId()) > wh.config.MaxIDLengthLimit() { - return nil, errRequestIDTooLong + if err := validateRequestId(&request.RequestId, wh.config.MaxIDLengthLimit()); err != nil { + return nil, err } sa, err := wh.unaliasedSearchAttributesFrom(request.GetSearchAttributes(), namespaceName) @@ -1686,13 +1683,16 @@ func (wh *WorkflowHandler) SignalWithStartWorkflowExecution(ctx context.Context, return nil, errWorkflowTypeTooLong } + if err := common.ValidateUTF8String("WorkflowType", request.WorkflowType.GetName()); err != nil { + return nil, err + } namespaceName := namespace.Name(request.GetNamespace()) if err := wh.validateTaskQueue(request.TaskQueue, namespaceName); err != nil { return nil, err } - if len(request.GetRequestId()) > wh.config.MaxIDLengthLimit() { - return nil, errRequestIDTooLong + if err := validateRequestId(&request.RequestId, wh.config.MaxIDLengthLimit()); err != nil { + return nil, err } if err := wh.validateSignalWithStartWorkflowTimeouts(request); err != nil { @@ -3838,6 +3838,9 @@ func (wh *WorkflowHandler) validateTaskQueue(t *taskqueuepb.TaskQueue, namespace if len(t.GetName()) > wh.config.MaxIDLengthLimit() { return errTaskQueueTooLong } + if err := common.ValidateUTF8String("TaskQueue", t.GetName()); err != nil { + return err + } if t.GetKind() == enumspb.TASK_QUEUE_KIND_UNSPECIFIED { wh.logger.Warn("Unspecified task queue kind", @@ -3881,6 +3884,9 @@ func (wh *WorkflowHandler) validateBuildIdCompatibilityUpdate( errDeets = append(errDeets, fmt.Sprintf(" Worker build IDs to be no larger than %v characters", wh.config.WorkerBuildIdSizeLimit())) } + if err := common.ValidateUTF8String("BuildId", id); err != nil { + errDeets = append(errDeets, err.Error()) + } } if req.GetNamespace() == "" { @@ -4043,6 +4049,24 @@ func (wh *WorkflowHandler) validateRetryPolicy(namespaceName namespace.Name, ret return common.ValidateRetryPolicy(retryPolicy) } +func validateRequestId(requestID *string, lenLimit int) error { + if requestID == nil { + // should never happen, but just in case. + return serviceerror.NewInvalidArgument("RequestId is nil") + } + if *requestID == "" { + // For easy direct API use, we default the request ID here but expect all + // SDKs and other auto-retrying clients to set it + *requestID = uuid.New() + } + + if len(*requestID) > lenLimit { + return errRequestIDTooLong + } + + return common.ValidateUTF8String("RequestId", *requestID) +} + func (wh *WorkflowHandler) validateStartWorkflowTimeouts( request *workflowservice.StartWorkflowExecutionRequest, ) error { @@ -4101,7 +4125,7 @@ func (wh *WorkflowHandler) metricsScope(ctx context.Context) metrics.Handler { func (wh *WorkflowHandler) validateNamespace( namespace string, ) error { - if err := wh.validateUTF8String(namespace); err != nil { + if err := common.ValidateUTF8String("Namespace", namespace); err != nil { return err } if len(namespace) > wh.config.MaxIDLengthLimit() { @@ -4116,7 +4140,7 @@ func (wh *WorkflowHandler) validateWorkflowID( if workflowID == "" { return errWorkflowIDNotSet } - if err := wh.validateUTF8String(workflowID); err != nil { + if err := common.ValidateUTF8String("WorkflowId", workflowID); err != nil { return err } if len(workflowID) > wh.config.MaxIDLengthLimit() { @@ -4125,15 +4149,6 @@ func (wh *WorkflowHandler) validateWorkflowID( return nil } -func (wh *WorkflowHandler) validateUTF8String( - str string, -) error { - if !utf8.ValidString(str) { - return serviceerror.NewInvalidArgument(fmt.Sprintf("%v is not a valid UTF-8 string", str)) - } - return nil -} - func (wh *WorkflowHandler) canonicalizeScheduleSpec(schedule *schedpb.Schedule) error { if schedule.Spec == nil { schedule.Spec = &schedpb.ScheduleSpec{} diff --git a/service/frontend/workflow_handler_test.go b/service/frontend/workflow_handler_test.go index c219fe22211..e9755ed7712 100644 --- a/service/frontend/workflow_handler_test.go +++ b/service/frontend/workflow_handler_test.go @@ -2535,6 +2535,18 @@ func TestContextNearDeadline(t *testing.T) { assert.False(t, contextNearDeadline(ctx, time.Millisecond)) } +func TestValidateRequestId(t *testing.T) { + req := workflowservice.StartWorkflowExecutionRequest{RequestId: ""} + err := validateRequestId(&req.RequestId, 100) + assert.Nil(t, err) + assert.Len(t, req.RequestId, 36) // new UUID length + + req.RequestId = "\x87\x01" + err = validateRequestId(&req.RequestId, 100) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not a valid UTF-8 string") +} + func (s *workflowHandlerSuite) Test_DeleteWorkflowExecution() { config := s.newConfig() wh := s.getWorkflowHandler(config) diff --git a/service/history/command_checker.go b/service/history/command_checker.go index 6118d1d0d9d..bd42b888a35 100644 --- a/service/history/command_checker.go +++ b/service/history/command_checker.go @@ -410,7 +410,10 @@ func (v *commandAttrValidator) validateTimerScheduleAttributes( return failedCause, serviceerror.NewInvalidArgument("TimerId is not set on StartTimerCommand.") } if len(timerID) > v.maxIDLengthLimit { - return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("TimerID on StartTimerCommand exceeds length limit. TimerId=%s Length=%d Limit=%d", timerID, len(timerID), v.maxIDLengthLimit)) + return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("TimerId on StartTimerCommand exceeds length limit. TimerId=%s Length=%d Limit=%d", timerID, len(timerID), v.maxIDLengthLimit)) + } + if err := common.ValidateUTF8String("TimerId", timerID); err != nil { + return failedCause, err } if err := timer.ValidateAndCapTimer(attributes.GetStartToFireTimeout()); err != nil { return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("An invalid StartToFireTimeout is set on StartTimerCommand: %v. TimerId=%s", err, timerID)) @@ -540,6 +543,9 @@ func (v *commandAttrValidator) validateCancelExternalWorkflowExecutionAttributes if len(workflowID) > v.maxIDLengthLimit { return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("WorkflowId on RequestCancelExternalWorkflowExecutionCommand exceeds length limit. WorkflowId=%s Length=%d Limit=%d RunId=%s Namespace=%s", workflowID, len(workflowID), v.maxIDLengthLimit, runID, ns)) } + if err := common.ValidateUTF8String("WorkflowId", workflowID); err != nil { + return failedCause, err + } if runID != "" && uuid.Parse(runID) == nil { return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Invalid RunId set on RequestCancelExternalWorkflowExecutionCommand. WorkflowId=%s RunId=%s Namespace=%s", workflowID, runID, ns)) } @@ -585,6 +591,9 @@ func (v *commandAttrValidator) validateSignalExternalWorkflowExecutionAttributes if len(workflowID) > v.maxIDLengthLimit { return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("WorkflowId on SignalExternalWorkflowExecutionCommand exceeds length limit. WorkflowId=%s Length=%d Limit=%d Namespace=%s RunId=%s SignalName=%s", workflowID, len(workflowID), v.maxIDLengthLimit, ns, targetRunID, signalName)) } + if err := common.ValidateUTF8String("WorkflowId", workflowID); err != nil { + return failedCause, err + } if targetRunID != "" && uuid.Parse(targetRunID) == nil { return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Invalid RunId set on SignalExternalWorkflowExecutionCommand. WorkflowId=%s Namespace=%s RunId=%s SignalName=%s", workflowID, ns, targetRunID, signalName)) } @@ -748,6 +757,14 @@ func (v *commandAttrValidator) validateStartChildExecutionAttributes( return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("WorkflowType on StartChildWorkflowExecutionCommand exceeds length limit. WorkflowId=%s WorkflowType=%s Length=%d Limit=%d Namespace=%s", wfID, wfType, len(wfType), v.maxIDLengthLimit, ns)) } + if err := common.ValidateUTF8String("WorkflowId", wfID); err != nil { + return failedCause, err + } + + if err := common.ValidateUTF8String("WorkflowType", wfType); err != nil { + return failedCause, err + } + if err := timer.ValidateAndCapTimer(attributes.GetWorkflowExecutionTimeout()); err != nil { return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Invalid WorkflowExecutionTimeout on StartChildWorkflowExecutionCommand: %v. WorkflowId=%s WorkflowType=%s Namespace=%s", err, wfID, wfType, ns)) } @@ -813,6 +830,10 @@ func (v *commandAttrValidator) validateTaskQueue( return taskQueue, serviceerror.NewInvalidArgument(fmt.Sprintf("task queue name exceeds length limit of %v", v.maxIDLengthLimit)) } + if err := common.ValidateUTF8String("TaskQueue", name); err != nil { + return taskQueue, err + } + if strings.HasPrefix(name, reservedTaskQueuePrefix) { return taskQueue, serviceerror.NewInvalidArgument(fmt.Sprintf("task queue name cannot start with reserved prefix %v", reservedTaskQueuePrefix)) }