summaryrefslogtreecommitdiff
path: root/extra/stt.go
diff options
context:
space:
mode:
Diffstat (limited to 'extra/stt.go')
-rw-r--r--extra/stt.go48
1 files changed, 23 insertions, 25 deletions
diff --git a/extra/stt.go b/extra/stt.go
index c1dcba6..3e6e032 100644
--- a/extra/stt.go
+++ b/extra/stt.go
@@ -9,7 +9,7 @@ import (
"log/slog"
"mime/multipart"
"net/http"
- "time"
+ "strings"
"github.com/gordonklaus/portaudio"
)
@@ -25,22 +25,20 @@ type StreamCloser interface {
}
type WhisperSTT struct {
- logger *slog.Logger
- ServerURL string
- SampleRate int
- RawBuffer *bytes.Buffer
- WavBuffer *bytes.Buffer
- streamer StreamCloser
- recording bool
+ logger *slog.Logger
+ ServerURL string
+ SampleRate int
+ AudioBuffer *bytes.Buffer
+ streamer StreamCloser
+ recording bool
}
func NewWhisperSTT(logger *slog.Logger, serverURL string, sampleRate int) *WhisperSTT {
return &WhisperSTT{
- logger: logger,
- ServerURL: serverURL,
- SampleRate: sampleRate,
- RawBuffer: new(bytes.Buffer),
- WavBuffer: new(bytes.Buffer),
+ logger: logger,
+ ServerURL: serverURL,
+ SampleRate: sampleRate,
+ AudioBuffer: new(bytes.Buffer),
}
}
@@ -54,17 +52,14 @@ func (stt *WhisperSTT) StartRecording() error {
func (stt *WhisperSTT) StopRecording() (string, error) {
stt.recording = false
- time.Sleep(time.Millisecond * 200) // this is not the way
// wait loop to finish?
- if stt.RawBuffer == nil {
- err := errors.New("unexpected nil RawBuffer")
+ if stt.AudioBuffer == nil {
+ err := errors.New("unexpected nil AudioBuffer")
stt.logger.Error(err.Error())
return "", err
}
// Create WAV header first
- stt.writeWavHeader(stt.WavBuffer, len(stt.RawBuffer.Bytes())) // Write initial header with 0 size
- stt.WavBuffer.Write(stt.RawBuffer.Bytes())
- body := &bytes.Buffer{} // third buffer?
+ body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
// Add audio file part
part, err := writer.CreateFormFile("file", "recording.wav")
@@ -72,11 +67,15 @@ func (stt *WhisperSTT) StopRecording() (string, error) {
stt.logger.Error("fn: StopRecording", "error", err)
return "", err
}
- _, err = io.Copy(part, stt.WavBuffer)
- if err != nil {
+ // Stream directly to multipart writer: header + raw data
+ dataSize := stt.AudioBuffer.Len()
+ stt.writeWavHeader(part, dataSize)
+ if _, err := io.Copy(part, stt.AudioBuffer); err != nil {
stt.logger.Error("fn: StopRecording", "error", err)
return "", err
}
+ // Reset buffer for next recording
+ stt.AudioBuffer.Reset()
// Add response format field
err = writer.WriteField("response_format", "text")
if err != nil {
@@ -95,13 +94,12 @@ func (stt *WhisperSTT) StopRecording() (string, error) {
}
defer resp.Body.Close()
// Read and print response
- responseText, err := io.ReadAll(resp.Body)
+ responseTextBytes, err := io.ReadAll(resp.Body)
if err != nil {
stt.logger.Error("fn: StopRecording", "error", err)
return "", err
}
- stt.logger.Info("got transcript", "text", string(responseText))
- return string(responseText), nil
+ return strings.TrimRight(string(responseTextBytes), "\n"), nil
}
func (stt *WhisperSTT) writeWavHeader(w io.Writer, dataSize int) {
@@ -149,7 +147,7 @@ func (stt *WhisperSTT) microphoneStream(sampleRate int) error {
stt.logger.Error("reading stream", "error", err)
return
}
- if err := binary.Write(stt.RawBuffer, binary.LittleEndian, in); err != nil {
+ if err := binary.Write(stt.AudioBuffer, binary.LittleEndian, in); err != nil {
stt.logger.Error("writing to buffer", "error", err)
return
}