summaryrefslogtreecommitdiff
path: root/agent/agent.go
blob: 30e30e3b0473c6ba206767490079673c0bea85c5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
package agent

// Agent post-processes what a tool produced before that output is handed to
// the main LLM. Implementations may clean it up, condense it, or reshape it in
// any other way.
type Agent interface {
	// Process is given the arguments the tool was invoked with plus the tool's
	// unmodified output, and returns the bytes that should go into the main
	// LLM's context in place of that raw output.
	Process(args map[string]string, rawOutput []byte) []byte
}

// registry associates each tool name with the Agent that post-processes that
// tool's output. Entries are written by Register and looked up by Get.
var registry = map[string]Agent{}

// Register stores a as the agent responsible for output from the tool called
// toolName. Registering under a name that is already taken overwrites the
// earlier agent. The underlying map is accessed without any locking, so calls
// that may overlap with Register or Get from other goroutines must be
// synchronised by the caller.
func Register(toolName string, a Agent) {
	registry[toolName] = a
}

// Get looks up the agent registered for toolName. When nothing has been
// registered under that name, the result is a nil Agent.
func Get(toolName string) Agent {
	if found, ok := registry[toolName]; ok {
		return found
	}
	return nil
}

// FormatterAgent turns raw tool output into text by running it through a
// single formatting function. *FormatterAgent implements Agent.
type FormatterAgent struct {
	// formatFunc renders raw tool output as a string. When it is nil, Process
	// returns the raw output unchanged.
	formatFunc func(raw []byte) (string, error)
}

// NewFormatterAgent builds a FormatterAgent around formatFunc and returns a
// pointer to it. Passing a nil formatFunc is allowed; the resulting agent then
// hands raw output back unchanged.
func NewFormatterAgent(formatFunc func([]byte) (string, error)) *FormatterAgent {
	agent := new(FormatterAgent)
	agent.formatFunc = formatFunc
	return agent
}

// Process runs the agent's formatting function over rawOutput and returns the
// formatted text as bytes. args is part of the Agent signature and is not used
// here.
//
// A nil *FormatterAgent, or one whose formatting function is nil, returns
// rawOutput unchanged. If the formatting function reports an error, the raw
// output is returned with a one-line warning prefix, so the caller still
// receives the tool's data.
func (a *FormatterAgent) Process(args map[string]string, rawOutput []byte) []byte {
	// Guard the nil receiver. Otherwise a typed-nil *FormatterAgent stored in an
	// Agent interface would panic on the field access below.
	if a == nil || a.formatFunc == nil {
		return rawOutput
	}
	formatted, err := a.formatFunc(rawOutput)
	if err != nil {
		// Fall back to the raw output behind a warning prefix. Build it in a
		// single pre-sized allocation rather than string concatenation followed
		// by a []byte conversion.
		const warning = "[formatting failed, showing raw output]\n"
		out := make([]byte, 0, len(warning)+len(rawOutput))
		out = append(out, warning...)
		return append(out, rawOutput...)
	}
	return []byte(formatted)
}

// DefaultFormatter chooses a built-in formatting Agent by tool name:
// "websearch" is paired with FormatSearchResults and "read_url" with
// FormatWebPageContent. A fresh agent is created on every call. Any other tool
// name has no default formatter, and the result is a nil Agent.
func DefaultFormatter(toolName string) Agent {
	var format func([]byte) (string, error)
	switch toolName {
	case "websearch":
		format = FormatSearchResults
	case "read_url":
		format = FormatWebPageContent
	default:
		return nil
	}
	return NewFormatterAgent(format)
}