From 441225ede8ef959058de4933e83870eb30a55ecf Mon Sep 17 00:00:00 2001 From: Grail Finder Date: Sun, 18 May 2025 14:32:54 +0300 Subject: Refactor: cleanup stt mess, config use --- bot.go | 16 +++++++++------- config.example.toml | 3 +++ config/config.go | 3 +++ extra/stt.go | 48 +++++++++++++++++++++++------------------------- tui.go | 7 +++++-- 5 files changed, 43 insertions(+), 34 deletions(-) diff --git a/bot.go b/bot.go index 979ed2b..716bbd4 100644 --- a/bot.go +++ b/bot.go @@ -149,7 +149,7 @@ func sendMsgToLLM(body io.Reader) { // resp, err := httpClient.Post(cfg.CurrentAPI, "application/json", body) resp, err := httpClient.Do(req) if err != nil { - logger.Error("llamacpp api", "error", err, "body", string(bodyBytes)) + logger.Error("llamacpp api", "error", err) if err := notifyUser("error", "apicall failed:"+err.Error()); err != nil { logger.Error("failed to notify", "error", err) } @@ -498,6 +498,7 @@ func init() { // logLevel.Set(slog.LevelInfo) logger = slog.New(slog.NewTextHandler(logfile, &slog.HandlerOptions{Level: logLevel})) + // TODO: rename and/or put in cfg store = storage.NewProviderSQL("test.db", logger) if store == nil { os.Exit(1) @@ -511,7 +512,7 @@ func init() { } lastChat := loadOldChatOrGetNew() chatBody = &models.ChatBody{ - Model: "modl_name", + Model: "modelname", Stream: true, Messages: lastChat, } @@ -522,9 +523,10 @@ func init() { } choseChunkParser() httpClient = createClient(time.Second * 15) - // TODO: check config for orator - orator = extra.InitOrator(logger, "http://localhost:8880/v1/audio/speech") - asr = extra.NewWhisperSTT(logger, "http://localhost:8081/inference", 44100) - // go runModelNameTicker(time.Second * 120) - // tempLoad() + if cfg.TTS_ENABLED { + orator = extra.InitOrator(logger, cfg.TTS_URL) + } + if cfg.STT_ENABLED { + asr = extra.NewWhisperSTT(logger, cfg.STT_URL, 16000) + } } diff --git a/config.example.toml b/config.example.toml index 846dbba..16409f7 100644 --- a/config.example.toml +++ b/config.example.toml 
@@ -15,3 +15,6 @@ RAGWorkers = 5 # extra tts TTS_ENABLED = false TTS_URL = "http://localhost:8880/v1/audio/speech" +# extra stt +STT_ENABLED = false +STT_URL = "http://localhost:8081/inference" diff --git a/config/config.go b/config/config.go index 5e00dba..ccae96d 100644 --- a/config/config.go +++ b/config/config.go @@ -42,6 +42,9 @@ type Config struct { // TTS TTS_URL string `toml:"TTS_URL"` TTS_ENABLED bool `toml:"TTS_ENABLED"` + // STT + STT_URL string `toml:"STT_URL"` + STT_ENABLED bool `toml:"STT_ENABLED"` } func LoadConfigOrDefault(fn string) *Config { diff --git a/extra/stt.go b/extra/stt.go index c1dcba6..3e6e032 100644 --- a/extra/stt.go +++ b/extra/stt.go @@ -9,7 +9,7 @@ import ( "log/slog" "mime/multipart" "net/http" - "time" + "strings" "github.com/gordonklaus/portaudio" ) @@ -25,22 +25,20 @@ type StreamCloser interface { } type WhisperSTT struct { - logger *slog.Logger - ServerURL string - SampleRate int - RawBuffer *bytes.Buffer - WavBuffer *bytes.Buffer - streamer StreamCloser - recording bool + logger *slog.Logger + ServerURL string + SampleRate int + AudioBuffer *bytes.Buffer + streamer StreamCloser + recording bool } func NewWhisperSTT(logger *slog.Logger, serverURL string, sampleRate int) *WhisperSTT { return &WhisperSTT{ - logger: logger, - ServerURL: serverURL, - SampleRate: sampleRate, - RawBuffer: new(bytes.Buffer), - WavBuffer: new(bytes.Buffer), + logger: logger, + ServerURL: serverURL, + SampleRate: sampleRate, + AudioBuffer: new(bytes.Buffer), } } @@ -54,17 +52,14 @@ func (stt *WhisperSTT) StartRecording() error { func (stt *WhisperSTT) StopRecording() (string, error) { stt.recording = false - time.Sleep(time.Millisecond * 200) // this is not the way // wait loop to finish? 
- if stt.RawBuffer == nil { - err := errors.New("unexpected nil RawBuffer") + if stt.AudioBuffer == nil { + err := errors.New("unexpected nil AudioBuffer") stt.logger.Error(err.Error()) return "", err } // Create WAV header first - stt.writeWavHeader(stt.WavBuffer, len(stt.RawBuffer.Bytes())) // Write initial header with 0 size - stt.WavBuffer.Write(stt.RawBuffer.Bytes()) - body := &bytes.Buffer{} // third buffer? + body := &bytes.Buffer{} writer := multipart.NewWriter(body) // Add audio file part part, err := writer.CreateFormFile("file", "recording.wav") @@ -72,11 +67,15 @@ func (stt *WhisperSTT) StopRecording() (string, error) { stt.logger.Error("fn: StopRecording", "error", err) return "", err } - _, err = io.Copy(part, stt.WavBuffer) - if err != nil { + // Stream directly to multipart writer: header + raw data + dataSize := stt.AudioBuffer.Len() + stt.writeWavHeader(part, dataSize) + if _, err := io.Copy(part, stt.AudioBuffer); err != nil { stt.logger.Error("fn: StopRecording", "error", err) return "", err } + // Reset buffer for next recording + stt.AudioBuffer.Reset() // Add response format field err = writer.WriteField("response_format", "text") if err != nil { @@ -95,13 +94,12 @@ func (stt *WhisperSTT) StopRecording() (string, error) { } defer resp.Body.Close() // Read and print response - responseText, err := io.ReadAll(resp.Body) + responseTextBytes, err := io.ReadAll(resp.Body) if err != nil { stt.logger.Error("fn: StopRecording", "error", err) return "", err } - stt.logger.Info("got transcript", "text", string(responseText)) - return string(responseText), nil + return strings.TrimRight(string(responseTextBytes), "\n"), nil } func (stt *WhisperSTT) writeWavHeader(w io.Writer, dataSize int) { @@ -149,7 +147,7 @@ func (stt *WhisperSTT) microphoneStream(sampleRate int) error { stt.logger.Error("reading stream", "error", err) return } - if err := binary.Write(stt.RawBuffer, binary.LittleEndian, in); err != nil { + if err := binary.Write(stt.AudioBuffer, 
binary.LittleEndian, in); err != nil { stt.logger.Error("writing to buffer", "error", err) return } diff --git a/tui.go b/tui.go index 07d988c..a2e3ded 100644 --- a/tui.go +++ b/tui.go @@ -666,6 +666,7 @@ func init() { pages.AddPage(imgPage, imgView, true, true) return nil } + // TODO: move to menu or table // if event.Key() == tcell.KeyCtrlR && cfg.HFToken != "" { // // rag load // // menu of the text files from defined rag directory @@ -685,7 +686,7 @@ func init() { // pages.AddPage(RAGPage, chatRAGTable, true, true) // return nil // } - if event.Key() == tcell.KeyCtrlR { + if event.Key() == tcell.KeyCtrlR && cfg.STT_ENABLED { defer updateStatusLine() if asr.IsRecording() { userSpeech, err := asr.StopRecording() @@ -694,7 +695,9 @@ func init() { return nil } if userSpeech != "" { - textArea.SetText(userSpeech, true) + // append instead of replacing + prevText := textArea.GetText() + textArea.SetText(prevText+userSpeech, true) } else { logger.Warn("empty user speech") } -- cgit v1.2.3