summaryrefslogtreecommitdiff
path: root/rag/rag_test.go
blob: 4944007e79442f308b8603eb55409a67557c993e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
package rag

import (
	"testing"
)

// TestDetectPhrases checks that detectPhrases returns exactly the expected
// phrases, in the expected order, for a handful of sample queries. Queries
// consisting only of stop words, or of a single short word, must yield no
// phrases at all.
func TestDetectPhrases(t *testing.T) {
	cases := []struct {
		query  string
		expect []string
	}{
		{"bald prophet and two she bears", []string{"bald prophet", "two she", "two she bears", "she bears"}},
		{"she bears", []string{"she bears"}},
		{"the quick brown fox", []string{"quick brown", "quick brown fox", "brown fox"}},
		// stop words: "in" and "the" are stop words
		{"in the house", []string{}},
		// short
		{"a", []string{}},
	}

	// sameOrder reports whether a and b hold the same strings in the same
	// order; a nil slice and an empty slice are treated as equal.
	sameOrder := func(a, b []string) bool {
		if len(a) != len(b) {
			return false
		}
		for i := range a {
			if a[i] != b[i] {
				return false
			}
		}
		return true
	}

	for _, c := range cases {
		if got := detectPhrases(c.query); !sameOrder(got, c.expect) {
			t.Errorf("detectPhrases(%q) = %v, want %v", c.query, got, c.expect)
		}
	}
}

// TestCountPhraseMatches checks how many of the phrases found in a query are
// reported as occurring in a piece of text by countPhraseMatches.
func TestCountPhraseMatches(t *testing.T) {
	type testCase struct {
		text   string
		query  string
		expect int
	}

	cases := []testCase{
		{"two she bears came out of the wood", "she bears", 1},
		// only "she bears" matches
		{"bald head and she bears", "bald prophet and two she bears", 1},
		{"no match here", "she bears", 0},
		// "she bears" and "bald prophet"
		{"she bears and bald prophet", "bald prophet she bears", 2},
	}

	for i := range cases {
		c := cases[i]
		got := countPhraseMatches(c.text, c.query)
		if got == c.expect {
			continue
		}
		t.Errorf("countPhraseMatches(%q, %q) = %d, want %d", c.text, c.query, got, c.expect)
	}
}

// TestAreSlugsAdjacent checks adjacency decisions for pairs of chunk slugs.
// Neighbouring batch numbers of the same file are adjacent in either order.
// A gap of two batches, or a different file prefix, is not adjacent. The
// file_* cases pin how batch and chunk indices interact.
func TestAreSlugsAdjacent(t *testing.T) {
	pairs := []struct {
		a, b string
		want bool
	}{
		{"kjv_bible.epub_1786_0", "kjv_bible.epub_1787_0", true},
		{"kjv_bible.epub_1787_0", "kjv_bible.epub_1786_0", true},
		{"kjv_bible.epub_1786_0", "kjv_bible.epub_1788_0", false},
		{"otherfile.txt_1_0", "kjv_bible.epub_1786_0", false},
		{"file_1_0", "file_1_1", true},
		// different batch: sequential batches with same chunk index are adjacent
		{"file_1_0", "file_2_0", true},
	}

	for i := 0; i < len(pairs); i++ {
		p := pairs[i]
		if got := areSlugsAdjacent(p.a, p.b); got != p.want {
			t.Errorf("areSlugsAdjacent(%q, %q) = %v, want %v", p.a, p.b, got, p.want)
		}
	}
}

// TestParseSlugIndices checks that parseSlugIndices extracts the trailing
// batch and chunk numbers from slugs shaped like <name>_<batch>_<chunk>. It
// also checks that ok=false is reported when the slug does not end in two
// numeric segments.
func TestParseSlugIndices(t *testing.T) {
	// wantBatch and wantChunk are only compared when wantOk is true; for
	// failing cases they are left at zero.
	tests := []struct {
		slug      string
		wantBatch int
		wantChunk int
		wantOk    bool
	}{
		{"kjv_bible.epub_1786_0", 1786, 0, true},
		{"file_1_5", 1, 5, true},
		{"no_underscore", 0, 0, false},
		{"file_abc_def", 0, 0, false},
		// The final segment is not numeric, so the slug does not end in
		// _<batch>_<chunk> and parsing is expected to fail.
		{"file_123_456_extra", 0, 0, false},
	}

	for _, tt := range tests {
		batch, chunk, ok := parseSlugIndices(tt.slug)
		if ok != tt.wantOk {
			t.Errorf("parseSlugIndices(%q) ok = %v, want %v", tt.slug, ok, tt.wantOk)
			continue
		}
		if ok && (batch != tt.wantBatch || chunk != tt.wantChunk) {
			t.Errorf("parseSlugIndices(%q) = (%d, %d), want (%d, %d)", tt.slug, batch, chunk, tt.wantBatch, tt.wantChunk)
		}
	}
}