Skip to content

Commit

Permalink
Fix unmanaged self retain missing corrisponding release (caused memor…
Browse files Browse the repository at this point in the history
…y leak)
  • Loading branch information
exPHAT committed May 15, 2023
1 parent d2231f4 commit 92af8ea
Showing 1 changed file with 31 additions and 18 deletions.
49 changes: 31 additions & 18 deletions Sources/SwiftWhisper/Whisper.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import whisper_cpp

public class Whisper {
private let whisperContext: OpaquePointer
private var unmanagedSelf: Unmanaged<Whisper>?

public var delegate: WhisperDelegate?
public var params: WhisperParams
Expand All @@ -14,17 +15,13 @@ public class Whisper {
public init(fromFileURL fileURL: URL, withParams params: WhisperParams = .default) {
self.whisperContext = fileURL.relativePath.withCString { whisper_init_from_file($0) }
self.params = params

prepareCallbacks()
}

public init(fromData data: Data, withParams params: WhisperParams = .default) {
var copy = data // Need to copy memory so we can gaurentee exclusive ownership over pointer

self.whisperContext = copy.withUnsafeMutableBytes { whisper_init_from_buffer($0.baseAddress!, data.count) }
self.params = params

prepareCallbacks()
}

deinit {
Expand All @@ -38,8 +35,11 @@ public class Whisper {

We can unwrap that and obtain a copy of self inside the callback.
*/
params.new_segment_callback_user_data = Unmanaged.passRetained(self).toOpaque()
params.encoder_begin_callback_user_data = Unmanaged.passRetained(self).toOpaque()
cleanupCallbacks()
let unmanagedSelf = Unmanaged.passRetained(self)
self.unmanagedSelf = unmanagedSelf
params.new_segment_callback_user_data = unmanagedSelf.toOpaque()
params.encoder_begin_callback_user_data = unmanagedSelf.toOpaque()

// swiftlint:disable line_length
params.new_segment_callback = { (ctx: OpaquePointer?, _: OpaquePointer?, newSegmentCount: Int32, userData: UnsafeMutableRawPointer?) in
Expand Down Expand Up @@ -94,32 +94,45 @@ public class Whisper {
}
}

private func cleanupCallbacks() {
guard let unmanagedSelf else { return }

unmanagedSelf.release()
self.unmanagedSelf = nil
}

public func transcribe(audioFrames: [Float], completionHandler: @escaping (Result<[Segment], Error>) -> Void) {
prepareCallbacks()

let wrappedCompletionHandler: (Result<[Segment], Error>) -> Void = { result in
self.cleanupCallbacks()
completionHandler(result)
}

guard !inProgress else {
completionHandler(.failure(WhisperError.instanceBusy))
wrappedCompletionHandler(.failure(WhisperError.instanceBusy))
return
}
guard audioFrames.count > 0 else {
completionHandler(.failure(WhisperError.invalidFrames))
wrappedCompletionHandler(.failure(WhisperError.invalidFrames))
return
}

inProgress = true
frameCount = audioFrames.count

DispatchQueue.global(qos: .userInitiated).async { [unowned self] in

whisper_full(whisperContext, params.whisperParams, audioFrames, Int32(audioFrames.count))
DispatchQueue.global(qos: .userInitiated).async {
whisper_full(self.whisperContext, self.params.whisperParams, audioFrames, Int32(audioFrames.count))

let segmentCount = whisper_full_n_segments(whisperContext)
let segmentCount = whisper_full_n_segments(self.whisperContext)

var segments: [Segment] = []
segments.reserveCapacity(Int(segmentCount))

for index in 0..<segmentCount {
guard let text = whisper_full_get_segment_text(whisperContext, index) else { continue }
let startTime = whisper_full_get_segment_t0(whisperContext, index)
let endTime = whisper_full_get_segment_t1(whisperContext, index)
guard let text = whisper_full_get_segment_text(self.whisperContext, index) else { continue }
let startTime = whisper_full_get_segment_t0(self.whisperContext, index)
let endTime = whisper_full_get_segment_t1(self.whisperContext, index)

segments.append(
.init(
Expand All @@ -130,20 +143,20 @@ public class Whisper {
)
}

if let cancelCallback {
if let cancelCallback = self.cancelCallback {
DispatchQueue.main.async {
// Should cancel callback be called after delegate and completionHandler?
cancelCallback()

let error = WhisperError.cancelled

self.delegate?.whisper(self, didErrorWith: error)
completionHandler(.failure(error))
wrappedCompletionHandler(.failure(error))
}
} else {
DispatchQueue.main.async {
self.delegate?.whisper(self, didCompleteWithSegments: segments)
completionHandler(.success(segments))
wrappedCompletionHandler(.success(segments))
}
}

Expand Down

0 comments on commit 92af8ea

Please sign in to comment.