Skip to content

Commit

Permalink
add option to filter based on build ids
Browse files Browse the repository at this point in the history
Signed-off-by: Prajithp <prajithpalakkuda@gmail.com>
  • Loading branch information
Prajithp committed Sep 27, 2024
1 parent f5d7f78 commit 2cfaa31
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 8 deletions.
93 changes: 87 additions & 6 deletions pkg/scalers/temporal.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"log/slog"
"strconv"
"strings"
"time"

"github.com/go-logr/logr"
Expand All @@ -22,6 +23,16 @@ const (
temporalDefaultTargetQueueLength = 5
temporalDefaultActivationQueueLength = 0
temporalDefaultNamespace = "default"
temporalDefaultSelectAllActive = true
temporalDefaultSelectUnversioned = true
)

var (
temporalDefauleQueueTypes = []sdk.TaskQueueType{
sdk.TaskQueueTypeActivity,
sdk.TaskQueueTypeWorkflow,
sdk.TaskQueueTypeNexus,
}
)

type temporalScaler struct {
Expand All @@ -38,21 +49,26 @@ type temporalMetadata struct {
triggerIndex int
targetQueueSize int64
queueName string
queueTypes []string
buildIDs []string
allActive bool
unversioned bool
apiKey *string
}

func NewTemporalScaler(config *scalersconfig.ScalerConfig) (Scaler, error) {
logger := InitializeLogger(config, "temporal_scaler")

metricType, err := GetMetricTargetType(config)
if err != nil {
return nil, fmt.Errorf("failed to get scaler metric type: %w", err)
}

meta, err := parseTemporalMetadata(config)
meta, err := parseTemporalMetadata(config, logger)
if err != nil {
return nil, fmt.Errorf("failed to parse Temporal metadata: %w", err)
}

logger := InitializeLogger(config, "temporal_scaler")

c, err := getTemporalClient(meta)
if err != nil {
return nil, fmt.Errorf("failed to create Temporal client connection: %w", err)
Expand Down Expand Up @@ -100,9 +116,22 @@ func (s *temporalScaler) GetMetricsAndActivity(ctx context.Context, metricName s
}

func (s *temporalScaler) getQueueSize(ctx context.Context) (int64, error) {
queueType := getQueueTypes(s.metadata.queueTypes)

var selection *sdk.TaskQueueVersionSelection
if s.metadata.allActive || s.metadata.unversioned || len(s.metadata.buildIDs) > 0 {
selection = &sdk.TaskQueueVersionSelection{
AllActive: s.metadata.allActive,
Unversioned: s.metadata.unversioned,
BuildIDs: s.metadata.buildIDs,
}
}

resp, err := s.tcl.DescribeTaskQueueEnhanced(ctx, sdk.DescribeTaskQueueEnhancedOptions{
TaskQueue: s.metadata.queueName,
ReportStats: true,
TaskQueue: s.metadata.queueName,
ReportStats: true,
Versions: selection,
TaskQueueTypes: queueType,
})
if err != nil {
return 0, fmt.Errorf("failed to get Temporal queue size: %w", err)
Expand All @@ -111,6 +140,27 @@ func (s *temporalScaler) getQueueSize(ctx context.Context) (int64, error) {
return getCombinedBacklogCount(resp), nil
}

func getQueueTypes(queueTypes []string) []sdk.TaskQueueType {
var taskQueueTypes []sdk.TaskQueueType
for _, t := range queueTypes {
var taskQueueType sdk.TaskQueueType
switch t {
case "workflow":
taskQueueType = sdk.TaskQueueTypeWorkflow
case "activity":
taskQueueType = sdk.TaskQueueTypeActivity
case "nexus":
taskQueueType = sdk.TaskQueueTypeNexus
}
taskQueueTypes = append(taskQueueTypes, taskQueueType)
}

if len(taskQueueTypes) == 0 {
return temporalDefauleQueueTypes
}
return taskQueueTypes
}

func getCombinedBacklogCount(description sdk.TaskQueueDescription) int64 {
var count int64
for _, versionInfo := range description.VersionsInfo {
Expand Down Expand Up @@ -138,7 +188,7 @@ func getTemporalClient(meta *temporalMetadata) (sdk.Client, error) {
})
}

func parseTemporalMetadata(config *scalersconfig.ScalerConfig) (*temporalMetadata, error) {
func parseTemporalMetadata(config *scalersconfig.ScalerConfig, logger logr.Logger) (*temporalMetadata, error) {
meta := &temporalMetadata{}
meta.activationLagThreshold = temporalDefaultActivationQueueLength
meta.targetQueueSize = temporalDefaultTargetQueueLength
Expand Down Expand Up @@ -176,6 +226,37 @@ func parseTemporalMetadata(config *scalersconfig.ScalerConfig) (*temporalMetadat
return nil, errors.New("no queueName provided")
}

// if buildIds is provided, it will be used to filter the queue and make sure
// selectAllActive and selectUnversioned are set to false to avoid considering
if buildIds, ok := config.TriggerMetadata["buildIds"]; ok && buildIds != "" {
meta.buildIDs = strings.Split(buildIds, ",")
}

if val, ok := config.TriggerMetadata["selectAllActive"]; ok && val != "" {
allActive, err := strconv.ParseBool(val)
if err != nil {
meta.allActive = temporalDefaultSelectAllActive
logger.Error(err, "Error parsing Temoral queue metadata selectAllActive, using default %n", temporalDefaultSelectAllActive)
} else {
meta.allActive = allActive
}
}

if val, ok := config.TriggerMetadata["selectUnversioned"]; ok && val != "" {
unversioned, err := strconv.ParseBool(val)
if err != nil {
meta.unversioned = temporalDefaultSelectUnversioned
logger.Error(err, "Error parsing Temoral queue metadata selectUnversioned, using default %n", temporalDefaultSelectUnversioned)
} else {
meta.unversioned = unversioned
}
}

// optional, valide queueTypes are workflow, activity, nexus
if val, ok := config.TriggerMetadata["queueTypes"]; ok && val != "" {
meta.queueTypes = strings.Split(val, ",")
}

meta.triggerIndex = config.TriggerIndex
return meta, nil
}
7 changes: 5 additions & 2 deletions pkg/scalers/temporal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"testing"

"github.com/go-logr/logr"
"github.com/kedacore/keda/v2/pkg/scalers/scalersconfig"
"github.com/stretchr/testify/assert"
)
Expand All @@ -12,6 +13,8 @@ var (
temporalEndpoint = "localhost:7233"
temporalNamespace = "v2"
temporalQueueName = "default"

logger = logr.Discard()
)

type parseTemporalMetadataTestData struct {
Expand Down Expand Up @@ -51,7 +54,7 @@ func TestTemporalGetMetricSpecForScaling(t *testing.T) {
meta, err := parseTemporalMetadata(&scalersconfig.ScalerConfig{
TriggerMetadata: testData.metadataTestData.metadata,
TriggerIndex: testData.triggerIndex,
})
}, logger)
if err != nil {
t.Fatal("Could not parse metadata:", err)
}
Expand Down Expand Up @@ -127,7 +130,7 @@ func TestParseTemporalMetadata(t *testing.T) {
config := &scalersconfig.ScalerConfig{
TriggerMetadata: c.metadata,
}
meta, err := parseTemporalMetadata(config)
meta, err := parseTemporalMetadata(config, logger)
if c.wantErr == true && err != nil {
t.Log("Expected error, got err")
} else {
Expand Down

0 comments on commit 2cfaa31

Please sign in to comment.