summaryrefslogtreecommitdiff
path: root/extra
diff options
context:
space:
mode:
Diffstat (limited to 'extra')
-rw-r--r--extra/cluedo.go73
-rw-r--r--extra/cluedo_test.go50
-rw-r--r--extra/stt.go199
-rw-r--r--extra/tts.go212
-rw-r--r--extra/twentyq.go11
-rw-r--r--extra/vad.go1
-rw-r--r--extra/websearch.go13
-rw-r--r--extra/whisper_binary.go318
8 files changed, 877 insertions, 0 deletions
diff --git a/extra/cluedo.go b/extra/cluedo.go
new file mode 100644
index 0000000..1ef11cc
--- /dev/null
+++ b/extra/cluedo.go
@@ -0,0 +1,73 @@
+package extra
+
+import (
+ "math/rand"
+ "strings"
+)
+
+var (
+ rooms = []string{"HALL", "LOUNGE", "DINING ROOM", "KITCHEN", "BALLROOM", "CONSERVATORY", "BILLIARD ROOM", "LIBRARY", "STUDY"}
+ weapons = []string{"CANDLESTICK", "DAGGER", "LEAD PIPE", "REVOLVER", "ROPE", "SPANNER"}
+ people = []string{"Miss Scarlett", "Colonel Mustard", "Mrs. White", "Reverend Green", "Mrs. Peacock", "Professor Plum"}
+)
+
+type MurderTrifecta struct {
+ Murderer string
+ Weapon string
+ Room string
+}
+
+type CluedoRoundInfo struct {
+ Answer MurderTrifecta
+ PlayersCards map[string][]string
+}
+
+func (c *CluedoRoundInfo) GetPlayerCards(player string) string {
+ // maybe format it a little
+ return "cards of " + player + "are " + strings.Join(c.PlayersCards[player], ",")
+}
+
+func CluedoPrepCards(playerOrder []string) *CluedoRoundInfo {
+ res := &CluedoRoundInfo{}
+ // Select murder components
+ trifecta := MurderTrifecta{
+ Murderer: people[rand.Intn(len(people))],
+ Weapon: weapons[rand.Intn(len(weapons))],
+ Room: rooms[rand.Intn(len(rooms))],
+ }
+ // Collect non-murder cards
+ var notInvolved []string
+ for _, room := range rooms {
+ if room != trifecta.Room {
+ notInvolved = append(notInvolved, room)
+ }
+ }
+ for _, weapon := range weapons {
+ if weapon != trifecta.Weapon {
+ notInvolved = append(notInvolved, weapon)
+ }
+ }
+ for _, person := range people {
+ if person != trifecta.Murderer {
+ notInvolved = append(notInvolved, person)
+ }
+ }
+ // Shuffle and distribute cards
+ rand.Shuffle(len(notInvolved), func(i, j int) {
+ notInvolved[i], notInvolved[j] = notInvolved[j], notInvolved[i]
+ })
+ players := map[string][]string{}
+ cardsPerPlayer := len(notInvolved) / len(playerOrder)
+ // playerOrder := []string{"{{user}}", "{{char}}", "{{char2}}"}
+ for i, player := range playerOrder {
+ start := i * cardsPerPlayer
+ end := (i + 1) * cardsPerPlayer
+ if end > len(notInvolved) {
+ end = len(notInvolved)
+ }
+ players[player] = notInvolved[start:end]
+ }
+ res.Answer = trifecta
+ res.PlayersCards = players
+ return res
+}
diff --git a/extra/cluedo_test.go b/extra/cluedo_test.go
new file mode 100644
index 0000000..e7a53b1
--- /dev/null
+++ b/extra/cluedo_test.go
@@ -0,0 +1,50 @@
+package extra
+
+import (
+ "testing"
+)
+
+func TestPrepCards(t *testing.T) {
+ // Run the function to get the murder combination and player cards
+ roundInfo := CluedoPrepCards([]string{"{{user}}", "{{char}}", "{{char2}}"})
+ // Create a map to track all distributed cards
+ distributedCards := make(map[string]bool)
+ // Check that the murder combination cards are not distributed to players
+ murderCards := []string{roundInfo.Answer.Murderer, roundInfo.Answer.Weapon, roundInfo.Answer.Room}
+ for _, card := range murderCards {
+ if distributedCards[card] {
+ t.Errorf("Murder card %s was distributed to a player", card)
+ }
+ }
+ // Check each player's cards
+ for player, cards := range roundInfo.PlayersCards {
+ for _, card := range cards {
+ // Ensure the card is not part of the murder combination
+ for _, murderCard := range murderCards {
+ if card == murderCard {
+ t.Errorf("Player %s has a murder card: %s", player, card)
+ }
+ }
+ // Ensure the card is unique and not already distributed
+ if distributedCards[card] {
+ t.Errorf("Card %s is duplicated in player %s's hand", card, player)
+ }
+ distributedCards[card] = true
+ }
+ }
+ // Verify that all non-murder cards are distributed
+ allCards := append(append([]string{}, rooms...), weapons...)
+ allCards = append(allCards, people...)
+ for _, card := range allCards {
+ isMurderCard := false
+ for _, murderCard := range murderCards {
+ if card == murderCard {
+ isMurderCard = true
+ break
+ }
+ }
+ if !isMurderCard && !distributedCards[card] {
+ t.Errorf("Card %s was not distributed to any player", card)
+ }
+ }
+}
diff --git a/extra/stt.go b/extra/stt.go
new file mode 100644
index 0000000..e33a94d
--- /dev/null
+++ b/extra/stt.go
@@ -0,0 +1,199 @@
+package extra
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "gf-lt/config"
+ "io"
+ "log/slog"
+ "mime/multipart"
+ "net/http"
+ "regexp"
+ "strings"
+ "syscall"
+
+ "github.com/gordonklaus/portaudio"
+)
+
+var specialRE = regexp.MustCompile(`\[.*?\]`)
+
+type STT interface {
+ StartRecording() error
+ StopRecording() (string, error)
+ IsRecording() bool
+}
+
+type StreamCloser interface {
+ Close() error
+}
+
+func NewSTT(logger *slog.Logger, cfg *config.Config) STT {
+ switch cfg.STT_TYPE {
+ case "WHISPER_BINARY":
+ logger.Debug("stt init, chosen whisper binary")
+ return NewWhisperBinary(logger, cfg)
+ case "WHISPER_SERVER":
+ logger.Debug("stt init, chosen whisper server")
+ return NewWhisperServer(logger, cfg)
+ }
+ return NewWhisperServer(logger, cfg)
+}
+
+type WhisperServer struct {
+ logger *slog.Logger
+ ServerURL string
+ SampleRate int
+ AudioBuffer *bytes.Buffer
+ recording bool
+}
+
+func NewWhisperServer(logger *slog.Logger, cfg *config.Config) *WhisperServer {
+ return &WhisperServer{
+ logger: logger,
+ ServerURL: cfg.STT_URL,
+ SampleRate: cfg.STT_SR,
+ AudioBuffer: new(bytes.Buffer),
+ }
+}
+
+func (stt *WhisperServer) StartRecording() error {
+ if err := stt.microphoneStream(stt.SampleRate); err != nil {
+ return fmt.Errorf("failed to init microphone: %w", err)
+ }
+ stt.recording = true
+ return nil
+}
+
+func (stt *WhisperServer) StopRecording() (string, error) {
+ stt.recording = false
+ // wait loop to finish?
+ if stt.AudioBuffer == nil {
+ err := errors.New("unexpected nil AudioBuffer")
+ stt.logger.Error(err.Error())
+ return "", err
+ }
+ // Create WAV header first
+ body := &bytes.Buffer{}
+ writer := multipart.NewWriter(body)
+ // Add audio file part
+ part, err := writer.CreateFormFile("file", "recording.wav")
+ if err != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
+ }
+ // Stream directly to multipart writer: header + raw data
+ dataSize := stt.AudioBuffer.Len()
+ stt.writeWavHeader(part, dataSize)
+ if _, err := io.Copy(part, stt.AudioBuffer); err != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
+ }
+ // Reset buffer for next recording
+ stt.AudioBuffer.Reset()
+ // Add response format field
+ err = writer.WriteField("response_format", "text")
+ if err != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
+ }
+ if writer.Close() != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
+ }
+ // Send request
+ resp, err := http.Post(stt.ServerURL, writer.FormDataContentType(), body) //nolint:noctx
+ if err != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
+ }
+ defer resp.Body.Close()
+ // Read and print response
+ responseTextBytes, err := io.ReadAll(resp.Body)
+ if err != nil {
+ stt.logger.Error("fn: StopRecording", "error", err)
+ return "", err
+ }
+ resptext := strings.TrimRight(string(responseTextBytes), "\n")
+ // in case there are special tokens like [_BEG_]
+ resptext = specialRE.ReplaceAllString(resptext, "")
+ return strings.TrimSpace(strings.ReplaceAll(resptext, "\n ", "\n")), nil
+}
+
+func (stt *WhisperServer) writeWavHeader(w io.Writer, dataSize int) {
+ header := make([]byte, 44)
+ copy(header[0:4], "RIFF")
+ binary.LittleEndian.PutUint32(header[4:8], uint32(36+dataSize))
+ copy(header[8:12], "WAVE")
+ copy(header[12:16], "fmt ")
+ binary.LittleEndian.PutUint32(header[16:20], 16)
+ binary.LittleEndian.PutUint16(header[20:22], 1)
+ binary.LittleEndian.PutUint16(header[22:24], 1)
+ binary.LittleEndian.PutUint32(header[24:28], uint32(stt.SampleRate))
+ binary.LittleEndian.PutUint32(header[28:32], uint32(stt.SampleRate)*1*(16/8))
+ binary.LittleEndian.PutUint16(header[32:34], 1*(16/8))
+ binary.LittleEndian.PutUint16(header[34:36], 16)
+ copy(header[36:40], "data")
+ binary.LittleEndian.PutUint32(header[40:44], uint32(dataSize))
+ if _, err := w.Write(header); err != nil {
+ stt.logger.Error("writeWavHeader", "error", err)
+ }
+}
+
+func (stt *WhisperServer) IsRecording() bool {
+ return stt.recording
+}
+
+func (stt *WhisperServer) microphoneStream(sampleRate int) error {
+ // Temporarily redirect stderr to suppress ALSA warnings during PortAudio init
+ origStderr, errDup := syscall.Dup(syscall.Stderr)
+ if errDup != nil {
+ return fmt.Errorf("failed to dup stderr: %w", errDup)
+ }
+ nullFD, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0)
+ if err != nil {
+ _ = syscall.Close(origStderr) // Close the dup'd fd if open fails
+ return fmt.Errorf("failed to open /dev/null: %w", err)
+ }
+ // redirect stderr
+ _ = syscall.Dup2(nullFD, syscall.Stderr)
+ // Initialize PortAudio (this is where ALSA warnings occur)
+ defer func() {
+ // Restore stderr
+ _ = syscall.Dup2(origStderr, syscall.Stderr)
+ _ = syscall.Close(origStderr)
+ _ = syscall.Close(nullFD)
+ }()
+ if err := portaudio.Initialize(); err != nil {
+ return fmt.Errorf("portaudio init failed: %w", err)
+ }
+ in := make([]int16, 64)
+ stream, err := portaudio.OpenDefaultStream(1, 0, float64(sampleRate), len(in), in)
+ if err != nil {
+ if paErr := portaudio.Terminate(); paErr != nil {
+ return fmt.Errorf("failed to open microphone: %w; terminate error: %w", err, paErr)
+ }
+ return fmt.Errorf("failed to open microphone: %w", err)
+ }
+ go func(stream *portaudio.Stream) {
+ if err := stream.Start(); err != nil {
+ stt.logger.Error("microphoneStream", "error", err)
+ return
+ }
+ for {
+ if !stt.IsRecording() {
+ return
+ }
+ if err := stream.Read(); err != nil {
+ stt.logger.Error("reading stream", "error", err)
+ return
+ }
+ if err := binary.Write(stt.AudioBuffer, binary.LittleEndian, in); err != nil {
+ stt.logger.Error("writing to buffer", "error", err)
+ return
+ }
+ }
+ }(stream)
+ return nil
+}
diff --git a/extra/tts.go b/extra/tts.go
new file mode 100644
index 0000000..c6f373a
--- /dev/null
+++ b/extra/tts.go
@@ -0,0 +1,212 @@
+package extra
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "gf-lt/config"
+ "gf-lt/models"
+ "io"
+ "log/slog"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/gopxl/beep/v2"
+ "github.com/gopxl/beep/v2/mp3"
+ "github.com/gopxl/beep/v2/speaker"
+ "github.com/neurosnap/sentences/english"
+)
+
+var (
+ TTSTextChan = make(chan string, 10000)
+ TTSFlushChan = make(chan bool, 1)
+ TTSDoneChan = make(chan bool, 1)
+ // endsWithPunctuation = regexp.MustCompile(`[;.!?]$`)
+)
+
+type Orator interface {
+ Speak(text string) error
+ Stop()
+ // pause and resume?
+ GetLogger() *slog.Logger
+}
+
+// impl https://github.com/remsky/Kokoro-FastAPI
+type KokoroOrator struct {
+ logger *slog.Logger
+ URL string
+ Format models.AudioFormat
+ Stream bool
+ Speed float32
+ Language string
+ Voice string
+ currentStream *beep.Ctrl // Added for playback control
+ textBuffer strings.Builder
+ // textBuffer bytes.Buffer
+}
+
+func (o *KokoroOrator) stoproutine() {
+ <-TTSDoneChan
+ o.logger.Debug("orator got done signal")
+ o.Stop()
+ // drain the channel
+ for len(TTSTextChan) > 0 {
+ <-TTSTextChan
+ }
+}
+
+func (o *KokoroOrator) readroutine() {
+ tokenizer, _ := english.NewSentenceTokenizer(nil)
+ // var sentenceBuf bytes.Buffer
+ // var remainder strings.Builder
+ for {
+ select {
+ case chunk := <-TTSTextChan:
+ // sentenceBuf.WriteString(chunk)
+ // text := sentenceBuf.String()
+ _, err := o.textBuffer.WriteString(chunk)
+ if err != nil {
+ o.logger.Warn("failed to write to stringbuilder", "error", err)
+ continue
+ }
+ text := o.textBuffer.String()
+ sentences := tokenizer.Tokenize(text)
+ o.logger.Debug("adding chunk", "chunk", chunk, "text", text, "sen-len", len(sentences))
+ for i, sentence := range sentences {
+ if i == len(sentences)-1 { // last sentence
+ o.textBuffer.Reset()
+ _, err := o.textBuffer.WriteString(sentence.Text)
+ if err != nil {
+ o.logger.Warn("failed to write to stringbuilder", "error", err)
+ continue
+ }
+ continue // if only one (often incomplete) sentence; wait for next chunk
+ }
+ o.logger.Debug("calling Speak with sentence", "sent", sentence.Text)
+ if err := o.Speak(sentence.Text); err != nil {
+ o.logger.Error("tts failed", "sentence", sentence.Text, "error", err)
+ }
+ }
+ case <-TTSFlushChan:
+ o.logger.Debug("got flushchan signal start")
+ // lln is done get the whole message out
+ if len(TTSTextChan) > 0 { // otherwise might get stuck
+ for chunk := range TTSTextChan {
+ _, err := o.textBuffer.WriteString(chunk)
+ if err != nil {
+ o.logger.Warn("failed to write to stringbuilder", "error", err)
+ continue
+ }
+ if len(TTSTextChan) == 0 {
+ break
+ }
+ }
+ }
+ // INFO: if there is a lot of text it will take some time to make with tts at once
+ // to avoid this pause, it might be better to keep splitting on sentences
+ // but keepinig in mind that remainder could be ommited by tokenizer
+ // Flush remaining text
+ remaining := o.textBuffer.String()
+ o.textBuffer.Reset()
+ if remaining != "" {
+ o.logger.Debug("calling Speak with remainder", "rem", remaining)
+ if err := o.Speak(remaining); err != nil {
+ o.logger.Error("tts failed", "sentence", remaining, "error", err)
+ }
+ }
+ }
+ }
+}
+
+func NewOrator(log *slog.Logger, cfg *config.Config) Orator {
+ orator := &KokoroOrator{
+ logger: log,
+ URL: cfg.TTS_URL,
+ Format: models.AFMP3,
+ Stream: false,
+ Speed: cfg.TTS_SPEED,
+ Language: "a",
+ Voice: "af_bella(1)+af_sky(1)",
+ }
+ go orator.readroutine()
+ go orator.stoproutine()
+ return orator
+}
+
+func (o *KokoroOrator) GetLogger() *slog.Logger {
+ return o.logger
+}
+
+func (o *KokoroOrator) requestSound(text string) (io.ReadCloser, error) {
+ payload := map[string]interface{}{
+ "input": text,
+ "voice": o.Voice,
+ "response_format": o.Format,
+ "download_format": o.Format,
+ "stream": o.Stream,
+ "speed": o.Speed,
+ // "return_download_link": true,
+ "lang_code": o.Language,
+ }
+ payloadBytes, err := json.Marshal(payload)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal payload: %w", err)
+ }
+ req, err := http.NewRequest("POST", o.URL, bytes.NewBuffer(payloadBytes)) //nolint:noctx
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+ req.Header.Set("accept", "application/json")
+ req.Header.Set("Content-Type", "application/json")
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("request failed: %w", err)
+ }
+ if resp.StatusCode != http.StatusOK {
+ defer resp.Body.Close()
+ return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+ }
+ return resp.Body, nil
+}
+
+func (o *KokoroOrator) Speak(text string) error {
+ o.logger.Debug("fn: Speak is called", "text-len", len(text))
+ body, err := o.requestSound(text)
+ if err != nil {
+ o.logger.Error("request failed", "error", err)
+ return fmt.Errorf("request failed: %w", err)
+ }
+ defer body.Close()
+ // Decode the mp3 audio from response body
+ streamer, format, err := mp3.Decode(body)
+ if err != nil {
+ o.logger.Error("mp3 decode failed", "error", err)
+ return fmt.Errorf("mp3 decode failed: %w", err)
+ }
+ defer streamer.Close()
+ // here it spams with errors that speaker cannot be initialized more than once, but how would we deal with many audio records then?
+ if err := speaker.Init(format.SampleRate, format.SampleRate.N(time.Second/10)); err != nil {
+ o.logger.Debug("failed to init speaker", "error", err)
+ }
+ done := make(chan bool)
+ // Create controllable stream and store reference
+ o.currentStream = &beep.Ctrl{Streamer: beep.Seq(streamer, beep.Callback(func() {
+ close(done)
+ o.currentStream = nil
+ })), Paused: false}
+ speaker.Play(o.currentStream)
+ <-done // we hang in this routine;
+ return nil
+}
+
+func (o *KokoroOrator) Stop() {
+ // speaker.Clear()
+ o.logger.Debug("attempted to stop orator", "orator", o)
+ speaker.Lock()
+ defer speaker.Unlock()
+ if o.currentStream != nil {
+ // o.currentStream.Paused = true
+ o.currentStream.Streamer = nil
+ }
+}
diff --git a/extra/twentyq.go b/extra/twentyq.go
new file mode 100644
index 0000000..30c08cc
--- /dev/null
+++ b/extra/twentyq.go
@@ -0,0 +1,11 @@
+package extra
+
+import "math/rand"
+
+var (
+ chars = []string{"Shrek", "Garfield", "Jack the Ripper"}
+)
+
+func GetRandomChar() string {
+ return chars[rand.Intn(len(chars))]
+}
diff --git a/extra/vad.go b/extra/vad.go
new file mode 100644
index 0000000..2a9e238
--- /dev/null
+++ b/extra/vad.go
@@ -0,0 +1 @@
+package extra
diff --git a/extra/websearch.go b/extra/websearch.go
new file mode 100644
index 0000000..99bc1b6
--- /dev/null
+++ b/extra/websearch.go
@@ -0,0 +1,13 @@
+package extra
+
+import "github.com/GrailFinder/searchagent/searcher"
+
+var WebSearcher searcher.WebSurfer
+
+func init() {
+ sa, err := searcher.NewWebSurfer(searcher.SearcherTypeScraper, "")
+ if err != nil {
+ panic("failed to init seachagent; error: " + err.Error())
+ }
+ WebSearcher = sa
+}
diff --git a/extra/whisper_binary.go b/extra/whisper_binary.go
new file mode 100644
index 0000000..a016a30
--- /dev/null
+++ b/extra/whisper_binary.go
@@ -0,0 +1,318 @@
+package extra
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "gf-lt/config"
+ "io"
+ "log/slog"
+ "os"
+ "os/exec"
+ "strings"
+ "sync"
+ "syscall"
+
+ "github.com/gordonklaus/portaudio"
+)
+
+type WhisperBinary struct {
+ logger *slog.Logger
+ whisperPath string
+ modelPath string
+ lang string
+ ctx context.Context
+ cancel context.CancelFunc
+ mu sync.Mutex
+ recording bool
+ audioBuffer []int16
+}
+
+func NewWhisperBinary(logger *slog.Logger, cfg *config.Config) *WhisperBinary {
+ ctx, cancel := context.WithCancel(context.Background())
+ // Set ALSA error handler first
+ return &WhisperBinary{
+ logger: logger,
+ whisperPath: cfg.WhisperBinaryPath,
+ modelPath: cfg.WhisperModelPath,
+ lang: cfg.STT_LANG,
+ ctx: ctx,
+ cancel: cancel,
+ }
+}
+
+func (w *WhisperBinary) StartRecording() error {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if w.recording {
+ return errors.New("recording is already in progress")
+ }
+ // Temporarily redirect stderr to suppress ALSA warnings during PortAudio init
+ origStderr, errDup := syscall.Dup(syscall.Stderr)
+ if errDup != nil {
+ return fmt.Errorf("failed to dup stderr: %w", errDup)
+ }
+ nullFD, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0)
+ if err != nil {
+ _ = syscall.Close(origStderr) // Close the dup'd fd if open fails
+ return fmt.Errorf("failed to open /dev/null: %w", err)
+ }
+ // redirect stderr
+ _ = syscall.Dup2(nullFD, syscall.Stderr)
+ // Initialize PortAudio (this is where ALSA warnings occur)
+ portaudioErr := portaudio.Initialize()
+ defer func() {
+ // Restore stderr
+ _ = syscall.Dup2(origStderr, syscall.Stderr)
+ _ = syscall.Close(origStderr)
+ _ = syscall.Close(nullFD)
+ }()
+ if portaudioErr != nil {
+ return fmt.Errorf("portaudio init failed: %w", portaudioErr)
+ }
+ // Initialize audio buffer
+ w.audioBuffer = make([]int16, 0)
+ in := make([]int16, 1024) // buffer size
+ stream, err := portaudio.OpenDefaultStream(1, 0, 16000.0, len(in), in)
+ if err != nil {
+ if paErr := portaudio.Terminate(); paErr != nil {
+ return fmt.Errorf("failed to open microphone: %w; terminate error: %w", err, paErr)
+ }
+ return fmt.Errorf("failed to open microphone: %w", err)
+ }
+ go w.recordAudio(stream, in)
+ w.recording = true
+ w.logger.Debug("Recording started")
+ return nil
+}
+
+func (w *WhisperBinary) recordAudio(stream *portaudio.Stream, in []int16) {
+ defer func() {
+ w.logger.Debug("recordAudio defer function called")
+ _ = stream.Stop() // Stop the stream
+ _ = portaudio.Terminate() // ignoring error as we're shutting down
+ w.logger.Debug("recordAudio terminated")
+ }()
+ w.logger.Debug("Starting audio stream")
+ if err := stream.Start(); err != nil {
+ w.logger.Error("Failed to start audio stream", "error", err)
+ return
+ }
+ w.logger.Debug("Audio stream started, entering recording loop")
+ for {
+ select {
+ case <-w.ctx.Done():
+ w.logger.Debug("Context done, exiting recording loop")
+ return
+ default:
+ // Check recording status with minimal lock time
+ w.mu.Lock()
+ recording := w.recording
+ w.mu.Unlock()
+
+ if !recording {
+ w.logger.Debug("Recording flag is false, exiting recording loop")
+ return
+ }
+ if err := stream.Read(); err != nil {
+ w.logger.Error("Error reading from stream", "error", err)
+ return
+ }
+ // Append samples to buffer - only acquire lock when necessary
+ w.mu.Lock()
+ if w.audioBuffer == nil {
+ w.audioBuffer = make([]int16, 0)
+ }
+ // Make a copy of the input buffer to avoid overwriting
+ tempBuffer := make([]int16, len(in))
+ copy(tempBuffer, in)
+ w.audioBuffer = append(w.audioBuffer, tempBuffer...)
+ w.mu.Unlock()
+ }
+ }
+}
+
+func (w *WhisperBinary) StopRecording() (string, error) {
+ w.logger.Debug("StopRecording called")
+ w.mu.Lock()
+ if !w.recording {
+ w.mu.Unlock()
+ return "", errors.New("not currently recording")
+ }
+ w.logger.Debug("Setting recording to false and cancelling context")
+ w.recording = false
+ w.cancel() // This will stop the recording goroutine
+ w.mu.Unlock()
+ // // Small delay to allow the recording goroutine to react to context cancellation
+ // time.Sleep(20 * time.Millisecond)
+ // Save the recorded audio to a temporary file
+ tempFile, err := w.saveAudioToTempFile()
+ if err != nil {
+ w.logger.Error("Error saving audio to temp file", "error", err)
+ return "", fmt.Errorf("failed to save audio to temp file: %w", err)
+ }
+ w.logger.Debug("Saved audio to temp file", "file", tempFile)
+ // Run the whisper binary with a separate context to avoid cancellation during transcription
+ cmd := exec.Command(w.whisperPath, "-m", w.modelPath, "-l", w.lang, tempFile, "2>/dev/null")
+ var outBuf bytes.Buffer
+ cmd.Stdout = &outBuf
+ // Redirect stderr to suppress ALSA warnings and other stderr output
+ cmd.Stderr = io.Discard // Suppress stderr output from whisper binary
+ w.logger.Debug("Running whisper binary command")
+ if err := cmd.Run(); err != nil {
+ // Clean up audio buffer
+ w.mu.Lock()
+ w.audioBuffer = nil
+ w.mu.Unlock()
+ // Since we're suppressing stderr, we'll just log that the command failed
+ w.logger.Error("Error running whisper binary", "error", err)
+ return "", fmt.Errorf("whisper binary failed: %w", err)
+ }
+ result := outBuf.String()
+ w.logger.Debug("Whisper binary completed", "result", result)
+ // Clean up audio buffer
+ w.mu.Lock()
+ w.audioBuffer = nil
+ w.mu.Unlock()
+ // Clean up the temporary file after transcription
+ w.logger.Debug("StopRecording completed")
+ os.Remove(tempFile)
+ result = strings.TrimRight(result, "\n")
+ // in case there are special tokens like [_BEG_]
+ result = specialRE.ReplaceAllString(result, "")
+ return strings.TrimSpace(strings.ReplaceAll(result, "\n ", "\n")), nil
+}
+
+// saveAudioToTempFile saves the recorded audio data to a temporary WAV file
+func (w *WhisperBinary) saveAudioToTempFile() (string, error) {
+ w.logger.Debug("saveAudioToTempFile called")
+ // Create temporary WAV file
+ tempFile, err := os.CreateTemp("", "recording_*.wav")
+ if err != nil {
+ w.logger.Error("Failed to create temp file", "error", err)
+ return "", fmt.Errorf("failed to create temp file: %w", err)
+ }
+ w.logger.Debug("Created temp file", "file", tempFile.Name())
+ defer tempFile.Close()
+
+ // Write WAV header and data
+ w.logger.Debug("About to write WAV file", "file", tempFile.Name())
+ err = w.writeWAVFile(tempFile.Name())
+ if err != nil {
+ w.logger.Error("Error writing WAV file", "error", err)
+ return "", fmt.Errorf("failed to write WAV file: %w", err)
+ }
+ w.logger.Debug("WAV file written successfully", "file", tempFile.Name())
+
+ return tempFile.Name(), nil
+}
+
+// writeWAVFile creates a WAV file from the recorded audio data
+func (w *WhisperBinary) writeWAVFile(filename string) error {
+ w.logger.Debug("writeWAVFile called", "filename", filename)
+ // Open file for writing
+ file, err := os.Create(filename)
+ if err != nil {
+ w.logger.Error("Error creating file", "error", err)
+ return err
+ }
+ defer file.Close()
+
+ w.logger.Debug("About to acquire mutex in writeWAVFile")
+ w.mu.Lock()
+ w.logger.Debug("Locked mutex, copying audio buffer")
+ audioData := make([]int16, len(w.audioBuffer))
+ copy(audioData, w.audioBuffer)
+ w.mu.Unlock()
+ w.logger.Debug("Unlocked mutex", "audio_data_length", len(audioData))
+
+ if len(audioData) == 0 {
+ w.logger.Warn("No audio data to write")
+ return errors.New("no audio data to write")
+ }
+
+ // Calculate data size (number of samples * size of int16)
+ dataSize := len(audioData) * 2 // 2 bytes per int16 sample
+ w.logger.Debug("Calculated data size", "size", dataSize)
+
+ // Write WAV header with the correct data size
+ header := w.createWAVHeader(16000, 1, 16, dataSize)
+ _, err = file.Write(header)
+ if err != nil {
+ w.logger.Error("Error writing WAV header", "error", err)
+ return err
+ }
+ w.logger.Debug("WAV header written successfully")
+
+ // Write audio data
+ w.logger.Debug("About to write audio data samples")
+ for i, sample := range audioData {
+ // Write little-endian 16-bit sample
+ _, err := file.Write([]byte{byte(sample), byte(sample >> 8)})
+ if err != nil {
+ w.logger.Error("Error writing sample", "index", i, "error", err)
+ return err
+ }
+ // Log progress every 10000 samples to avoid too much output
+ if i%10000 == 0 {
+ w.logger.Debug("Written samples", "count", i)
+ }
+ }
+ w.logger.Debug("All audio data written successfully")
+
+ return nil
+}
+
+// createWAVHeader creates a WAV file header
+func (w *WhisperBinary) createWAVHeader(sampleRate, channels, bitsPerSample int, dataSize int) []byte {
+ header := make([]byte, 44)
+ copy(header[0:4], "RIFF")
+ // Total file size will be updated later
+ copy(header[8:12], "WAVE")
+ copy(header[12:16], "fmt ")
+ // fmt chunk size (16 for PCM)
+ header[16] = 16
+ header[17] = 0
+ header[18] = 0
+ header[19] = 0
+ // Audio format (1 = PCM)
+ header[20] = 1
+ header[21] = 0
+ // Number of channels
+ header[22] = byte(channels)
+ header[23] = 0
+ // Sample rate
+ header[24] = byte(sampleRate)
+ header[25] = byte(sampleRate >> 8)
+ header[26] = byte(sampleRate >> 16)
+ header[27] = byte(sampleRate >> 24)
+ // Byte rate
+ byteRate := sampleRate * channels * bitsPerSample / 8
+ header[28] = byte(byteRate)
+ header[29] = byte(byteRate >> 8)
+ header[30] = byte(byteRate >> 16)
+ header[31] = byte(byteRate >> 24)
+ // Block align
+ blockAlign := channels * bitsPerSample / 8
+ header[32] = byte(blockAlign)
+ header[33] = 0
+ // Bits per sample
+ header[34] = byte(bitsPerSample)
+ header[35] = 0
+ // "data" subchunk
+ copy(header[36:40], "data")
+ // Data size
+ header[40] = byte(dataSize)
+ header[41] = byte(dataSize >> 8)
+ header[42] = byte(dataSize >> 16)
+ header[43] = byte(dataSize >> 24)
+
+ return header
+}
+
+func (w *WhisperBinary) IsRecording() bool {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ return w.recording
+}