summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.golangci.yml9
-rw-r--r--bot.go6
-rw-r--r--llm.go27
-rw-r--r--models/models.go7
4 files changed, 45 insertions, 4 deletions
diff --git a/.golangci.yml b/.golangci.yml
index 2c7e552..ce57300 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -1,6 +1,7 @@
version: "2"
run:
- concurrency: 2
+ timeout: 1m
+ concurrency: 4
tests: false
linters:
default: none
@@ -14,7 +15,13 @@ linters:
- prealloc
- staticcheck
- unused
+ - gocritic
+ - unconvert
+ - wastedassign
settings:
+ gocritic:
+ enabled-tags:
+ - performance
funlen:
lines: 80
statements: 50
diff --git a/bot.go b/bot.go
index bbb3f65..c6c1e77 100644
--- a/bot.go
+++ b/bot.go
@@ -682,8 +682,10 @@ func sendMsgToLLM(body io.Reader) {
answerText = strings.ReplaceAll(chunk.Chunk, "\n\n", "\n")
// Accumulate text to check for stop strings that might span across chunks
// check if chunk is in stopstrings => stop
- if slices.Contains(stopStrings, answerText) {
- logger.Debug("Stop string detected and handled", "stop_string", answerText)
+ // this check is needed only for openrouter /v1/completion, since it does not respect stop slice
+ if chunkParser.GetAPIType() == models.APITypeCompletion &&
+ slices.Contains(stopStrings, answerText) {
+ logger.Debug("stop string detected on client side for completion endpoint", "stop_string", answerText)
streamDone <- true
}
chunkChan <- answerText
diff --git a/llm.go b/llm.go
index 95de1d8..b2cd5e2 100644
--- a/llm.go
+++ b/llm.go
@@ -78,6 +78,7 @@ type ChunkParser interface {
ParseChunk([]byte) (*models.TextChunk, error)
FormMsg(msg, role string, cont bool) (io.Reader, error)
GetToken() string
+ GetAPIType() models.APIType
}
func choseChunkParser() {
@@ -127,6 +128,10 @@ type OpenRouterChat struct {
Model string
}
+func (lcp LCPCompletion) GetAPIType() models.APIType {
+ return models.APITypeCompletion
+}
+
func (lcp LCPCompletion) GetToken() string {
return ""
}
@@ -233,7 +238,11 @@ func (lcp LCPCompletion) ParseChunk(data []byte) (*models.TextChunk, error) {
return resp, nil
}
-func (op LCPChat) GetToken() string {
+func (lcp LCPChat) GetAPIType() models.APIType {
+ return models.APITypeChat
+}
+
+func (lcp LCPChat) GetToken() string {
return ""
}
@@ -371,6 +380,10 @@ func (op LCPChat) FormMsg(msg, role string, resume bool) (io.Reader, error) {
}
// deepseek
+func (ds DeepSeekerCompletion) GetAPIType() models.APIType {
+ return models.APITypeCompletion
+}
+
func (ds DeepSeekerCompletion) ParseChunk(data []byte) (*models.TextChunk, error) {
llmchunk := models.DSCompletionResp{}
if err := json.Unmarshal(data, &llmchunk); err != nil {
@@ -453,6 +466,10 @@ func (ds DeepSeekerCompletion) FormMsg(msg, role string, resume bool) (io.Reader
return bytes.NewReader(data), nil
}
+func (ds DeepSeekerChat) GetAPIType() models.APIType {
+ return models.APITypeChat
+}
+
func (ds DeepSeekerChat) ParseChunk(data []byte) (*models.TextChunk, error) {
llmchunk := models.DSChatStreamResp{}
if err := json.Unmarshal(data, &llmchunk); err != nil {
@@ -539,6 +556,10 @@ func (ds DeepSeekerChat) FormMsg(msg, role string, resume bool) (io.Reader, erro
}
// openrouter
+func (or OpenRouterCompletion) GetAPIType() models.APIType {
+ return models.APITypeCompletion
+}
+
func (or OpenRouterCompletion) ParseChunk(data []byte) (*models.TextChunk, error) {
llmchunk := models.OpenRouterCompletionResp{}
if err := json.Unmarshal(data, &llmchunk); err != nil {
@@ -618,6 +639,10 @@ func (or OpenRouterCompletion) FormMsg(msg, role string, resume bool) (io.Reader
}
// chat
+func (or OpenRouterChat) GetAPIType() models.APIType {
+ return models.APITypeChat
+}
+
func (or OpenRouterChat) ParseChunk(data []byte) (*models.TextChunk, error) {
llmchunk := models.OpenRouterChatResp{}
if err := json.Unmarshal(data, &llmchunk); err != nil {
diff --git a/models/models.go b/models/models.go
index e99832a..4133a7c 100644
--- a/models/models.go
+++ b/models/models.go
@@ -558,3 +558,10 @@ type ChatRoundReq struct {
Regen bool
Resume bool
}
+
+type APIType int
+
+const (
+ APITypeChat APIType = iota
+ APITypeCompletion
+)