Skip to content

Commit

Permalink
go : add temperature options (#2417)
Browse files Browse the repository at this point in the history
* Fixed go cuda bindings building

* Added note to go bindings Readme to build using cuda support

* Added temperature bindings for Go

---------

Co-authored-by: Binozo <entwickler@binozoworks.de>
  • Loading branch information
Binozo and Binozo authored Sep 20, 2024
1 parent bea43e0 commit 34972db
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 14 deletions.
12 changes: 12 additions & 0 deletions bindings/go/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,16 @@ func (p *Params) SetEntropyThold(t float32) {
p.entropy_thold = C.float(t)
}

func (p *Params) SetTemperature(t float32) {
p.temperature = C.float(t)
}

// Sets the fallback temperature incrementation
// Pass -1.0 to disable this feature
func (p *Params) SetTemperatureFallback(t float32) {
p.temperature_inc = C.float(t)
}

// Set initial prompt
func (p *Params) SetInitialPrompt(prompt string) {
p.initial_prompt = C.CString(prompt)
Expand Down Expand Up @@ -162,6 +172,8 @@ func (p *Params) String() string {
str += fmt.Sprintf(" audio_ctx=%d", p.audio_ctx)
str += fmt.Sprintf(" initial_prompt=%s", C.GoString(p.initial_prompt))
str += fmt.Sprintf(" entropy_thold=%f", p.entropy_thold)
str += fmt.Sprintf(" temperature=%f", p.temperature)
str += fmt.Sprintf(" temperature_inc=%f", p.temperature_inc)
str += fmt.Sprintf(" beam_size=%d", p.beam_search.beam_size)
if p.translate {
str += " translate"
Expand Down
11 changes: 11 additions & 0 deletions bindings/go/pkg/whisper/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,17 @@ func (context *context) SetEntropyThold(t float32) {
context.params.SetEntropyThold(t)
}

// Set Temperature
func (context *context) SetTemperature(t float32) {
context.params.SetTemperature(t)
}

// Set the fallback temperature incrementation
// Pass -1.0 to disable this feature
func (context *context) SetTemperatureFallback(t float32) {
context.params.SetTemperatureFallback(t)
}

// Set initial prompt
func (context *context) SetInitialPrompt(prompt string) {
context.params.SetInitialPrompt(prompt)
Expand Down
30 changes: 16 additions & 14 deletions bindings/go/pkg/whisper/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,22 @@ type Context interface {
IsMultilingual() bool // Return true if the model is multilingual.
Language() string // Get language

SetOffset(time.Duration) // Set offset
SetDuration(time.Duration) // Set duration
SetThreads(uint) // Set number of threads to use
SetSplitOnWord(bool) // Set split on word flag
SetTokenThreshold(float32) // Set timestamp token probability threshold
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
SetMaxSegmentLength(uint) // Set max segment length in characters
SetTokenTimestamps(bool) // Set token timestamps flag
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
SetAudioCtx(uint) // Set audio encoder context
SetMaxContext(n int) // Set maximum number of text context tokens to store
SetBeamSize(n int) // Set Beam Size
SetEntropyThold(t float32) // Set Entropy threshold
SetInitialPrompt(prompt string) // Set initial prompt
SetOffset(time.Duration) // Set offset
SetDuration(time.Duration) // Set duration
SetThreads(uint) // Set number of threads to use
SetSplitOnWord(bool) // Set split on word flag
SetTokenThreshold(float32) // Set timestamp token probability threshold
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold
SetMaxSegmentLength(uint) // Set max segment length in characters
SetTokenTimestamps(bool) // Set token timestamps flag
SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit)
SetAudioCtx(uint) // Set audio encoder context
SetMaxContext(n int) // Set maximum number of text context tokens to store
SetBeamSize(n int) // Set Beam Size
SetEntropyThold(t float32) // Set Entropy threshold
SetInitialPrompt(prompt string) // Set initial prompt
SetTemperature(t float32) // Set temperature
SetTemperatureFallback(t float32) // Set temperature incrementation

// Process mono audio data and return any errors.
// If defined, newly generated segments are passed to the
Expand Down

0 comments on commit 34972db

Please sign in to comment.