diff options
Diffstat (limited to 'extra/stt.go')
-rw-r--r-- | extra/stt.go | 48 |
1 files changed, 23 insertions, 25 deletions
diff --git a/extra/stt.go b/extra/stt.go index c1dcba6..3e6e032 100644 --- a/extra/stt.go +++ b/extra/stt.go @@ -9,7 +9,7 @@ import ( "log/slog" "mime/multipart" "net/http" - "time" + "strings" "github.com/gordonklaus/portaudio" ) @@ -25,22 +25,20 @@ type StreamCloser interface { } type WhisperSTT struct { - logger *slog.Logger - ServerURL string - SampleRate int - RawBuffer *bytes.Buffer - WavBuffer *bytes.Buffer - streamer StreamCloser - recording bool + logger *slog.Logger + ServerURL string + SampleRate int + AudioBuffer *bytes.Buffer + streamer StreamCloser + recording bool } func NewWhisperSTT(logger *slog.Logger, serverURL string, sampleRate int) *WhisperSTT { return &WhisperSTT{ - logger: logger, - ServerURL: serverURL, - SampleRate: sampleRate, - RawBuffer: new(bytes.Buffer), - WavBuffer: new(bytes.Buffer), + logger: logger, + ServerURL: serverURL, + SampleRate: sampleRate, + AudioBuffer: new(bytes.Buffer), } } @@ -54,17 +52,14 @@ func (stt *WhisperSTT) StartRecording() error { func (stt *WhisperSTT) StopRecording() (string, error) { stt.recording = false - time.Sleep(time.Millisecond * 200) // this is not the way // wait loop to finish? - if stt.RawBuffer == nil { - err := errors.New("unexpected nil RawBuffer") + if stt.AudioBuffer == nil { + err := errors.New("unexpected nil AudioBuffer") stt.logger.Error(err.Error()) return "", err } // Create WAV header first - stt.writeWavHeader(stt.WavBuffer, len(stt.RawBuffer.Bytes())) // Write initial header with 0 size - stt.WavBuffer.Write(stt.RawBuffer.Bytes()) - body := &bytes.Buffer{} // third buffer? + body := &bytes.Buffer{} writer := multipart.NewWriter(body) // Add audio file part part, err := writer.CreateFormFile("file", "recording.wav") @@ -72,11 +67,15 @@ func (stt *WhisperSTT) StopRecording() (string, error) { stt.logger.Error("fn: StopRecording", "error", err) return "", err } - _, err = io.Copy(part, stt.WavBuffer) - if err != nil { + // Stream directly to multipart writer: header + raw data + dataSize := stt.AudioBuffer.Len() + stt.writeWavHeader(part, dataSize) + if _, err := io.Copy(part, stt.AudioBuffer); err != nil { stt.logger.Error("fn: StopRecording", "error", err) return "", err } + // Reset buffer for next recording + stt.AudioBuffer.Reset() // Add response format field err = writer.WriteField("response_format", "text") if err != nil { @@ -95,13 +94,12 @@ func (stt *WhisperSTT) StopRecording() (string, error) { } defer resp.Body.Close() // Read and print response - responseText, err := io.ReadAll(resp.Body) + responseTextBytes, err := io.ReadAll(resp.Body) if err != nil { stt.logger.Error("fn: StopRecording", "error", err) return "", err } - stt.logger.Info("got transcript", "text", string(responseText)) - return string(responseText), nil + return strings.TrimRight(string(responseTextBytes), "\n"), nil } func (stt *WhisperSTT) writeWavHeader(w io.Writer, dataSize int) { @@ -149,7 +147,7 @@ func (stt *WhisperSTT) microphoneStream(sampleRate int) error { stt.logger.Error("reading stream", "error", err) return } - if err := binary.Write(stt.RawBuffer, binary.LittleEndian, in); err != nil { + if err := binary.Write(stt.AudioBuffer, binary.LittleEndian, in); err != nil { stt.logger.Error("writing to buffer", "error", err) return } |