summary refs log tree commit diff
path: root/extra/stt.go
diff options
context:
space:
mode:
Diffstat (limited to 'extra/stt.go')
-rw-r--r-- extra/stt.go 233
1 files changed, 102 insertions, 131 deletions
diff --git a/extra/stt.go b/extra/stt.go
index 6456488..c1dcba6 100644
--- a/extra/stt.go
+++ b/extra/stt.go
@@ -2,18 +2,16 @@ package extra
import (
"bytes"
- "encoding/json"
+ "encoding/binary"
"errors"
"fmt"
"io"
"log/slog"
+ "mime/multipart"
"net/http"
- "os"
- "os/signal"
+ "time"
- "github.com/MarkKremer/microphone/v2"
- "github.com/gopxl/beep/v2"
- "github.com/gopxl/beep/v2/wav"
+ "github.com/gordonklaus/portaudio"
)
type STT interface {
@@ -22,167 +20,140 @@ type STT interface {
IsRecording() bool
}
+// StreamCloser is the type of WhisperSTT's streamer field: any value that
+// can be closed.
+type StreamCloser interface {
+ Close() error
+}
+
+// WhisperSTT records microphone audio and obtains a transcript for it.
+// Samples are accumulated in RawBuffer while recording; StopRecording wraps
+// them in a WAV container (WavBuffer) and uploads the result.
type WhisperSTT struct {
logger *slog.Logger
+ // ServerURL is the server address handed to NewWhisperSTT.
ServerURL string
- SampleRate beep.SampleRate
- Buffer *bytes.Buffer
- streamer beep.StreamCloser
+ // SampleRate is the capture rate; it is passed to the input stream and
+ // written into the WAV header.
+ SampleRate int
+ // RawBuffer receives the captured int16 samples in little-endian order.
+ RawBuffer *bytes.Buffer
+ // WavBuffer holds the WAV header followed by the contents of RawBuffer.
+ WavBuffer *bytes.Buffer
+ // NOTE(review): nothing in the added code shown here assigns streamer.
+ streamer StreamCloser
+ // recording is set by StartRecording, cleared by StopRecording and polled
+ // by the capture loop.
recording bool
}
-type writeseeker struct {
- buf []byte
- pos int
-}
-
-func (m *writeseeker) Write(p []byte) (n int, err error) {
- minCap := m.pos + len(p)
- if minCap > cap(m.buf) { // Make sure buf has enough capacity:
- buf2 := make([]byte, len(m.buf), minCap+len(p)) // add some extra
- copy(buf2, m.buf)
- m.buf = buf2
- }
- if minCap > len(m.buf) {
- m.buf = m.buf[:minCap]
- }
- copy(m.buf[m.pos:], p)
- m.pos += len(p)
- return len(p), nil
-}
-
-func (m *writeseeker) Seek(offset int64, whence int) (int64, error) {
- newPos, offs := 0, int(offset)
- switch whence {
- case io.SeekStart:
- newPos = offs
- case io.SeekCurrent:
- newPos = m.pos + offs
- case io.SeekEnd:
- newPos = len(m.buf) + offs
- }
- if newPos < 0 {
- return 0, errors.New("negative result pos")
- }
- m.pos = newPos
- return int64(newPos), nil
-}
-
-// Reader returns an io.Reader. Use it, for example, with io.Copy, to copy the content of the WriterSeeker buffer to an io.Writer
-func (ws *writeseeker) Reader() io.Reader {
- return bytes.NewReader(ws.buf)
-}
-
-func NewWhisperSTT(logger *slog.Logger, serverURL string, sampleRate beep.SampleRate) *WhisperSTT {
+// NewWhisperSTT returns a WhisperSTT that stores the given logger, server
+// URL and sample rate, with RawBuffer and WavBuffer initialised to empty,
+// non-nil buffers.
+func NewWhisperSTT(logger *slog.Logger, serverURL string, sampleRate int) *WhisperSTT {
return &WhisperSTT{
logger: logger,
ServerURL: serverURL,
SampleRate: sampleRate,
- Buffer: new(bytes.Buffer),
+ RawBuffer: new(bytes.Buffer),
+ WavBuffer: new(bytes.Buffer),
}
}
+// StartRecording opens the default input device and starts the background
+// capture that appends samples to RawBuffer. It returns an error when the
+// audio backend cannot be initialised or the input stream cannot be opened.
func (stt *WhisperSTT) StartRecording() error {
- stream, err := microphoneStream(stt.SampleRate)
- if err != nil {
+ // Raise the flag before the capture goroutine is launched: its loop exits
+ // as soon as IsRecording reports false, so setting the flag only after
+ // microphoneStream returns could end the capture before it began.
+ stt.recording = true
+ if err := stt.microphoneStream(stt.SampleRate); err != nil {
+ stt.recording = false
return fmt.Errorf("failed to init microphone: %w", err)
}
-
- stt.streamer = stream
- stt.recording = true
-
- go stt.capture()
return nil
}
-func (stt *WhisperSTT) capture() {
- sink := beep.NewBuffer(beep.Format{
- SampleRate: stt.SampleRate,
- NumChannels: 1,
- Precision: 2,
- })
-
- // Append the streamer to the buffer and encode as WAV
- sink.Append(stt.streamer)
-
- // Encode the captured audio to WAV format using beep's WAV encoder
- // var wavBuf bytes.Buffer
- var wavBuf writeseeker
- if err := wav.Encode(&wavBuf, sink.Streamer(0, sink.Len()), beep.Format{
- SampleRate: stt.SampleRate,
- NumChannels: 1,
- Precision: 2,
- }); err != nil {
- stt.logger.Error("failed to encode WAV", "error", err)
- }
- r := wavBuf.Reader()
- // stt.Buffer = &wavBuf
- if _, err := io.Copy(stt.Buffer, r); err != nil {
- stt.logger.Error("failed to encode WAV", "error", err)
- }
-}
-
+// StopRecording ends the capture, wraps the recorded samples in a WAV
+// container, uploads it as multipart/form-data and returns the response body
+// as the transcript.
+//
+// The request goes to ServerURL, or to http://localhost:8081/inference when
+// ServerURL is empty. An error is returned when RawBuffer is nil, when the
+// multipart body cannot be built, when the request fails, or when the server
+// answers with a status other than 200 OK.
func (stt *WhisperSTT) StopRecording() (string, error) {
- if !stt.recording {
- return "", nil
- }
-
- stt.streamer.Close()
stt.recording = false
-
- // Send to Whisper.cpp server
- req, err := http.NewRequest("POST", stt.ServerURL, stt.Buffer)
+ // Crude grace period so the capture goroutine can observe the cleared flag
+ // and stop appending to RawBuffer. This is not real synchronisation.
+ time.Sleep(time.Millisecond * 200)
+ if stt.RawBuffer == nil {
+ err := errors.New("unexpected nil RawBuffer")
+ stt.logger.Error(err.Error())
+ return "", err
+ }
+ // Start from an empty WAV buffer in case an earlier call failed before the
+ // buffer was drained, then write the header (sized for the captured data)
+ // followed by the samples.
+ stt.WavBuffer.Reset()
+ stt.writeWavHeader(stt.WavBuffer, stt.RawBuffer.Len())
+ stt.WavBuffer.Write(stt.RawBuffer.Bytes())
+ // Drop the consumed samples so the next recording does not resend them.
+ stt.RawBuffer.Reset()
+ body := &bytes.Buffer{}
+ writer := multipart.NewWriter(body)
+ // Add audio file part
+ part, err := writer.CreateFormFile("file", "recording.wav")
if err != nil {
- return "", fmt.Errorf("failed to create request: %w", err)
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
}
- req.Header.Set("Content-Type", "audio/wav")
-
- resp, err := http.DefaultClient.Do(req)
+ _, err = io.Copy(part, stt.WavBuffer)
if err != nil {
- return "", fmt.Errorf("transcription request failed: %w", err)
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
}
- defer resp.Body.Close()
-
- if resp.StatusCode != http.StatusOK {
- return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+ // Add response format field
+ err = writer.WriteField("response_format", "text")
+ if err != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
}
-
- var result struct {
- Text string `json:"text"`
+ // Close writes the trailing multipart boundary; keep its error so a
+ // failure is actually reported instead of returning a nil error.
+ if err = writer.Close(); err != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
}
- if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
- return "", fmt.Errorf("failed to decode response: %w", err)
+ // Send request to the configured server, keeping the previous hard-coded
+ // address as the fallback when none was configured.
+ endpoint := stt.ServerURL
+ if endpoint == "" {
+ endpoint = "http://localhost:8081/inference"
+ }
+ resp, err := http.Post(endpoint, writer.FormDataContentType(), body)
+ if err != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
}
+ defer resp.Body.Close()
+ // Do not pass an error page off as a transcript.
+ if resp.StatusCode != http.StatusOK {
+ err = fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
+ }
+ // Read and log response
+ responseText, err := io.ReadAll(resp.Body)
+ if err != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
+ }
+ stt.logger.Info("got transcript", "text", string(responseText))
+ return string(responseText), nil
+}
- return result.Text, nil
+// writeWavHeader writes a 44-byte RIFF/WAVE header to w that describes
+// dataSize bytes of single-channel, 16-bit sample data at stt.SampleRate.
+// All multi-byte fields are little-endian. The result of w.Write is ignored.
+func (stt *WhisperSTT) writeWavHeader(w io.Writer, dataSize int) {
+ header := make([]byte, 44)
+ copy(header[0:4], "RIFF")
+ // RIFF chunk size: the 36 header bytes after this field plus the data.
+ binary.LittleEndian.PutUint32(header[4:8], uint32(36+dataSize))
+ copy(header[8:12], "WAVE")
+ copy(header[12:16], "fmt ")
+ binary.LittleEndian.PutUint32(header[16:20], 16) // size of the fmt chunk body
+ binary.LittleEndian.PutUint16(header[20:22], 1) // audio format: 1 = PCM
+ binary.LittleEndian.PutUint16(header[22:24], 1) // number of channels
+ binary.LittleEndian.PutUint32(header[24:28], uint32(stt.SampleRate))
+ // Byte rate: sample rate * channels * bytes per sample.
+ binary.LittleEndian.PutUint32(header[28:32], uint32(stt.SampleRate)*1*(16/8))
+ // Block align: channels * bytes per sample.
+ binary.LittleEndian.PutUint16(header[32:34], 1*(16/8))
+ binary.LittleEndian.PutUint16(header[34:36], 16) // bits per sample
+ copy(header[36:40], "data")
+ binary.LittleEndian.PutUint32(header[40:44], uint32(dataSize))
+ w.Write(header)
}
+// IsRecording reports the current value of the recording flag.
func (stt *WhisperSTT) IsRecording() bool {
return stt.recording
}
-func microphoneStream(sr beep.SampleRate) (beep.StreamCloser, error) {
- if err := microphone.Init(); err != nil {
- return nil, fmt.Errorf("microphone init failed: %w", err)
+// microphoneStream initialises PortAudio, opens the default input device
+// with one channel at sampleRate, and launches a goroutine that reads
+// 64-sample blocks and appends them to RawBuffer as little-endian int16
+// until IsRecording reports false or a read/write error occurs.
+//
+// It returns an error when PortAudio cannot be initialised or the stream
+// cannot be opened. Failures inside the goroutine (including a failed
+// stream start) are only logged.
+func (stt *WhisperSTT) microphoneStream(sampleRate int) error {
+ if err := portaudio.Initialize(); err != nil {
+ return fmt.Errorf("portaudio init failed: %w", err)
}
-
- stream, _, err := microphone.OpenDefaultStream(sr, 1) // 1 channel mono
+ in := make([]int16, 64)
+ stream, err := portaudio.OpenDefaultStream(1, 0, float64(sampleRate), len(in), in)
if err != nil {
- microphone.Terminate()
- return nil, fmt.Errorf("failed to open microphone: %w", err)
- }
-
- // Handle OS signals to clean up
- sig := make(chan os.Signal, 1)
- signal.Notify(sig, os.Interrupt, os.Kill)
- go func() {
- <-sig
- stream.Stop()
- stream.Close()
- microphone.Terminate()
- os.Exit(1)
- }()
-
- stream.Start()
- return stream, nil
+ portaudio.Terminate()
+ return fmt.Errorf("failed to open microphone: %w", err)
+ }
+ go func(stream *portaudio.Stream) {
+ // The stream is not stored anywhere in this function, so this goroutine
+ // must release it, together with the PortAudio reference taken by
+ // Initialize above, on every exit path.
+ defer func() {
+ if err := stream.Close(); err != nil {
+ stt.logger.Error("closing stream", "error", err)
+ }
+ if err := portaudio.Terminate(); err != nil {
+ stt.logger.Error("terminating portaudio", "error", err)
+ }
+ }()
+ if err := stream.Start(); err != nil {
+ stt.logger.Error("microphoneStream", "error", err)
+ return
+ }
+ // Registered after a successful Start, so it runs before Close.
+ defer func() {
+ if err := stream.Stop(); err != nil {
+ stt.logger.Error("stopping stream", "error", err)
+ }
+ }()
+ for stt.IsRecording() {
+ if err := stream.Read(); err != nil {
+ stt.logger.Error("reading stream", "error", err)
+ return
+ }
+ if err := binary.Write(stt.RawBuffer, binary.LittleEndian, in); err != nil {
+ stt.logger.Error("writing to buffer", "error", err)
+ return
+ }
+ }
+ }(stream)
+ return nil
}