diff options
Diffstat (limited to 'extra')
| -rw-r--r-- | extra/cluedo.go | 73 | ||||
| -rw-r--r-- | extra/cluedo_test.go | 50 | ||||
| -rw-r--r-- | extra/stt.go | 199 | ||||
| -rw-r--r-- | extra/tts.go | 212 | ||||
| -rw-r--r-- | extra/twentyq.go | 11 | ||||
| -rw-r--r-- | extra/vad.go | 1 | ||||
| -rw-r--r-- | extra/websearch.go | 13 | ||||
| -rw-r--r-- | extra/whisper_binary.go | 318 |
8 files changed, 877 insertions, 0 deletions
diff --git a/extra/cluedo.go b/extra/cluedo.go new file mode 100644 index 0000000..1ef11cc --- /dev/null +++ b/extra/cluedo.go @@ -0,0 +1,73 @@ +package extra + +import ( + "math/rand" + "strings" +) + +var ( + rooms = []string{"HALL", "LOUNGE", "DINING ROOM", "KITCHEN", "BALLROOM", "CONSERVATORY", "BILLIARD ROOM", "LIBRARY", "STUDY"} + weapons = []string{"CANDLESTICK", "DAGGER", "LEAD PIPE", "REVOLVER", "ROPE", "SPANNER"} + people = []string{"Miss Scarlett", "Colonel Mustard", "Mrs. White", "Reverend Green", "Mrs. Peacock", "Professor Plum"} +) + +type MurderTrifecta struct { + Murderer string + Weapon string + Room string +} + +type CluedoRoundInfo struct { + Answer MurderTrifecta + PlayersCards map[string][]string +} + +func (c *CluedoRoundInfo) GetPlayerCards(player string) string { + // maybe format it a little + return "cards of " + player + "are " + strings.Join(c.PlayersCards[player], ",") +} + +func CluedoPrepCards(playerOrder []string) *CluedoRoundInfo { + res := &CluedoRoundInfo{} + // Select murder components + trifecta := MurderTrifecta{ + Murderer: people[rand.Intn(len(people))], + Weapon: weapons[rand.Intn(len(weapons))], + Room: rooms[rand.Intn(len(rooms))], + } + // Collect non-murder cards + var notInvolved []string + for _, room := range rooms { + if room != trifecta.Room { + notInvolved = append(notInvolved, room) + } + } + for _, weapon := range weapons { + if weapon != trifecta.Weapon { + notInvolved = append(notInvolved, weapon) + } + } + for _, person := range people { + if person != trifecta.Murderer { + notInvolved = append(notInvolved, person) + } + } + // Shuffle and distribute cards + rand.Shuffle(len(notInvolved), func(i, j int) { + notInvolved[i], notInvolved[j] = notInvolved[j], notInvolved[i] + }) + players := map[string][]string{} + cardsPerPlayer := len(notInvolved) / len(playerOrder) + // playerOrder := []string{"{{user}}", "{{char}}", "{{char2}}"} + for i, player := range playerOrder { + start := i * cardsPerPlayer + end := (i + 1) * cardsPerPlayer + if end > len(notInvolved) { + end = len(notInvolved) + } + players[player] = notInvolved[start:end] + } + res.Answer = trifecta + res.PlayersCards = players + return res +} diff --git a/extra/cluedo_test.go b/extra/cluedo_test.go new file mode 100644 index 0000000..e7a53b1 --- /dev/null +++ b/extra/cluedo_test.go @@ -0,0 +1,50 @@ +package extra + +import ( + "testing" +) + +func TestPrepCards(t *testing.T) { + // Run the function to get the murder combination and player cards + roundInfo := CluedoPrepCards([]string{"{{user}}", "{{char}}", "{{char2}}"}) + // Create a map to track all distributed cards + distributedCards := make(map[string]bool) + // Check that the murder combination cards are not distributed to players + murderCards := []string{roundInfo.Answer.Murderer, roundInfo.Answer.Weapon, roundInfo.Answer.Room} + for _, card := range murderCards { + if distributedCards[card] { + t.Errorf("Murder card %s was distributed to a player", card) + } + } + // Check each player's cards + for player, cards := range roundInfo.PlayersCards { + for _, card := range cards { + // Ensure the card is not part of the murder combination + for _, murderCard := range murderCards { + if card == murderCard { + t.Errorf("Player %s has a murder card: %s", player, card) + } + } + // Ensure the card is unique and not already distributed + if distributedCards[card] { + t.Errorf("Card %s is duplicated in player %s's hand", card, player) + } + distributedCards[card] = true + } + } + // Verify that all non-murder cards are distributed + allCards := append(append([]string{}, rooms...), weapons...) + allCards = append(allCards, people...) + for _, card := range allCards { + isMurderCard := false + for _, murderCard := range murderCards { + if card == murderCard { + isMurderCard = true + break + } + } + if !isMurderCard && !distributedCards[card] { + t.Errorf("Card %s was not distributed to any player", card) + } + } +} diff --git a/extra/stt.go b/extra/stt.go new file mode 100644 index 0000000..e33a94d --- /dev/null +++ b/extra/stt.go @@ -0,0 +1,199 @@ +package extra + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "gf-lt/config" + "io" + "log/slog" + "mime/multipart" + "net/http" + "regexp" + "strings" + "syscall" + + "github.com/gordonklaus/portaudio" +) + +var specialRE = regexp.MustCompile(`\[.*?\]`) + +type STT interface { + StartRecording() error + StopRecording() (string, error) + IsRecording() bool +} + +type StreamCloser interface { + Close() error +} + +func NewSTT(logger *slog.Logger, cfg *config.Config) STT { + switch cfg.STT_TYPE { + case "WHISPER_BINARY": + logger.Debug("stt init, chosen whisper binary") + return NewWhisperBinary(logger, cfg) + case "WHISPER_SERVER": + logger.Debug("stt init, chosen whisper server") + return NewWhisperServer(logger, cfg) + } + return NewWhisperServer(logger, cfg) +} + +type WhisperServer struct { + logger *slog.Logger + ServerURL string + SampleRate int + AudioBuffer *bytes.Buffer + recording bool +} + +func NewWhisperServer(logger *slog.Logger, cfg *config.Config) *WhisperServer { + return &WhisperServer{ + logger: logger, + ServerURL: cfg.STT_URL, + SampleRate: cfg.STT_SR, + AudioBuffer: new(bytes.Buffer), + } +} + +func (stt *WhisperServer) StartRecording() error { + if err := stt.microphoneStream(stt.SampleRate); err != nil { + return fmt.Errorf("failed to init microphone: %w", err) + } + stt.recording = true + return nil +} + +func (stt *WhisperServer) StopRecording() (string, error) { + stt.recording = false + // wait loop to finish? + if stt.AudioBuffer == nil { + err := errors.New("unexpected nil AudioBuffer") + stt.logger.Error(err.Error()) + return "", err + } + // Create WAV header first + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + // Add audio file part + part, err := writer.CreateFormFile("file", "recording.wav") + if err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + // Stream directly to multipart writer: header + raw data + dataSize := stt.AudioBuffer.Len() + stt.writeWavHeader(part, dataSize) + if _, err := io.Copy(part, stt.AudioBuffer); err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + // Reset buffer for next recording + stt.AudioBuffer.Reset() + // Add response format field + err = writer.WriteField("response_format", "text") + if err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + if writer.Close() != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + // Send request + resp, err := http.Post(stt.ServerURL, writer.FormDataContentType(), body) //nolint:noctx + if err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + defer resp.Body.Close() + // Read and print response + responseTextBytes, err := io.ReadAll(resp.Body) + if err != nil { + stt.logger.Error("fn: StopRecording", "error", err) + return "", err + } + resptext := strings.TrimRight(string(responseTextBytes), "\n") + // in case there are special tokens like [_BEG_] + resptext = specialRE.ReplaceAllString(resptext, "") + return strings.TrimSpace(strings.ReplaceAll(resptext, "\n ", "\n")), nil +} + +func (stt *WhisperServer) writeWavHeader(w io.Writer, dataSize int) { + header := make([]byte, 44) + copy(header[0:4], "RIFF") + binary.LittleEndian.PutUint32(header[4:8], uint32(36+dataSize)) + copy(header[8:12], "WAVE") + copy(header[12:16], "fmt ") + binary.LittleEndian.PutUint32(header[16:20], 16) + binary.LittleEndian.PutUint16(header[20:22], 1) + binary.LittleEndian.PutUint16(header[22:24], 1) + binary.LittleEndian.PutUint32(header[24:28], uint32(stt.SampleRate)) + binary.LittleEndian.PutUint32(header[28:32], uint32(stt.SampleRate)*1*(16/8)) + binary.LittleEndian.PutUint16(header[32:34], 1*(16/8)) + binary.LittleEndian.PutUint16(header[34:36], 16) + copy(header[36:40], "data") + binary.LittleEndian.PutUint32(header[40:44], uint32(dataSize)) + if _, err := w.Write(header); err != nil { + stt.logger.Error("writeWavHeader", "error", err) + } +} + +func (stt *WhisperServer) IsRecording() bool { + return stt.recording +} + +func (stt *WhisperServer) microphoneStream(sampleRate int) error { + // Temporarily redirect stderr to suppress ALSA warnings during PortAudio init + origStderr, errDup := syscall.Dup(syscall.Stderr) + if errDup != nil { + return fmt.Errorf("failed to dup stderr: %w", errDup) + } + nullFD, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) + if err != nil { + _ = syscall.Close(origStderr) // Close the dup'd fd if open fails + return fmt.Errorf("failed to open /dev/null: %w", err) + } + // redirect stderr + _ = syscall.Dup2(nullFD, syscall.Stderr) + // Initialize PortAudio (this is where ALSA warnings occur) + defer func() { + // Restore stderr + _ = syscall.Dup2(origStderr, syscall.Stderr) + _ = syscall.Close(origStderr) + _ = syscall.Close(nullFD) + }() + if err := portaudio.Initialize(); err != nil { + return fmt.Errorf("portaudio init failed: %w", err) + } + in := make([]int16, 64) + stream, err := portaudio.OpenDefaultStream(1, 0, float64(sampleRate), len(in), in) + if err != nil { + if paErr := portaudio.Terminate(); paErr != nil { + return fmt.Errorf("failed to open microphone: %w; terminate error: %w", err, paErr) + } + return fmt.Errorf("failed to open microphone: %w", err) + } + go func(stream *portaudio.Stream) { + if err := stream.Start(); err != nil { + stt.logger.Error("microphoneStream", "error", err) + return + } + for { + if !stt.IsRecording() { + return + } + if err := stream.Read(); err != nil { + stt.logger.Error("reading stream", "error", err) + return + } + if err := binary.Write(stt.AudioBuffer, binary.LittleEndian, in); err != nil { + stt.logger.Error("writing to buffer", "error", err) + return + } + } + }(stream) + return nil +} diff --git a/extra/tts.go b/extra/tts.go new file mode 100644 index 0000000..c6f373a --- /dev/null +++ b/extra/tts.go @@ -0,0 +1,212 @@ +package extra + +import ( + "bytes" + "encoding/json" + "fmt" + "gf-lt/config" + "gf-lt/models" + "io" + "log/slog" + "net/http" + "strings" + "time" + + "github.com/gopxl/beep/v2" + "github.com/gopxl/beep/v2/mp3" + "github.com/gopxl/beep/v2/speaker" + "github.com/neurosnap/sentences/english" +) + +var ( + TTSTextChan = make(chan string, 10000) + TTSFlushChan = make(chan bool, 1) + TTSDoneChan = make(chan bool, 1) + // endsWithPunctuation = regexp.MustCompile(`[;.!?]$`) +) + +type Orator interface { + Speak(text string) error + Stop() + // pause and resume? + GetLogger() *slog.Logger +} + +// impl https://github.com/remsky/Kokoro-FastAPI +type KokoroOrator struct { + logger *slog.Logger + URL string + Format models.AudioFormat + Stream bool + Speed float32 + Language string + Voice string + currentStream *beep.Ctrl // Added for playback control + textBuffer strings.Builder + // textBuffer bytes.Buffer +} + +func (o *KokoroOrator) stoproutine() { + <-TTSDoneChan + o.logger.Debug("orator got done signal") + o.Stop() + // drain the channel + for len(TTSTextChan) > 0 { + <-TTSTextChan + } +} + +func (o *KokoroOrator) readroutine() { + tokenizer, _ := english.NewSentenceTokenizer(nil) + // var sentenceBuf bytes.Buffer + // var remainder strings.Builder + for { + select { + case chunk := <-TTSTextChan: + // sentenceBuf.WriteString(chunk) + // text := sentenceBuf.String() + _, err := o.textBuffer.WriteString(chunk) + if err != nil { + o.logger.Warn("failed to write to stringbuilder", "error", err) + continue + } + text := o.textBuffer.String() + sentences := tokenizer.Tokenize(text) + o.logger.Debug("adding chunk", "chunk", chunk, "text", text, "sen-len", len(sentences)) + for i, sentence := range sentences { + if i == len(sentences)-1 { // last sentence + o.textBuffer.Reset() + _, err := o.textBuffer.WriteString(sentence.Text) + if err != nil { + o.logger.Warn("failed to write to stringbuilder", "error", err) + continue + } + continue // if only one (often incomplete) sentence; wait for next chunk + } + o.logger.Debug("calling Speak with sentence", "sent", sentence.Text) + if err := o.Speak(sentence.Text); err != nil { + o.logger.Error("tts failed", "sentence", sentence.Text, "error", err) + } + } + case <-TTSFlushChan: + o.logger.Debug("got flushchan signal start") + // lln is done get the whole message out + if len(TTSTextChan) > 0 { // otherwise might get stuck + for chunk := range TTSTextChan { + _, err := o.textBuffer.WriteString(chunk) + if err != nil { + o.logger.Warn("failed to write to stringbuilder", "error", err) + continue + } + if len(TTSTextChan) == 0 { + break + } + } + } + // INFO: if there is a lot of text it will take some time to make with tts at once + // to avoid this pause, it might be better to keep splitting on sentences + // but keepinig in mind that remainder could be ommited by tokenizer + // Flush remaining text + remaining := o.textBuffer.String() + o.textBuffer.Reset() + if remaining != "" { + o.logger.Debug("calling Speak with remainder", "rem", remaining) + if err := o.Speak(remaining); err != nil { + o.logger.Error("tts failed", "sentence", remaining, "error", err) + } + } + } + } +} + +func NewOrator(log *slog.Logger, cfg *config.Config) Orator { + orator := &KokoroOrator{ + logger: log, + URL: cfg.TTS_URL, + Format: models.AFMP3, + Stream: false, + Speed: cfg.TTS_SPEED, + Language: "a", + Voice: "af_bella(1)+af_sky(1)", + } + go orator.readroutine() + go orator.stoproutine() + return orator +} + +func (o *KokoroOrator) GetLogger() *slog.Logger { + return o.logger +} + +func (o *KokoroOrator) requestSound(text string) (io.ReadCloser, error) { + payload := map[string]interface{}{ + "input": text, + "voice": o.Voice, + "response_format": o.Format, + "download_format": o.Format, + "stream": o.Stream, + "speed": o.Speed, + // "return_download_link": true, + "lang_code": o.Language, + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal payload: %w", err) + } + req, err := http.NewRequest("POST", o.URL, bytes.NewBuffer(payloadBytes)) //nolint:noctx + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("accept", "application/json") + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + return resp.Body, nil +} + +func (o *KokoroOrator) Speak(text string) error { + o.logger.Debug("fn: Speak is called", "text-len", len(text)) + body, err := o.requestSound(text) + if err != nil { + o.logger.Error("request failed", "error", err) + return fmt.Errorf("request failed: %w", err) + } + defer body.Close() + // Decode the mp3 audio from response body + streamer, format, err := mp3.Decode(body) + if err != nil { + o.logger.Error("mp3 decode failed", "error", err) + return fmt.Errorf("mp3 decode failed: %w", err) + } + defer streamer.Close() + // here it spams with errors that speaker cannot be initialized more than once, but how would we deal with many audio records then? + if err := speaker.Init(format.SampleRate, format.SampleRate.N(time.Second/10)); err != nil { + o.logger.Debug("failed to init speaker", "error", err) + } + done := make(chan bool) + // Create controllable stream and store reference + o.currentStream = &beep.Ctrl{Streamer: beep.Seq(streamer, beep.Callback(func() { + close(done) + o.currentStream = nil + })), Paused: false} + speaker.Play(o.currentStream) + <-done // we hang in this routine; + return nil +} + +func (o *KokoroOrator) Stop() { + // speaker.Clear() + o.logger.Debug("attempted to stop orator", "orator", o) + speaker.Lock() + defer speaker.Unlock() + if o.currentStream != nil { + // o.currentStream.Paused = true + o.currentStream.Streamer = nil + } +} diff --git a/extra/twentyq.go b/extra/twentyq.go new file mode 100644 index 0000000..30c08cc --- /dev/null +++ b/extra/twentyq.go @@ -0,0 +1,11 @@ +package extra + +import "math/rand" + +var ( + chars = []string{"Shrek", "Garfield", "Jack the Ripper"} +) + +func GetRandomChar() string { + return chars[rand.Intn(len(chars))] +} diff --git a/extra/vad.go b/extra/vad.go new file mode 100644 index 0000000..2a9e238 --- /dev/null +++ b/extra/vad.go @@ -0,0 +1 @@ +package extra diff --git a/extra/websearch.go b/extra/websearch.go new file mode 100644 index 0000000..99bc1b6 --- /dev/null +++ b/extra/websearch.go @@ -0,0 +1,13 @@ +package extra + +import "github.com/GrailFinder/searchagent/searcher" + +var WebSearcher searcher.WebSurfer + +func init() { + sa, err := searcher.NewWebSurfer(searcher.SearcherTypeScraper, "") + if err != nil { + panic("failed to init seachagent; error: " + err.Error()) + } + WebSearcher = sa +} diff --git a/extra/whisper_binary.go b/extra/whisper_binary.go new file mode 100644 index 0000000..a016a30 --- /dev/null +++ b/extra/whisper_binary.go @@ -0,0 +1,318 @@ +package extra + +import ( + "bytes" + "context" + "errors" + "fmt" + "gf-lt/config" + "io" + "log/slog" + "os" + "os/exec" + "strings" + "sync" + "syscall" + + "github.com/gordonklaus/portaudio" +) + +type WhisperBinary struct { + logger *slog.Logger + whisperPath string + modelPath string + lang string + ctx context.Context + cancel context.CancelFunc + mu sync.Mutex + recording bool + audioBuffer []int16 +} + +func NewWhisperBinary(logger *slog.Logger, cfg *config.Config) *WhisperBinary { + ctx, cancel := context.WithCancel(context.Background()) + // Set ALSA error handler first + return &WhisperBinary{ + logger: logger, + whisperPath: cfg.WhisperBinaryPath, + modelPath: cfg.WhisperModelPath, + lang: cfg.STT_LANG, + ctx: ctx, + cancel: cancel, + } +} + +func (w *WhisperBinary) StartRecording() error { + w.mu.Lock() + defer w.mu.Unlock() + if w.recording { + return errors.New("recording is already in progress") + } + // Temporarily redirect stderr to suppress ALSA warnings during PortAudio init + origStderr, errDup := syscall.Dup(syscall.Stderr) + if errDup != nil { + return fmt.Errorf("failed to dup stderr: %w", errDup) + } + nullFD, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0) + if err != nil { + _ = syscall.Close(origStderr) // Close the dup'd fd if open fails + return fmt.Errorf("failed to open /dev/null: %w", err) + } + // redirect stderr + _ = syscall.Dup2(nullFD, syscall.Stderr) + // Initialize PortAudio (this is where ALSA warnings occur) + portaudioErr := portaudio.Initialize() + defer func() { + // Restore stderr + _ = syscall.Dup2(origStderr, syscall.Stderr) + _ = syscall.Close(origStderr) + _ = syscall.Close(nullFD) + }() + if portaudioErr != nil { + return fmt.Errorf("portaudio init failed: %w", portaudioErr) + } + // Initialize audio buffer + w.audioBuffer = make([]int16, 0) + in := make([]int16, 1024) // buffer size + stream, err := portaudio.OpenDefaultStream(1, 0, 16000.0, len(in), in) + if err != nil { + if paErr := portaudio.Terminate(); paErr != nil { + return fmt.Errorf("failed to open microphone: %w; terminate error: %w", err, paErr) + } + return fmt.Errorf("failed to open microphone: %w", err) + } + go w.recordAudio(stream, in) + w.recording = true + w.logger.Debug("Recording started") + return nil +} + +func (w *WhisperBinary) recordAudio(stream *portaudio.Stream, in []int16) { + defer func() { + w.logger.Debug("recordAudio defer function called") + _ = stream.Stop() // Stop the stream + _ = portaudio.Terminate() // ignoring error as we're shutting down + w.logger.Debug("recordAudio terminated") + }() + w.logger.Debug("Starting audio stream") + if err := stream.Start(); err != nil { + w.logger.Error("Failed to start audio stream", "error", err) + return + } + w.logger.Debug("Audio stream started, entering recording loop") + for { + select { + case <-w.ctx.Done(): + w.logger.Debug("Context done, exiting recording loop") + return + default: + // Check recording status with minimal lock time + w.mu.Lock() + recording := w.recording + w.mu.Unlock() + + if !recording { + w.logger.Debug("Recording flag is false, exiting recording loop") + return + } + if err := stream.Read(); err != nil { + w.logger.Error("Error reading from stream", "error", err) + return + } + // Append samples to buffer - only acquire lock when necessary + w.mu.Lock() + if w.audioBuffer == nil { + w.audioBuffer = make([]int16, 0) + } + // Make a copy of the input buffer to avoid overwriting + tempBuffer := make([]int16, len(in)) + copy(tempBuffer, in) + w.audioBuffer = append(w.audioBuffer, tempBuffer...) + w.mu.Unlock() + } + } +} + +func (w *WhisperBinary) StopRecording() (string, error) { + w.logger.Debug("StopRecording called") + w.mu.Lock() + if !w.recording { + w.mu.Unlock() + return "", errors.New("not currently recording") + } + w.logger.Debug("Setting recording to false and cancelling context") + w.recording = false + w.cancel() // This will stop the recording goroutine + w.mu.Unlock() + // // Small delay to allow the recording goroutine to react to context cancellation + // time.Sleep(20 * time.Millisecond) + // Save the recorded audio to a temporary file + tempFile, err := w.saveAudioToTempFile() + if err != nil { + w.logger.Error("Error saving audio to temp file", "error", err) + return "", fmt.Errorf("failed to save audio to temp file: %w", err) + } + w.logger.Debug("Saved audio to temp file", "file", tempFile) + // Run the whisper binary with a separate context to avoid cancellation during transcription + cmd := exec.Command(w.whisperPath, "-m", w.modelPath, "-l", w.lang, tempFile, "2>/dev/null") + var outBuf bytes.Buffer + cmd.Stdout = &outBuf + // Redirect stderr to suppress ALSA warnings and other stderr output + cmd.Stderr = io.Discard // Suppress stderr output from whisper binary + w.logger.Debug("Running whisper binary command") + if err := cmd.Run(); err != nil { + // Clean up audio buffer + w.mu.Lock() + w.audioBuffer = nil + w.mu.Unlock() + // Since we're suppressing stderr, we'll just log that the command failed + w.logger.Error("Error running whisper binary", "error", err) + return "", fmt.Errorf("whisper binary failed: %w", err) + } + result := outBuf.String() + w.logger.Debug("Whisper binary completed", "result", result) + // Clean up audio buffer + w.mu.Lock() + w.audioBuffer = nil + w.mu.Unlock() + // Clean up the temporary file after transcription + w.logger.Debug("StopRecording completed") + os.Remove(tempFile) + result = strings.TrimRight(result, "\n") + // in case there are special tokens like [_BEG_] + result = specialRE.ReplaceAllString(result, "") + return strings.TrimSpace(strings.ReplaceAll(result, "\n ", "\n")), nil +} + +// saveAudioToTempFile saves the recorded audio data to a temporary WAV file +func (w *WhisperBinary) saveAudioToTempFile() (string, error) { + w.logger.Debug("saveAudioToTempFile called") + // Create temporary WAV file + tempFile, err := os.CreateTemp("", "recording_*.wav") + if err != nil { + w.logger.Error("Failed to create temp file", "error", err) + return "", fmt.Errorf("failed to create temp file: %w", err) + } + w.logger.Debug("Created temp file", "file", tempFile.Name()) + defer tempFile.Close() + + // Write WAV header and data + w.logger.Debug("About to write WAV file", "file", tempFile.Name()) + err = w.writeWAVFile(tempFile.Name()) + if err != nil { + w.logger.Error("Error writing WAV file", "error", err) + return "", fmt.Errorf("failed to write WAV file: %w", err) + } + w.logger.Debug("WAV file written successfully", "file", tempFile.Name()) + + return tempFile.Name(), nil +} + +// writeWAVFile creates a WAV file from the recorded audio data +func (w *WhisperBinary) writeWAVFile(filename string) error { + w.logger.Debug("writeWAVFile called", "filename", filename) + // Open file for writing + file, err := os.Create(filename) + if err != nil { + w.logger.Error("Error creating file", "error", err) + return err + } + defer file.Close() + + w.logger.Debug("About to acquire mutex in writeWAVFile") + w.mu.Lock() + w.logger.Debug("Locked mutex, copying audio buffer") + audioData := make([]int16, len(w.audioBuffer)) + copy(audioData, w.audioBuffer) + w.mu.Unlock() + w.logger.Debug("Unlocked mutex", "audio_data_length", len(audioData)) + + if len(audioData) == 0 { + w.logger.Warn("No audio data to write") + return errors.New("no audio data to write") + } + + // Calculate data size (number of samples * size of int16) + dataSize := len(audioData) * 2 // 2 bytes per int16 sample + w.logger.Debug("Calculated data size", "size", dataSize) + + // Write WAV header with the correct data size + header := w.createWAVHeader(16000, 1, 16, dataSize) + _, err = file.Write(header) + if err != nil { + w.logger.Error("Error writing WAV header", "error", err) + return err + } + w.logger.Debug("WAV header written successfully") + + // Write audio data + w.logger.Debug("About to write audio data samples") + for i, sample := range audioData { + // Write little-endian 16-bit sample + _, err := file.Write([]byte{byte(sample), byte(sample >> 8)}) + if err != nil { + w.logger.Error("Error writing sample", "index", i, "error", err) + return err + } + // Log progress every 10000 samples to avoid too much output + if i%10000 == 0 { + w.logger.Debug("Written samples", "count", i) + } + } + w.logger.Debug("All audio data written successfully") + + return nil +} + +// createWAVHeader creates a WAV file header +func (w *WhisperBinary) createWAVHeader(sampleRate, channels, bitsPerSample int, dataSize int) []byte { + header := make([]byte, 44) + copy(header[0:4], "RIFF") + // Total file size will be updated later + copy(header[8:12], "WAVE") + copy(header[12:16], "fmt ") + // fmt chunk size (16 for PCM) + header[16] = 16 + header[17] = 0 + header[18] = 0 + header[19] = 0 + // Audio format (1 = PCM) + header[20] = 1 + header[21] = 0 + // Number of channels + header[22] = byte(channels) + header[23] = 0 + // Sample rate + header[24] = byte(sampleRate) + header[25] = byte(sampleRate >> 8) + header[26] = byte(sampleRate >> 16) + header[27] = byte(sampleRate >> 24) + // Byte rate + byteRate := sampleRate * channels * bitsPerSample / 8 + header[28] = byte(byteRate) + header[29] = byte(byteRate >> 8) + header[30] = byte(byteRate >> 16) + header[31] = byte(byteRate >> 24) + // Block align + blockAlign := channels * bitsPerSample / 8 + header[32] = byte(blockAlign) + header[33] = 0 + // Bits per sample + header[34] = byte(bitsPerSample) + header[35] = 0 + // "data" subchunk + copy(header[36:40], "data") + // Data size + header[40] = byte(dataSize) + header[41] = byte(dataSize >> 8) + header[42] = byte(dataSize >> 16) + header[43] = byte(dataSize >> 24) + + return header +} + +func (w *WhisperBinary) IsRecording() bool { + w.mu.Lock() + defer w.mu.Unlock() + return w.recording +} |
