diff options
Diffstat (limited to 'extra')
| -rw-r--r-- | extra/google_tts.go | 218 | ||||
| -rw-r--r-- | extra/kokoro.go | 259 | ||||
| -rw-r--r-- | extra/stt.go | 70 | ||||
| -rw-r--r-- | extra/tts.go | 69 | ||||
| -rw-r--r-- | extra/tts_test.go | 40 | ||||
| -rw-r--r-- | extra/whisper_binary.go | 176 | ||||
| -rw-r--r-- | extra/whisper_server.go | 156 |
7 files changed, 988 insertions, 0 deletions
diff --git a/extra/google_tts.go b/extra/google_tts.go new file mode 100644 index 0000000..782075d --- /dev/null +++ b/extra/google_tts.go @@ -0,0 +1,218 @@ +//go:build extra +// +build extra + +package extra + +import ( + "fmt" + "gf-lt/models" + "io" + "log/slog" + "os/exec" + "strings" + "sync" + + google_translate_tts "github.com/GrailFinder/google-translate-tts" + "github.com/neurosnap/sentences/english" +) + +type GoogleTranslateOrator struct { + logger *slog.Logger + mu sync.Mutex + speech *google_translate_tts.Speech + // fields for playback control + cmd *exec.Cmd + cmdMu sync.Mutex + stopCh chan struct{} + // text buffer and interrupt flag + textBuffer strings.Builder + interrupt bool + Speed float32 +} + +func (o *GoogleTranslateOrator) stoproutine() { + for { + <-TTSDoneChan + o.logger.Debug("orator got done signal") + o.Stop() + for len(TTSTextChan) > 0 { + <-TTSTextChan + } + o.mu.Lock() + o.textBuffer.Reset() + o.interrupt = true + o.mu.Unlock() + } +} + +func (o *GoogleTranslateOrator) readroutine() { + tokenizer, _ := english.NewSentenceTokenizer(nil) + for { + select { + case chunk := <-TTSTextChan: + o.mu.Lock() + o.interrupt = false + _, err := o.textBuffer.WriteString(chunk) + if err != nil { + o.logger.Warn("failed to write to stringbuilder", "error", err) + o.mu.Unlock() + continue + } + text := o.textBuffer.String() + sentences := tokenizer.Tokenize(text) + o.logger.Debug("adding chunk", "chunk", chunk, "text", text, "sen-len", len(sentences)) + if len(sentences) <= 1 { + o.mu.Unlock() + continue + } + completeSentences := sentences[:len(sentences)-1] + remaining := sentences[len(sentences)-1].Text + o.textBuffer.Reset() + o.textBuffer.WriteString(remaining) + o.mu.Unlock() + for _, sentence := range completeSentences { + o.mu.Lock() + interrupted := o.interrupt + o.mu.Unlock() + if interrupted { + return + } + cleanedText := models.CleanText(sentence.Text) + if cleanedText == "" { + continue + } + o.logger.Debug("calling Speak with sentence", "sent", cleanedText) + if err := o.Speak(cleanedText); err != nil { + o.logger.Error("tts failed", "sentence", cleanedText, "error", err) + } + } + case <-TTSFlushChan: + o.logger.Debug("got flushchan signal start") + // lln is done get the whole message out + if len(TTSTextChan) > 0 { // otherwise might get stuck + for chunk := range TTSTextChan { + o.mu.Lock() + _, err := o.textBuffer.WriteString(chunk) + o.mu.Unlock() + if err != nil { + o.logger.Warn("failed to write to stringbuilder", "error", err) + continue + } + if len(TTSTextChan) == 0 { + break + } + } + } + o.mu.Lock() + remaining := o.textBuffer.String() + remaining = models.CleanText(remaining) + o.textBuffer.Reset() + o.mu.Unlock() + if remaining == "" { + continue + } + o.logger.Debug("calling Speak with remainder", "rem", remaining) + sentencesRem := tokenizer.Tokenize(remaining) + for _, rs := range sentencesRem { // to avoid dumping large volume of text + o.mu.Lock() + interrupt := o.interrupt + o.mu.Unlock() + if interrupt { + break + } + if err := o.Speak(rs.Text); err != nil { + o.logger.Error("tts failed", "sentence", rs.Text, "error", err) + } + } + } + } +} + +func (o *GoogleTranslateOrator) GetLogger() *slog.Logger { + return o.logger +} + +func (o *GoogleTranslateOrator) Speak(text string) error { + o.logger.Debug("fn: Speak is called", "text-len", len(text)) + // Generate MP3 data directly as an io.Reader + reader, err := o.speech.GenerateSpeech(text) + if err != nil { + return fmt.Errorf("generate speech failed: %w", err) + } + // Wrap in io.NopCloser since GenerateSpeech returns io.Reader (no close needed) + body := io.NopCloser(reader) + defer body.Close() + // Build ffplay command with optional speed filter + args := []string{"-nodisp", "-autoexit"} + if o.Speed > 0.1 && o.Speed != 1.0 { + // atempo range is 0.5 to 2.0; you might clamp it here + args = append(args, "-af", fmt.Sprintf("atempo=%.2f", o.Speed)) + } + args = append(args, "-i", "pipe:0") + cmd := exec.Command("ffplay", args...) + stdin, err := cmd.StdinPipe() + if err != nil { + return fmt.Errorf("failed to get stdin pipe: %w", err) + } + o.cmdMu.Lock() + o.cmd = cmd + o.stopCh = make(chan struct{}) + o.cmdMu.Unlock() + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start ffplay: %w", err) + } + copyErr := make(chan error, 1) + go func() { + _, err := io.Copy(stdin, body) + stdin.Close() + copyErr <- err + }() + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + select { + case <-o.stopCh: + if o.cmd != nil && o.cmd.Process != nil { + o.cmd.Process.Kill() + } + <-done + return nil + case copyErrVal := <-copyErr: + if copyErrVal != nil { + if o.cmd != nil && o.cmd.Process != nil { + o.cmd.Process.Kill() + } + <-done + return copyErrVal + } + return <-done + case err := <-done: + return err + } +} + +func (o *GoogleTranslateOrator) Stop() { + o.cmdMu.Lock() + defer o.cmdMu.Unlock() + // Signal any running Speak to stop + if o.stopCh != nil { + select { + case <-o.stopCh: // already closed + default: + close(o.stopCh) + } + o.stopCh = nil + } + // Kill the external player process if it's still running + if o.cmd != nil && o.cmd.Process != nil { + o.cmd.Process.Kill() + o.cmd.Wait() // clean up zombie process + o.cmd = nil + } + // Also reset text buffer and interrupt flag (with o.mu) + o.mu.Lock() + o.textBuffer.Reset() + o.interrupt = true + o.mu.Unlock() +} diff --git a/extra/kokoro.go b/extra/kokoro.go new file mode 100644 index 0000000..e3ca047 --- /dev/null +++ b/extra/kokoro.go @@ -0,0 +1,259 @@ +//go:build extra +// +build extra + +package extra + +import ( + "bytes" + "encoding/json" + "fmt" + "gf-lt/models" + "io" + "log/slog" + "net/http" + "os/exec" + "strings" + "sync" + + "github.com/neurosnap/sentences/english" +) + +type KokoroOrator struct { + logger *slog.Logger + mu sync.Mutex + URL string + Format models.AudioFormat + Stream bool + Speed float32 + Language string + Voice string + // fields for playback control + cmd *exec.Cmd + cmdMu sync.Mutex + stopCh chan struct{} + // textBuffer, interrupt etc. remain the same + textBuffer strings.Builder + interrupt bool +} + +func (o *KokoroOrator) GetLogger() *slog.Logger { + return o.logger +} + +func (o *KokoroOrator) Speak(text string) error { + o.logger.Debug("fn: Speak is called", "text-len", len(text)) + body, err := o.requestSound(text) + if err != nil { + return fmt.Errorf("request failed: %w", err) + } + defer body.Close() + cmd := exec.Command("ffplay", "-nodisp", "-autoexit", "-i", "pipe:0") + stdin, err := cmd.StdinPipe() + if err != nil { + return fmt.Errorf("failed to get stdin pipe: %w", err) + } + o.cmdMu.Lock() + o.cmd = cmd + o.stopCh = make(chan struct{}) + o.cmdMu.Unlock() + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start ffplay: %w", err) + } + // Copy audio in background + copyErr := make(chan error, 1) + go func() { + _, err := io.Copy(stdin, body) + stdin.Close() + copyErr <- err + }() + // Wait for player in background + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + // Wait for BOTH copy and player, but ensure we block until done + select { + case <-o.stopCh: + // Stop requested: kill player and wait for it to exit + if o.cmd != nil && o.cmd.Process != nil { + o.cmd.Process.Kill() + } + <-done // Wait for process to actually exit + return nil + case copyErrVal := <-copyErr: + if copyErrVal != nil { + // Copy failed: kill player and wait + if o.cmd != nil && o.cmd.Process != nil { + o.cmd.Process.Kill() + } + <-done + return copyErrVal + } + // Copy succeeded, now wait for playback to complete + return <-done + case err := <-done: + // Playback finished normally (copy must have succeeded or player would have exited early) + return err + } +} +func (o *KokoroOrator) requestSound(text string) (io.ReadCloser, error) { + if o.URL == "" { + return nil, fmt.Errorf("TTS URL is empty") + } + payload := map[string]interface{}{ + "input": text, + "voice": o.Voice, + "response_format": o.Format, + "download_format": o.Format, + "stream": o.Stream, + "speed": o.Speed, + // "return_download_link": true, + "lang_code": o.Language, + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal payload: %w", err) + } + req, err := http.NewRequest("POST", o.URL, bytes.NewBuffer(payloadBytes)) //nolint:noctx + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("accept", "application/json") + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + return resp.Body, nil +} + +func (o *KokoroOrator) stoproutine() { + for { + <-TTSDoneChan + o.logger.Debug("orator got done signal") + // 1. Stop any ongoing playback (kills external player, closes stopCh) + o.Stop() + // 2. Drain any pending text chunks + for len(TTSTextChan) > 0 { + <-TTSTextChan + } + // 3. Reset internal state + o.mu.Lock() + o.textBuffer.Reset() + o.interrupt = true + o.mu.Unlock() + } +} + +func (o *KokoroOrator) Stop() { + o.cmdMu.Lock() + defer o.cmdMu.Unlock() + // Signal any running Speak to stop + if o.stopCh != nil { + select { + case <-o.stopCh: // already closed + default: + close(o.stopCh) + } + o.stopCh = nil + } + // Kill the external player process if it's still running + if o.cmd != nil && o.cmd.Process != nil { + o.cmd.Process.Kill() + o.cmd.Wait() // clean up zombie process + o.cmd = nil + } + // Also reset text buffer and interrupt flag (with o.mu) + o.mu.Lock() + o.textBuffer.Reset() + o.interrupt = true + o.mu.Unlock() +} + +func (o *KokoroOrator) readroutine() { + tokenizer, _ := english.NewSentenceTokenizer(nil) + for { + select { + case chunk := <-TTSTextChan: + o.mu.Lock() + o.interrupt = false + _, err := o.textBuffer.WriteString(chunk) + if err != nil { + o.logger.Warn("failed to write to stringbuilder", "error", err) + o.mu.Unlock() + continue + } + text := o.textBuffer.String() + sentences := tokenizer.Tokenize(text) + o.logger.Debug("adding chunk", "chunk", chunk, "text", text, "sen-len", len(sentences)) + if len(sentences) <= 1 { + o.mu.Unlock() + continue + } + completeSentences := sentences[:len(sentences)-1] + remaining := sentences[len(sentences)-1].Text + o.textBuffer.Reset() + o.textBuffer.WriteString(remaining) + o.mu.Unlock() + for _, sentence := range completeSentences { + o.mu.Lock() + interrupted := o.interrupt + o.mu.Unlock() + if interrupted { + return + } + cleanedText := models.CleanText(sentence.Text) + if cleanedText == "" { + continue + } + o.logger.Debug("calling Speak with sentence", "sent", cleanedText) + if err := o.Speak(cleanedText); err != nil { + o.logger.Error("tts failed", "sentence", cleanedText, "error", err) + } + } + case <-TTSFlushChan: + o.logger.Debug("got flushchan signal start") + // lln is done get the whole message out + if len(TTSTextChan) > 0 { // otherwise might get stuck + for chunk := range TTSTextChan { + o.mu.Lock() + _, err := o.textBuffer.WriteString(chunk) + o.mu.Unlock() + if err != nil { + o.logger.Warn("failed to write to stringbuilder", "error", err) + continue + } + if len(TTSTextChan) == 0 { + break + } + } + } + // flush remaining text + o.mu.Lock() + remaining := o.textBuffer.String() + remaining = models.CleanText(remaining) + o.textBuffer.Reset() + o.mu.Unlock() + if remaining == "" { + continue + } + o.logger.Debug("calling Speak with remainder", "rem", remaining) + sentencesRem := tokenizer.Tokenize(remaining) + for _, rs := range sentencesRem { // to avoid dumping large volume of text + o.mu.Lock() + interrupt := o.interrupt + o.mu.Unlock() + if interrupt { + break + } + if err := o.Speak(rs.Text); err != nil { + o.logger.Error("tts failed", "sentence", rs, "error", err) + } + } + } + } +} diff --git a/extra/stt.go b/extra/stt.go new file mode 100644 index 0000000..7bbf2fd --- /dev/null +++ b/extra/stt.go @@ -0,0 +1,70 @@ +//go:build extra +// +build extra + +package extra + +import ( + "bytes" + "encoding/binary" + "gf-lt/config" + "io" + "log/slog" + "regexp" +) + +var specialRE = regexp.MustCompile(`\[.*?\]`) + +type STT interface { + StartRecording() error + StopRecording() (string, error) + IsRecording() bool +} + +type StreamCloser interface { + Close() error +} + +func NewSTT(logger *slog.Logger, cfg *config.Config) STT { + switch cfg.STT_TYPE { + case "WHISPER_BINARY": + logger.Debug("stt init, chosen whisper binary") + return NewWhisperBinary(logger, cfg) + case "WHISPER_SERVER": + logger.Debug("stt init, chosen whisper server") + return NewWhisperServer(logger, cfg) + } + return NewWhisperServer(logger, cfg) +} + +func NewWhisperServer(logger *slog.Logger, cfg *config.Config) *WhisperServer { + return &WhisperServer{ + logger: logger, + ServerURL: cfg.STT_URL, + SampleRate: cfg.STT_SR, + AudioBuffer: new(bytes.Buffer), + } +} + +func (stt *WhisperServer) writeWavHeader(w io.Writer, dataSize int) { + header := make([]byte, 44) + copy(header[0:4], "RIFF") + binary.LittleEndian.PutUint32(header[4:8], uint32(36+dataSize)) + copy(header[8:12], "WAVE") + copy(header[12:16], "fmt ") + binary.LittleEndian.PutUint32(header[16:20], 16) + binary.LittleEndian.PutUint16(header[20:22], 1) + binary.LittleEndian.PutUint16(header[22:24], 1) + binary.LittleEndian.PutUint32(header[24:28], uint32(stt.SampleRate)) + binary.LittleEndian.PutUint32(header[28:32], uint32(stt.SampleRate)*1*(16/8)) + binary.LittleEndian.PutUint16(header[32:34], 1*(16/8)) + binary.LittleEndian.PutUint16(header[34:36], 16) + copy(header[36:40], "data") + binary.LittleEndian.PutUint32(header[40:44], uint32(dataSize)) + if _, err := w.Write(header); err != nil { + stt.logger.Error("writeWavHeader", "error", err) + } +} + +func (stt *WhisperServer) IsRecording() bool { + return stt.recording +} diff --git a/extra/tts.go b/extra/tts.go new file mode 100644 index 0000000..2ddb0ae --- /dev/null +++ b/extra/tts.go @@ -0,0 +1,69 @@ +//go:build extra +// +build extra + +package extra + +import ( + "gf-lt/config" + "gf-lt/models" + "log/slog" + "os" + "strings" + + google_translate_tts "github.com/GrailFinder/google-translate-tts" +) + +var ( + TTSTextChan = make(chan string, 10000) + TTSFlushChan = make(chan bool, 1) + TTSDoneChan = make(chan bool, 1) + // endsWithPunctuation = regexp.MustCompile(`[;.!?]$`) +) + +type Orator interface { + Speak(text string) error + Stop() + // pause and resume? + GetLogger() *slog.Logger +} + +func NewOrator(log *slog.Logger, cfg *config.Config) Orator { + provider := cfg.TTS_PROVIDER + if provider == "" { + provider = "google" // does not require local setup + } + switch strings.ToLower(provider) { + case "kokoro": // kokoro + orator := &KokoroOrator{ + logger: log, + URL: cfg.TTS_URL, + Format: models.AFMP3, + Stream: false, + Speed: cfg.TTS_SPEED, + Language: "a", + Voice: "af_bella(1)+af_sky(1)", + } + go orator.readroutine() + go orator.stoproutine() + return orator + default: + language := cfg.TTS_LANGUAGE + if language == "" { + language = "en" + } + speech := &google_translate_tts.Speech{ + Folder: os.TempDir() + "/gf-lt-tts", // Temporary directory for caching + Language: language, + Proxy: "", // Proxy not supported + Speed: cfg.TTS_SPEED, + } + orator := &GoogleTranslateOrator{ + logger: log, + speech: speech, + Speed: cfg.TTS_SPEED, + } + go orator.readroutine() + go orator.stoproutine() + return orator + } +} diff --git a/extra/tts_test.go b/extra/tts_test.go new file mode 100644 index 0000000..a21d9b8 --- /dev/null +++ b/extra/tts_test.go @@ -0,0 +1,40 @@ +//go:build extra +// +build extra + +package extra + +import ( + "testing" +) + +func TestCleanText(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"Hello world", "Hello world"}, + {"**Bold text**", "Bold text"}, + {"*Italic text*", "Italic text"}, + {"# Header", "Header"}, + {"_Underlined text_", "Underlined text"}, + {"~Strikethrough text~", "Strikethrough text"}, + {"`Code text`", "Code text"}, + {"[Link text](url)", "Link text(url)"}, + {"Mixed *markdown* and #headers#!", "Mixed markdown and headers"}, + {"<html>tags</html>", "tags"}, + {"|---|", ""}, // Table separator + {"|====|", ""}, // Table separator with equals + {"| - - - |", ""}, // Table separator with spaced dashes + {"| cell1 | cell2 |", "cell1 cell2"}, // Table row with content + {" Trailing spaces ", "Trailing spaces"}, + {"", ""}, + {"***", ""}, + } + + for _, test := range tests { + result := cleanText(test.input) + if result != test.expected { + t.Errorf("cleanText(%q) = %q; expected %q", test.input, result, test.expected) + } + } +}
\ No newline at end of file diff --git a/extra/whisper_binary.go b/extra/whisper_binary.go new file mode 100644 index 0000000..1c35952 --- /dev/null +++ b/extra/whisper_binary.go @@ -0,0 +1,176 @@ +//go:build extra +// +build extra + +package extra + +import ( + "bytes" + "context" + "errors" + "fmt" + "gf-lt/config" + "log/slog" + "os" + "os/exec" + "strings" + "sync" + "syscall" + "time" +) + +type WhisperBinary struct { + logger *slog.Logger + whisperPath string + modelPath string + lang string + // Per-recording fields (protected by mu) + mu sync.Mutex + recording bool + tempFile string + ctx context.Context + cancel context.CancelFunc + cmd *exec.Cmd + cmdMu sync.Mutex +} + +func (w *WhisperBinary) StartRecording() error { + w.mu.Lock() + defer w.mu.Unlock() + if w.recording { + return errors.New("recording is already in progress") + } + // Fresh context for this recording + ctx, cancel := context.WithCancel(context.Background()) + w.ctx = ctx + w.cancel = cancel + // Create temporary file + tempFile, err := os.CreateTemp("", "recording_*.wav") + if err != nil { + cancel() + return fmt.Errorf("failed to create temp file: %w", err) + } + tempFile.Close() + w.tempFile = tempFile.Name() + // ffmpeg command: capture from default microphone, write WAV + args := []string{ + "-f", "alsa", // or "pulse" if preferred + "-i", "default", + "-acodec", "pcm_s16le", + "-ar", "16000", + "-ac", "1", + "-y", // overwrite output file + w.tempFile, + } + cmd := exec.CommandContext(w.ctx, "ffmpeg", args...) + // Capture stderr for debugging (optional, but useful for diagnosing) + stderr, err := cmd.StderrPipe() + if err != nil { + cancel() + os.Remove(w.tempFile) + return fmt.Errorf("failed to create stderr pipe: %w", err) + } + go func() { + buf := make([]byte, 1024) + for { + n, err := stderr.Read(buf) + if n > 0 { + w.logger.Debug("ffmpeg stderr", "output", string(buf[:n])) + } + if err != nil { + break + } + } + }() + w.cmdMu.Lock() + w.cmd = cmd + w.cmdMu.Unlock() + if err := cmd.Start(); err != nil { + cancel() + os.Remove(w.tempFile) + return fmt.Errorf("failed to start ffmpeg: %w", err) + } + w.recording = true + w.logger.Debug("Recording started", "file", w.tempFile) + return nil +} + +func (w *WhisperBinary) StopRecording() (string, error) { + w.mu.Lock() + defer w.mu.Unlock() + if !w.recording { + return "", errors.New("not currently recording") + } + w.recording = false + // Gracefully stop ffmpeg + w.cmdMu.Lock() + if w.cmd != nil && w.cmd.Process != nil { + w.logger.Debug("Sending SIGTERM to ffmpeg") + w.cmd.Process.Signal(syscall.SIGTERM) + // Wait for process to exit (up to 2 seconds) + done := make(chan error, 1) + go func() { + done <- w.cmd.Wait() + }() + select { + case <-done: + w.logger.Debug("ffmpeg exited after SIGTERM") + case <-time.After(2 * time.Second): + w.logger.Warn("ffmpeg did not exit, sending SIGKILL") + w.cmd.Process.Kill() + <-done + } + } + w.cmdMu.Unlock() + // Cancel context (already done, but for cleanliness) + if w.cancel != nil { + w.cancel() + } + // Validate temp file + if w.tempFile == "" { + return "", errors.New("no recording file") + } + defer os.Remove(w.tempFile) + info, err := os.Stat(w.tempFile) + if err != nil { + return "", fmt.Errorf("failed to stat temp file: %w", err) + } + if info.Size() < 44 { // WAV header is 44 bytes + // Log ffmpeg stderr? Already captured in debug logs. + return "", fmt.Errorf("recording file too small (%d bytes), possibly no audio captured", info.Size()) + } + // Run whisper.cpp binary + cmd := exec.Command(w.whisperPath, "-m", w.modelPath, "-l", w.lang, w.tempFile) + var outBuf, errBuf bytes.Buffer + cmd.Stdout = &outBuf + cmd.Stderr = &errBuf + if err := cmd.Run(); err != nil { + w.logger.Error("whisper binary failed", + "error", err, + "stderr", errBuf.String(), + "file_size", info.Size()) + return "", fmt.Errorf("whisper binary failed: %w (stderr: %s)", err, errBuf.String()) + } + result := strings.TrimRight(outBuf.String(), "\n") + result = specialRE.ReplaceAllString(result, "") + return strings.TrimSpace(strings.ReplaceAll(result, "\n ", "\n")), nil +} + +// IsRecording returns true if a recording is in progress. +func (w *WhisperBinary) IsRecording() bool { + w.mu.Lock() + defer w.mu.Unlock() + return w.recording +} + +func NewWhisperBinary(logger *slog.Logger, cfg *config.Config) *WhisperBinary { + ctx, cancel := context.WithCancel(context.Background()) + // Set ALSA error handler first + return &WhisperBinary{ + logger: logger, + whisperPath: cfg.WhisperBinaryPath, + modelPath: cfg.WhisperModelPath, + lang: cfg.STT_LANG, + ctx: ctx, + cancel: cancel, + } +} diff --git a/extra/whisper_server.go b/extra/whisper_server.go new file mode 100644 index 0000000..7532f4a --- /dev/null +++ b/extra/whisper_server.go @@ -0,0 +1,156 @@ +//go:build extra +// +build extra + +package extra + +import ( + "bytes" + "errors" + "fmt" + "io" + "log/slog" + "mime/multipart" + "net/http" + "os/exec" + "strings" + "sync" +) + +type WhisperServer struct { + logger *slog.Logger + ServerURL string + SampleRate int + AudioBuffer *bytes.Buffer + recording bool // protected by mu + mu sync.Mutex // protects recording & AudioBuffer + cmd *exec.Cmd // protected by cmdMu + stopCh chan struct{} // protected by cmdMu + cmdMu sync.Mutex // protects cmd and stopCh +} + +func (stt *WhisperServer) StartRecording() error { + stt.mu.Lock() + defer stt.mu.Unlock() + if stt.recording { + return nil + } + // Build ffmpeg command for microphone capture + args := []string{ + "-f", "alsa", + "-i", "default", + "-acodec", "pcm_s16le", + "-ar", fmt.Sprint(stt.SampleRate), + "-ac", "1", + "-f", "s16le", + "-", + } + cmd := exec.Command("ffmpeg", args...) + stdout, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("failed to get stdout pipe: %w", err) + } + stt.cmdMu.Lock() + stt.cmd = cmd + stt.stopCh = make(chan struct{}) + stt.cmdMu.Unlock() + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start ffmpeg: %w", err) + } + stt.recording = true + stt.AudioBuffer.Reset() + // Read PCM data in goroutine + go func() { + buf := make([]byte, 4096) + for { + select { + case <-stt.stopCh: + return + default: + n, err := stdout.Read(buf) + if n > 0 { + stt.mu.Lock() + stt.AudioBuffer.Write(buf[:n]) + stt.mu.Unlock() + } + if err != nil { + if err != io.EOF { + stt.logger.Error("recording read error", "error", err) + } + return + } + } + } + }() + return nil +} + +func (stt *WhisperServer) StopRecording() (string, error) { + stt.mu.Lock() + defer stt.mu.Unlock() + if !stt.recording { + return "", errors.New("not recording") + } + stt.recording = false + // Stop ffmpeg + stt.cmdMu.Lock() + if stt.cmd != nil && stt.cmd.Process != nil { + stt.cmd.Process.Kill() + stt.cmd.Wait() + } + close(stt.stopCh) + stt.cmdMu.Unlock() + // Rest of StopRecording unchanged (WAV header + HTTP upload) + // ... + stt.recording = false + // wait loop to finish? + if stt.AudioBuffer == nil { + err := errors.New("unexpected nil AudioBuffer") + stt.logger.Error(err.Error()) + return "", err + } + // Create WAV header first + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + // Add audio file part + part, err := writer.CreateFormFile("file", "recording.wav") + if err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + // Stream directly to multipart writer: header + raw data + dataSize := stt.AudioBuffer.Len() + stt.writeWavHeader(part, dataSize) + if _, err := io.Copy(part, stt.AudioBuffer); err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + // Reset buffer for next recording + stt.AudioBuffer.Reset() + // Add response format field + err = writer.WriteField("response_format", "text") + if err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + if writer.Close() != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + // Send request + resp, err := http.Post(stt.ServerURL, writer.FormDataContentType(), body) //nolint:noctx + if err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + defer resp.Body.Close() + // Read and print response + responseTextBytes, err := io.ReadAll(resp.Body) + if err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + resptext := strings.TrimRight(string(responseTextBytes), "\n") + // in case there are special tokens like [_BEG_] + resptext = specialRE.ReplaceAllString(resptext, "") + return strings.TrimSpace(strings.ReplaceAll(resptext, "\n ", "\n")), nil +} |
