Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validation for a few string fields #5487

Merged
merged 5 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions common/searchattribute/encode_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import (
"errors"
"fmt"
"time"
"unicode/utf8"

commonpb "go.temporal.io/api/common/v1"
enumspb "go.temporal.io/api/enums/v1"
Expand Down Expand Up @@ -72,16 +73,37 @@
case enumspb.INDEXED_VALUE_TYPE_INT:
return decodeValueTyped[int64](value, allowList)
case enumspb.INDEXED_VALUE_TYPE_KEYWORD:
return decodeValueTyped[string](value, allowList)
return validateStrings(decodeValueTyped[string](value, allowList))
case enumspb.INDEXED_VALUE_TYPE_TEXT:
return decodeValueTyped[string](value, allowList)
return validateStrings(decodeValueTyped[string](value, allowList))
case enumspb.INDEXED_VALUE_TYPE_KEYWORD_LIST:
return decodeValueTyped[[]string](value, false)
return validateStrings(decodeValueTyped[[]string](value, false))
default:
return nil, fmt.Errorf("%w: %v", ErrInvalidType, t)
}
}

// errInvalidUTF8String is the static error wrapped by validateStrings whenever
// a decoded string value is not valid UTF-8. Keeping it static lets callers in
// this package match it with errors.Is instead of comparing message text.
var errInvalidUTF8String = errors.New("search attribute value is not a valid UTF-8 string")

// validateStrings post-processes the (value, error) pair returned by a decode
// call and rejects string payloads that are not valid UTF-8.
//
// If err is non-nil, the pair is returned unchanged. Otherwise a string value,
// or each element of a []string value, is checked with utf8.ValidString. The
// first invalid string yields a nil value and an error wrapping
// errInvalidUTF8String. Values of any other dynamic type pass through untouched.
func validateStrings(anyValue any, err error) (any, error) {
	if err != nil {
		return anyValue, err
	}

	switch value := anyValue.(type) {
	case string:
		if !utf8.ValidString(value) {
			// %q escapes the invalid bytes so the error message itself stays valid UTF-8.
			return nil, fmt.Errorf("%w: %q", errInvalidUTF8String, value)
		}
	case []string:
		for _, item := range value {
			if !utf8.ValidString(item) {
				return nil, fmt.Errorf("%w: %q", errInvalidUTF8String, item)
			}
		}
	}
	return anyValue, nil
}

// decodeValueTyped tries to decode to the given type.
// If the input is a list and allowList is false, then it will return only the first element.
// If the input is a list and allowList is true, then it will return the decoded list.
Expand Down
19 changes: 19 additions & 0 deletions common/searchattribute/encode_value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
package searchattribute

import (
"errors"
"testing"
"time"

Expand Down Expand Up @@ -372,3 +373,21 @@
"Datetime Search Attribute is expected to be encoded in RFC 3339 format")
s.Equal("Datetime", string(encodedPayload.Metadata["type"]))
}

func Test_ValidateStrings(t *testing.T) {
_, err := validateStrings("anything here", errors.New("test error"))
assert.Error(t, err)
assert.Contains(t, err.Error(), "test error")

_, err = validateStrings("\x87\x01", nil)
assert.Error(t, err)
assert.Contains(t, err.Error(), "is not a valid UTF-8 string")

value, err := validateStrings("anything here", nil)
assert.Nil(t, err)
assert.Equal(t, "anything here", value)

value, err = validateStrings([]string{"abc", "\x87\x01"}, nil)

Check failure on line 390 in common/searchattribute/encode_value_test.go

View workflow job for this annotation

GitHub Actions / lint

SA4006: this value of `value` is never used (staticcheck)
assert.Error(t, err)
assert.Contains(t, err.Error(), "is not a valid UTF-8 string")
}
8 changes: 8 additions & 0 deletions common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"strings"
"sync"
"time"
"unicode/utf8"

"github.com/dgryski/go-farm"
commonpb "go.temporal.io/api/common/v1"
Expand Down Expand Up @@ -818,3 +819,10 @@ func OverrideWorkflowTaskTimeout(
// CloneProto returns the result of proto.Clone(v), asserted back to the
// caller's concrete message type T.
func CloneProto[T proto.Message](v T) T {
	cloned := proto.Clone(v)
	return cloned.(T)
}

// ValidateUTF8String checks that strValue is valid UTF-8.
//
// It returns nil when the string is valid. Otherwise it returns the error built
// by serviceerror.NewInvalidArgument, with a message that names fieldName. The
// offending value is rendered with %q so its invalid bytes are escaped and the
// error message itself remains valid UTF-8.
func ValidateUTF8String(fieldName string, strValue string) error {
	if utf8.ValidString(strValue) {
		return nil
	}
	return serviceerror.NewInvalidArgument(fmt.Sprintf("%s %q is not a valid UTF-8 string", fieldName, strValue))
}
62 changes: 41 additions & 21 deletions service/frontend/workflow_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,10 @@ func (wh *WorkflowHandler) StartWorkflowExecution(ctx context.Context, request *
return nil, errWorkflowTypeTooLong
}

if !utf8.ValidString(request.WorkflowType.GetName()) {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf("WorkflowType %v is not a valid UTF-8 string", request.WorkflowType.GetName()))
}

if err := wh.validateTaskQueue(request.TaskQueue, namespaceName); err != nil {
return nil, err
}
Expand All @@ -385,14 +389,8 @@ func (wh *WorkflowHandler) StartWorkflowExecution(ctx context.Context, request *
return nil, err
}

if request.GetRequestId() == "" {
// For easy direct API use, we default the request ID here but expect all
// SDKs and other auto-retrying clients to set it
request.RequestId = uuid.New()
}

if len(request.GetRequestId()) > wh.config.MaxIDLengthLimit() {
return nil, errRequestIDTooLong
if err := validateRequestId(&request.RequestId, wh.config.MaxIDLengthLimit()); err != nil {
return nil, err
}

sa, err := wh.unaliasedSearchAttributesFrom(request.GetSearchAttributes(), namespaceName)
Expand Down Expand Up @@ -675,6 +673,9 @@ func (wh *WorkflowHandler) RespondWorkflowTaskCompleted(
if len(request.GetIdentity()) > wh.config.MaxIDLengthLimit() {
return nil, errIdentityTooLong
}
if err := common.ValidateUTF8String("Identity", request.GetIdentity()); err != nil {
return nil, err
}
yiminc marked this conversation as resolved.
Show resolved Hide resolved

if err := wh.validateVersioningInfo(
request.Namespace,
Expand Down Expand Up @@ -1686,13 +1687,17 @@ func (wh *WorkflowHandler) SignalWithStartWorkflowExecution(ctx context.Context,
return nil, errWorkflowTypeTooLong
}

if !utf8.ValidString(request.WorkflowType.GetName()) {
return nil, serviceerror.NewInvalidArgument(fmt.Sprintf("WorkflowType %v is not a valid UTF-8 string", request.WorkflowType.GetName()))
}

namespaceName := namespace.Name(request.GetNamespace())
if err := wh.validateTaskQueue(request.TaskQueue, namespaceName); err != nil {
return nil, err
}

if len(request.GetRequestId()) > wh.config.MaxIDLengthLimit() {
return nil, errRequestIDTooLong
if err := validateRequestId(&request.RequestId, wh.config.MaxIDLengthLimit()); err != nil {
return nil, err
}

if err := wh.validateSignalWithStartWorkflowTimeouts(request); err != nil {
Expand Down Expand Up @@ -3838,6 +3843,9 @@ func (wh *WorkflowHandler) validateTaskQueue(t *taskqueuepb.TaskQueue, namespace
if len(t.GetName()) > wh.config.MaxIDLengthLimit() {
return errTaskQueueTooLong
}
if err := common.ValidateUTF8String("TaskQueue", t.GetName()); err != nil {
return err
}

if t.GetKind() == enumspb.TASK_QUEUE_KIND_UNSPECIFIED {
wh.logger.Warn("Unspecified task queue kind",
Expand Down Expand Up @@ -3881,6 +3889,9 @@ func (wh *WorkflowHandler) validateBuildIdCompatibilityUpdate(
errDeets = append(errDeets, fmt.Sprintf(" Worker build IDs to be no larger than %v characters",
wh.config.WorkerBuildIdSizeLimit()))
}
if err := common.ValidateUTF8String("BuildId", id); err != nil {
errDeets = append(errDeets, err.Error())
}
}

if req.GetNamespace() == "" {
Expand Down Expand Up @@ -4043,6 +4054,24 @@ func (wh *WorkflowHandler) validateRetryPolicy(namespaceName namespace.Name, ret
return common.ValidateRetryPolicy(retryPolicy)
}

// validateRequestId normalizes and validates the request ID that requestID
// points to.
//
// An empty ID is replaced in place with the value of uuid.New(). The resulting
// ID must not be longer than lenLimit bytes and must pass
// common.ValidateUTF8String. A nil pointer is rejected with an InvalidArgument
// error.
func validateRequestId(requestID *string, lenLimit int) error {
	switch {
	case requestID == nil:
		// Callers pass the address of a request field, so this should not occur;
		// guard anyway rather than dereference nil.
		return serviceerror.NewInvalidArgument("RequestId is nil")
	case *requestID == "":
		// For easy direct API use the ID is defaulted here; SDKs and other
		// auto-retrying clients are expected to set it themselves.
		*requestID = uuid.New()
	}

	id := *requestID
	if len(id) > lenLimit {
		return errRequestIDTooLong
	}
	return common.ValidateUTF8String("RequestId", id)
}

func (wh *WorkflowHandler) validateStartWorkflowTimeouts(
request *workflowservice.StartWorkflowExecutionRequest,
) error {
Expand Down Expand Up @@ -4101,7 +4130,7 @@ func (wh *WorkflowHandler) metricsScope(ctx context.Context) metrics.Handler {
func (wh *WorkflowHandler) validateNamespace(
namespace string,
) error {
if err := wh.validateUTF8String(namespace); err != nil {
if err := common.ValidateUTF8String("Namespace", namespace); err != nil {
return err
}
if len(namespace) > wh.config.MaxIDLengthLimit() {
Expand All @@ -4116,7 +4145,7 @@ func (wh *WorkflowHandler) validateWorkflowID(
if workflowID == "" {
return errWorkflowIDNotSet
}
if err := wh.validateUTF8String(workflowID); err != nil {
if err := common.ValidateUTF8String("WorkflowId", workflowID); err != nil {
return err
}
if len(workflowID) > wh.config.MaxIDLengthLimit() {
Expand All @@ -4125,15 +4154,6 @@ func (wh *WorkflowHandler) validateWorkflowID(
return nil
}

// validateUTF8String returns nil when str is valid UTF-8, and otherwise the
// error built by serviceerror.NewInvalidArgument. The value is rendered with %q
// so its invalid bytes are escaped and the error message itself stays valid
// UTF-8.
func (wh *WorkflowHandler) validateUTF8String(
	str string,
) error {
	if !utf8.ValidString(str) {
		return serviceerror.NewInvalidArgument(fmt.Sprintf("%q is not a valid UTF-8 string", str))
	}
	return nil
}

func (wh *WorkflowHandler) canonicalizeScheduleSpec(schedule *schedpb.Schedule) error {
if schedule.Spec == nil {
schedule.Spec = &schedpb.ScheduleSpec{}
Expand Down
12 changes: 12 additions & 0 deletions service/frontend/workflow_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2535,6 +2535,18 @@ func TestContextNearDeadline(t *testing.T) {
assert.False(t, contextNearDeadline(ctx, time.Millisecond))
}

// TestValidateRequestId covers every branch of validateRequestId: defaulting an
// empty ID, rejecting invalid UTF-8, enforcing the length limit, and rejecting
// a nil pointer.
func TestValidateRequestId(t *testing.T) {
	// An empty ID is replaced with a generated one.
	req := workflowservice.StartWorkflowExecutionRequest{RequestId: ""}
	err := validateRequestId(&req.RequestId, 100)
	assert.NoError(t, err)
	assert.Len(t, req.RequestId, 36) // new UUID length

	// Invalid UTF-8 is rejected.
	req.RequestId = "\x87\x01"
	err = validateRequestId(&req.RequestId, 100)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "not a valid UTF-8 string")

	// An ID longer than the limit is rejected with the dedicated error.
	req.RequestId = "0123456789"
	err = validateRequestId(&req.RequestId, 5)
	assert.Equal(t, errRequestIDTooLong, err)

	// A nil pointer is reported instead of being dereferenced.
	err = validateRequestId(nil, 100)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "RequestId is nil")
}

func (s *workflowHandlerSuite) Test_DeleteWorkflowExecution() {
config := s.newConfig()
wh := s.getWorkflowHandler(config)
Expand Down
14 changes: 13 additions & 1 deletion service/history/command_checker.go
yycptt marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ package history
import (
"fmt"
"strings"
"unicode/utf8"

"github.com/pborman/uuid"
commandpb "go.temporal.io/api/command/v1"
Expand Down Expand Up @@ -410,7 +411,10 @@ func (v *commandAttrValidator) validateTimerScheduleAttributes(
return failedCause, serviceerror.NewInvalidArgument("TimerId is not set on StartTimerCommand.")
}
if len(timerID) > v.maxIDLengthLimit {
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("TimerID on StartTimerCommand exceeds length limit. TimerId=%s Length=%d Limit=%d", timerID, len(timerID), v.maxIDLengthLimit))
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("TimerId on StartTimerCommand exceeds length limit. TimerId=%s Length=%d Limit=%d", timerID, len(timerID), v.maxIDLengthLimit))
}
if err := common.ValidateUTF8String("TimerId", timerID); err != nil {
return failedCause, err
}
if err := timer.ValidateAndCapTimer(attributes.GetStartToFireTimeout()); err != nil {
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("An invalid StartToFireTimeout is set on StartTimerCommand: %v. TimerId=%s", err, timerID))
Expand Down Expand Up @@ -748,6 +752,10 @@ func (v *commandAttrValidator) validateStartChildExecutionAttributes(
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("WorkflowType on StartChildWorkflowExecutionCommand exceeds length limit. WorkflowId=%s WorkflowType=%s Length=%d Limit=%d Namespace=%s", wfID, wfType, len(wfType), v.maxIDLengthLimit, ns))
}

if !utf8.ValidString(wfType) {
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("WorkflowType %v is not a valid UTF-8 string", wfType))
}
yiminc marked this conversation as resolved.
Show resolved Hide resolved

if err := timer.ValidateAndCapTimer(attributes.GetWorkflowExecutionTimeout()); err != nil {
return failedCause, serviceerror.NewInvalidArgument(fmt.Sprintf("Invalid WorkflowExecutionTimeout on StartChildWorkflowExecutionCommand: %v. WorkflowId=%s WorkflowType=%s Namespace=%s", err, wfID, wfType, ns))
}
Expand Down Expand Up @@ -813,6 +821,10 @@ func (v *commandAttrValidator) validateTaskQueue(
return taskQueue, serviceerror.NewInvalidArgument(fmt.Sprintf("task queue name exceeds length limit of %v", v.maxIDLengthLimit))
}

if err := common.ValidateUTF8String("TaskQueue", name); err != nil {
return taskQueue, err
}

if strings.HasPrefix(name, reservedTaskQueuePrefix) {
return taskQueue, serviceerror.NewInvalidArgument(fmt.Sprintf("task queue name cannot start with reserved prefix %v", reservedTaskQueuePrefix))
}
Expand Down
Loading