diff --git a/src/wkok/openai_clojure/sse.clj b/src/wkok/openai_clojure/sse.clj index 8415e56..d40f9f7 100644 --- a/src/wkok/openai_clojure/sse.clj +++ b/src/wkok/openai_clojure/sse.clj @@ -3,8 +3,11 @@ [hato.client :as http] [clojure.core.async :as a] [clojure.string :as string] - [cheshire.core :as json]) - (:import (java.io InputStream))) + [cheshire.core :as json] + [clojure.core.async.impl.protocols :as impl]) + (:import (java.io InputStream) + (clojure.lang Counted) + (java.util LinkedList))) (def event-mask (re-pattern (str "(?s).+?\n\n"))) @@ -26,13 +29,39 @@ (-> (subs raw-event data-idx) (json/parse-string true))))) +(deftype InfiniteBuffer [^LinkedList buf] + impl/UnblockingBuffer + impl/Buffer + (full? [_this] + false) + (remove! [_this] + (.removeLast buf)) + (add!* [this itm] + (.addFirst buf itm) + this) + (close-buf! [_this]) + Counted + (count [_this] + (.size buf))) + +(defn infinite-buffer [] + (InfiniteBuffer. (LinkedList.))) + (defn calc-buffer-size - "Buffer size should be at least equal to max_tokens - or 16 (the default in openai as of 2023-02-19) - plus the [DONE] terminator" - [{:keys [max_tokens] - :or {max_tokens 16}}] - (inc max_tokens)) + "- Use stream_buffer_len if provided. + - Otherwise, buffer size should be at least equal to max_tokens + plus the [DONE] terminator if it is provided. + - Else fallbacks on ##Inf and use an infinite-buffer instead" + [{:keys [stream_buffer_len max_tokens]}] + (or stream_buffer_len + (when max_tokens (inc max_tokens)) + ##Inf)) + +(defn make-buffer [params] + (let [size (calc-buffer-size params)] + (if (= size ##Inf) + (infinite-buffer) + (a/sliding-buffer size)))) (defn sse-events "Returns a core.async channel with events as clojure data structures. @@ -41,8 +70,7 @@ (let [event-stream ^InputStream (:body (http/request (merge request params {:as :stream}))) - buffer-size (calc-buffer-size params) - events (a/chan (a/sliding-buffer buffer-size) (map parse-event))] + events (a/chan (make-buffer params) (map parse-event))] (a/thread (loop [byte-coll []] (let [byte-arr (byte-array (max 1 (.available event-stream)))