diff --git a/.dockerignore b/.dockerignore
index c37dbb09..b65d32c9 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,6 +1,7 @@
# More info: https://docs.docker.com/engine/reference/builder/#dockerignore-file
# Ignore build and test binaries.
bin/
+benchmarks/
charts/
components/
docs/
diff --git a/api/v1/model_types.go b/api/v1/model_types.go
index 5ac9f51c..466e550e 100644
--- a/api/v1/model_types.go
+++ b/api/v1/model_types.go
@@ -117,6 +117,11 @@ type ModelSpec struct {
// DEPRECATED.
// +kubebuilder:validation:Optional
Owner string `json:"owner"`
+
+ // LoadBalancing configuration for the model.
+ // If not specified, a default is used based on the engine and request.
+ // +kubebuilder:default={}
+ LoadBalancing LoadBalancing `json:"loadBalancing,omitempty"`
}
// +kubebuilder:validation:Enum=TextGeneration;TextEmbedding;SpeechToText
@@ -146,6 +151,44 @@ type Adapter struct {
URL string `json:"url"`
}
+type LoadBalancing struct {
+ // +kubebuilder:validation:Optional
+ // +kubebuilder:default=LeastLoad
+ Strategy LoadBalancingStrategy `json:"strategy,omitempty"`
+ // +kubebuilder:validation:Optional
+ // +kubebuilder:default={}
+ PrefixHash PrefixHash `json:"prefixHash,omitempty"`
+}
+
+// +kubebuilder:validation:Enum=LeastLoad;PrefixHash
+type LoadBalancingStrategy string
+
+const (
+ LeastLoadStrategy LoadBalancingStrategy = "LeastLoad"
+ PrefixHashStrategy LoadBalancingStrategy = "PrefixHash"
+)
+
+type PrefixHash struct {
+ // MeanLoadPercentage is the maximum load of any single endpoint, expressed as a
+ // percentage of the mean load across all endpoints in the hash ring. Defaults to 125%,
+ // a widely accepted value for the Consistent Hashing with Bounded Loads algorithm.
+ // +kubebuilder:default=125
+ // +kubebuilder:validation:Optional
+ // +kubebuilder:validation:Minimum=100
+ MeanLoadPercentage int `json:"meanLoadFactor,omitempty"`
+ // Replication is the number of replicas of each endpoint on the hash ring.
+ // Higher values will result in a more even distribution of load but will
+ // decrease lookup performance.
+ // +kubebuilder:validation:XValidation:rule="self == oldSelf", message="replication is immutable."
+ // +kubebuilder:default=20
+ // +kubebuilder:validation:Optional
+ Replication int `json:"replication,omitempty"`
+ // PrefixCharLength is the number of characters to count when building the prefix to hash.
+ // +kubebuilder:validation:Optional
+ // +kubebuilder:default=100
+ PrefixCharLength int `json:"prefixCharLength,omitempty"`
+}
+
// ModelStatus defines the observed state of Model.
type ModelStatus struct {
Replicas ModelStatusReplicas `json:"replicas,omitempty"`
diff --git a/api/v1/zz_generated.deepcopy.go b/api/v1/zz_generated.deepcopy.go
index 9e61934e..1c7c57f6 100644
--- a/api/v1/zz_generated.deepcopy.go
+++ b/api/v1/zz_generated.deepcopy.go
@@ -39,6 +39,22 @@ func (in *Adapter) DeepCopy() *Adapter {
return out
}
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *LoadBalancing) DeepCopyInto(out *LoadBalancing) {
+ *out = *in
+ out.PrefixHash = in.PrefixHash
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new LoadBalancing.
+func (in *LoadBalancing) DeepCopy() *LoadBalancing {
+ if in == nil {
+ return nil
+ }
+ out := new(LoadBalancing)
+ in.DeepCopyInto(out)
+ return out
+}
+
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *Model) DeepCopyInto(out *Model) {
*out = *in
@@ -143,6 +159,7 @@ func (in *ModelSpec) DeepCopyInto(out *ModelSpec) {
*out = new(int64)
**out = **in
}
+ out.LoadBalancing = in.LoadBalancing
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ModelSpec.
@@ -205,3 +222,18 @@ func (in *ModelStatusReplicas) DeepCopy() *ModelStatusReplicas {
in.DeepCopyInto(out)
return out
}
+
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *PrefixHash) DeepCopyInto(out *PrefixHash) {
+ *out = *in
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PrefixHash.
+func (in *PrefixHash) DeepCopy() *PrefixHash {
+ if in == nil {
+ return nil
+ }
+ out := new(PrefixHash)
+ in.DeepCopyInto(out)
+ return out
+}
diff --git a/benchmarks/chat/.dockerignore b/benchmarks/chat/.dockerignore
new file mode 100644
index 00000000..f899baac
--- /dev/null
+++ b/benchmarks/chat/.dockerignore
@@ -0,0 +1 @@
+data/ShareGPT_V3_unfiltered_cleaned_split.json
\ No newline at end of file
diff --git a/benchmarks/chat/.gitignore b/benchmarks/chat/.gitignore
new file mode 100644
index 00000000..503304d4
--- /dev/null
+++ b/benchmarks/chat/.gitignore
@@ -0,0 +1 @@
+data/*.json
\ No newline at end of file
diff --git a/benchmarks/chat/Dockerfile b/benchmarks/chat/Dockerfile
new file mode 100644
index 00000000..dda6926d
--- /dev/null
+++ b/benchmarks/chat/Dockerfile
@@ -0,0 +1,14 @@
+FROM ubuntu:20.04
+
+RUN apt-get update && apt-get install -y build-essential make python3 wget vim
+
+# Install k6 binary.
+ENV K6_VERSION=v0.55.0
+RUN wget https://github.com/grafana/k6/releases/download/${K6_VERSION}/k6-${K6_VERSION}-linux-amd64.tar.gz && tar -zxvf k6-${K6_VERSION}-linux-amd64.tar.gz && mv k6-${K6_VERSION}-linux-amd64/k6 /usr/local/bin && rm k6-${K6_VERSION}-linux-amd64.tar.gz
+
+WORKDIR /work
+
+COPY ./k6.js .
+COPY ./Makefile .
+COPY ./data ./data
+COPY ./scenarios ./scenarios
\ No newline at end of file
diff --git a/benchmarks/chat/Makefile b/benchmarks/chat/Makefile
new file mode 100644
index 00000000..b819802b
--- /dev/null
+++ b/benchmarks/chat/Makefile
@@ -0,0 +1,10 @@
+data/ShareGPT_V3_unfiltered_cleaned_split.json:
+ cd data && wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
+
+.PHONY: data
+data: data/ShareGPT_V3_unfiltered_cleaned_split.json
+ cd data && python prepare-message-threads.py
+
+run:
+ ls scenarios/${SCENARIO}
+ CONFIG_DIR=scenarios/${SCENARIO} DATA_DIR=data MODEL_ADDR=kubeai/openai k6 run ./k6.js
\ No newline at end of file
diff --git a/benchmarks/chat/data/prepare-message-threads.py b/benchmarks/chat/data/prepare-message-threads.py
new file mode 100644
index 00000000..7e8b74bf
--- /dev/null
+++ b/benchmarks/chat/data/prepare-message-threads.py
@@ -0,0 +1,43 @@
+import json
+
+
+def main():
+ with open("./ShareGPT_V3_unfiltered_cleaned_split.json", "r") as f:
+ data = json.load(f)
+
+ # Select the first conversations that start with a human message.
+ max = 2000
+ output = []
+ for entry in data:
+ conv = entry.get("conversations")
+ if conv and conv[0]["from"] == "human" and len(conv[0]["value"]) != 0:
+ # Keep only the messages authored by the human participant.
+ totalContentLength = 0
+ userMessages = []
+ for c in conv:
+ if c["from"] == "human":
+ content = c["value"]
+ userMessages.append(content)
+ totalContentLength += len(content)
+
+ if totalContentLength < 2500:
+ continue
+
+ if len(userMessages) < 5:
+ continue
+
+ # Replace the original conversation with just the user messages.
+ entry["userMessages"] = userMessages
+ del entry["conversations"]
+ output.append(entry)
+
+ if len(output) >= max:
+ break
+
+ with open("./message-threads.json", "w") as f:
+ json.dump(output, f, indent=4)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/benchmarks/chat/k6.js b/benchmarks/chat/k6.js
new file mode 100644
index 00000000..5428b40c
--- /dev/null
+++ b/benchmarks/chat/k6.js
@@ -0,0 +1,71 @@
+import { check } from 'k6';
+import { scenario } from 'k6/execution';
+import http from 'k6/http';
+import { Trend, Counter } from 'k6/metrics';
+
+const model_addr = __ENV.MODEL_ADDR;
+const config_dir = __ENV.CONFIG_DIR;
+const data_dir = __ENV.DATA_DIR;
+
+const timePerToken = new Trend('time_per_token', true);
+const tokens = new Counter('tokens');
+const new_tokens = new Counter('new_tokens');
+const input_tokens = new Counter('input_tokens');
+
+const k6Options = JSON.parse(open(`${config_dir}/k6.json`));
+const baseRequest = JSON.parse(open(`${config_dir}/base-request.json`));
+const messageThreads = JSON.parse(open(`${data_dir}/message-threads.json`))
+
+export const options = k6Options;
+
+export default function run() {
+ const headers = { 'Content-Type': 'application/json' };
+ const msgThread = messageThreads[scenario.iterationInTest % messageThreads.length];
+ var payload = JSON.parse(JSON.stringify(baseRequest));
+
+ // console.log(`Message thread: ${JSON.stringify(msgThread)}`);
+
+ // Iterate over all the messages in the thread, appending the completions to the same payload.
+ for (let i = 0; i < msgThread["userMessages"].length; i++) {
+ payload.messages.push({
+ "role": "user",
+ "content": msgThread["userMessages"][i]
+ });
+ //console.log(`Payload: ${JSON.stringify(payload)}`);
+
+ const res = http.post(`http://${model_addr}/v1/chat/completions`, JSON.stringify(payload), {
+ headers,
+ });
+ if (res.status >= 400 && res.status < 500) {
+ return;
+ }
+
+ check(res, {
+ 'Post status is 200': (res) => res.status === 200,
+ });
+ const duration = res.timings.duration;
+
+ if (res.status === 200) {
+ // console.log(`Status: ${res.status}`);
+ const body = res.json();
+
+ const completion_tokens = body.usage.completion_tokens;
+ const prompt_tokens = body.usage.prompt_tokens;
+ const latency_ms_per_token = duration / completion_tokens;
+
+ new_tokens.add(completion_tokens);
+ input_tokens.add(prompt_tokens);
+ timePerToken.add(latency_ms_per_token);
+ tokens.add(completion_tokens + prompt_tokens);
+
+ const msg0 = body.choices[0].message;
+ payload.messages.push({
+ "role": msg0.role,
+ "content": msg0.content
+ });
+ } else {
+ console.log(`Error Status: ${res.status}`);
+ console.log(`Response: ${res.body}`);
+ }
+ }
+}
diff --git a/benchmarks/chat/scenarios/least-load-vs-prefix-hash/README.md b/benchmarks/chat/scenarios/least-load-vs-prefix-hash/README.md
new file mode 100644
index 00000000..fe068da5
--- /dev/null
+++ b/benchmarks/chat/scenarios/least-load-vs-prefix-hash/README.md
@@ -0,0 +1,131 @@
+# Results
+
+Under specific conditions:
+
+* Restricted GPU memory
+* Low `max_tokens` to be generated
+* Chat threads with decently long user messages
+
+Prefix hashing showed a `34%` decrease in average time per token.
+
+`712.11ms (LeastLoad) --> 469.34ms (PrefixHash)`
+
+## Steps taken
+
+```bash
+gcloud container clusters create-auto cluster-1 \
+ --location=us-central1
+skaffold run -f ./skaffold.yaml --tail --port-forward --profile kubeai-only-gke --default-repo us-central1-docker.pkg.dev/substratus-dev
+
+cd ./benchmarks/chat
+make data
+export IMG=us-central1-docker.pkg.dev/substratus-dev/default/kubeai-benchmark-chat:v0.0.2
+docker build -t $IMG . && docker push $IMG
+
+kubectl apply -f ./scenarios/least-load-vs-prefix-hash/model.yaml
+kubectl apply -f ./scenarios/least-load-vs-prefix-hash/pod.yaml
+
+# Run 2x (to ensure both cases start with a preloaded cache)
+kubectl exec -it chat-benchmark -- SCENARIO=least-load-vs-prefix-hash make run
+
+kubectl patch model llama-3.1-8b-instruct-fp8-l4 --type='merge' -p '{"spec": {"loadBalancing": {"strategy": "PrefixHash"}}}'
+kubectl exec -it chat-benchmark -- SCENARIO=least-load-vs-prefix-hash make run
+```
+
+## Next Steps
+
+* Rerun with increased replicas (e.g. 10 instead of 2)
+
+## Benchmark Output
+
+### LeastLoad
+
+```
+ /\ Grafana /‾‾/
+ /\ / \ |\ __ / /
+ / \/ \ | |/ / / ‾‾\
+ / \ | ( | (‾) |
+ / __________ \ |_|\_\ \_____/
+
+ execution: local
+ script: ./k6.js
+ output: -
+
+ scenarios: (100.00%) 1 scenario, 80 max VUs, 10m30s max duration (incl. graceful stop):
+ * chat: 1000 iterations shared among 80 VUs (maxDuration: 10m0s, gracefulStop: 30s)
+
+
+ ✓ Post status is 200
+
+ checks.........................: 100.00% 7341 out of 7341
+ data_received..................: 4.7 MB 7.9 kB/s
+ data_sent......................: 25 MB 42 kB/s
+ http_req_blocked...............: avg=161.4µs min=2.83µs med=5.8µs max=16.67ms p(90)=8.06µs p(95)=10.19µs
+ http_req_connecting............: avg=55.73µs min=0s med=0s max=8.41ms p(90)=0s p(95)=0s
+ http_req_duration..............: avg=6.31s min=165.25ms med=6.66s max=11.65s p(90)=8.55s p(95)=9.07s
+ { expected_response:true }...: avg=6.31s min=165.25ms med=6.66s max=11.65s p(90)=8.55s p(95)=9.07s
+ ✓ http_req_failed................: 0.00% 0 out of 7341
+ http_req_receiving.............: avg=84.64µs min=29.4µs med=74.05µs max=732.69µs p(90)=129.94µs p(95)=154.19µs
+ http_req_sending...............: avg=68µs min=12.1µs med=32.3µs max=1.38ms p(90)=144.04µs p(95)=173.19µs
+ http_req_tls_handshaking.......: avg=0s min=0s med=0s max=0s p(90)=0s p(95)=0s
+ http_req_waiting...............: avg=6.31s min=165.04ms med=6.66s max=11.65s p(90)=8.55s p(95)=9.07s
+ http_reqs......................: 7341 12.422953/s
+ input_tokens...................: 4990223 8444.803735/s
+ iteration_duration.............: avg=46.39s min=6.73s med=41.26s max=4m13s p(90)=1m8s p(95)=1m28s
+ iterations.....................: 1000 1.69227/s
+ new_tokens.....................: 68062 115.179268/s
+ time_per_token.................: avg=712.11ms min=39.56ms med=703.28ms max=2.69s p(90)=928.58ms p(95)=1.09s
+ tokens.........................: 5058285 8559.983003/s
+ vus............................: 1 min=0 max=80
+ vus_max........................: 80 min=21 max=80
+
+
+running (09m50.9s), 00/80 VUs, 1000 complete and 0 interrupted iterations
+chat ✓ [======================================] 80 VUs 09m50.9s/10m0s 1000/1000 shared iters
+```
+
+### PrefixHash
+
+```
+ /\ Grafana /‾‾/
+ /\ / \ |\ __ / /
+ / \/ \ | |/ / / ‾‾\
+ / \ | ( | (‾) |
+ / __________ \ |_|\_\ \_____/
+
+ execution: local
+ script: ./k6.js
+ output: -
+
+ scenarios: (100.00%) 1 scenario, 80 max VUs, 10m30s max duration (incl. graceful stop):
+ * chat: 1000 iterations shared among 80 VUs (maxDuration: 10m0s, gracefulStop: 30s)
+
+
+ ✓ Post status is 200
+
+ checks.........................: 100.00% 7341 out of 7341
+ data_received..................: 4.7 MB 12 kB/s
+ data_sent......................: 25 MB 65 kB/s
+ http_req_blocked...............: avg=268.24µs min=2.94µs med=5.76µs max=28.19ms p(90)=8.17µs p(95)=10.41µs
+ http_req_connecting............: avg=136.33µs min=0s med=0s max=17.7ms p(90)=0s p(95)=0s
+ http_req_duration..............: avg=4.08s min=151.9ms med=2.45s max=12.32s p(90)=9.63s p(95)=10.26s
+ { expected_response:true }...: avg=4.08s min=151.9ms med=2.45s max=12.32s p(90)=9.63s p(95)=10.26s
+ ✓ http_req_failed................: 0.00% 0 out of 7341
+ http_req_receiving.............: avg=81.81µs min=28.68µs med=72.08µs max=786.09µs p(90)=125.04µs p(95)=148.6µs
+ http_req_sending...............: avg=63.61µs min=11.85µs med=31.65µs max=1.59ms p(90)=136.85µs p(95)=161.88µs
+ http_req_tls_handshaking.......: avg=0s min=0s med=0s max=0s p(90)=0s p(95)=0s
+ http_req_waiting...............: avg=4.08s min=151.81ms med=2.45s max=12.32s p(90)=9.63s p(95)=10.26s
+ http_reqs......................: 7341 19.230625/s
+ input_tokens...................: 4990576 13073.409349/s
+ iteration_duration.............: avg=29.98s min=2.37s med=20.29s max=2m53s p(90)=1m1s p(95)=1m18s
+ iterations.....................: 1000 2.619619/s
+ new_tokens.....................: 68218 178.705191/s
+ time_per_token.................: avg=469.34ms min=44.2ms med=257.72ms max=3.86s p(90)=1s p(95)=1.1s
+ tokens.........................: 5058794 13252.11454/s
+ vus............................: 3 min=0 max=80
+ vus_max........................: 80 min=19 max=80
+
+
+running (06m21.7s), 00/80 VUs, 1000 complete and 0 interrupted iterations
+chat ✓ [======================================] 80 VUs 06m21.7s/10m0s 1000/1000 shared iters
+```
diff --git a/benchmarks/chat/scenarios/least-load-vs-prefix-hash/base-request.json b/benchmarks/chat/scenarios/least-load-vs-prefix-hash/base-request.json
new file mode 100644
index 00000000..68bbd9eb
--- /dev/null
+++ b/benchmarks/chat/scenarios/least-load-vs-prefix-hash/base-request.json
@@ -0,0 +1,6 @@
+{
+ "model": "llama-3.1-8b-instruct-fp8-l4",
+ "max_tokens": 10,
+ "temperature": 0,
+ "messages": []
+}
\ No newline at end of file
diff --git a/benchmarks/chat/scenarios/least-load-vs-prefix-hash/k6.json b/benchmarks/chat/scenarios/least-load-vs-prefix-hash/k6.json
new file mode 100644
index 00000000..9b82968f
--- /dev/null
+++ b/benchmarks/chat/scenarios/least-load-vs-prefix-hash/k6.json
@@ -0,0 +1,15 @@
+{
+ "thresholds": {
+ "http_req_failed": [
+ "rate==0"
+ ]
+ },
+ "scenarios": {
+ "chat": {
+ "executor": "shared-iterations",
+ "vus": 80,
+ "iterations": 1000,
+ "maxDuration": "600s"
+ }
+ }
+}
\ No newline at end of file
diff --git a/benchmarks/chat/scenarios/least-load-vs-prefix-hash/model.yaml b/benchmarks/chat/scenarios/least-load-vs-prefix-hash/model.yaml
new file mode 100644
index 00000000..8be83a82
--- /dev/null
+++ b/benchmarks/chat/scenarios/least-load-vs-prefix-hash/model.yaml
@@ -0,0 +1,17 @@
+apiVersion: kubeai.org/v1
+kind: Model
+metadata:
+ name: llama-3.1-8b-instruct-fp8-l4
+spec:
+ features: [TextGeneration]
+ url: hf://neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8
+ engine: VLLM
+ args:
+ - --enable-prefix-caching
+ - --max-model-len=16384
+ - --max-num-batched-tokens=16384
+ - --gpu-memory-utilization=0.6
+ - --disable-log-requests
+ resourceProfile: nvidia-gpu-l4:1
+ minReplicas: 2
+ maxReplicas: 2
diff --git a/benchmarks/chat/scenarios/least-load-vs-prefix-hash/pod.yaml b/benchmarks/chat/scenarios/least-load-vs-prefix-hash/pod.yaml
new file mode 100644
index 00000000..872c9706
--- /dev/null
+++ b/benchmarks/chat/scenarios/least-load-vs-prefix-hash/pod.yaml
@@ -0,0 +1,19 @@
+apiVersion: v1
+kind: Pod
+metadata:
+ name: chat-benchmark
+spec:
+ restartPolicy: Never
+ containers:
+ - name: bench
+ image: us-central1-docker.pkg.dev/substratus-dev/default/kubeai-benchmark-chat:v0.0.2
+ command: ["sleep", "infinity"]
+ resources:
+ requests:
+ cpu: 6
+ ephemeral-storage: 10Gi
+ memory: 24Gi
+ limits:
+ cpu: 6
+ ephemeral-storage: 10Gi
+ memory: 24Gi
\ No newline at end of file
diff --git a/charts/kubeai/templates/crds/kubeai.org_models.yaml b/charts/kubeai/templates/crds/kubeai.org_models.yaml
index 24803018..b991005b 100644
--- a/charts/kubeai/templates/crds/kubeai.org_models.yaml
+++ b/charts/kubeai/templates/crds/kubeai.org_models.yaml
@@ -106,6 +106,46 @@ spec:
Image to be used for the server process.
Will be set from ResourceProfile + Engine if not specified.
type: string
+ loadBalancing:
+ default: {}
+ description: |-
+ LoadBalancing configuration for the model.
+ If not specified, a default is used based on the engine and request.
+ properties:
+ prefixHash:
+ default: {}
+ properties:
+ meanLoadFactor:
+ default: 125
+ description: |-
+ MeanLoadPercentage is the maximum load of any single endpoint, expressed as a
+ percentage of the mean load across all endpoints in the hash ring. Defaults to 125%,
+ a widely accepted value for the Consistent Hashing with Bounded Loads algorithm.
+ minimum: 100
+ type: integer
+ prefixCharLength:
+ default: 100
+ description: PrefixCharLength is the number of characters
+ to count when building the prefix to hash.
+ type: integer
+ replication:
+ default: 20
+ description: |-
+ Replication is the number of replicas of each endpoint on the hash ring.
+ Higher values will result in a more even distribution of load but will
+ decrease lookup performance.
+ type: integer
+ x-kubernetes-validations:
+ - message: replication is immutable.
+ rule: self == oldSelf
+ type: object
+ strategy:
+ default: LeastLoad
+ enum:
+ - LeastLoad
+ - PrefixHash
+ type: string
+ type: object
maxReplicas:
description: |-
MaxReplicas is the maximum number of Pod replicas that the model can scale up to.
diff --git a/docs/README.md b/docs/README.md
index 3d00c17f..53713dfe 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -5,7 +5,8 @@ The easiest way to serve ML models in production. Supports LLMs, embeddings, and
✅️ OpenAI API Compatibility: Drop-in replacement for OpenAI
⚖️ Autoscaling: Scale from zero, autoscale based on load
🧠 Serve text generation models with vLLM or Ollama
-🔌 Lora Adapter aware routing
+🔌 Dynamic LoRA adapter loading
+⛕ Inference-optimized load balancing
💬 Speech to Text API with FasterWhisper
🧮 Embedding/Vector API with Infinity
🚀 Multi-platform: CPU, GPU, TPU
diff --git a/docs/concepts/load-balancing.md b/docs/concepts/load-balancing.md
new file mode 100644
index 00000000..665b0542
--- /dev/null
+++ b/docs/concepts/load-balancing.md
@@ -0,0 +1,26 @@
+# Load Balancing
+
+To optimize inference performance and resource utilization, KubeAI supports load balancing strategies specifically tailored for model inference servers such as vLLM. This document explains two primary load balancing strategies available in KubeAI: Least Load and Prefix Hash.
+
+## Least Load
+
+The Least Load strategy distributes inference requests to the model replica with the fewest in-flight requests. This strategy aims to balance the inference workload evenly across available replicas, reducing the risk of overloading any single server.
+
+## Prefix Hash
+
+The Prefix Hash strategy leverages the Consistent Hashing with Bounded Loads (CHWBL) algorithm to optimize the performance of engines such as vLLM that support prefix caching. This strategy increases the likelihood of KV cache hits for common prefixes. See the vLLM prefix caching docs for more info.
+
+With this strategy, KubeAI hashes incoming requests based on their prefixes (along with the requested LoRA adapter name, if present). Requests with the same hash value are routed to the same replica, except when that replica's in-flight requests exceed the overall average by a configurable percentage.
+
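+Concretely, with `n` endpoints and `totalLoad` in-flight requests across all of them, a replica with `load` in-flight requests is only used while `load + 1 <= (meanLoadFactor / 100) * (totalLoad + 1) / n` (the `+1` accounts for the incoming request); otherwise the request moves on to the next replica on the hash ring.
+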
+This strategy has the most benefit for use cases such as chat completion, because the entire chat thread is resent with each successive request.
+
+KubeAI supports this strategy for the following endpoints:
+
+```
+/openai/v1/completions
+/openai/v1/chat/completions
+```
+
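+A Model opts into prefix hashing through its `loadBalancing` spec. A minimal sketch (the model name and omitted fields are illustrative; the nested `prefixHash` values shown are the defaults and can be left out):
+
+```yaml
+apiVersion: kubeai.org/v1
+kind: Model
+metadata:
+  name: my-model
+spec:
+  # ... engine, url, resourceProfile, etc.
+  loadBalancing:
+    strategy: PrefixHash
+    prefixHash:
+      meanLoadFactor: 125
+      replication: 20
+      prefixCharLength: 100
+```
+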
+## Next
+
+See the [Kubernetes API docs](../reference/kubernetes-api.md) for how to configure load balancing on a Model.
\ No newline at end of file
diff --git a/docs/reference/kubernetes-api.md b/docs/reference/kubernetes-api.md
index 4761242a..b049a9f1 100644
--- a/docs/reference/kubernetes-api.md
+++ b/docs/reference/kubernetes-api.md
@@ -30,6 +30,41 @@ _Appears in:_
| `url` _string_ | | | |
+#### LoadBalancing
+
+
+
+
+
+
+
+_Appears in:_
+- [ModelSpec](#modelspec)
+
+| Field | Description | Default | Validation |
+| --- | --- | --- | --- |
+| `strategy` _[LoadBalancingStrategy](#loadbalancingstrategy)_ | | LeastLoad | Enum: [LeastLoad PrefixHash] <br />Optional: \{\} |
+| `prefixHash` _[PrefixHash](#prefixhash)_ | | \{ \} | Optional: \{\} |
+
+
+#### LoadBalancingStrategy
+
+_Underlying type:_ _string_
+
+
+
+_Validation:_
+- Enum: [LeastLoad PrefixHash]
+
+_Appears in:_
+- [LoadBalancing](#loadbalancing)
+
+| Field | Description |
+| --- | --- |
+| `LeastLoad` | |
+| `PrefixHash` | |
+
+
#### Model
@@ -92,6 +127,7 @@ _Appears in:_
| `targetRequests` _integer_ | TargetRequests is average number of active requests that the autoscaler will try to maintain on model server Pods. | 100 | Minimum: 1 |
| `scaleDownDelaySeconds` _integer_ | ScaleDownDelay is the minimum time before a deployment is scaled down after the autoscaling algorithm determines that it should be scaled down. | 30 | |
| `owner` _string_ | Owner of the model. Used solely to populate the owner field in the OpenAI /v1/models endpoint. DEPRECATED. | | Optional: \{\} |
+| `loadBalancing` _[LoadBalancing](#loadbalancing)_ | LoadBalancing configuration for the model. If not specified, a default is used based on the engine and request. | \{ \} | |
#### ModelStatus
@@ -144,3 +180,21 @@ _Appears in:_
| `ready` _integer_ | | | |
+#### PrefixHash
+
+
+
+
+
+
+
+_Appears in:_
+- [LoadBalancing](#loadbalancing)
+
+| Field | Description | Default | Validation |
+| --- | --- | --- | --- |
+| `meanLoadFactor` _integer_ | MeanLoadPercentage is the maximum load of any single endpoint, expressed as a percentage of the mean load across all endpoints in the hash ring. Defaults to 125%, a widely accepted value for the Consistent Hashing with Bounded Loads algorithm. | 125 | Minimum: 100 <br />Optional: \{\} |
+| `replication` _integer_ | Replication is the number of replicas of each endpoint on the hash ring. Higher values will result in a more even distribution of load but will decrease lookup performance. | 20 | Optional: \{\} |
+| `prefixCharLength` _integer_ | PrefixCharLength is the number of characters to count when building the prefix to hash. | 100 | Optional: \{\} |
+
+
diff --git a/go.mod b/go.mod
index 8aaf47dd..d6bcfd8a 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,7 @@ module github.com/substratusai/kubeai
go 1.22.0
require (
+ github.com/cespare/xxhash v1.1.0
github.com/go-playground/validator/v10 v10.22.0
github.com/google/uuid v1.6.0
github.com/onsi/ginkgo/v2 v2.17.1
diff --git a/go.sum b/go.sum
index d516a700..ce8876a6 100644
--- a/go.sum
+++ b/go.sum
@@ -36,6 +36,8 @@ github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/IBM/sarama v1.43.3 h1:Yj6L2IaNvb2mRBop39N7mmJAHBVY3dTPncr3qGVkxPA=
github.com/IBM/sarama v1.43.3/go.mod h1:FVIRaLrhK3Cla/9FfRF5X9Zua2KpS3SYIXxhac1H+FQ=
+github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE=
+github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU=
@@ -73,6 +75,8 @@ github.com/aws/smithy-go v1.20.3/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
+github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
+github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
@@ -297,6 +301,8 @@ github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5X
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
+github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ=
+github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
diff --git a/hack/dev-models/vllm-chat.yaml b/hack/dev-models/vllm-chat.yaml
new file mode 100644
index 00000000..d2a25edf
--- /dev/null
+++ b/hack/dev-models/vllm-chat.yaml
@@ -0,0 +1,18 @@
+apiVersion: kubeai.org/v1
+kind: Model
+metadata:
+ name: tinyllama-chat
+spec:
+ features: [TextGeneration]
+ owner: meta-llama
+ url: hf://TinyLlama/TinyLlama-1.1B-Chat-v1.0
+ #adapters:
+ #- name: foo
+ # url: hf://jashing/tinyllama-colorist-lora
+ #- name: bar
+ # url: s3://substratus-ai-test-0/adapters/jashing/tinyllama-colorist-lora
+ #- name: baz
+ # url: gs://substratus-ai-test-0/adapters/jashing/tinyllama-colorist-lora
+ engine: VLLM
+ resourceProfile: nvidia-gpu-l4:1
+ minReplicas: 1
\ No newline at end of file
diff --git a/internal/apiutils/requests.go b/internal/apiutils/model.go
similarity index 100%
rename from internal/apiutils/requests.go
rename to internal/apiutils/model.go
diff --git a/internal/apiutils/requests_test.go b/internal/apiutils/model_test.go
similarity index 100%
rename from internal/apiutils/requests_test.go
rename to internal/apiutils/model_test.go
diff --git a/internal/apiutils/request.go b/internal/apiutils/request.go
new file mode 100644
index 00000000..9ab5447f
--- /dev/null
+++ b/internal/apiutils/request.go
@@ -0,0 +1,294 @@
+package apiutils
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "io"
+ "mime"
+ "mime/multipart"
+ "net/http"
+
+ "context"
+
+ "github.com/google/uuid"
+ v1 "github.com/substratusai/kubeai/api/v1"
+)
+
+var (
+ ErrBadRequest = fmt.Errorf("bad request")
+ ErrModelNotFound = fmt.Errorf("model not found")
+)
+
+type Request struct {
+ Body []byte
+ bodyPayload map[string]interface{}
+
+ Selectors []string
+
+ ID string
+
+ // RequestedModel is the model name requested by the client.
+ // This might contain the adapter name as well.
+ RequestedModel string
+
+ Model string
+ Adapter string
+
+ LoadBalancing v1.LoadBalancing
+
+ Prefix string
+
+ ContentLength int64
+}
+
+type ModelClient interface {
+ LookupModel(ctx context.Context, model, adapter string, selectors []string) (*v1.Model, error)
+}
+
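+// ParseRequest parses an incoming API request: it extracts the model (and optional
+// adapter) from the body, resolves the corresponding Model via the client, and, for
+// models using the PrefixHash strategy, computes the prefix used for routing.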
+func ParseRequest(ctx context.Context, client ModelClient, body io.Reader, path string, headers http.Header) (*Request, error) {
+ r := &Request{
+ ID: uuid.New().String(),
+ }
+
+ r.Selectors = headers.Values("X-Label-Selector")
+
+ // Parse media type (with params - which are used for multipart form data)
+ var (
+ contentType = headers.Get("Content-Type")
+ mediaType string
+ mediaParams map[string]string
+ )
+ if contentType == "" {
+ mediaType = "application/json"
+ mediaParams = map[string]string{}
+ } else {
+ var err error
+ mediaType, mediaParams, err = mime.ParseMediaType(contentType)
+ if err != nil {
+ return nil, fmt.Errorf("%w: parse media type: %w", ErrBadRequest, err)
+ }
+ }
+
+ switch mediaType {
+ // Multipart form data is used for endpoints that accept file uploads:
+ case "multipart/form-data":
+ if err := r.readMultiPartBody(body, mediaParams); err != nil {
+ return nil, fmt.Errorf("%w: reading multipart form data: %w", ErrBadRequest, err)
+ }
+
+ // Assume "application/json":
+ default:
+ if err := r.readJSONBody(body); err != nil {
+ return nil, fmt.Errorf("%w: reading model from body: %w", ErrBadRequest, err)
+ }
+ }
+
+ if err := r.lookupModel(ctx, client, path); err != nil {
+ return nil, err
+ }
+
+ return r, nil
+}
+
+func (r *Request) readMultiPartBody(body io.Reader, mediaParams map[string]string) error {
+ boundary := mediaParams["boundary"]
+ if boundary == "" {
+ return fmt.Errorf("no boundary specified in multipart form data")
+ }
+
+ var buf bytes.Buffer
+ mw := multipart.NewWriter(&buf)
+ // Keep the same boundary as the initial request (probably not necessary)
+ mw.SetBoundary(boundary)
+
+ // Iterate over the parts of the multipart form data:
+ // - If the part is named "model", save the value to the proxy request.
+ // - Otherwise, just copy the part to the new multipart writer.
+ mr := multipart.NewReader(body, boundary)
+ for {
+ p, err := mr.NextPart()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return fmt.Errorf("interating over multipart form: %w", err)
+ }
+
+ if p.FormName() == "model" {
+ value, err := io.ReadAll(p)
+ if err != nil {
+ return fmt.Errorf("reading multipart form value: %w", err)
+ }
+ r.Model, r.Adapter = SplitModelAdapter(string(value))
+ r.RequestedModel = string(value)
+ // WORKAROUND ALERT:
+ // Omit the "model" field from the proxy request to avoid FasterWhisper validation issues:
+ // See https://github.com/fedirz/faster-whisper-server/issues/71
+ continue
+ }
+
+ // Copy the part to the new multipart writer.
+ pp, err := mw.CreatePart(p.Header)
+ if err != nil {
+ return fmt.Errorf("creating part: %w", err)
+ }
+ if _, err := io.Copy(pp, p); err != nil {
+ return fmt.Errorf("copying part: %w", err)
+ }
+ }
+
+ // Fully write to buffer.
+ if err := mw.Close(); err != nil {
+ return fmt.Errorf("closing multipart writer: %w", err)
+ }
+ r.Body = buf.Bytes()
+ // Set a new content length based on the new body - which had the "model" field removed.
+ r.ContentLength = int64(len(r.Body))
+
+ return nil
+}
+
+func (r *Request) readJSONBody(body io.Reader) error {
+ var payload map[string]interface{}
+ if err := json.NewDecoder(body).Decode(&payload); err != nil {
+ return fmt.Errorf("decoding: %w", err)
+ }
+
+ modelInf, ok := payload["model"]
+ if !ok {
+ return fmt.Errorf("missing 'model' field")
+ }
+ r.bodyPayload = payload
+
+ modelStr, ok := modelInf.(string)
+ if !ok {
+ return fmt.Errorf("field 'model' should be a string")
+ }
+
+ r.RequestedModel = modelStr
+ r.Model, r.Adapter = SplitModelAdapter(modelStr)
+
+ if r.Adapter != "" {
+ // vLLM expects the adapter to be in the model field.
+ payload["model"] = r.Adapter
+ }
+
+ rewritten, err := json.Marshal(payload)
+ if err != nil {
+ return fmt.Errorf("remarshalling: %w", err)
+ }
+ r.Body = rewritten
+ r.ContentLength = int64(len(r.Body))
+
+ return nil
+}
+
+func (r *Request) lookupModel(ctx context.Context, client ModelClient, path string) error {
+ model, err := client.LookupModel(ctx, r.Model, r.Adapter, r.Selectors)
+ if err != nil {
+ return fmt.Errorf("lookup model: %w", err)
+ }
+ if model == nil {
+ return fmt.Errorf("%w: %q", ErrModelNotFound, r.RequestedModel)
+ }
+
+ r.LoadBalancing = model.Spec.LoadBalancing
+
+ if r.LoadBalancing.Strategy == v1.PrefixHashStrategy && r.bodyPayload != nil {
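+ // The decoded body is retained only so the prefix can be derived here; release it once done.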
+ defer func() {
+ r.bodyPayload = nil
+ }()
+ switch path {
+ case "/v1/completions":
+ prefix, err := getPrefixForCompletionRequest(r.bodyPayload, r.LoadBalancing.PrefixHash.PrefixCharLength)
+ if err != nil {
+ return fmt.Errorf("getting prefix for completion request: %w", err)
+ }
+ r.Prefix = prefix
+ case "/v1/chat/completions":
+ prefix, err := getPrefixForChatCompletionRequest(r.bodyPayload, r.LoadBalancing.PrefixHash.PrefixCharLength)
+ if err != nil {
+ return fmt.Errorf("getting prefix for chat completion request: %w", err)
+ }
+ r.Prefix = prefix
+ }
+ }
+
+ return nil
+}
+
+func getPrefixForCompletionRequest(body map[string]interface{}, n int) (string, error) {
+ // Example request body:
+ // {
+ // "model": "gpt-3.5-turbo-instruct",
+ // "prompt": "Say this is a test",
+ // "max_tokens": 7,
+ // "temperature": 0
+ // }
+ promptInf, ok := body["prompt"]
+ if !ok {
+ return "", fmt.Errorf("missing '.prompt' field")
+ }
+ prompt, ok := promptInf.(string)
+ if !ok {
+ return "", fmt.Errorf("'.prompt' field should be a string")
+ }
+ return firstNChars(prompt, n), nil
+}
+
+func getPrefixForChatCompletionRequest(body map[string]interface{}, n int) (string, error) {
+ // Example request body:
+ // {
+ // "model": "gpt-4o",
+ // "messages": [
+ // {
+ // "role": "system",
+ // "content": "You are a helpful assistant."
+ // },
+ // {
+ // "role": "user",
+ // "content": "Hello!"
+ // }
+ // ]
+ // }
+ messagesInf, ok := body["messages"]
+ if !ok {
+ return "", fmt.Errorf("missing '.messages' field")
+ }
+ messages, ok := messagesInf.([]interface{})
+ if !ok {
+ return "", fmt.Errorf("'.messages' field should be an array")
+ }
+ if len(messages) == 0 {
+ return "", fmt.Errorf("empty '.messages' field")
+ }
+
+ // Find the first user request and return the first n characters.
+ for i, msgInf := range messages {
+ msg, ok := msgInf.(map[string]interface{})
+ if !ok {
+ return "", fmt.Errorf("'.messages[i]' should be an object")
+ }
+ if msg["role"] == "user" {
+ textInf, ok := msg["content"]
+ if !ok {
+ return "", fmt.Errorf("missing '.messages[%d].content' field", i)
+ }
+ text, ok := textInf.(string)
+ if !ok {
+ return "", fmt.Errorf("'.messages[%d].content' should be a string", i)
+ }
+ return firstNChars(text, n), nil
+ }
+ }
+
+ return "", fmt.Errorf("no user message found")
+}
+
+// firstNChars returns the first n characters of a string.
+// This function is needed because Go's string indexing is based on bytes, not runes.
+func firstNChars(s string, n int) string {
+ runes := []rune(s)
+ return string(runes[:min(n, len(runes))])
+}
diff --git a/internal/apiutils/request_test.go b/internal/apiutils/request_test.go
new file mode 100644
index 00000000..8e70617c
--- /dev/null
+++ b/internal/apiutils/request_test.go
@@ -0,0 +1,213 @@
+package apiutils
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ v1 "github.com/substratusai/kubeai/api/v1"
+)
+
+func Test_getPrefixForCompletionRequest(t *testing.T) {
+ cases := []struct {
+ input string
+ n int
+ exp string
+ expErrorContains []string
+ }{
+ {`{}`, 0, "", []string{"missing", "prompt"}},
+ {`{}`, 9, "", []string{"missing", "prompt"}},
+ {`{"prompt": "abc"}`, 0, "", nil},
+ {`{"prompt": "abc"}`, 9, "abc", nil},
+ {`{"prompt": "abcefghijk"}`, 9, "abcefghij", nil},
+ {`{"prompt": "世界"}`, 0, "", nil},
+ {`{"prompt": "世界"}`, 1, "世", nil},
+ {`{"prompt": "世界"}`, 2, "世界", nil},
+ {`{"prompt": "世界"}`, 3, "世界", nil},
+ }
+ for _, c := range cases {
+ t.Run(fmt.Sprintf("%q %d", c.input, c.n), func(t *testing.T) {
+ var body map[string]interface{}
+ require.NoError(t, json.Unmarshal([]byte(c.input), &body))
+ out, err := getPrefixForCompletionRequest(body, c.n)
+ if c.expErrorContains != nil {
+ for _, ec := range c.expErrorContains {
+ require.ErrorContains(t, err, ec)
+ }
+ return
+ } else {
+ require.NoError(t, err)
+ }
+ require.Equal(t, c.exp, out)
+ })
+ }
+}
+
+func Test_getPrefixForChatCompletionRequest(t *testing.T) {
+ cases := []struct {
+ input string
+ n int
+ exp string
+ expErrorContains []string
+ }{
+ {`{}`, 0, "", []string{"missing", "messages"}},
+ {`{}`, 0, "", []string{"missing", "messages"}},
+ {`{"messages": []}`, 0, "", []string{"empty"}},
+ {`{"messages": []}`, 9, "", []string{"empty"}},
+ {`{"messages": [{"role": "user", "content": "abc"}]}`, 0, "", nil},
+ {`{"messages": [{"role": "user", "content": "abc"}]}`, 9, "abc", nil},
+ {`{"messages": [{"role": "user", "content": "abcefghijk"}]}`, 9, "abcefghij", nil},
+ {`{"messages": [{"role": "user", "content": "世界"}]}`, 0, "", nil},
+ {`{"messages": [{"role": "user", "content": "世界"}]}`, 1, "世", nil},
+ {`{"messages": [{"role": "user", "content": "世界"}]}`, 2, "世界", nil},
+ {`{"messages": [{"role": "user", "content": "世界"}]}`, 3, "世界", nil},
+ {`{"messages": [{"role": "user", "content": "abc"}, {"role": "user", "content": "xyz"}]}`, 0, "", nil},
+ {`{"messages": [{"role": "user", "content": "abc"}, {"role": "user", "content": "xyz"}]}`, 9, "abc", nil},
+ {`{"messages": [{"role": "system", "content": "abc"}, {"role": "user", "content": "xyz"}]}`, 0, "", nil},
+ {`{"messages": [{"role": "system", "content": "abc"}, {"role": "user", "content": "xyz"}]}`, 9, "xyz", nil},
+ {`{"messages": [{"role": "system", "content": "abc"}]}`, 9, "", []string{"no", "user", "found"}},
+ }
+ for _, c := range cases {
+ t.Run(fmt.Sprintf("%q %d", c.input, c.n), func(t *testing.T) {
+ var body map[string]interface{}
+ require.NoError(t, json.Unmarshal([]byte(c.input), &body))
+ out, err := getPrefixForChatCompletionRequest(body, c.n)
+ if c.expErrorContains != nil {
+ for _, ec := range c.expErrorContains {
+ require.ErrorContains(t, err, ec)
+ }
+ return
+ } else {
+ require.NoError(t, err)
+ }
+ require.Equal(t, c.exp, out)
+ })
+ }
+}
+
+func Test_firstNChars(t *testing.T) {
+ cases := []struct {
+ input string
+ n int
+ exp string
+ }{
+ {"", 0, ""},
+ {"", 1, ""},
+ {"abc", 0, ""},
+ {"abc", 1, "a"},
+ {"abc", 2, "ab"},
+ {"abc", 3, "abc"},
+ {"abc", 4, "abc"},
+ {"世界", 1, "世"},
+ {"世界", 2, "世界"},
+ {"世界", 3, "世界"},
+ }
+ for _, c := range cases {
+ t.Run(fmt.Sprintf("%q %d", c.input, c.n), func(t *testing.T) {
+ require.Equal(t, c.exp, firstNChars(c.input, c.n))
+ })
+ }
+}
+
+func TestParseRequest(t *testing.T) {
+ cases := []struct {
+ name string
+ body string
+ path string
+ headers http.Header
+ expModel string
+ expAdapter string
+ expPrefix string
+ expErrorContains []string
+ }{
+ {
+ name: "empty",
+ body: `{}`,
+ expErrorContains: []string{"bad request"},
+ },
+ {
+ name: "model only",
+ body: `{"model": "test-model"}`,
+ expModel: "test-model",
+ },
+ {
+ name: "model and adapter",
+ body: `{"model": "test-model_test-adapter"}`,
+ expModel: "test-model",
+ expAdapter: "test-adapter",
+ },
+ {
+ name: "openai chat completion missing messages",
+ body: `{"model": "test-model"}`,
+ path: "/v1/chat/completions",
+ expModel: "test-model",
+ expErrorContains: []string{"missing", "messages"},
+ },
+ {
+ name: "openai chat completion missing user message",
+ body: `{"model": "test-model", "messages": [{"role": "system", "content": "test"}]}`,
+ path: "/v1/chat/completions",
+ expModel: "test-model",
+ expErrorContains: []string{"no", "user", "found"},
+ },
+ {
+ name: "openai chat completion",
+ body: `{"model": "test-model", "messages": [{"role": "user", "content": "test-prefix"}]}`,
+ path: "/v1/chat/completions",
+ expModel: "test-model",
+ expPrefix: "test-prefi", // "test-prefix" (max 10) --> "test-prefi"
+ },
+ {
+ name: "openai legacy completion",
+ body: `{"model": "test-model", "prompt": "test-prefix"}`,
+ path: "/v1/completions",
+ expModel: "test-model",
+ expPrefix: "test-prefi", // "test-prefix" (max 10) --> "test-prefi"
+
+ },
+ }
+ for _, c := range cases {
+ t.Run(c.name, func(t *testing.T) {
+ ctx := context.Background()
+
+ mockClient := &mockModelClient{prefixCharLen: 10}
+
+ req, err := ParseRequest(ctx, mockClient, bytes.NewReader([]byte(c.body)), c.path, c.headers)
+ if c.expErrorContains != nil {
+ for _, ec := range c.expErrorContains {
+ require.ErrorContains(t, err, ec)
+ }
+ return
+ } else {
+ require.NoError(t, err)
+ }
+
+ require.Equal(t, c.expModel, req.Model)
+ require.Equal(t, c.expAdapter, req.Adapter)
+ require.Equal(t, c.expPrefix, req.Prefix)
+ })
+ }
+
+}
+
+type mockModelClient struct {
+ prefixCharLen int
+}
+
+func (m *mockModelClient) LookupModel(ctx context.Context, model, adapter string, selectors []string) (*v1.Model, error) {
+ return &v1.Model{
+ Spec: v1.ModelSpec{
+ LoadBalancing: v1.LoadBalancing{
+ Strategy: v1.PrefixHashStrategy,
+ PrefixHash: v1.PrefixHash{
+ // "test-prefix" --> "test-prefi"
+ PrefixCharLength: m.prefixCharLen,
+ },
+ },
+ },
+ }, nil
+}
diff --git a/internal/endpoints/endpoints.go b/internal/endpoints/endpoints.go
deleted file mode 100644
index 26a55eaf..00000000
--- a/internal/endpoints/endpoints.go
+++ /dev/null
@@ -1,138 +0,0 @@
-package endpoints
-
-import (
- "context"
- "log"
- "sync"
- "sync/atomic"
-)
-
-func newEndpointGroup() *endpointGroup {
- e := &endpointGroup{}
- e.endpoints = make(map[string]endpoint)
- e.bcast = make(chan struct{})
- return e
-}
-
-type endpointGroup struct {
- mtx sync.RWMutex
- endpoints map[string]endpoint
-
- bmtx sync.RWMutex
- bcast chan struct{} // closed when there's a broadcast
-}
-
-func newEndpoint(attrs endpointAttrs) endpoint {
- return endpoint{
- inFlight: &atomic.Int64{},
- endpointAttrs: attrs,
- }
-}
-
-type endpoint struct {
- inFlight *atomic.Int64
- endpointAttrs
-}
-
-// getBestAddr returns the best "IP:Port". It blocks until there are available endpoints
-// in the endpoint group. It selects the host with the minimum in-flight requests
-// among all the available endpoints.
-func (e *endpointGroup) getBestAddr(ctx context.Context, adapter string, awaitChangeEndpoints bool) (string, func(), error) {
- e.mtx.RLock()
- // await endpoints exists
- for awaitChangeEndpoints || len(e.endpoints) == 0 {
- e.mtx.RUnlock()
- select {
- case <-e.awaitEndpoints():
- case <-ctx.Done():
- return "", func() {}, ctx.Err()
- }
- e.mtx.RLock()
- }
- var bestAddr string
- var minInFlight int
- for addr, ep := range e.endpoints {
- if adapter != "" {
- // Skip endpoints that don't have the requested adapter.
- if _, ok := ep.adapters[adapter]; !ok {
- continue
- }
- }
- inFlight := int(e.endpoints[addr].inFlight.Load())
- if bestAddr == "" || inFlight < minInFlight {
- bestAddr = addr
- minInFlight = inFlight
- }
- }
-
- if bestAddr == "" {
- e.mtx.RUnlock()
- return e.getBestAddr(ctx, adapter, true)
- }
-
- ep := e.endpoints[bestAddr]
- ep.inFlight.Add(1)
- decFunc := func() {
- log.Printf("decrementing in-flight count for %s, new in-flight: %v", bestAddr, ep.inFlight.Add(-1))
- }
- e.mtx.RUnlock()
- return bestAddr, decFunc, nil
-}
-
-func (e *endpointGroup) awaitEndpoints() chan struct{} {
- e.bmtx.RLock()
- defer e.bmtx.RUnlock()
- return e.bcast
-}
-
-func (e *endpointGroup) getAllAddrs() []string {
- e.mtx.RLock()
- defer e.mtx.RUnlock()
-
- var hosts []string
- for ip := range e.endpoints {
- hosts = append(hosts, ip)
- }
-
- return hosts
-}
-
-func (g *endpointGroup) lenIPs() int {
- g.mtx.RLock()
- defer g.mtx.RUnlock()
- return len(g.endpoints)
-}
-
-type endpointAttrs struct {
- adapters map[string]struct{}
-}
-
-func (g *endpointGroup) setAddrs(addrs map[string]endpointAttrs) {
- g.mtx.Lock()
- for addr, attrs := range addrs {
- if ep, ok := g.endpoints[addr]; ok {
- ep.adapters = attrs.adapters
- } else {
- g.endpoints[addr] = newEndpoint(attrs)
- }
- }
- for addr := range g.endpoints {
- if _, ok := addrs[addr]; !ok {
- delete(g.endpoints, addr)
- }
- }
- g.mtx.Unlock()
-
- // notify waiting requests
- if len(addrs) > 0 {
- g.broadcastEndpoints()
- }
-}
-
-func (g *endpointGroup) broadcastEndpoints() {
- g.bmtx.Lock()
- defer g.bmtx.Unlock()
-
- close(g.bcast)
- g.bcast = make(chan struct{})
-}
diff --git a/internal/endpoints/resolver_test.go b/internal/endpoints/resolver_test.go
deleted file mode 100644
index b5d3179c..00000000
--- a/internal/endpoints/resolver_test.go
+++ /dev/null
@@ -1,72 +0,0 @@
-package endpoints
-
-import (
- "context"
- "testing"
- "time"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestAwaitBestHost(t *testing.T) {
- const (
- myModel = "my-model"
- myAdapter = "my-adapter"
- myAddrWithoutAdapter = "10.0.0.1:8000"
- myAddrWithAdapter = "10.0.0.2:8000"
- )
-
- manager := &Resolver{endpoints: make(map[string]*endpointGroup, 1)}
-
- testCases := map[string]struct {
- model string
- adapter string
- addrs map[string]endpointAttrs
- expAddr string
- expErr error
- }{
- "model without adapter": {
- model: myModel,
- expAddr: myAddrWithoutAdapter,
- addrs: map[string]endpointAttrs{myAddrWithoutAdapter: {}},
- },
- "model with adapter": {
- model: myModel,
- adapter: myAdapter,
- addrs: map[string]endpointAttrs{
- myAddrWithoutAdapter: {},
- myAddrWithAdapter: {adapters: map[string]struct{}{
- myAdapter: {},
- }},
- },
- expAddr: myAddrWithAdapter,
- },
- "unknown model blocks until timeout": {
- model: "unknown-model",
- addrs: map[string]endpointAttrs{
- myAddrWithoutAdapter: {},
- },
- expErr: context.DeadlineExceeded,
- },
- // not covered: unknown port with multiple ports on entrypoint
- }
-
- for name, spec := range testCases {
- t.Run(name, func(t *testing.T) {
- manager.getEndpoints(myModel).setAddrs(spec.addrs)
-
- ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
- defer cancel()
-
- gotAddr, gotFunc, gotErr := manager.AwaitBestAddress(ctx, spec.model, spec.adapter)
- if spec.expErr != nil {
- require.ErrorIs(t, spec.expErr, gotErr)
- return
- }
- require.NoError(t, gotErr)
- gotFunc()
- assert.Equal(t, spec.expAddr, gotAddr)
- })
- }
-}
diff --git a/internal/loadbalancer/balance_chwbl.go b/internal/loadbalancer/balance_chwbl.go
new file mode 100644
index 00000000..40458092
--- /dev/null
+++ b/internal/loadbalancer/balance_chwbl.go
@@ -0,0 +1,142 @@
+package loadbalancer
+
+import (
+ "context"
+ "fmt"
+ "sort"
+
+ "github.com/cespare/xxhash"
+ "github.com/substratusai/kubeai/internal/metrics"
+)
+
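+// chwblGetAddr performs the CHWBL lookup: it hashes the request key onto the ring,
+// then walks the ring from that point until it finds an endpoint that serves the
+// requested adapter and whose in-flight load is within the configured bound,
+// falling back to the first adapter-capable endpoint if none satisfies the bound.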
+func (g *group) chwblGetAddr(key string, loadFactor float64, adapter string) (endpoint, bool) {
+ if len(g.chwblHashes) == 0 {
+ return endpoint{}, false
+ }
+
+ h := chwblHash(key)
+ _, i0 := g.chwblSearch(h)
+
+ // defaultEndpoint is a fallback: an endpoint that can serve the request (has the
+ // adapter) but may not satisfy the load bound. It is used only if no endpoint on
+ // the ring satisfies the bound.
+ var defaultEndpoint *endpoint
+
+ i := i0
+ // Avoid an infinite loop by checking if we've checked all the endpoints.
+ for n := 0; n < len(g.chwblSortedHashes); n++ {
+ name := g.chwblHashes[g.chwblSortedHashes[i]]
+ ep, ok := g.endpoints[name]
+ if !ok {
+ panic(fmt.Sprintf("endpoints corrupted, %q should be in map", name))
+ }
+
+ var adapterMatches bool
+ if adapter == "" {
+ adapterMatches = true
+ } else {
+ _, adapterMatches = ep.adapters[adapter]
+ }
+
+ if adapterMatches {
+ if defaultEndpoint == nil {
+ // Save the first endpoint that has the adapter in case no
+ // endpoint is found with acceptable load.
+ defaultEndpoint = &ep
+ }
+ if chwblLoadOK(ep.inFlight.Load(), g.totalInFlight.Load(), len(g.endpoints), loadFactor) {
+ metrics.InferenceRequestsHashLookupIterations.Record(context.Background(), int64(n+1))
+ return ep, true
+ }
+ }
+
+ i++
+ if i >= len(g.chwblSortedHashes) {
+ // wrap around
+ i = 0
+ }
+ }
+
+ if defaultEndpoint != nil {
+ metrics.InferenceRequestsHashLookupIterations.Record(context.Background(), int64(len(g.chwblSortedHashes)))
+ return *defaultEndpoint, true
+ }
+ return endpoint{}, false
+}
+
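+// chwblAddEndpoint adds chwblReplication virtual nodes for the endpoint to the
+// hash ring and re-sorts the ring so lookups can binary search it.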
+func (g *group) chwblAddEndpoint(name string) {
+ for i := 0; i < g.chwblReplication; i++ {
+ h := chwblHashEndpointReplica(name, i)
+ g.chwblHashes[h] = name
+ g.chwblSortedHashes = append(g.chwblSortedHashes, h)
+ }
+
+ // sort hashes in ascending order
+ sort.Slice(g.chwblSortedHashes, func(i int, j int) bool {
+ return g.chwblSortedHashes[i] < g.chwblSortedHashes[j]
+ })
+}
+
+func (g *group) chwblRemoveEndpoint(name string) {
+ for i := 0; i < g.chwblReplication; i++ {
+ h := chwblHashEndpointReplica(name, i)
+ delete(g.chwblHashes, h)
+ g.chwblDeleteSortedHash(h)
+ }
+}
+
+// chwblSearch returns the first hash on the ring that is >= key (wrapping around to the start) and its index.
+func (g *group) chwblSearch(key uint64) (uint64, int) {
+ idx := sort.Search(len(g.chwblSortedHashes), func(i int) bool {
+ return g.chwblSortedHashes[i] >= key
+ })
+
+ if idx >= len(g.chwblSortedHashes) {
+ idx = 0
+ }
+ return g.chwblSortedHashes[idx], idx
+}
+
+func (g *group) chwblDeleteSortedHash(val uint64) {
+ idx := -1
+ left := 0
+ right := len(g.chwblSortedHashes) - 1
+ for left <= right {
+ middle := (left + right) / 2
+ current := g.chwblSortedHashes[middle]
+ if current == val {
+ idx = middle
+ break
+ } else if current < val {
+ left = middle + 1
+ } else if current > val {
+ right = middle - 1
+ }
+ }
+ if idx != -1 {
+ g.chwblSortedHashes = append(g.chwblSortedHashes[:idx], g.chwblSortedHashes[idx+1:]...)
+ }
+}
+
+func chwblHash(s string) uint64 {
+ return xxhash.Sum64([]byte(s))
+}
+
+func chwblHashEndpointReplica(name string, replica int) uint64 {
+ return chwblHash(chwblEndpointReplicaHashInput(name, replica))
+}
+
+func chwblEndpointReplicaHashInput(name string, replica int) string {
+ return fmt.Sprintf("%s%d", name, replica)
+}
+
+func chwblLoadOK(load, totalLoad int64, n int, loadFactor float64) bool {
+ if totalLoad == 0 {
+ return true
+ }
+
+ // The "+1"s are to simulate the load of the new request.
+ avgLoad := float64(totalLoad+1) / float64(n)
+ threshold := avgLoad * loadFactor
+ ok := float64(load)+1 <= threshold
+ return ok
+}
diff --git a/internal/loadbalancer/balance_least_load.go b/internal/loadbalancer/balance_least_load.go
new file mode 100644
index 00000000..84ddd7ab
--- /dev/null
+++ b/internal/loadbalancer/balance_least_load.go
@@ -0,0 +1,23 @@
+package loadbalancer
+
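+// getAddrLeastLoad returns the endpoint with the fewest in-flight requests,
+// considering only endpoints that have the requested adapter (when one is named).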
+func (g *group) getAddrLeastLoad(adapter string) (endpoint, bool) {
+ var bestEp endpoint
+ var found bool
+ var minInFlight int
+ for _, ep := range g.endpoints {
+ if adapter != "" {
+ // Skip endpoints that don't have the requested adapter.
+ if _, ok := ep.adapters[adapter]; !ok {
+ continue
+ }
+ }
+ inFlight := int(ep.inFlight.Load())
+ if !found || inFlight < minInFlight {
+ bestEp = ep
+ found = true
+ minInFlight = inFlight
+ }
+ }
+
+ return bestEp, found
+}
diff --git a/internal/loadbalancer/group.go b/internal/loadbalancer/group.go
new file mode 100644
index 00000000..c126bfec
--- /dev/null
+++ b/internal/loadbalancer/group.go
@@ -0,0 +1,149 @@
+package loadbalancer
+
+import (
+ "context"
+ "fmt"
+ "sync"
+ "sync/atomic"
+
+ v1 "github.com/substratusai/kubeai/api/v1"
+ "github.com/substratusai/kubeai/internal/apiutils"
+)
+
+func newEndpointGroup() *group {
+ g := &group{
+ endpoints: make(map[string]endpoint),
+ totalInFlight: &atomic.Int64{},
+ chwblReplication: 100,
+ chwblHashes: map[uint64]string{},
+ chwblSortedHashes: []uint64{},
+ bcast: make(chan struct{}),
+ }
+ return g
+}
+
+type group struct {
+ mtx sync.RWMutex
+
+ endpoints map[string]endpoint
+
+ totalInFlight *atomic.Int64
+
+ // the number of times an endpoint is replicated on the hash ring
+ chwblReplication int
+ // map of hash to endpoint
+ chwblHashes map[uint64]string
+ // sorted list of hashed node-replicas
+ chwblSortedHashes []uint64
+
+ bmtx sync.RWMutex
+ bcast chan struct{} // closed when there's a broadcast
+}
+
+type endpoint struct {
+ address string
+
+ inFlight *atomic.Int64
+
+ adapters map[string]struct{}
+}
+
+// getBestAddr returns the best "IP:Port". It blocks until there are available endpoints
+// in the endpoint group.
+func (g *group) getBestAddr(ctx context.Context, req *apiutils.Request, awaitChangeEndpoints bool) (string, func(), error) {
+ g.mtx.RLock()
+ // await endpoints exists
+ for awaitChangeEndpoints || len(g.endpoints) == 0 {
+ g.mtx.RUnlock()
+ select {
+ case <-g.awaitEndpoints():
+ case <-ctx.Done():
+ return "", func() {}, ctx.Err()
+ }
+ g.mtx.RLock()
+ }
+
+ var ep endpoint
+ var found bool
+ switch req.LoadBalancing.Strategy {
+ case v1.PrefixHashStrategy:
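+ // The hash key combines the adapter name and the request prefix so that
+ // requests for the same adapter and similar prompts map to the same replica.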
+ ep, found = g.chwblGetAddr(req.Adapter+req.Prefix, float64(req.LoadBalancing.PrefixHash.MeanLoadPercentage)/100, req.Adapter)
+ case v1.LeastLoadStrategy:
+ ep, found = g.getAddrLeastLoad(req.Adapter)
+ default:
+ return "", func() {}, fmt.Errorf("unknown load balancing strategy: %v", req.LoadBalancing.Strategy)
+ }
+
+ if !found {
+ g.mtx.RUnlock()
+ return g.getBestAddr(ctx, req, true)
+ }
+
+ g.addInFlight(ep.inFlight, 1)
+ decFunc := func() {
+ g.addInFlight(ep.inFlight, -1)
+ }
+ g.mtx.RUnlock()
+ return ep.address, decFunc, nil
+}
+
+func (g *group) awaitEndpoints() chan struct{} {
+ g.bmtx.RLock()
+ defer g.bmtx.RUnlock()
+ return g.bcast
+}
+
+func (g *group) getAllAddrs() []string {
+ g.mtx.RLock()
+ defer g.mtx.RUnlock()
+
+ var hosts []string
+ for _, ep := range g.endpoints {
+ hosts = append(hosts, ep.address)
+ }
+
+ return hosts
+}
+
+func (g *group) reconcileEndpoints(observed map[string]endpoint) {
+ g.mtx.Lock()
+ for name, observedEp := range observed {
+ if currentEp, ok := g.endpoints[name]; ok {
+ currentEp.adapters = observedEp.adapters
+ g.endpoints[name] = currentEp
+ } else {
+ g.endpoints[name] = endpoint{
+ inFlight: &atomic.Int64{},
+ address: observedEp.address,
+ adapters: observedEp.adapters,
+ }
+ g.chwblAddEndpoint(name)
+ }
+ }
+ for name, ep := range g.endpoints {
+ if _, ok := observed[name]; !ok {
+ g.totalInFlight.Add(-ep.inFlight.Load())
+ g.chwblRemoveEndpoint(name)
+ delete(g.endpoints, name)
+ }
+ }
+ g.mtx.Unlock()
+
+ // notify waiting requests
+ if len(observed) > 0 {
+ g.broadcastEndpoints()
+ }
+}
+
+func (g *group) broadcastEndpoints() {
+ g.bmtx.Lock()
+ defer g.bmtx.Unlock()
+
+ close(g.bcast)
+ g.bcast = make(chan struct{})
+}
+
+func (g *group) addInFlight(endpointInFlight *atomic.Int64, add int64) int64 {
+ g.totalInFlight.Add(add)
+ return endpointInFlight.Add(add)
+}
diff --git a/internal/endpoints/endpoints_bench_test.go b/internal/loadbalancer/group_bench_test.go
similarity index 50%
rename from internal/endpoints/endpoints_bench_test.go
rename to internal/loadbalancer/group_bench_test.go
index db083aa2..32e08cb3 100644
--- a/internal/endpoints/endpoints_bench_test.go
+++ b/internal/loadbalancer/group_bench_test.go
@@ -1,17 +1,19 @@
-package endpoints
+package loadbalancer
import (
"context"
"testing"
+
+ "github.com/substratusai/kubeai/internal/apiutils"
)
func BenchmarkEndpointGroup(b *testing.B) {
e := newEndpointGroup()
- e.setAddrs(map[string]endpointAttrs{"10.0.0.1": {}})
+ e.reconcileEndpoints(map[string]endpoint{"pod1": {address: "10.0.0.1:8000"}})
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
- _, f, err := e.getBestAddr(context.Background(), "", false)
+ _, f, err := e.getBestAddr(context.Background(), &apiutils.Request{}, false)
if err != nil {
b.Fatal(err)
}
diff --git a/internal/endpoints/endpoints_test.go b/internal/loadbalancer/group_test.go
similarity index 56%
rename from internal/endpoints/endpoints_test.go
rename to internal/loadbalancer/group_test.go
index 13e50df5..20dcac92 100644
--- a/internal/endpoints/endpoints_test.go
+++ b/internal/loadbalancer/group_test.go
@@ -1,4 +1,4 @@
-package endpoints
+package loadbalancer
import (
"context"
@@ -8,32 +8,47 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ v1 "github.com/substratusai/kubeai/api/v1"
+ "github.com/substratusai/kubeai/internal/apiutils"
"k8s.io/apimachinery/pkg/util/rand"
)
func TestConcurrentAccess(t *testing.T) {
- const myModel = "myModel"
+ const (
+ myModel = "myModel"
+ myAddr = "10.0.0.1:8000"
+ )
testCases := map[string]struct {
readerCount int
writerCount int
}{
- "lot of reader": {readerCount: 1_000, writerCount: 1},
- "lot of writer": {readerCount: 1, writerCount: 1_000},
- "lot of both": {readerCount: 1_000, writerCount: 1_000},
+ "one reader_one_writer": {readerCount: 1, writerCount: 1},
+ "lot of reader": {readerCount: 1_000, writerCount: 1},
+ "lot of writer": {readerCount: 1, writerCount: 1_000},
+ "lot of both": {readerCount: 1_000, writerCount: 1_000},
}
for name, spec := range testCases {
- randomReadFn := []func(g *endpointGroup){
- func(g *endpointGroup) { g.getBestAddr(context.Background(), "", false) },
- func(g *endpointGroup) { g.getAllAddrs() },
- func(g *endpointGroup) { g.lenIPs() },
+ randomReadFn := []func(g *group){
+ func(g *group) {
+ ip, f, err := g.getBestAddr(context.Background(), &apiutils.Request{
+ Model: myModel,
+ LoadBalancing: v1.LoadBalancing{
+ Strategy: v1.LeastLoadStrategy,
+ },
+ }, false)
+ require.NoError(t, err)
+ defer f()
+ assert.Equal(t, myAddr, ip)
+ },
+ func(g *group) { g.getAllAddrs() },
}
t.Run(name, func(t *testing.T) {
- // setup endpoint with one service so that requests are not waiting
- endpoint := newEndpointGroup()
- endpoint.setAddrs(
- map[string]endpointAttrs{myModel: {}},
+ // set up the group with one endpoint so that requests are not waiting
+ group := newEndpointGroup()
+ group.reconcileEndpoints(
+ map[string]endpoint{myModel: {address: myAddr}},
)
var startWg, doneWg sync.WaitGroup
@@ -50,10 +65,10 @@ func TestConcurrentAccess(t *testing.T) {
}
}
// when
- startTogether(spec.readerCount, func() { randomReadFn[rand.Intn(len(randomReadFn)-1)](endpoint) })
+ startTogether(spec.readerCount, func() { randomReadFn[rand.Intn(len(randomReadFn)-1)](group) })
startTogether(spec.writerCount, func() {
- endpoint.setAddrs(
- map[string]endpointAttrs{rand.String(1): {}},
+ group.reconcileEndpoints(
+ map[string]endpoint{myModel: {address: myAddr}},
)
})
doneWg.Wait()
@@ -77,16 +92,16 @@ func TestBlockAndWaitForEndpoints(t *testing.T) {
}()
}
}
- endpoint := newEndpointGroup()
+ group := newEndpointGroup()
ctx := context.TODO()
startTogether(100, func() {
- endpoint.getBestAddr(ctx, "", false)
+ group.getBestAddr(ctx, &apiutils.Request{}, false)
})
startWg.Wait()
// when broadcast triggered
- endpoint.setAddrs(
- map[string]endpointAttrs{rand.String(4): {}},
+ group.reconcileEndpoints(
+ map[string]endpoint{rand.String(4): {}},
)
// then
doneWg.Wait()
@@ -102,7 +117,7 @@ func TestAbortOnCtxCancel(t *testing.T) {
go func(t *testing.T) {
startWg.Wait()
endpoint := newEndpointGroup()
- _, f, err := endpoint.getBestAddr(ctx, "", false)
+ _, f, err := endpoint.getBestAddr(ctx, &apiutils.Request{}, false)
defer f()
require.Error(t, err)
doneWg.Done()
diff --git a/internal/endpoints/resolver.go b/internal/loadbalancer/load_balancer.go
similarity index 63%
rename from internal/endpoints/resolver.go
rename to internal/loadbalancer/load_balancer.go
index ab50c68e..fac8428f 100644
--- a/internal/endpoints/resolver.go
+++ b/internal/loadbalancer/load_balancer.go
@@ -1,4 +1,4 @@
-package endpoints
+package loadbalancer
import (
"context"
@@ -7,7 +7,8 @@ import (
"strings"
"sync"
- kubeaiv1 "github.com/substratusai/kubeai/api/v1"
+ v1 "github.com/substratusai/kubeai/api/v1"
+ "github.com/substratusai/kubeai/internal/apiutils"
"github.com/substratusai/kubeai/internal/k8sutils"
corev1 "k8s.io/api/core/v1"
"k8s.io/utils/ptr"
@@ -17,10 +18,10 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
)
-func NewResolver(mgr ctrl.Manager) (*Resolver, error) {
- r := &Resolver{}
+func New(mgr ctrl.Manager) (*LoadBalancer, error) {
+ r := &LoadBalancer{}
r.Client = mgr.GetClient()
- r.endpoints = map[string]*endpointGroup{}
+ r.groups = map[string]*group{}
r.ExcludePods = map[string]struct{}{}
if err := r.SetupWithManager(mgr); err != nil {
return nil, err
@@ -28,12 +29,12 @@ func NewResolver(mgr ctrl.Manager) (*Resolver, error) {
return r, nil
}
-type Resolver struct {
+type LoadBalancer struct {
client.Client
endpointsMtx sync.Mutex
// map[]endpointGroup
- endpoints map[string]*endpointGroup
+ groups map[string]*group
selfIPsMtx sync.RWMutex
selfIPs []string
@@ -41,14 +42,14 @@ type Resolver struct {
ExcludePods map[string]struct{}
}
-func (r *Resolver) SetupWithManager(mgr ctrl.Manager) error {
+func (r *LoadBalancer) SetupWithManager(mgr ctrl.Manager) error {
return ctrl.NewControllerManagedBy(mgr).
WithOptions(controller.Options{NeedLeaderElection: ptr.To(false)}).
For(&corev1.Pod{}).
Complete(r)
}
-func (r *Resolver) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
+func (r *LoadBalancer) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
var pod corev1.Pod
if err := r.Get(ctx, req.NamespacedName, &pod); err != nil {
return ctrl.Result{}, client.IgnoreNotFound(err)
@@ -80,17 +81,17 @@ func (r *Resolver) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result
return ctrl.Result{}, nil
}
- modelName, ok := labels[kubeaiv1.PodModelLabel]
+ modelName, ok := labels[v1.PodModelLabel]
if !ok {
return ctrl.Result{}, nil
}
var podList corev1.PodList
- if err := r.List(ctx, &podList, client.InNamespace(pod.Namespace), client.MatchingLabels{kubeaiv1.PodModelLabel: modelName}); err != nil {
+ if err := r.List(ctx, &podList, client.InNamespace(pod.Namespace), client.MatchingLabels{v1.PodModelLabel: modelName}); err != nil {
return ctrl.Result{}, fmt.Errorf("listing matching pods: %w", err)
}
- addrs := map[string]endpointAttrs{}
+ observedEndpoints := map[string]endpoint{}
for _, pod := range podList.Items {
if _, exclude := r.ExcludePods[pod.Name]; exclude {
continue
@@ -101,14 +102,14 @@ func (r *Resolver) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result
// The Model controller should always set the port annotation in the Pods it creates
// to communicate the port that the given backend listens on.
- port := getPodAnnotation(pod, kubeaiv1.ModelPodPortAnnotation)
+ port := getPodAnnotation(pod, v1.ModelPodPortAnnotation)
if port == "" {
- log.Printf("ERROR: No port annotation %q found for pod %s, skipping", kubeaiv1.ModelPodPortAnnotation, pod.Name)
+ log.Printf("ERROR: No port annotation %q found for pod %s, skipping", v1.ModelPodPortAnnotation, pod.Name)
continue
}
// Allow overriding the IP address of the pod.
- ip := getPodAnnotation(pod, kubeaiv1.ModelPodIPAnnotation)
+ ip := getPodAnnotation(pod, v1.ModelPodIPAnnotation)
if ip == "" {
ip = pod.Status.PodIP
}
@@ -118,26 +119,27 @@ func (r *Resolver) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result
continue
}
- addrs[ip+":"+port] = getEndpointAttrs(pod)
+ observedEndpoints[pod.Namespace+"/"+pod.Name] = endpoint{
+ address: ip + ":" + port,
+ adapters: getEndpointAdapters(pod),
+ }
}
- r.getEndpoints(modelName).setAddrs(addrs)
+ r.getEndpoints(modelName).reconcileEndpoints(observedEndpoints)
return ctrl.Result{}, nil
}
-func getEndpointAttrs(pod corev1.Pod) endpointAttrs {
- attrs := endpointAttrs{
- adapters: map[string]struct{}{},
- }
+func getEndpointAdapters(pod corev1.Pod) map[string]struct{} {
+ adapters := map[string]struct{}{}
for k := range pod.GetLabels() {
- if strings.HasPrefix(k, kubeaiv1.PodAdapterLabelPrefix) {
- attrs.adapters[strings.TrimPrefix(k, kubeaiv1.PodAdapterLabelPrefix)] = struct{}{}
+ if strings.HasPrefix(k, v1.PodAdapterLabelPrefix) {
+ adapters[strings.TrimPrefix(k, v1.PodAdapterLabelPrefix)] = struct{}{}
}
}
- return attrs
+ return adapters
}
func getPodAnnotation(pod corev1.Pod, key string) string {
@@ -147,18 +149,21 @@ func getPodAnnotation(pod corev1.Pod, key string) string {
return ""
}
-func (r *Resolver) getEndpoints(model string) *endpointGroup {
+// getEndpoints returns the endpoint group for the given model.
+// If the group does not exist, it is created.
+// This assumes that the existence of the model has already been checked.
+func (r *LoadBalancer) getEndpoints(model string) *group {
r.endpointsMtx.Lock()
- e, ok := r.endpoints[model]
+ g, ok := r.groups[model]
if !ok {
- e = newEndpointGroup()
- r.endpoints[model] = e
+ g = newEndpointGroup()
+ r.groups[model] = g
}
r.endpointsMtx.Unlock()
- return e
+ return g
}
-func (r *Resolver) GetSelfIPs() []string {
+func (r *LoadBalancer) GetSelfIPs() []string {
r.selfIPsMtx.RLock()
defer r.selfIPsMtx.RUnlock()
return r.selfIPs
@@ -167,11 +172,11 @@ func (r *Resolver) GetSelfIPs() []string {
// AwaitBestAddress returns the "IP:Port" with the lowest number of in-flight requests. It will block until an endpoint
// becomes available or the context times out. It returns a function that should be called when the
// request is complete to decrement the in-flight count.
-func (r *Resolver) AwaitBestAddress(ctx context.Context, model, adapter string) (string, func(), error) {
- return r.getEndpoints(model).getBestAddr(ctx, adapter, false)
+func (r *LoadBalancer) AwaitBestAddress(ctx context.Context, req *apiutils.Request) (string, func(), error) {
+ return r.getEndpoints(req.Model).getBestAddr(ctx, req, false)
}
// GetAllHosts retrieves the list of all hosts for a given model.
-func (r *Resolver) GetAllAddresses(model string) []string {
+func (r *LoadBalancer) GetAllAddresses(model string) []string {
return r.getEndpoints(model).getAllAddrs()
}
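
For orientation, here is a minimal caller sketch of the AwaitBestAddress contract (assumed wiring; the real callers are the modelproxy and messenger changes later in this diff). It presumes it lives in this package so the context, fmt and apiutils imports above are available; the point is simply that the returned function must be called exactly once so the endpoint's in-flight count is decremented.

// forwardOnce is a hypothetical helper, not part of the patch.
func forwardOnce(ctx context.Context, lb *LoadBalancer, req *apiutils.Request) error {
	addr, done, err := lb.AwaitBestAddress(ctx, req)
	if err != nil {
		return fmt.Errorf("awaiting address: %w", err)
	}
	// Decrement the in-flight count once the request has been fully handled.
	defer done()

	// Proxy the request to http://<addr><path>; addr is an "IP:Port" string.
	_ = addr
	return nil
}
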
diff --git a/internal/loadbalancer/load_balancer_test.go b/internal/loadbalancer/load_balancer_test.go
new file mode 100644
index 00000000..e8b05eb1
--- /dev/null
+++ b/internal/loadbalancer/load_balancer_test.go
@@ -0,0 +1,423 @@
+package loadbalancer
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ v1 "github.com/substratusai/kubeai/api/v1"
+ "github.com/substratusai/kubeai/internal/apiutils"
+ "github.com/substratusai/kubeai/internal/metrics/metricstest"
+)
+
+func TestAwaitBestHostBehavior(t *testing.T) {
+ const (
+ myModel = "my-model"
+ myAdapter = "my-adapter"
+ myPodWithoutAdapter = "pod1"
+ myPodWithAdapter = "pod2"
+ myAddrWithoutAdapter = "10.0.0.1:8000"
+ myAddrWithAdapter = "10.0.0.2:8000"
+ )
+
+ testCases := map[string]struct {
+ model string
+ adapter string
+ endpoints map[string]endpoint
+ strategies []v1.LoadBalancingStrategy
+ expAddr string
+ expErr error
+ }{
+ "model only": {
+ model: myModel,
+ strategies: []v1.LoadBalancingStrategy{
+ v1.LeastLoadStrategy,
+ v1.PrefixHashStrategy,
+ },
+ expAddr: myAddrWithoutAdapter,
+ endpoints: map[string]endpoint{
+ myPodWithoutAdapter: {address: myAddrWithoutAdapter},
+ },
+ },
+ "model and adapter": {
+ model: myModel,
+ adapter: myAdapter,
+ endpoints: map[string]endpoint{
+ myPodWithoutAdapter: {
+ address: myAddrWithoutAdapter,
+ },
+ myPodWithAdapter: {
+ address: myAddrWithAdapter,
+ adapters: map[string]struct{}{
+ myAdapter: {},
+ }},
+ },
+ strategies: []v1.LoadBalancingStrategy{
+ v1.LeastLoadStrategy,
+ v1.PrefixHashStrategy,
+ },
+ expAddr: myAddrWithAdapter,
+ },
+ "no matching model blocks until timeout": {
+ model: "unknown-model",
+ endpoints: map[string]endpoint{
+ myPodWithoutAdapter: {address: myAddrWithoutAdapter},
+ },
+ strategies: []v1.LoadBalancingStrategy{
+ v1.LeastLoadStrategy,
+ v1.PrefixHashStrategy,
+ },
+ expErr: context.DeadlineExceeded,
+ },
+ "no matching adapter blocks until timeout": {
+ model: myModel,
+ adapter: "unknown-adapter",
+ endpoints: map[string]endpoint{
+ myPodWithoutAdapter: {address: myAddrWithoutAdapter},
+ },
+ strategies: []v1.LoadBalancingStrategy{
+ v1.LeastLoadStrategy,
+ v1.PrefixHashStrategy,
+ },
+ expErr: context.DeadlineExceeded,
+ },
+ // not covered: unknown port with multiple ports on entrypoint
+ }
+
+ for name, spec := range testCases {
+ for _, strategy := range spec.strategies {
+ t.Run(name+" with "+string(strategy)+" strategy", func(t *testing.T) {
+ metricstest.Init(t)
+
+ manager := &LoadBalancer{
+ groups: map[string]*group{},
+ }
+
+ manager.getEndpoints(myModel).reconcileEndpoints(spec.endpoints)
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
+ defer cancel()
+
+ gotAddr, gotFunc, gotErr := manager.AwaitBestAddress(ctx, &apiutils.Request{
+ Model: spec.model,
+ Adapter: spec.adapter,
+ LoadBalancing: v1.LoadBalancing{
+ Strategy: strategy,
+ PrefixHash: v1.PrefixHash{
+ MeanLoadPercentage: 125,
+ Replication: 1,
+ },
+ },
+ })
+ if spec.expErr != nil {
+ require.ErrorIs(t, spec.expErr, gotErr)
+ return
+ }
+ require.NoError(t, gotErr)
+ gotFunc()
+ assert.Equal(t, spec.expAddr, gotAddr)
+ })
+ }
+ }
+}
+
+func TestLoadBalancingStrategies(t *testing.T) {
+ const (
+ modelA = "model-a"
+ modelB = "model-b"
+
+ adapterA1 = "adapter-a-1"
+ adapterA2 = "adapter-a-2"
+
+ podA1Name = "pod-a-1"
+ podA1Addr = "10.0.0.1:8000"
+
+ podA2Name = "pod-a-2"
+ podA2Addr = "10.0.0.2:8000"
+
+ podB1Name = "pod-b-1"
+ podB1Addr = "10.0.0.3:8000"
+
+ podB2Name = "pod-b-2"
+ podB2Addr = "10.0.0.4:8000"
+ )
+
+ var (
+ podA1Hash = chwblEndpointReplicaHashInput(podA1Name, 0)
+ podA2Hash = chwblEndpointReplicaHashInput(podA2Name, 0)
+ )
+
+ type testStep struct {
+ name string
+
+ requestCount int
+ model string
+ adapter string
+ prefix string
+
+ expectedAddrCounts map[string]int
+ completeForAddrs map[string]int
+ }
+ cases := []struct {
+ name string
+ // map[model name]map[endpoint name]endpoint
+ modelEndpoints map[string]map[string]endpoint
+ // map[model name]map[endpoint name]in-flight request count
+ initialInFlight map[string]map[string]int64
+ loadBalancing v1.LoadBalancing
+ steps []testStep
+ }{
+ {
+ name: "least load strategy",
+ modelEndpoints: map[string]map[string]endpoint{
+ modelA: {
+ podA1Name: {address: podA1Addr, adapters: map[string]struct{}{adapterA1: {}}},
+ podA2Name: {address: podA2Addr, adapters: map[string]struct{}{adapterA2: {}}},
+ },
+ modelB: {
+ podB1Name: {address: podB1Addr},
+ podB2Name: {address: podB2Addr},
+ },
+ },
+ loadBalancing: v1.LoadBalancing{
+ Strategy: v1.LeastLoadStrategy,
+ },
+ steps: []testStep{
+ {
+ name: "first 2 requests to model-a",
+ model: modelA,
+ requestCount: 2,
+ expectedAddrCounts: map[string]int{
+ podA1Addr: 1,
+ podA2Addr: 1,
+ },
+ },
+ {
+ name: "a lot more requests to model-a",
+ model: modelA,
+ requestCount: 100,
+ expectedAddrCounts: map[string]int{
+ podA1Addr: 50,
+ podA2Addr: 50,
+ },
+ },
+ {
+ name: "requests to model-a adapter-a-1",
+ model: modelA,
+ adapter: adapterA1,
+ requestCount: 50,
+ expectedAddrCounts: map[string]int{
+ podA1Addr: 50,
+ },
+ },
+ {
+ name: "requests to model-a without adapter should be distributed to the other pod",
+ model: modelA,
+ requestCount: 52,
+ expectedAddrCounts: map[string]int{
+ podA1Addr: 1,
+ podA2Addr: 51,
+ },
+ },
+ {
+ name: "back to even balance",
+ model: modelA,
+ requestCount: 2,
+ expectedAddrCounts: map[string]int{
+ podA1Addr: 1,
+ podA2Addr: 1,
+ },
+ },
+ {
+ name: "complete some request for pod-a-2",
+ completeForAddrs: map[string]int{
+ podA2Addr: 10,
+ },
+ },
+ {
+ name: "requests to model-a should now be distributed to the other pod",
+ model: modelA,
+ requestCount: 12,
+ expectedAddrCounts: map[string]int{
+ podA1Addr: 1,
+ podA2Addr: 11,
+ },
+ },
+ {
+ name: "first requests to model-b",
+ model: modelB,
+ requestCount: 2,
+ expectedAddrCounts: map[string]int{
+ podB1Addr: 1,
+ podB2Addr: 1,
+ },
+ },
+ },
+ },
+ {
+ name: "prefix hash strategy",
+ modelEndpoints: map[string]map[string]endpoint{
+ modelA: {
+ podA1Name: {address: podA1Addr},
+ podA2Name: {address: podA2Addr},
+ },
+ modelB: {
+ podB1Name: {address: podB1Addr},
+ },
+ },
+ initialInFlight: map[string]map[string]int64{
+ modelA: {
+ podA1Name: 10,
+ podA2Name: 10,
+ },
+ },
+ loadBalancing: v1.LoadBalancing{
+ Strategy: v1.PrefixHashStrategy,
+ PrefixHash: v1.PrefixHash{
+ MeanLoadPercentage: 150,
+ Replication: 1,
+ },
+ },
+ steps: []testStep{
+ {
+ name: "first request to model-a, preferring pod-a-1, each pod has 10 in-flight requests",
+ model: modelA,
+ prefix: podA1Hash,
+ requestCount: 1,
+ expectedAddrCounts: map[string]int{
+ podA1Addr: 1,
+ },
+ },
+ {
+ // load0 load1 1.5*((load0+load1+1)/2) (load0)+1 <= (thres)
+ // 10 10 15.75 TRUE
+ // 11 10 16.5 TRUE
+ // 12 10 17.25 TRUE
+ // 13 10 18 TRUE
+ // 14 10 18.75 TRUE
+ // 15 10 19.5 TRUE
+ // 16 10 20.25 TRUE
+ // 17 10 21 TRUE
+ // 18 10 21.75 TRUE
+ // 19 10 22.5 TRUE
+ // 20 10 23.25 TRUE
+ // 21 10 24 TRUE
+ // 22 10 24.75 TRUE
+ // 23 10 25.5 TRUE
+ // 24 10 26.25 TRUE
+ // 25 10 27 TRUE
+ // 26 10 27.75 TRUE
+ // 27 10 28.5 TRUE
+ // 28 10 29.25 TRUE
+ // 29 10 30 TRUE
+ // 30 10 30.75 FALSE
+ name: "20 more requests preferring pod-a-1",
+ model: modelA,
+ // By making sure that the prefix matches the input used to hash the endpoint (pod-a-1),
+ // we can ensure that the algorithm will prefer pod-a-1.
+ prefix: podA1Hash,
+ requestCount: 20,
+ // See the table above for the expected distribution.
+ expectedAddrCounts: map[string]int{
+ podA1Addr: 19,
+ podA2Addr: 1,
+ },
+ },
+ {
+ // 30 10 30.75 FALSE
+ // 30 11 31.5 TRUE <-- 1 request (starting here)
+ // 31 11 32.25 TRUE <-- 2 requests
+ // 32 11 33 TRUE <-- 3 requests
+ // 33 11 33.75 FALSE <-- 4 requests
+ name: "4 more requests preferring pod-a-1",
+ model: modelA,
+ prefix: podA1Hash,
+ requestCount: 4,
+ // See the table above for the expected distribution.
+ expectedAddrCounts: map[string]int{
+ podA1Addr: 3,
+ podA2Addr: 1,
+ },
+ },
+ {
+ name: "with pod-a-1 near max load, requests preferring pod-a-2 should be distributed to pod-a-2",
+ model: modelA,
+ prefix: podA2Hash,
+ requestCount: 20,
+ expectedAddrCounts: map[string]int{
+ podA2Addr: 20,
+ },
+ },
+ {
+ name: "requests to model-b should be distributed to pod-b-1, as it is the only endpoint",
+ model: modelB,
+ // Use a prefix that hashes onto a model-a endpoint but matches nothing
+ // in model-b (to test hash-ring separation by model).
+ prefix: podA2Hash,
+ requestCount: 100_000,
+ expectedAddrCounts: map[string]int{
+ podB1Addr: 100_000,
+ },
+ },
+ },
+ },
+ }
+ for _, c := range cases {
+ t.Run(c.name, func(t *testing.T) {
+ manager := &LoadBalancer{
+ groups: map[string]*group{},
+ }
+
+ for model, endpoints := range c.modelEndpoints {
+ manager.getEndpoints(model).reconcileEndpoints(endpoints)
+ }
+
+ for modelName, inFlight := range c.initialInFlight {
+ for endpointName, count := range inFlight {
+ g := manager.getEndpoints(modelName)
+ ep := g.endpoints[endpointName]
+ g.addInFlight(ep.inFlight, count)
+ }
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+
+ doneFuncs := map[string][]func(){}
+ for _, step := range c.steps {
+ counts := map[string]int{}
+ for i := 0; i < step.requestCount; i++ {
+ // fmt.Println("request: ", step.name, "i: ", i)
+ addr, done, err := manager.AwaitBestAddress(ctx, &apiutils.Request{
+ Model: step.model,
+ Adapter: step.adapter,
+ Prefix: step.prefix,
+ LoadBalancing: c.loadBalancing,
+ })
+ require.NoError(t, err, "request: "+step.name)
+ doneFuncs[addr] = append(doneFuncs[addr], done)
+ counts[addr]++
+ }
+ if step.expectedAddrCounts != nil {
+ require.Equalf(t, step.expectedAddrCounts, counts, "request: %s", step.name)
+ }
+
+ for addr, count := range step.completeForAddrs {
+ for i := 0; i < count; i++ {
+ doneFuncs[addr][i]()
+ // remove the done function from the list
+ doneFuncs[addr] = doneFuncs[addr][1:]
+ }
+ }
+ }
+
+ for _, dones := range doneFuncs {
+ for _, done := range dones {
+ done()
+ }
+ }
+ })
+ }
+}
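
The comment tables in the prefix-hash test case above are consistent with reading the bound as loadFactor * (totalInFlight + 1) / numEndpoints, an endpoint staying eligible while its load + 1 is at or below that bound. The standalone sketch below reproduces the first table under that assumption (it is an illustration of the arithmetic, not the implementation):

package main

import "fmt"

func main() {
	const loadFactor = 1.5 // MeanLoadPercentage: 150
	load := map[string]int{"pod-a-1": 10, "pod-a-2": 10}

	// 1 initial request plus 20 more, all preferring pod-a-1.
	for i := 0; i < 21; i++ {
		total := load["pod-a-1"] + load["pod-a-2"]
		threshold := loadFactor * float64(total+1) / 2
		eligible := float64(load["pod-a-1"]+1) <= threshold
		fmt.Printf("load0=%d load1=%d threshold=%.2f eligible=%v\n",
			load["pod-a-1"], load["pod-a-2"], threshold, eligible)
		if eligible {
			load["pod-a-1"]++ // the request lands on the preferred endpoint
		} else {
			load["pod-a-2"]++ // it spills over to the next endpoint on the ring
		}
	}
}

Run this way, 20 of the 21 requests land on pod-a-1 and one spills over to pod-a-2, matching the expected counts in the first two steps of that test case.
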
diff --git a/internal/manager/run.go b/internal/manager/run.go
index fb9a6dcb..3ea13877 100644
--- a/internal/manager/run.go
+++ b/internal/manager/run.go
@@ -33,13 +33,13 @@ import (
metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"
kubeaiv1 "github.com/substratusai/kubeai/api/v1"
- "github.com/substratusai/kubeai/internal/endpoints"
"github.com/substratusai/kubeai/internal/leader"
+ "github.com/substratusai/kubeai/internal/loadbalancer"
"github.com/substratusai/kubeai/internal/messenger"
"github.com/substratusai/kubeai/internal/modelautoscaler"
+ "github.com/substratusai/kubeai/internal/modelclient"
"github.com/substratusai/kubeai/internal/modelcontroller"
"github.com/substratusai/kubeai/internal/modelproxy"
- "github.com/substratusai/kubeai/internal/modelscaler"
"github.com/substratusai/kubeai/internal/openaiserver"
"github.com/substratusai/kubeai/internal/vllmclient"
@@ -204,7 +204,7 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error {
cfg.LeaderElection.RetryPeriod.Duration,
)
- endpointResolver, err := endpoints.NewResolver(mgr)
+ loadBalancer, err := loadbalancer.New(mgr)
if err != nil {
return fmt.Errorf("unable to setup model resolver: %w", err)
}
@@ -239,7 +239,7 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error {
return fmt.Errorf("unable to set up ready check: %w", err)
}
- modelScaler := modelscaler.NewModelScaler(mgr.GetClient(), namespace)
+ modelClient := modelclient.NewModelClient(mgr.GetClient(), namespace)
metricsPort, err := parsePortFromAddr(cfg.MetricsAddr)
if err != nil {
@@ -250,8 +250,8 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error {
ctx,
k8sClient,
leaderElection,
- modelScaler,
- endpointResolver,
+ modelClient,
+ loadBalancer,
cfg.ModelAutoscaling,
metricsPort,
types.NamespacedName{Name: cfg.ModelAutoscaling.StateConfigMapName, Namespace: namespace},
@@ -261,7 +261,7 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error {
return fmt.Errorf("unable to create model autoscaler: %w", err)
}
- modelProxy := modelproxy.NewHandler(modelScaler, endpointResolver, 3, nil)
+ modelProxy := modelproxy.NewHandler(modelClient, loadBalancer, 3, nil)
openaiHandler := openaiserver.NewHandler(mgr.GetClient(), modelProxy)
mux := http.NewServeMux()
mux.Handle("/openai/", openaiHandler)
@@ -288,8 +288,8 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error {
stream.ResponsesURL,
stream.MaxHandlers,
cfg.Messaging.ErrorMaxBackoff.Duration,
- modelScaler,
- endpointResolver,
+ modelClient,
+ loadBalancer,
httpClient,
)
if err != nil {
diff --git a/internal/messenger/messenger.go b/internal/messenger/messenger.go
index ce6d51b3..7ea23460 100644
--- a/internal/messenger/messenger.go
+++ b/internal/messenger/messenger.go
@@ -13,6 +13,7 @@ import (
"sync"
"time"
+ v1 "github.com/substratusai/kubeai/api/v1"
"github.com/substratusai/kubeai/internal/apiutils"
"github.com/substratusai/kubeai/internal/metrics"
"go.opentelemetry.io/otel/attribute"
@@ -21,8 +22,8 @@ import (
)
type Messenger struct {
- modelScaler ModelScaler
- resolver EndpointResolver
+ modelClient ModelClient
+ loadBalancer LoadBalancer
HTTPC *http.Client
@@ -43,8 +44,8 @@ func NewMessenger(
responsesURL string,
maxHandlers int,
errorMaxBackoff time.Duration,
- modelScaler ModelScaler,
- resolver EndpointResolver,
+ modelClient ModelClient,
+ lb LoadBalancer,
httpClient *http.Client,
) (*Messenger, error) {
requests, err := pubsub.OpenSubscription(ctx, requestsURL)
@@ -58,8 +59,8 @@ func NewMessenger(
}
return &Messenger{
- modelScaler: modelScaler,
- resolver: resolver,
+ modelClient: modelClient,
+ loadBalancer: lb,
HTTPC: httpClient,
requestsURL: requestsURL,
requests: requests,
@@ -69,13 +70,13 @@ func NewMessenger(
}, nil
}
-type ModelScaler interface {
- LookupModel(ctx context.Context, model, adapter string, selectors []string) (bool, error)
+type ModelClient interface {
+ LookupModel(ctx context.Context, model, adapter string, selectors []string) (*v1.Model, error)
ScaleAtLeastOneReplica(ctx context.Context, model string) error
}
-type EndpointResolver interface {
- AwaitBestAddress(ctx context.Context, model, adapter string) (string, func(), error)
+type LoadBalancer interface {
+ AwaitBestAddress(ctx context.Context, req *apiutils.Request) (string, func(), error)
}
func (m *Messenger) Start(ctx context.Context) error {
@@ -192,71 +193,62 @@ func (m *Messenger) handleRequest(ctx context.Context, msg *pubsub.Message) {
}
}
*/
- req, err := parseRequest(ctx, msg)
+ mr, err := m.parseMsgRequest(ctx, msg)
if err != nil {
- m.sendResponse(req, m.jsonError("error parsing request: %v", err), http.StatusBadRequest)
+ if errors.Is(err, apiutils.ErrBadRequest) {
+ m.sendResponse(mr, m.jsonError("%v", err), http.StatusBadRequest)
+ } else if errors.Is(err, apiutils.ErrModelNotFound) {
+ m.sendResponse(mr, m.jsonError("%v", err), http.StatusNotFound)
+ } else {
+ m.sendResponse(mr, m.jsonError("parsing request: %v", err), http.StatusInternalServerError)
+ }
return
}
metricAttrs := metric.WithAttributeSet(attribute.NewSet(
- metrics.AttrRequestModel.String(req.model),
+ metrics.AttrRequestModel.String(mr.Model),
metrics.AttrRequestType.String(metrics.AttrRequestTypeMessage),
))
metrics.InferenceRequestsActive.Add(ctx, 1, metricAttrs)
defer metrics.InferenceRequestsActive.Add(ctx, -1, metricAttrs)
- modelExists, err := m.modelScaler.LookupModel(ctx, req.model, req.adapter, nil)
- if err != nil {
- m.sendResponse(req, m.jsonError("error checking if model exists: %v", err), http.StatusInternalServerError)
- return
- }
- if !modelExists {
- // Send a 400 response to the client, however it is possible the backend
- // will be deployed soon or another subscriber will handle it.
- m.sendResponse(req, m.jsonError("model not found: %s", req.model), http.StatusNotFound)
- return
- }
-
// Ensure the backend is scaled to at least one Pod.
- m.modelScaler.ScaleAtLeastOneReplica(ctx, req.model)
+ m.modelClient.ScaleAtLeastOneReplica(ctx, mr.Model)
log.Printf("Awaiting host for message %s", msg.LoggableID)
- host, completeFunc, err := m.resolver.AwaitBestAddress(ctx, req.model, req.adapter)
+ host, completeFunc, err := m.loadBalancer.AwaitBestAddress(ctx, mr.Request)
if err != nil {
- m.sendResponse(req, m.jsonError("error awaiting host for backend: %v", err), http.StatusBadGateway)
+ m.sendResponse(mr, m.jsonError("error awaiting host for backend: %v", err), http.StatusBadGateway)
return
}
defer completeFunc()
- url := fmt.Sprintf("http://%s%s", host, req.path)
+ url := fmt.Sprintf("http://%s%s", host, mr.path)
log.Printf("Sending request to backend for message %s: %s", msg.LoggableID, url)
- respPayload, respCode, err := m.sendBackendRequest(ctx, url, req.body)
+ respPayload, respCode, err := m.sendBackendRequest(ctx, url, mr.Body)
if err != nil {
- m.sendResponse(req, m.jsonError("error sending request to backend: %v", err), http.StatusBadGateway)
+ m.sendResponse(mr, m.jsonError("error sending request to backend: %v", err), http.StatusBadGateway)
return
}
- m.sendResponse(req, respPayload, respCode)
+ m.sendResponse(mr, respPayload, respCode)
}
func (m *Messenger) Stop(ctx context.Context) error {
return m.requests.Shutdown(ctx)
}
-type request struct {
- ctx context.Context
- msg *pubsub.Message
- metadata map[string]interface{}
- path string
- body json.RawMessage
- requestedModel string
- model string
- adapter string
+type msgRequest struct {
+ ctx context.Context
+ *apiutils.Request
+ msg *pubsub.Message
+ metadata map[string]interface{}
+ path string
}
-func parseRequest(ctx context.Context, msg *pubsub.Message) (*request, error) {
- req := &request{
+func (m *Messenger) parseMsgRequest(ctx context.Context, msg *pubsub.Message) (*msgRequest, error) {
+ req := &msgRequest{
ctx: ctx,
msg: msg,
}
@@ -280,34 +272,12 @@ func parseRequest(ctx context.Context, msg *pubsub.Message) (*request, error) {
req.metadata = payload.Metadata
req.path = path
- req.body = payload.Body
- var payloadBody map[string]interface{}
- if err := json.Unmarshal(payload.Body, &payloadBody); err != nil {
- return nil, fmt.Errorf("decoding: %w", err)
- }
- modelInf, ok := payloadBody["model"]
- if !ok {
- return nil, fmt.Errorf("missing '.body.model' field")
- }
- modelStr, ok := modelInf.(string)
- if !ok {
- return nil, fmt.Errorf("field '.body.model' should be a string")
- }
-
- req.requestedModel = modelStr
- req.model, req.adapter = apiutils.SplitModelAdapter(modelStr)
-
- // Assuming this is a vLLM request.
- // vLLM expects the adapter to be in the model field.
- if req.adapter != "" {
- payloadBody["model"] = req.adapter
- rewrittenBody, err := json.Marshal(payloadBody)
- if err != nil {
- return nil, fmt.Errorf("remarshalling: %w", err)
- }
- req.body = rewrittenBody
+ apiR, err := apiutils.ParseRequest(ctx, m.modelClient, bytes.NewReader(payload.Body), path, http.Header{})
+ if err != nil {
+ return req, err
}
+ req.Request = apiR
return req, nil
}
@@ -335,7 +305,7 @@ func (m *Messenger) sendBackendRequest(ctx context.Context, url string, body []b
return payload, resp.StatusCode, nil
}
-func (m *Messenger) sendResponse(req *request, body []byte, statusCode int) {
+func (m *Messenger) sendResponse(req *msgRequest, body []byte, statusCode int) {
log.Printf("Sending response to message: %v", req.msg.LoggableID)
response := struct {
diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go
index 4c9ba2d3..8c0b5612 100644
--- a/internal/metrics/metrics.go
+++ b/internal/metrics/metrics.go
@@ -1,6 +1,7 @@
package metrics
import (
+ "fmt"
"strings"
"go.opentelemetry.io/otel/attribute"
@@ -13,8 +14,10 @@ const (
// Metrics used to autoscale models:
var (
- InferenceRequestsActiveMetricName = "kubeai.inference.requests.active"
- InferenceRequestsActive metric.Int64UpDownCounter
+ InferenceRequestsActiveMetricName = "kubeai.inference.requests.active"
+ InferenceRequestsActive metric.Int64UpDownCounter
+ InferenceRequestsHashLookupIterationsMetricName = "kubeai.inference.requests.hash.lookup.iterations"
+ InferenceRequestsHashLookupIterations metric.Int64Histogram
)
// Attributes:
@@ -36,7 +39,14 @@ func Init(meter metric.Meter) error {
metric.WithDescription("The number of active requests by model"),
)
if err != nil {
- return err
+ return fmt.Errorf("%s: %w", InferenceRequestsActiveMetricName, err)
+ }
+ InferenceRequestsHashLookupIterations, err = meter.Int64Histogram(InferenceRequestsHashLookupIterationsMetricName,
+ metric.WithDescription("The number of vnodes considered while searching for the best endpoint for a request"),
+ metric.WithExplicitBucketBoundaries(1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024),
+ )
+ if err != nil {
+ return fmt.Errorf("%s: %w", InferenceRequestsHashLookupIterationsMetricName, err)
}
return nil
diff --git a/internal/modelautoscaler/autoscaler.go b/internal/modelautoscaler/autoscaler.go
index 6d4bf88b..52c5d05c 100644
--- a/internal/modelautoscaler/autoscaler.go
+++ b/internal/modelautoscaler/autoscaler.go
@@ -9,9 +9,9 @@ import (
"time"
"github.com/substratusai/kubeai/internal/config"
- "github.com/substratusai/kubeai/internal/endpoints"
"github.com/substratusai/kubeai/internal/leader"
- "github.com/substratusai/kubeai/internal/modelscaler"
+ "github.com/substratusai/kubeai/internal/loadbalancer"
+ "github.com/substratusai/kubeai/internal/modelclient"
"github.com/substratusai/kubeai/internal/movingaverage"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
@@ -21,8 +21,8 @@ func New(
ctx context.Context,
k8sClient client.Client,
leaderElection *leader.Election,
- scaler *modelscaler.ModelScaler,
- resolver *endpoints.Resolver,
+ modelClient *modelclient.ModelClient,
+ resolver *loadbalancer.LoadBalancer,
cfg config.ModelAutoscaling,
metricsPort int,
stateConfigMapRef types.NamespacedName,
@@ -31,7 +31,7 @@ func New(
a := &Autoscaler{
k8sClient: k8sClient,
leaderElection: leaderElection,
- scaler: scaler,
+ modelClient: modelClient,
resolver: resolver,
movingAvgByModel: map[string]*movingaverage.Simple{},
cfg: cfg,
@@ -78,8 +78,8 @@ type Autoscaler struct {
leaderElection *leader.Election
- scaler *modelscaler.ModelScaler
- resolver *endpoints.Resolver
+ modelClient *modelclient.ModelClient
+ resolver *loadbalancer.LoadBalancer
cfg config.ModelAutoscaling
@@ -107,7 +107,7 @@ func (a *Autoscaler) Start(ctx context.Context) {
// TODO: Remove hardcoded Service lookup by name "lingo".
- models, err := a.scaler.ListAllModels(ctx)
+ models, err := a.modelClient.ListAllModels(ctx)
if err != nil {
log.Printf("Failed to list models: %v", err)
continue
@@ -159,7 +159,7 @@ func (a *Autoscaler) Start(ctx context.Context) {
ceil := math.Ceil(normalized)
log.Printf("Calculated target replicas for model %q: ceil(%v/%v) = %v, current requests: sum(%v) = %v, history: %v",
m.Name, avgActiveRequests, *m.Spec.TargetRequests, ceil, activeRequests, activeRequestSum, avg.History())
- a.scaler.Scale(ctx, &m, int32(ceil), a.cfg.RequiredConsecutiveScaleDowns(*m.Spec.ScaleDownDelaySeconds))
+ a.modelClient.Scale(ctx, &m, int32(ceil), a.cfg.RequiredConsecutiveScaleDowns(*m.Spec.ScaleDownDelaySeconds))
nextModelState.Models[m.Name] = modelState{
AverageActiveRequests: avgActiveRequests,
diff --git a/internal/modelclient/client.go b/internal/modelclient/client.go
new file mode 100644
index 00000000..9d2d1c32
--- /dev/null
+++ b/internal/modelclient/client.go
@@ -0,0 +1,73 @@
+package modelclient
+
+import (
+ "context"
+ "fmt"
+ "sync"
+
+ kubeaiv1 "github.com/substratusai/kubeai/api/v1"
+ apierrors "k8s.io/apimachinery/pkg/api/errors"
+ "k8s.io/apimachinery/pkg/labels"
+ "k8s.io/apimachinery/pkg/types"
+ "sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+type ModelClient struct {
+ client client.Client
+ namespace string
+ consecutiveScaleDownsMtx sync.RWMutex
+ consecutiveScaleDowns map[string]int
+}
+
+func NewModelClient(client client.Client, namespace string) *ModelClient {
+ return &ModelClient{client: client, namespace: namespace, consecutiveScaleDowns: map[string]int{}}
+}
+
+// LookupModel checks if a model exists and matches the given label selectors
+// and (optionally) the requested adapter. It returns nil with no error when
+// there is no match.
+func (c *ModelClient) LookupModel(ctx context.Context, model, adapter string, labelSelectors []string) (*kubeaiv1.Model, error) {
+ m := &kubeaiv1.Model{}
+ if err := c.client.Get(ctx, types.NamespacedName{Name: model, Namespace: c.namespace}, m); err != nil {
+ if apierrors.IsNotFound(err) {
+ return nil, nil
+ }
+ return nil, err
+ }
+
+ modelLabels := m.GetLabels()
+ if modelLabels == nil {
+ modelLabels = map[string]string{}
+ }
+ for _, sel := range labelSelectors {
+ parsedSel, err := labels.Parse(sel)
+ if err != nil {
+ return nil, fmt.Errorf("parse label selector: %w", err)
+ }
+ if !parsedSel.Matches(labels.Set(modelLabels)) {
+ return nil, nil
+ }
+ }
+
+ if adapter != "" {
+ adapterFound := false
+ for _, a := range m.Spec.Adapters {
+ if a.Name == adapter {
+ adapterFound = true
+ break
+ }
+ }
+ if !adapterFound {
+ return nil, nil
+ }
+ }
+
+ return m, nil
+}
+
+func (s *ModelClient) ListAllModels(ctx context.Context) ([]kubeaiv1.Model, error) {
+ models := &kubeaiv1.ModelList{}
+ if err := s.client.List(ctx, models, client.InNamespace(s.namespace)); err != nil {
+ return nil, fmt.Errorf("list models: %w", err)
+ }
+
+ return models.Items, nil
+}
diff --git a/internal/modelclient/scale.go b/internal/modelclient/scale.go
new file mode 100644
index 00000000..f576da7c
--- /dev/null
+++ b/internal/modelclient/scale.go
@@ -0,0 +1,100 @@
+package modelclient
+
+import (
+ "context"
+ "fmt"
+ "log"
+
+ kubeaiv1 "github.com/substratusai/kubeai/api/v1"
+ autoscalingv1 "k8s.io/api/autoscaling/v1"
+ "k8s.io/apimachinery/pkg/types"
+ "sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+func (c *ModelClient) ScaleAtLeastOneReplica(ctx context.Context, model string) error {
+ obj := &kubeaiv1.Model{}
+ if err := c.client.Get(ctx, types.NamespacedName{Namespace: c.namespace, Name: model}, obj); err != nil {
+ return fmt.Errorf("get scale: %w", err)
+ }
+
+ if obj.Spec.AutoscalingDisabled {
+ return nil
+ }
+
+ replicas := int32(0)
+ if obj.Spec.Replicas != nil {
+ replicas = *obj.Spec.Replicas
+ }
+
+ if replicas == 0 && !obj.Spec.AutoscalingDisabled {
+ scale := &autoscalingv1.Scale{
+ Spec: autoscalingv1.ScaleSpec{Replicas: 1},
+ }
+ if err := c.client.SubResource("scale").Update(ctx, obj, client.WithSubResourceBody(scale)); err != nil {
+ return fmt.Errorf("update scale: %w", err)
+ }
+ }
+
+ return nil
+}
+
+// Scale scales the model to the desired number of replicas, enforcing the min and max replica bounds.
+// Model should have .Spec defined before calling Scale().
+func (c *ModelClient) Scale(ctx context.Context, model *kubeaiv1.Model, replicas int32, requiredConsecutiveScaleDowns int) error {
+ //obj := &kubeaiv1.Model{}
+ //if err := s.client.Get(ctx, types.NamespacedName{Namespace: s.namespace, Name: model}, obj); err != nil {
+ // return fmt.Errorf("get scale: %w", err)
+ //}
+
+ replicas = enforceReplicaBounds(replicas, model)
+
+ var existingReplicas int32 = 0
+ if model.Spec.Replicas != nil {
+ existingReplicas = *model.Spec.Replicas
+ }
+
+ if existingReplicas > replicas {
+ // Scale down
+ c.consecutiveScaleDownsMtx.RLock()
+ consec := c.consecutiveScaleDowns[model.Name]
+ c.consecutiveScaleDownsMtx.RUnlock()
+ if consec < requiredConsecutiveScaleDowns {
+ log.Printf("model %s has %d consecutive scale downs (< %d), not scaling down yet", model.Name, consec, requiredConsecutiveScaleDowns)
+ c.consecutiveScaleDownsMtx.Lock()
+ c.consecutiveScaleDowns[model.Name]++
+ c.consecutiveScaleDownsMtx.Unlock()
+ return nil
+ }
+ } else {
+ // Scale up or constant scale.
+ c.consecutiveScaleDownsMtx.Lock()
+ c.consecutiveScaleDowns[model.Name] = 0
+ c.consecutiveScaleDownsMtx.Unlock()
+ }
+
+ if existingReplicas != replicas {
+ log.Printf("scaling model %s from %d to %d replicas", model.Name, existingReplicas, replicas)
+ scale := &autoscalingv1.Scale{
+ Spec: autoscalingv1.ScaleSpec{Replicas: replicas},
+ }
+ if err := c.client.SubResource("scale").Update(ctx, model, client.WithSubResourceBody(scale)); err != nil {
+ return fmt.Errorf("update scale: %w", err)
+ }
+ }
+
+ return nil
+}
+
+func enforceReplicaBounds(replicas int32, model *kubeaiv1.Model) int32 {
+ max := model.Spec.MaxReplicas
+ min := model.Spec.MinReplicas
+ if max != nil {
+ if replicas > *max {
+ return *max
+ }
+ }
+ if replicas < min {
+ return min
+ }
+ return replicas
+}
diff --git a/internal/modelproxy/handler.go b/internal/modelproxy/handler.go
index 5edaf4a6..d931e23e 100644
--- a/internal/modelproxy/handler.go
+++ b/internal/modelproxy/handler.go
@@ -8,40 +8,42 @@ import (
"net/http/httputil"
"net/url"
+ v1 "github.com/substratusai/kubeai/api/v1"
+ "github.com/substratusai/kubeai/internal/apiutils"
"github.com/substratusai/kubeai/internal/metrics"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
-type ModelScaler interface {
- LookupModel(ctx context.Context, model, adapter string, selectors []string) (bool, error)
+type ModelClient interface {
+ LookupModel(ctx context.Context, model, adapter string, selectors []string) (*v1.Model, error)
ScaleAtLeastOneReplica(ctx context.Context, model string) error
}
-type EndpointResolver interface {
- AwaitBestAddress(ctx context.Context, model, adapter string) (string, func(), error)
+type LoadBalancer interface {
+ AwaitBestAddress(ctx context.Context, req *apiutils.Request) (string, func(), error)
}
// Handler serves http requests for end-clients.
// It is also responsible for triggering scale-from-zero.
type Handler struct {
- modelScaler ModelScaler
- resolver EndpointResolver
- maxRetries int
- retryCodes map[int]struct{}
+ modelClient ModelClient
+ loadBalancer LoadBalancer
+ maxRetries int
+ retryCodes map[int]struct{}
}
func NewHandler(
- modelScaler ModelScaler,
- resolver EndpointResolver,
+ modelClient ModelClient,
+ loadBalancer LoadBalancer,
maxRetries int,
retryCodes map[int]struct{},
) *Handler {
return &Handler{
- modelScaler: modelScaler,
- resolver: resolver,
- maxRetries: maxRetries,
- retryCodes: retryCodes,
+ modelClient: modelClient,
+ loadBalancer: loadBalancer,
+ maxRetries: maxRetries,
+ retryCodes: retryCodes,
}
}
@@ -57,35 +59,30 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Proxy", "lingo")
- pr := newProxyRequest(r)
-
// TODO: Only parse model for paths that would have a model.
- if err := pr.parse(); err != nil {
- pr.sendErrorResponse(w, http.StatusBadRequest, "unable to parse model: %v", err)
+ pr, err := h.parseProxyRequest(r)
+ if err != nil {
+ if errors.Is(err, apiutils.ErrBadRequest) {
+ pr.sendErrorResponse(w, http.StatusBadRequest, "%v", err)
+ } else if errors.Is(err, apiutils.ErrModelNotFound) {
+ pr.sendErrorResponse(w, http.StatusNotFound, "%v", err)
+ } else {
+ pr.sendErrorResponse(w, http.StatusInternalServerError, "parsing request: %v", err)
+ }
return
}
- log.Println("model:", pr.model, "adapter:", pr.adapter)
+ log.Println("model:", pr.Model, "adapter:", pr.Adapter)
metricAttrs := metric.WithAttributeSet(attribute.NewSet(
- metrics.AttrRequestModel.String(pr.requestedModel),
+ metrics.AttrRequestModel.String(pr.RequestedModel),
metrics.AttrRequestType.String(metrics.AttrRequestTypeHTTP),
))
- metrics.InferenceRequestsActive.Add(pr.r.Context(), 1, metricAttrs)
- defer metrics.InferenceRequestsActive.Add(pr.r.Context(), -1, metricAttrs)
-
- modelExists, err := h.modelScaler.LookupModel(r.Context(), pr.model, pr.adapter, pr.selectors)
- if err != nil {
- pr.sendErrorResponse(w, http.StatusInternalServerError, "unable to resolve model: %v", err)
- return
- }
- if !modelExists {
- pr.sendErrorResponse(w, http.StatusNotFound, "model not found: %v", pr.requestedModel)
- return
- }
+ metrics.InferenceRequestsActive.Add(pr.http.Context(), 1, metricAttrs)
+ defer metrics.InferenceRequestsActive.Add(pr.http.Context(), -1, metricAttrs)
// Ensure the backend is scaled to at least one Pod.
- if err := h.modelScaler.ScaleAtLeastOneReplica(r.Context(), pr.model); err != nil {
+ if err := h.modelClient.ScaleAtLeastOneReplica(r.Context(), pr.Model); err != nil {
pr.sendErrorResponse(w, http.StatusInternalServerError, "unable to scale model: %v", err)
return
}
@@ -98,9 +95,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var AdditionalProxyRewrite = func(*httputil.ProxyRequest) {}
func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) {
- log.Printf("Waiting for host: %v", pr.id)
+ log.Printf("Waiting for host: %v", pr.ID)
- addr, decrementInflight, err := h.resolver.AwaitBestAddress(pr.r.Context(), pr.model, pr.adapter)
+ addr, decrementInflight, err := h.loadBalancer.AwaitBestAddress(pr.http.Context(), pr.Request)
if err != nil {
switch {
case errors.Is(err, context.Canceled):
@@ -148,7 +145,7 @@ func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) {
if err != nil && r.Context().Err() == nil && pr.attempt < h.maxRetries {
pr.attempt++
- log.Printf("Retrying request (%v/%v): %v: %v", pr.attempt, h.maxRetries, pr.id, err)
+ log.Printf("Retrying request (%v/%v): %v: %v", pr.attempt, h.maxRetries, pr.ID, err)
h.proxyHTTP(w, pr)
return
}
@@ -158,7 +155,7 @@ func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) {
}
}
- log.Printf("Proxying request to ip %v: %v\n", addr, pr.id)
+ log.Printf("Proxying request to ip %v: %v\n", addr, pr.ID)
proxy.ServeHTTP(w, pr.httpRequest())
}
diff --git a/internal/modelproxy/handler_test.go b/internal/modelproxy/handler_test.go
index 5028ad92..f8579efc 100644
--- a/internal/modelproxy/handler_test.go
+++ b/internal/modelproxy/handler_test.go
@@ -13,8 +13,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ v1 "github.com/substratusai/kubeai/api/v1"
"github.com/substratusai/kubeai/internal/apiutils"
"github.com/substratusai/kubeai/internal/metrics/metricstest"
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
func TestHandler(t *testing.T) {
@@ -58,13 +60,13 @@ func TestHandler(t *testing.T) {
"no model": {
reqBody: "{}",
expCode: http.StatusBadRequest,
- expBody: `{"error":"unable to parse model: reading model from body: missing 'model' field"}` + "\n",
+ expBody: `{"error":"bad request: reading model from body: missing 'model' field"}` + "\n",
expBackendRequestCount: 0,
},
"model not found": {
reqBody: `{"model":"does-not-exist"}`,
expCode: http.StatusNotFound,
- expBody: `{"error":"model not found: does-not-exist"}` + "\n",
+ expBody: fmt.Sprintf(`{"error":%q}`, `model not found: "does-not-exist"`) + "\n",
expBackendRequestCount: 0,
},
"happy 200 model in body": {
@@ -93,7 +95,7 @@ func TestHandler(t *testing.T) {
"404 model+adapter in body but missing adapter": {
reqBody: fmt.Sprintf(`{"model":%q}`, apiutils.MergeModelAdapter(model1, "no-such-adapter")),
expCode: http.StatusNotFound,
- expBody: fmt.Sprintf(`{"error":"model not found: %s"}`, apiutils.MergeModelAdapter(model1, "no-such-adapter")) + "\n",
+ expBody: fmt.Sprintf(`{"error":%q}`, `model not found: "`+apiutils.MergeModelAdapter(model1, "no-such-adapter")+`"`) + "\n",
},
"happy 200 only model in form data": {
reqHeaders: map[string]string{"Content-Type": "multipart/form-data; boundary=12345"},
@@ -276,27 +278,29 @@ type testModelInterface struct {
models map[string]testMockModel
}
-func (t *testModelInterface) LookupModel(ctx context.Context, model, adapter string, selector []string) (bool, error) {
+func (t *testModelInterface) LookupModel(ctx context.Context, model, adapter string, selector []string) (*v1.Model, error) {
m, ok := t.models[model]
if ok {
if adapter == "" {
- return true, nil
+ return &v1.Model{ObjectMeta: metav1.ObjectMeta{Name: model}}, nil
}
if m.adapters == nil {
- return false, nil
+ return nil, nil
+ }
+ if m.adapters[adapter] {
+ return &v1.Model{ObjectMeta: metav1.ObjectMeta{Name: model}}, nil
}
- return m.adapters[adapter], nil
}
- return false, nil
+ return nil, nil
}
func (t *testModelInterface) ScaleAtLeastOneReplica(ctx context.Context, model string) error {
return nil
}
-func (t *testModelInterface) AwaitBestAddress(ctx context.Context, model, adapter string) (string, func(), error) {
+func (t *testModelInterface) AwaitBestAddress(ctx context.Context, req *apiutils.Request) (string, func(), error) {
t.hostRequestCount++
- t.requestedModel = model
- t.requestedAdapter = adapter
+ t.requestedModel = req.Model
+ t.requestedAdapter = req.Adapter
return t.address, func() {}, nil
}
diff --git a/internal/modelproxy/request.go b/internal/modelproxy/request.go
index f7ae912c..7d29402f 100644
--- a/internal/modelproxy/request.go
+++ b/internal/modelproxy/request.go
@@ -6,164 +6,37 @@ import (
"fmt"
"io"
"log"
- "mime"
- "mime/multipart"
"net/http"
- "github.com/google/uuid"
"github.com/substratusai/kubeai/internal/apiutils"
)
// proxyRequest keeps track of the state of a request that is to be proxied.
type proxyRequest struct {
+ *apiutils.Request
+
// r is the original request. It is stored here so that is can be cloned
// and sent to the backend while preserving the original request body.
- r *http.Request
- // body will be stored here if the request body needed to be read
- // in order to determine the model.
- body []byte
-
- selectors []string
-
- id string
- status int
- requestedModel string
- model string
- adapter string
- attempt int
+ http *http.Request
+ status int
+ attempt int
}
-func newProxyRequest(r *http.Request) *proxyRequest {
+func (h *Handler) parseProxyRequest(r *http.Request) (*proxyRequest, error) {
pr := &proxyRequest{
- r: r,
- id: uuid.New().String(),
+ http: r,
status: http.StatusOK,
}
- return pr
-}
-
-// parse attempts to determine the model from the request.
-// It first checks the "X-Model" header, and if that is not set, it
-// attempts to unmarshal the request body as JSON and extract the
-// .model field.
-func (pr *proxyRequest) parse() error {
- pr.selectors = pr.r.Header.Values("X-Label-Selector")
-
- // Parse media type (with params - which are used for multipart form data)
- var (
- contentType = pr.r.Header.Get("Content-Type")
- mediaType string
- mediaParams map[string]string
- )
- if contentType == "" {
- mediaType = "application/json"
- mediaParams = map[string]string{}
- } else {
- var err error
- mediaType, mediaParams, err = mime.ParseMediaType(contentType)
- if err != nil {
- return fmt.Errorf("parse media type: %w", err)
- }
- }
-
- switch mediaType {
- // Multipart form data is used for endpoints that accept file uploads:
- case "multipart/form-data":
- boundary := mediaParams["boundary"]
- if boundary == "" {
- return fmt.Errorf("no boundary specified in multipart form data")
- }
-
- var buf bytes.Buffer
- mw := multipart.NewWriter(&buf)
- // Keep the same boundary as the initial request (probably not necessary)
- mw.SetBoundary(boundary)
-
- // Iterate over the parts of the multipart form data:
- // - If the part is named "model", save the value to the proxy request.
- // - Otherwise, just copy the part to the new multipart writer.
- mr := multipart.NewReader(pr.r.Body, boundary)
- for {
- p, err := mr.NextPart()
- if err == io.EOF {
- break
- }
- if err != nil {
- return fmt.Errorf("interating over multipart form: %w", err)
- }
-
- if p.FormName() == "model" {
- value, err := io.ReadAll(p)
- if err != nil {
- return fmt.Errorf("reading multipart form value: %w", err)
- }
- pr.model, pr.adapter = apiutils.SplitModelAdapter(string(value))
- pr.requestedModel = string(value)
- // WORKAROUND ALERT:
- // Omit the "model" field from the proxy request to avoid FasterWhisper validation issues:
- // See https://github.com/fedirz/faster-whisper-server/issues/71
- continue
- }
-
- // Copy the part to the new multipart writer.
- pp, err := mw.CreatePart(p.Header)
- if err != nil {
- return fmt.Errorf("creating part: %w", err)
- }
- if _, err := io.Copy(pp, p); err != nil {
- return fmt.Errorf("copying part: %w", err)
- }
- }
-
- // Fully write to buffer.
- if err := mw.Close(); err != nil {
- return fmt.Errorf("closing multipart writer: %w", err)
- }
- pr.body = buf.Bytes()
- // Set a new content length based on the new body - which had the "model" field removed.
- pr.r.ContentLength = int64(len(pr.body))
-
- // Assume "application/json":
- default:
- if err := pr.readModelFromBody(pr.r.Body); err != nil {
- return fmt.Errorf("reading model from body: %w", err)
- }
- }
-
- return nil
-}
-
-func (pr *proxyRequest) readModelFromBody(r io.ReadCloser) error {
- var payload map[string]interface{}
- if err := json.NewDecoder(r).Decode(&payload); err != nil {
- return fmt.Errorf("decoding: %w", err)
- }
- modelInf, ok := payload["model"]
- if !ok {
- return fmt.Errorf("missing 'model' field")
- }
- modelStr, ok := modelInf.(string)
- if !ok {
- return fmt.Errorf("field 'model' should be a string")
- }
-
- pr.requestedModel = modelStr
- pr.model, pr.adapter = apiutils.SplitModelAdapter(modelStr)
-
- if pr.adapter != "" {
- // vLLM expects the adapter to be in the model field.
- payload["model"] = pr.adapter
- }
-
- body, err := json.Marshal(payload)
+ apiReq, err := apiutils.ParseRequest(r.Context(), h.modelClient, r.Body, r.URL.Path, r.Header)
if err != nil {
- return fmt.Errorf("remarshalling: %w", err)
+ return pr, err
}
- pr.body = body
- pr.r.ContentLength = int64(len(pr.body))
+ // The content length might have changed after the body was read and rewritten.
+ r.ContentLength = apiReq.ContentLength
+ pr.Request = apiReq
- return nil
+ return pr, nil
}
// sendErrorResponse sends an error response to the client and
@@ -198,9 +71,9 @@ func (pr *proxyRequest) setStatus(w http.ResponseWriter, code int) {
// request, preserving the original request body even if it was already
// read (i.e. if the body was inspected to determine the model).
func (pr *proxyRequest) httpRequest() *http.Request {
- clone := pr.r.Clone(pr.r.Context())
- if pr.body != nil {
- clone.Body = io.NopCloser(bytes.NewReader(pr.body))
+ clone := pr.http.Clone(pr.http.Context())
+ if pr.Body != nil {
+ clone.Body = io.NopCloser(bytes.NewReader(pr.Body))
}
return clone
}
diff --git a/internal/modelscaler/scaler.go b/internal/modelscaler/scaler.go
deleted file mode 100644
index fdb44294..00000000
--- a/internal/modelscaler/scaler.go
+++ /dev/null
@@ -1,163 +0,0 @@
-package modelscaler
-
-import (
- "context"
- "fmt"
- "log"
- "sync"
-
- kubeaiv1 "github.com/substratusai/kubeai/api/v1"
- autoscalingv1 "k8s.io/api/autoscaling/v1"
- apierrors "k8s.io/apimachinery/pkg/api/errors"
- "k8s.io/apimachinery/pkg/labels"
- "k8s.io/apimachinery/pkg/types"
- "sigs.k8s.io/controller-runtime/pkg/client"
-)
-
-type ModelScaler struct {
- client client.Client
- namespace string
- consecutiveScaleDownsMtx sync.RWMutex
- consecutiveScaleDowns map[string]int
-}
-
-func NewModelScaler(client client.Client, namespace string) *ModelScaler {
- return &ModelScaler{client: client, namespace: namespace, consecutiveScaleDowns: map[string]int{}}
-}
-
-// LookupModel checks if a model exists and matches the given label selectors.
-func (s *ModelScaler) LookupModel(ctx context.Context, model, adapter string, labelSelectors []string) (bool, error) {
- m := &kubeaiv1.Model{}
- if err := s.client.Get(ctx, types.NamespacedName{Name: model, Namespace: s.namespace}, m); err != nil {
- if apierrors.IsNotFound(err) {
- return false, nil
- }
- return false, err
- }
-
- modelLabels := m.GetLabels()
- if modelLabels == nil {
- modelLabels = map[string]string{}
- }
- for _, sel := range labelSelectors {
- parsedSel, err := labels.Parse(sel)
- if err != nil {
- return false, fmt.Errorf("parse label selector: %w", err)
- }
- if !parsedSel.Matches(labels.Set(modelLabels)) {
- return false, nil
- }
- }
-
- if adapter != "" {
- adapterFound := false
- for _, a := range m.Spec.Adapters {
- if a.Name == adapter {
- adapterFound = true
- break
- }
- }
- if !adapterFound {
- return false, nil
- }
- }
-
- return true, nil
-}
-
-func (s *ModelScaler) ListAllModels(ctx context.Context) ([]kubeaiv1.Model, error) {
- models := &kubeaiv1.ModelList{}
- if err := s.client.List(ctx, models, client.InNamespace(s.namespace)); err != nil {
- return nil, fmt.Errorf("list models: %w", err)
- }
-
- return models.Items, nil
-}
-
-func (s *ModelScaler) ScaleAtLeastOneReplica(ctx context.Context, model string) error {
- obj := &kubeaiv1.Model{}
- if err := s.client.Get(ctx, types.NamespacedName{Namespace: s.namespace, Name: model}, obj); err != nil {
- return fmt.Errorf("get scale: %w", err)
- }
-
- if obj.Spec.AutoscalingDisabled {
- return nil
- }
-
- replicas := int32(0)
- if obj.Spec.Replicas != nil {
- replicas = *obj.Spec.Replicas
- }
-
- if replicas == 0 && !obj.Spec.AutoscalingDisabled {
- scale := &autoscalingv1.Scale{
- Spec: autoscalingv1.ScaleSpec{Replicas: 1},
- }
- if err := s.client.SubResource("scale").Update(ctx, obj, client.WithSubResourceBody(scale)); err != nil {
- return fmt.Errorf("update scale: %w", err)
- }
- }
-
- return nil
-}
-
-// Scale scales the model to the desired number of replicas, enforcing the min and max replica bounds.
-// Model should have .Spec defined before calling Scale().
-func (s *ModelScaler) Scale(ctx context.Context, model *kubeaiv1.Model, replicas int32, requiredConsecutiveScaleDowns int) error {
- //obj := &kubeaiv1.Model{}
- //if err := s.client.Get(ctx, types.NamespacedName{Namespace: s.namespace, Name: model}, obj); err != nil {
- // return fmt.Errorf("get scale: %w", err)
- //}
-
- replicas = enforceReplicaBounds(replicas, model)
-
- var existingReplicas int32 = 0
- if model.Spec.Replicas != nil {
- existingReplicas = *model.Spec.Replicas
- }
-
- if existingReplicas > replicas {
- // Scale down
- s.consecutiveScaleDownsMtx.RLock()
- consec := s.consecutiveScaleDowns[model.Name]
- s.consecutiveScaleDownsMtx.RUnlock()
- if consec < requiredConsecutiveScaleDowns {
- log.Printf("model %s has %d consecutive scale downs (< %d), not scaling down yet", model.Name, consec, requiredConsecutiveScaleDowns)
- s.consecutiveScaleDownsMtx.Lock()
- s.consecutiveScaleDowns[model.Name]++
- s.consecutiveScaleDownsMtx.Unlock()
- return nil
- }
- } else {
- // Scale up or constant scale.
- s.consecutiveScaleDownsMtx.Lock()
- s.consecutiveScaleDowns[model.Name] = 0
- s.consecutiveScaleDownsMtx.Unlock()
- }
-
- if existingReplicas != replicas {
- log.Printf("scaling model %s from %d to %d replicas", model.Name, existingReplicas, replicas)
- scale := &autoscalingv1.Scale{
- Spec: autoscalingv1.ScaleSpec{Replicas: replicas},
- }
- if err := s.client.SubResource("scale").Update(ctx, model, client.WithSubResourceBody(scale)); err != nil {
- return fmt.Errorf("update scale: %w", err)
- }
- }
-
- return nil
-}
-
-func enforceReplicaBounds(replicas int32, model *kubeaiv1.Model) int32 {
- max := model.Spec.MaxReplicas
- min := model.Spec.MinReplicas
- if max != nil {
- if replicas > *max {
- return *max
- }
- }
- if replicas < min {
- return min
- }
- return replicas
-}
diff --git a/test/integration/messenger_test.go b/test/integration/messenger_test.go
index a3e6485a..b26e6119 100644
--- a/test/integration/messenger_test.go
+++ b/test/integration/messenger_test.go
@@ -81,7 +81,7 @@ func TestMessenger(t *testing.T) {
shouldReceiveResponseMessage(t, m.Name, "a")
sendRequestMessage(t, "/v1/completions", "non-existant-model", "b")
- shouldReceiveResponseErrMessage(t, http.StatusNotFound, "model not found: non-existant-model", "b")
+ shouldReceiveResponseErrMessage(t, http.StatusNotFound, "model not found: \"non-existant-model\"", "b")
}
func shouldReceiveResponseErrMessage(t *testing.T, statusCode int, message string, id string) {