diff options
-rw-r--r-- | bot.go | 50 | ||||
-rw-r--r-- | extra/stt.go | 233 | ||||
-rw-r--r-- | go.mod | 3 | ||||
-rw-r--r-- | go.sum | 4 | ||||
-rw-r--r-- | main.go | 2 | ||||
-rw-r--r-- | tui.go | 52 |
6 files changed, 165 insertions, 179 deletions
@@ -25,32 +25,8 @@ import ( "github.com/rivo/tview" ) -var httpClient = &http.Client{} - -func createClient(connectTimeout time.Duration) *http.Client { - // Custom transport with connection timeout - transport := &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - // Create a dialer with connection timeout - dialer := &net.Dialer{ - Timeout: connectTimeout, - KeepAlive: 30 * time.Second, // Optional - } - return dialer.DialContext(ctx, network, addr) - }, - // Other transport settings (optional) - TLSHandshakeTimeout: connectTimeout, - ResponseHeaderTimeout: connectTimeout, - } - - // Client with no overall timeout (or set to streaming-safe duration) - return &http.Client{ - Transport: transport, - Timeout: 0, // No overall timeout (for streaming) - } -} - var ( + httpClient = &http.Client{} cluedoState *extra.CluedoRoundInfo // Current game state playerOrder []string // Turn order tracking cfg *config.Config @@ -68,6 +44,7 @@ var ( ragger *rag.RAG chunkParser ChunkParser orator extra.Orator + asr extra.STT defaultLCPProps = map[string]float32{ "temperature": 0.8, "dry_multiplier": 0.0, @@ -76,6 +53,28 @@ var ( } ) +func createClient(connectTimeout time.Duration) *http.Client { + // Custom transport with connection timeout + transport := &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + // Create a dialer with connection timeout + dialer := &net.Dialer{ + Timeout: connectTimeout, + KeepAlive: 30 * time.Second, // Optional + } + return dialer.DialContext(ctx, network, addr) + }, + // Other transport settings (optional) + TLSHandshakeTimeout: connectTimeout, + ResponseHeaderTimeout: connectTimeout, + } + // Client with no overall timeout (or set to streaming-safe duration) + return &http.Client{ + Transport: transport, + Timeout: 0, // No overall timeout (for streaming) + } +} + func fetchModelName() *models.LLMModels { // TODO: to config api := "http://localhost:8080/v1/models" @@ -525,6 +524,7 @@ func init() { httpClient = createClient(time.Second * 15) // TODO: check config for orator orator = extra.InitOrator(logger, "http://localhost:8880/v1/audio/speech") + asr = extra.NewWhisperSTT(logger, "http://localhost:8081/inference", 44100) // go runModelNameTicker(time.Second * 120) // tempLoad() } diff --git a/extra/stt.go b/extra/stt.go index 6456488..c1dcba6 100644 --- a/extra/stt.go +++ b/extra/stt.go @@ -2,18 +2,16 @@ package extra import ( "bytes" - "encoding/json" + "encoding/binary" "errors" "fmt" "io" "log/slog" + "mime/multipart" "net/http" - "os" - "os/signal" + "time" - "github.com/MarkKremer/microphone/v2" - "github.com/gopxl/beep/v2" - "github.com/gopxl/beep/v2/wav" + "github.com/gordonklaus/portaudio" ) type STT interface { @@ -22,167 +20,140 @@ type STT interface { IsRecording() bool } +type StreamCloser interface { + Close() error +} + type WhisperSTT struct { logger *slog.Logger ServerURL string - SampleRate beep.SampleRate - Buffer *bytes.Buffer - streamer beep.StreamCloser + SampleRate int + RawBuffer *bytes.Buffer + WavBuffer *bytes.Buffer + streamer StreamCloser recording bool } -type writeseeker struct { - buf []byte - pos int -} - -func (m *writeseeker) Write(p []byte) (n int, err error) { - minCap := m.pos + len(p) - if minCap > cap(m.buf) { // Make sure buf has enough capacity: - buf2 := make([]byte, len(m.buf), minCap+len(p)) // add some extra - copy(buf2, m.buf) - m.buf = buf2 - } - if minCap > len(m.buf) { - m.buf = m.buf[:minCap] - } - copy(m.buf[m.pos:], p) - m.pos += len(p) - return len(p), nil -} - -func (m *writeseeker) Seek(offset int64, whence int) (int64, error) { - newPos, offs := 0, int(offset) - switch whence { - case io.SeekStart: - newPos = offs - case io.SeekCurrent: - newPos = m.pos + offs - case io.SeekEnd: - newPos = len(m.buf) + offs - } - if newPos < 0 { - return 0, errors.New("negative result pos") - } - m.pos = newPos - return int64(newPos), nil -} - -// Reader returns an io.Reader. Use it, for example, with io.Copy, to copy the content of the WriterSeeker buffer to an io.Writer -func (ws *writeseeker) Reader() io.Reader { - return bytes.NewReader(ws.buf) -} - -func NewWhisperSTT(logger *slog.Logger, serverURL string, sampleRate beep.SampleRate) *WhisperSTT { +func NewWhisperSTT(logger *slog.Logger, serverURL string, sampleRate int) *WhisperSTT { return &WhisperSTT{ logger: logger, ServerURL: serverURL, SampleRate: sampleRate, - Buffer: new(bytes.Buffer), + RawBuffer: new(bytes.Buffer), + WavBuffer: new(bytes.Buffer), } } func (stt *WhisperSTT) StartRecording() error { - stream, err := microphoneStream(stt.SampleRate) - if err != nil { + if err := stt.microphoneStream(stt.SampleRate); err != nil { return fmt.Errorf("failed to init microphone: %w", err) } - - stt.streamer = stream stt.recording = true - - go stt.capture() return nil } -func (stt *WhisperSTT) capture() { - sink := beep.NewBuffer(beep.Format{ - SampleRate: stt.SampleRate, - NumChannels: 1, - Precision: 2, - }) - - // Append the streamer to the buffer and encode as WAV - sink.Append(stt.streamer) - - // Encode the captured audio to WAV format using beep's WAV encoder - // var wavBuf bytes.Buffer - var wavBuf writeseeker - if err := wav.Encode(&wavBuf, sink.Streamer(0, sink.Len()), beep.Format{ - SampleRate: stt.SampleRate, - NumChannels: 1, - Precision: 2, - }); err != nil { - stt.logger.Error("failed to encode WAV", "error", err) - } - r := wavBuf.Reader() - // stt.Buffer = &wavBuf - if _, err := io.Copy(stt.Buffer, r); err != nil { - stt.logger.Error("failed to encode WAV", "error", err) - } -} - func (stt *WhisperSTT) StopRecording() (string, error) { - if !stt.recording { - return "", nil - } - - stt.streamer.Close() stt.recording = false - - // Send to Whisper.cpp server - req, err := http.NewRequest("POST", stt.ServerURL, stt.Buffer) + time.Sleep(time.Millisecond * 200) // this is not the way + // wait loop to finish? + if stt.RawBuffer == nil { + err := errors.New("unexpected nil RawBuffer") + stt.logger.Error(err.Error()) + return "", err + } + // Create WAV header first + stt.writeWavHeader(stt.WavBuffer, len(stt.RawBuffer.Bytes())) // Write initial header with 0 size + stt.WavBuffer.Write(stt.RawBuffer.Bytes()) + body := &bytes.Buffer{} // third buffer? + writer := multipart.NewWriter(body) + // Add audio file part + part, err := writer.CreateFormFile("file", "recording.wav") if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) + stt.logger.Error("fn: StopRecording", "error", err) + return "", err } - req.Header.Set("Content-Type", "audio/wav") - - resp, err := http.DefaultClient.Do(req) + _, err = io.Copy(part, stt.WavBuffer) if err != nil { - return "", fmt.Errorf("transcription request failed: %w", err) + stt.logger.Error("fn: StopRecording", "error", err) + return "", err } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode) + // Add response format field + err = writer.WriteField("response_format", "text") + if err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err } - - var result struct { - Text string `json:"text"` + if writer.Close() != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return "", fmt.Errorf("failed to decode response: %w", err) + // Send request + resp, err := http.Post("http://localhost:8081/inference", writer.FormDataContentType(), body) + if err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err } + defer resp.Body.Close() + // Read and print response + responseText, err := io.ReadAll(resp.Body) + if err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + stt.logger.Info("got transcript", "text", string(responseText)) + return string(responseText), nil +} - return result.Text, nil +func (stt *WhisperSTT) writeWavHeader(w io.Writer, dataSize int) { + header := make([]byte, 44) + copy(header[0:4], "RIFF") + binary.LittleEndian.PutUint32(header[4:8], uint32(36+dataSize)) + copy(header[8:12], "WAVE") + copy(header[12:16], "fmt ") + binary.LittleEndian.PutUint32(header[16:20], 16) + binary.LittleEndian.PutUint16(header[20:22], 1) + binary.LittleEndian.PutUint16(header[22:24], 1) + binary.LittleEndian.PutUint32(header[24:28], uint32(stt.SampleRate)) + binary.LittleEndian.PutUint32(header[28:32], uint32(stt.SampleRate)*1*(16/8)) + binary.LittleEndian.PutUint16(header[32:34], 1*(16/8)) + binary.LittleEndian.PutUint16(header[34:36], 16) + copy(header[36:40], "data") + binary.LittleEndian.PutUint32(header[40:44], uint32(dataSize)) + w.Write(header) } func (stt *WhisperSTT) IsRecording() bool { return stt.recording } -func microphoneStream(sr beep.SampleRate) (beep.StreamCloser, error) { - if err := microphone.Init(); err != nil { - return nil, fmt.Errorf("microphone init failed: %w", err) +func (stt *WhisperSTT) microphoneStream(sampleRate int) error { + if err := portaudio.Initialize(); err != nil { + return fmt.Errorf("portaudio init failed: %w", err) } - - stream, _, err := microphone.OpenDefaultStream(sr, 1) // 1 channel mono + in := make([]int16, 64) + stream, err := portaudio.OpenDefaultStream(1, 0, float64(sampleRate), len(in), in) if err != nil { - microphone.Terminate() - return nil, fmt.Errorf("failed to open microphone: %w", err) - } - - // Handle OS signals to clean up - sig := make(chan os.Signal, 1) - signal.Notify(sig, os.Interrupt, os.Kill) - go func() { - <-sig - stream.Stop() - stream.Close() - microphone.Terminate() - os.Exit(1) - }() - - stream.Start() - return stream, nil + portaudio.Terminate() + return fmt.Errorf("failed to open microphone: %w", err) + } + go func(stream *portaudio.Stream) { + if err := stream.Start(); err != nil { + stt.logger.Error("microphoneStream", "error", err) + return + } + for { + if !stt.IsRecording() { + return + } + if err := stream.Read(); err != nil { + stt.logger.Error("reading stream", "error", err) + return + } + if err := binary.Write(stt.RawBuffer, binary.LittleEndian, in); err != nil { + stt.logger.Error("writing to buffer", "error", err) + return + } + } + }(stream) + return nil } @@ -4,11 +4,11 @@ go 1.23.2 require ( github.com/BurntSushi/toml v1.4.0 - github.com/MarkKremer/microphone/v2 v2.0.1 github.com/asg017/sqlite-vec-go-bindings v0.1.6 github.com/gdamore/tcell/v2 v2.7.4 github.com/glebarez/go-sqlite v1.22.0 github.com/gopxl/beep/v2 v2.1.0 + github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 github.com/jmoiron/sqlx v1.4.0 github.com/ncruces/go-sqlite3 v0.21.3 github.com/neurosnap/sentences v1.1.2 @@ -21,7 +21,6 @@ require ( github.com/ebitengine/purego v0.7.1 // indirect github.com/gdamore/encoding v1.0.0 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 // indirect github.com/hajimehoshi/go-mp3 v0.3.4 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -2,8 +2,6 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= -github.com/MarkKremer/microphone/v2 v2.0.1 h1:PWI0MgBu3Nd9CSxdnIjwol8qshstNfywERIMOLD03Zk= -github.com/MarkKremer/microphone/v2 v2.0.1/go.mod h1:IdM74GKdsZAWVbkgX8xLGAdd4ytzBt7uk5F0brfTZRM= github.com/asg017/sqlite-vec-go-bindings v0.1.6 h1:Nx0jAzyS38XpkKznJ9xQjFXz2X9tI7KqjwVxV8RNoww= github.com/asg017/sqlite-vec-go-bindings v0.1.6/go.mod h1:A8+cTt/nKFsYCQF6OgzSNpKZrzNo5gQsXBTfsXHXY0Q= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -51,8 +49,6 @@ github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g= github.com/neurosnap/sentences v1.1.2 h1:iphYOzx/XckXeBiLIUBkPu2EKMJ+6jDbz/sLJZ7ZoUw= github.com/neurosnap/sentences v1.1.2/go.mod h1:/pwU4E9XNL21ygMIkOIllv/SMy2ujHwpf8GQPu1YPbQ= -github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e h1:s2RNOM/IGdY0Y6qfTeUKhDawdHDpK9RGBdx80qN4Ttw= -github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e/go.mod h1:nBdnFKj15wFbf94Rwfq4m30eAcyY9V/IyKAGQFtqkW0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -12,7 +12,7 @@ var ( botRespMode = false editMode = false selectedIndex = int(-1) - indexLine = "F12 to show keys help | bot resp mode: [orange:-:b]%v[-:-:-] (F6) | char: [orange:-:b]%s[-:-:-] (ctrl+s) | chat: [orange:-:b]%s[-:-:-] (F1) | RAGEnabled: [orange:-:b]%v[-:-:-] (F11) | toolUseAdviced: [orange:-:b]%v[-:-:-] (ctrl+k) | model: [orange:-:b]%s[-:-:-] (ctrl+l)\nAPI_URL: [orange:-:b]%s[-:-:-] (ctrl+v) | ThinkUse: [orange:-:b]%v[-:-:-] (ctrl+p) | Log Level: [orange:-:b]%v[-:-:-] (ctrl+p)" + indexLine = "F12 to show keys help | bot resp mode: [orange:-:b]%v[-:-:-] (F6) | char: [orange:-:b]%s[-:-:-] (ctrl+s) | chat: [orange:-:b]%s[-:-:-] (F1) | RAGEnabled: [orange:-:b]%v[-:-:-] (F11) | toolUseAdviced: [orange:-:b]%v[-:-:-] (ctrl+k) | model: [orange:-:b]%s[-:-:-] (ctrl+l)\nAPI_URL: [orange:-:b]%s[-:-:-] (ctrl+v) | ThinkUse: [orange:-:b]%v[-:-:-] (ctrl+p) | Log Level: [orange:-:b]%v[-:-:-] (ctrl+p) | Recording: [orange:-:b]%v[-:-:-] (ctrl+r)" focusSwitcher = map[tview.Primitive]tview.Primitive{} ) @@ -139,7 +139,7 @@ func colorText() { } func updateStatusLine() { - position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled, cfg.ToolUse, chatBody.Model, cfg.CurrentAPI, cfg.ThinkUse, logLevel.Level())) + position.SetText(fmt.Sprintf(indexLine, botRespMode, cfg.AssistantRole, activeChatName, cfg.RAGEnabled, cfg.ToolUse, chatBody.Model, cfg.CurrentAPI, cfg.ThinkUse, logLevel.Level(), asr.IsRecording())) } func initSysCards() ([]string, error) { @@ -666,24 +666,44 @@ func init() { pages.AddPage(imgPage, imgView, true, true) return nil } - if event.Key() == tcell.KeyCtrlR && cfg.HFToken != "" { - // rag load - // menu of the text files from defined rag directory - files, err := os.ReadDir(cfg.RAGDir) - if err != nil { - logger.Error("failed to read dir", "dir", cfg.RAGDir, "error", err) + // if event.Key() == tcell.KeyCtrlR && cfg.HFToken != "" { + // // rag load + // // menu of the text files from defined rag directory + // files, err := os.ReadDir(cfg.RAGDir) + // if err != nil { + // logger.Error("failed to read dir", "dir", cfg.RAGDir, "error", err) + // return nil + // } + // fileList := []string{} + // for _, f := range files { + // if f.IsDir() { + // continue + // } + // fileList = append(fileList, f.Name()) + // } + // chatRAGTable := makeRAGTable(fileList) + // pages.AddPage(RAGPage, chatRAGTable, true, true) + // return nil + // } + if event.Key() == tcell.KeyCtrlR { + defer updateStatusLine() + if asr.IsRecording() { + userSpeech, err := asr.StopRecording() + if err != nil { + logger.Error("failed to inference user speech", "error", err) + return nil + } + if userSpeech != "" { + textArea.SetText(userSpeech, true) + } else { + logger.Warn("empty user speech") + } return nil } - fileList := []string{} - for _, f := range files { - if f.IsDir() { - continue - } - fileList = append(fileList, f.Name()) + if err := asr.StartRecording(); err != nil { + logger.Error("failed to start recording user speech", "error", err) + return nil } - chatRAGTable := makeRAGTable(fileList) - pages.AddPage(RAGPage, chatRAGTable, true, true) - return nil } if event.Key() == tcell.KeyCtrlW { // INFO: continue bot/text message |